2#ifndef TTG_DEVICE_TASK_H
3#define TTG_DEVICE_TASK_H
15#ifdef TTG_HAVE_COROUTINE
21 struct device_input_data_t {
22 using impl_data_t =
decltype(TTG_IMPL_NS::buffer_data(std::declval<
ttg::Buffer<int>>()));
24 device_input_data_t(impl_data_t data,
ttg::scope scope,
bool isconst,
bool isscratch)
27 impl_data_t impl_data;
33 template <
typename... Ts>
35 std::tuple<std::add_lvalue_reference_t<Ts>...> ties;
39 template<
typename... Ts, std::size_t... Is>
40 auto extract_buffer_data(detail::to_device_t<Ts...>& a, std::index_sequence<Is...>) {
41 using arg_types = std::tuple<Ts...>;
42 return std::array<device_input_data_t,
sizeof...(Ts)>{
43 device_input_data_t{TTG_IMPL_NS::buffer_data(std::get<Is>(a.ties)),
44 std::get<Is>(a.ties).scope(),
52 std::vector<detail::device_input_data_t> m_data;
56 template<
typename... Args>
58 : m_data{{TTG_IMPL_NS::buffer_data(args), args.scope(),
59 std::is_const_v<std::remove_reference_t<Args>>,
65 using type = std::remove_reference_t<T>;
66 m_data.emplace_back(TTG_IMPL_NS::buffer_data(v), v.scope(), std::is_const_v<type>,
70 ttg::span<detail::device_input_data_t> span() {
71 return ttg::span(m_data);
78 struct to_device_t<Input> {
90 template <
typename... Args>
92 inline auto select(Args &&...args) {
93 return detail::to_device_t<std::remove_reference_t<Args>...>{std::tie(std::forward<Args>(args)...)};
97 inline auto select(Input& input) {
98 return detail::to_device_t<Input>{input};
103 enum ttg_device_coro_state {
104 TTG_DEVICE_CORO_STATE_NONE,
105 TTG_DEVICE_CORO_INIT,
106 TTG_DEVICE_CORO_WAIT_TRANSFER,
107 TTG_DEVICE_CORO_WAIT_KERNEL,
108 TTG_DEVICE_CORO_SENDOUT,
109 TTG_DEVICE_CORO_COMPLETE
112 template <
typename... Ts>
113 struct wait_kernel_t {
114 std::tuple<std::add_lvalue_reference_t<Ts>...> ties;
117 constexpr bool await_ready() const noexcept {
return false; }
120 template <
typename Promise>
123 void await_resume() noexcept {
124 if constexpr (
sizeof...(Ts) > 0) {
126 TTG_IMPL_NS::post_device_out(ties);
137 template <
typename... Buffers>
139 inline auto wait(Buffers &&...args) {
143 "Only ttg::Buffer and ttg::devicescratch can be waited on!");
144 return detail::wait_kernel_t<std::remove_reference_t<Buffers>...>{std::tie(std::forward<Buffers>(args)...)};
155 struct send_coro_promise_type;
160 struct send_coro_state :
public send_coro_handle_type {
161 using base_type = send_coro_handle_type;
166 using promise_type = send_coro_promise_type;
170 send_coro_state(base_type base) : base_type(
std::move(base)) {}
172 base_type &handle() {
return *
this; }
175 inline bool ready() {
return true; }
178 inline bool completed();
182 struct send_coro_promise_type {
193 send_coro_state get_return_object() {
return send_coro_state{send_coro_handle_type::from_promise(*
this)}; }
198 void unhandled_exception() {
199 std::cerr <<
"Send coroutine caught an unhandled exception!" << std::endl;
203 void return_void() {}
206 template <
typename Key,
typename Value, ttg::Runtime Runtime = ttg::ttg_runtime>
207 inline send_coro_state send_coro(
const Key &key, Value &&value,
ttg::Out<Key, std::decay_t<Value>> &t,
211 t.prepare_send(k, std::forward<Value>(value));
213 t.send(k, std::forward<Value>(value));
216 template <
typename Value, ttg::Runtime Runtime = ttg::ttg_runtime>
217 inline send_coro_state sendv_coro(Value &&value,
ttg::Out<
void, std::decay_t<Value>> &t,
220 t.prepare_send(std::forward<Value>(value));
222 t.sendv(std::forward<Value>(value));
225 template <
typename Key, ttg::Runtime Runtime = ttg::ttg_runtime>
233 template <ttg::Runtime Runtime = ttg::ttg_runtime>
241 send_coro_state coro;
245 template <
size_t i,
typename keyT,
typename valueT,
typename... out_keysT,
typename... out_valuesT,
249 return detail::send_t{
250 detail::send_coro(key, copy_handler(std::forward<valueT>(value)), std::get<i>(t), copy_handler)};
253 template <
size_t i,
typename valueT,
typename... out_keysT,
typename... out_valuesT,
257 return detail::send_t{detail::sendv_coro(copy_handler(std::forward<valueT>(value)), std::get<i>(t), copy_handler)};
260 template <
size_t i,
typename Key,
typename... out_keysT,
typename... out_valuesT,
263 return detail::send_t{detail::sendk_coro(key, std::get<i>(t))};
272 template <
typename keyT,
typename valueT, ttg::Runtime Runtime = ttg::ttg_runtime>
273 inline detail::send_t
send(
size_t i,
const keyT &key, valueT &&value) {
275 auto *terminal_ptr = ttg::detail::get_out_terminal<keyT, valueT>(i,
"ttg::device::send(i, key, value)");
276 return detail::send_t{detail::send_coro(key, copy_handler(std::forward<valueT>(value)), *terminal_ptr, copy_handler)};
286 template <
size_t i,
typename keyT,
typename valueT>
287 inline auto send(
const keyT &key, valueT &&value) {
288 return ttg::device::send(i, key, std::forward<valueT>(value));
292 template <
typename valueT, ttg::Runtime Runtime = ttg::ttg_runtime>
293 inline detail::send_t
sendv(std::size_t i, valueT &&value) {
294 auto *terminal_ptr = ttg::detail::get_out_terminal<void, valueT>(i,
"ttg::device::send(i, key, value)");
296 return detail::send_t{detail::sendv_coro(copy_handler(std::forward<valueT>(value)), *terminal_ptr, copy_handler)};
299 template <
typename Key, ttg::Runtime Runtime = ttg::ttg_runtime>
300 inline detail::send_t
sendk(std::size_t i,
const Key& key) {
301 auto *terminal_ptr = ttg::detail::get_out_terminal<Key, void>(i,
"ttg::device::send(i, key, value)");
302 return detail::send_t{detail::sendk_coro(key, *terminal_ptr)};
305 template <ttg::Runtime Runtime = ttg::ttg_runtime>
306 inline detail::send_t
send(std::size_t i) {
307 auto *terminal_ptr = ttg::detail::get_out_terminal<void, void>(i,
"ttg::device::send(i, key, value)");
308 return detail::send_t{detail::send_coro(*terminal_ptr)};
312 template <std::
size_t i,
typename valueT, ttg::Runtime Runtime = ttg::ttg_runtime>
313 inline detail::send_t
sendv(valueT &&value) {
314 return sendv(i, std::forward<valueT>(value));
317 template <
size_t i,
typename Key, ttg::Runtime Runtime = ttg::ttg_runtime>
318 inline detail::send_t
sendk(
const Key& key) {
319 return sendk(i, key);
322 template <
size_t i, ttg::Runtime Runtime = ttg::ttg_runtime>
323 inline detail::send_t
sendk() {
329 template<
typename T,
typename Enabler =
void>
330 struct broadcast_keylist_trait {
336 struct broadcast_keylist_trait<T,
std::enable_if_t<ttg::meta::is_iterable_v<T>>> {
337 using key_type =
decltype(*std::begin(std::declval<T>()));
340 template <
size_t KeyId,
size_t I,
size_t... Is,
typename... RangesT,
typename valueT,
341 typename... out_keysT,
typename... out_valuesT>
342 inline void prepare_broadcast(
const std::tuple<RangesT...> &keylists, valueT &&value,
344 std::get<I>(t).prepare_send(std::get<KeyId>(keylists), std::forward<valueT>(value));
345 if constexpr (
sizeof...(Is) > 0) {
346 prepare_broadcast<KeyId+1, Is...>(keylists, std::forward<valueT>(value), t);
350 template <
size_t KeyId,
size_t I,
size_t... Is,
typename... RangesT,
typename valueT>
351 inline void prepare_broadcast(
const std::tuple<RangesT...> &keylists, valueT &&value) {
352 using key_t =
typename broadcast_keylist_trait<
353 std::tuple_element_t<KeyId, std::tuple<std::remove_reference_t<RangesT>...>>
355 auto *terminal_ptr = ttg::detail::get_out_terminal<key_t, valueT>(I,
"ttg::device::broadcast(keylists, value)");
356 terminal_ptr->prepare_send(std::get<KeyId>(keylists), value);
357 if constexpr (
sizeof...(Is) > 0) {
358 prepare_broadcast<KeyId+1, Is...>(keylists, std::forward<valueT>(value));
362 template <
size_t KeyId,
size_t I,
size_t... Is,
typename... RangesT,
typename valueT,
363 typename... out_keysT,
typename... out_valuesT>
364 inline void broadcast(
const std::tuple<RangesT...> &keylists, valueT &&value,
366 std::get<I>(t).broadcast(std::get<KeyId>(keylists), std::forward<valueT>(value));
367 if constexpr (
sizeof...(Is) > 0) {
368 detail::broadcast<KeyId+1, Is...>(keylists, std::forward<valueT>(value), t);
372 template <
size_t KeyId,
size_t I,
size_t... Is,
typename... RangesT,
typename valueT>
373 inline void broadcast(
const std::tuple<RangesT...> &keylists, valueT &&value) {
374 using key_t =
typename broadcast_keylist_trait<
375 std::tuple_element_t<KeyId, std::tuple<std::remove_reference_t<RangesT>...>>
377 auto *terminal_ptr = ttg::detail::get_out_terminal<key_t, valueT>(I,
"ttg::device::broadcast(keylists, value)");
378 terminal_ptr->broadcast(std::get<KeyId>(keylists), value);
379 if constexpr (
sizeof...(Is) > 0) {
380 ttg::device::detail::broadcast<KeyId+1, Is...>(keylists, std::forward<valueT>(value));
385 template <
size_t I,
size_t... Is,
typename RangesT,
typename valueT,
386 typename... out_keysT,
typename... out_valuesT,
388 inline send_coro_state
389 broadcast_coro(RangesT &&keylists, valueT &&value,
393 RangesT kl = std::forward<RangesT>(keylists);
396 prepare_broadcast<0, I, Is...>(kl, std::forward<valueT>(value), t);
398 ttg::device::detail::broadcast<0, I, Is...>(kl, std::forward<valueT>(value), t);
401 prepare_broadcast<0, I, Is...>(std::tie(kl), std::forward<valueT>(value), t);
403 ttg::device::detail::broadcast<0, I, Is...>(std::tie(kl), std::forward<valueT>(value), t);
408 template <
size_t I,
size_t... Is,
typename RangesT,
typename valueT,
410 inline send_coro_state
411 broadcast_coro(RangesT &&keylists, valueT &&value,
414 RangesT kl = std::forward<RangesT>(keylists);
417 static_assert(
sizeof...(Is)+1 == std::tuple_size_v<RangesT>,
418 "Size of keylist tuple must match the number of output terminals");
419 prepare_broadcast<0, I, Is...>(kl, std::forward<valueT>(value));
421 ttg::device::detail::broadcast<0, I, Is...>(kl, std::forward<valueT>(value));
424 prepare_broadcast<0, I, Is...>(std::tie(kl), std::forward<valueT>(value));
426 ttg::device::detail::broadcast<0, I, Is...>(std::tie(kl), std::forward<valueT>(value));
434 template <
size_t KeyId,
size_t I,
size_t... Is,
typename... RangesT,
435 typename... out_keysT,
typename... out_valuesT>
436 inline void broadcastk(
const std::tuple<RangesT...> &keylists,
438 std::get<I>(t).broadcast(std::get<KeyId>(keylists));
439 if constexpr (
sizeof...(Is) > 0) {
440 detail::broadcastk<KeyId+1, Is...>(keylists, t);
444 template <
size_t KeyId,
size_t I,
size_t... Is,
typename... RangesT>
445 inline void broadcastk(
const std::tuple<RangesT...> &keylists) {
446 using key_t =
typename broadcast_keylist_trait<
447 std::tuple_element_t<KeyId, std::tuple<std::remove_reference_t<RangesT>...>>
449 auto *terminal_ptr = ttg::detail::get_out_terminal<key_t, void>(I,
"ttg::device::broadcastk(keylists)");
450 terminal_ptr->broadcast(std::get<KeyId>(keylists));
451 if constexpr (
sizeof...(Is) > 0) {
452 ttg::device::detail::broadcastk<KeyId+1, Is...>(keylists);
457 template <
size_t I,
size_t... Is,
typename RangesT,
458 typename... out_keysT,
typename... out_valuesT,
460 inline send_coro_state
461 broadcastk_coro(RangesT &&keylists,
463 RangesT kl = std::forward<RangesT>(keylists);
467 ttg::device::detail::broadcastk<0, I, Is...>(kl, t);
471 ttg::device::detail::broadcastk<0, I, Is...>(std::tie(kl), t);
476 template <
size_t I,
size_t... Is,
typename RangesT,
478 inline send_coro_state
479 broadcastk_coro(RangesT &&keylists) {
480 RangesT kl = std::forward<RangesT>(keylists);
483 static_assert(
sizeof...(Is)+1 == std::tuple_size_v<RangesT>,
484 "Size of keylist tuple must match the number of output terminals");
486 ttg::device::detail::broadcastk<0, I, Is...>(kl);
490 ttg::device::detail::broadcastk<0, I, Is...>(std::tie(kl));
496 template <
size_t I,
size_t... Is,
typename rangeT,
typename valueT,
typename... out_keysT,
typename... out_valuesT,
499 inline detail::send_t
broadcast(rangeT &&keylist,
503 return detail::send_t{
504 detail::broadcast_coro<I, Is...>(std::forward<rangeT>(keylist),
505 copy_handler(std::forward<valueT>(value)),
506 t, std::move(copy_handler))};
510 template <
size_t i,
typename rangeT,
typename valueT,
512 inline detail::send_t
broadcast(rangeT &&keylist, valueT &&value) {
514 return detail::send_t{detail::broadcast_coro<i>(std::tie(keylist),
515 copy_handler(std::forward<valueT>(value)),
516 std::move(copy_handler))};
520 template <
size_t I,
size_t... Is,
typename rangeT,
typename... out_keysT,
typename... out_valuesT,
523 inline detail::send_t
broadcastk(rangeT &&keylist,
526 return detail::send_t{
527 detail::broadcastk_coro<I, Is...>(std::forward<rangeT>(keylist), t)};
531 template <
size_t i,
typename rangeT,
533 inline detail::send_t
broadcastk(rangeT &&keylist) {
534 if constexpr (std::is_rvalue_reference_v<
decltype(keylist)>) {
535 return detail::send_t{detail::broadcastk_coro<i>(std::forward<rangeT>(keylist))};
537 return detail::send_t{detail::broadcastk_coro<i>(std::tie(keylist))};
543 std::vector<device::detail::send_t> forward(Args&&... args) {
545 return std::vector<device::detail::send_t>{std::forward<Args>(args)...};
554 struct device_task_promise_type;
567 struct Task :
public detail::device_task_handle_type {
568 using base_type = detail::device_task_handle_type;
573 using promise_type = detail::device_task_promise_type;
577 Task(base_type base) : base_type(
std::move(base)) {}
579 base_type& handle() {
return *
this; }
582 inline bool ready() {
587 inline bool completed();
596 struct device_task_promise_type {
602 m_state = ttg::device::detail::TTG_DEVICE_CORO_INIT;
611 m_state = ttg::device::detail::TTG_DEVICE_CORO_COMPLETE;
616 template<
typename... Views>
618 return yield_value(views);
621 template<
typename... Ts>
623 auto arr = detail::extract_buffer_data(a, std::make_index_sequence<
sizeof...(Ts)>{});
624 bool need_transfer = !(TTG_IMPL_NS::register_device_memory(ttg::span(arr)));
626 m_state = ttg::device::detail::TTG_DEVICE_CORO_WAIT_TRANSFER;
631 bool need_transfer = !(TTG_IMPL_NS::register_device_memory(a.input.span()));
633 m_state = ttg::device::detail::TTG_DEVICE_CORO_WAIT_TRANSFER;
637 template<
typename... Ts>
638 auto await_transform(detail::wait_kernel_t<Ts...>&& a) {
640 if constexpr (
sizeof...(Ts) > 0) {
641 TTG_IMPL_NS::mark_device_out(a.ties);
643 m_state = ttg::device::detail::TTG_DEVICE_CORO_WAIT_KERNEL;
648 m_sends = std::forward<std::vector<device::detail::send_t>>(v);
649 m_state = ttg::device::detail::TTG_DEVICE_CORO_SENDOUT;
655 m_sends.push_back(std::forward<device::detail::send_t>(v));
656 m_state = ttg::device::detail::TTG_DEVICE_CORO_SENDOUT;
661 m_state = ttg::device::detail::TTG_DEVICE_CORO_COMPLETE;
664 bool complete()
const {
665 return m_state == ttg::device::detail::TTG_DEVICE_CORO_COMPLETE;
668 ttg::device::Task get_return_object() {
return {detail::device_task_handle_type::from_promise(*
this)}; }
670 void unhandled_exception() {
671 std::cerr <<
"Task coroutine caught an unhandled exception!" << std::endl;
679 for (
auto& send : m_sends) {
690 std::vector<device::detail::send_t> m_sends;
691 ttg_device_coro_state m_state = ttg::device::detail::TTG_DEVICE_CORO_STATE_NONE;
697 bool Task::completed() {
return base_type::promise().state() == ttg::device::detail::TTG_DEVICE_CORO_COMPLETE; }
699 struct device_wait_kernel
709 struct device_reducer_promise_type;
714 struct device_reducer :
public device_reducer_handle_type {
715 using base_type = device_reducer_handle_type;
720 using promise_type = device_reducer_promise_type;
724 device_reducer(base_type base) : base_type(
std::move(base)) {}
726 base_type& handle() {
return *
this; }
729 inline bool ready() {
734 inline bool completed();
742 struct device_reducer_promise_type {
748 m_state = ttg::device::detail::TTG_DEVICE_CORO_INIT;
757 m_state = ttg::device::detail::TTG_DEVICE_CORO_COMPLETE;
761 template<
typename... Ts>
763 bool need_transfer = !(TTG_IMPL_NS::register_device_memory(a.ties));
765 m_state = ttg::device::detail::TTG_DEVICE_CORO_WAIT_TRANSFER;
770 m_state = ttg::device::detail::TTG_DEVICE_CORO_COMPLETE;
773 bool complete()
const {
774 return m_state == ttg::device::detail::TTG_DEVICE_CORO_COMPLETE;
777 device_reducer get_return_object() {
return device_reducer{device_reducer_handle_type::from_promise(*
this)}; }
779 void unhandled_exception() { }
787 ttg::device::detail::ttg_device_coro_state m_state = ttg::device::detail::TTG_DEVICE_CORO_STATE_NONE;
791 bool device_reducer::completed() {
return base_type::promise().state() == ttg::device::detail::TTG_DEVICE_CORO_COMPLETE; }
std::enable_if_t<!meta::is_void_v< Key > &&meta::is_void_v< Value >, void > sendk(const Key &key)
std::enable_if_t< meta::is_all_void_v< Key, Value >, void > send()
A complete version of void.
constexpr auto data(C &c) -> decltype(c.data())
std::integral_constant< bool,(Flags &const_) !=0 > is_const
void broadcast(const std::tuple< RangesT... > &keylists, valueT &&value, std::tuple< ttg::Out< out_keysT, out_valuesT >... > &t)
TTG_CXX_COROUTINE_NAMESPACE::suspend_always suspend_always
void send(const keyT &key, valueT &&value, ttg::Out< keyT, valueT > &t)
Sends a task id and a value to the given output terminal.
constexpr const ttg::Runtime ttg_runtime
TTG_CXX_COROUTINE_NAMESPACE::suspend_never suspend_never
TTG_CXX_COROUTINE_NAMESPACE::coroutine_handle< Promise > coroutine_handle
void sendk(const keyT &key, ttg::Out< keyT, void > &t)
Sends a task id (without an accompanying value) to the given output terminal.
void sendv(valueT &&value, ttg::Out< void, valueT > &t)
Sends a value (without an accompanying task id) to the given output terminal.
void broadcastk(const rangeT &keylist, std::tuple< ttg::Out< out_keysT, out_valuesT >... > &t)