task.h
Go to the documentation of this file.
1 #ifndef TTG_DEVICE_TASK_H
2 #define TTG_DEVICE_TASK_H
3 
4 #include <array>
5 #include <type_traits>
6 #include <span>
7 
8 
9 #include "ttg/fwd.h"
10 #include "ttg/impl_selector.h"
11 #include "ttg/ptr.h"
12 #include "ttg/devicescope.h"
13 
14 #ifdef TTG_HAVE_COROUTINE
15 
16 namespace ttg::device {
17 
18  namespace detail {
19 
20  struct device_input_data_t {
21  using impl_data_t = decltype(TTG_IMPL_NS::buffer_data(std::declval<ttg::Buffer<int>>()));
22 
23  device_input_data_t(impl_data_t data, ttg::scope scope, bool isconst, bool isscratch)
24  : impl_data(data), scope(scope), is_const(isconst), is_scratch(isscratch)
25  { }
26  impl_data_t impl_data;
28  bool is_const;
29  bool is_scratch;
30  };
31 
/* Result of ttg::device::select(): lvalue references to every selected
 * object, inspected by the device task promise when it is co_await'ed. */
template <typename... Ts>
struct to_device_t {
  using ref_tuple_type = std::tuple<std::add_lvalue_reference_t<Ts>...>;
  ref_tuple_type ties;
};
36 
37  /* extract buffer information from to_device_t */
38  template<typename... Ts, std::size_t... Is>
39  auto extract_buffer_data(detail::to_device_t<Ts...>& a, std::index_sequence<Is...>) {
40  using arg_types = std::tuple<Ts...>;
41  return std::array<device_input_data_t, sizeof...(Ts)>{
42  device_input_data_t{TTG_IMPL_NS::buffer_data(std::get<Is>(a.ties)),
43  std::get<Is>(a.ties).scope(),
44  ttg::meta::is_const_v<std::tuple_element_t<Is, arg_types>>,
45  ttg::meta::is_devicescratch_v<std::tuple_element_t<Is, arg_types>>}...};
46  }
47  } // namespace detail
48 
49  struct Input {
50  private:
51  std::vector<detail::device_input_data_t> m_data;
52 
53  public:
54  Input() { }
55  template<typename... Args>
56  Input(Args&&... args)
57  : m_data{{TTG_IMPL_NS::buffer_data(args), args.scope(),
58  std::is_const_v<std::remove_reference_t<Args>>,
59  ttg::meta::is_devicescratch_v<std::decay_t<Args>>}...}
60  { }
61 
62  template<typename T>
63  void add(T&& v) {
64  using type = std::remove_reference_t<T>;
65  m_data.emplace_back(TTG_IMPL_NS::buffer_data(v), v.scope(), std::is_const_v<type>,
66  ttg::meta::is_devicescratch_v<type>);
67  }
68 
69  ttg::span<detail::device_input_data_t> span() {
70  return ttg::span(m_data);
71  }
72  };
73 
74  namespace detail {
75  // overload for Input
76  template <>
77  struct to_device_t<Input> {
78  Input& input;
79  };
80  } // namespace detail
81 
89  template <typename... Args>
90  [[nodiscard]]
91  inline auto select(Args &&...args) {
92  return detail::to_device_t<std::remove_reference_t<Args>...>{std::tie(std::forward<Args>(args)...)};
93  }
94 
95  [[nodiscard]]
96  inline auto select(Input& input) {
97  return detail::to_device_t<Input>{input};
98  }
99 
100  namespace detail {
101 
/* Phases of a device task coroutine, stored in the m_state member of
 * device_task_promise_type. */
enum ttg_device_coro_state {
  TTG_DEVICE_CORO_STATE_NONE = 0,  // promise created, coroutine not started yet
  TTG_DEVICE_CORO_INIT,            // set in initial_suspend()
  TTG_DEVICE_CORO_WAIT_TRANSFER,   // set when a select(...) result is co_await'ed
  TTG_DEVICE_CORO_WAIT_KERNEL,     // set when a wait(...) result is co_await'ed
  TTG_DEVICE_CORO_SENDOUT,         // set when send/broadcast operations are co_await'ed
  TTG_DEVICE_CORO_COMPLETE         // set in return_void() / final_suspend()
};
110 
/* Awaitable produced by ttg::device::wait(). `ties` holds lvalue references
 * to the Buffer/devicescratch objects named in the wait() call. */
111  template <typename... Ts>
112  struct wait_kernel_t {
113  std::tuple<std::add_lvalue_reference_t<Ts>...> ties;
114 
115  /* always suspend */
116  constexpr bool await_ready() const noexcept { return false; }
117 
118  /* always suspend */
// a void-returning await_suspend leaves the coroutine suspended and returns
// control to the caller/resumer of the coroutine
119  template <typename Promise>
120  constexpr void await_suspend(ttg::coroutine_handle<Promise>) const noexcept {}
121 
122  void await_resume() noexcept {
123  if constexpr (sizeof...(Ts) > 0) {
124  /* hook to allow the backend to handle the data after pushout */
// NOTE(review): original source line 125 (the backend hook call itself) is not
// present in this listing, so the branch looks empty here — check the header.
126  }
127  }
128  };
129  } // namespace detail
130 
136  template <typename... Buffers>
137  [[nodiscard]]
138  inline auto wait(Buffers &&...args) {
139  static_assert(
140  ((ttg::meta::is_buffer_v<std::decay_t<Buffers>> || ttg::meta::is_devicescratch_v<std::decay_t<Buffers>>) &&
141  ...),
142  "Only ttg::Buffer and ttg::devicescratch can be waited on!");
143  return detail::wait_kernel_t<std::remove_reference_t<Buffers>...>{std::tie(std::forward<Buffers>(args)...)};
144  }
145 
146  /******************************
147  * Send/Broadcast handling
148  * We pass the value returned by the backend's copy handler into a coroutine
149  * and execute the first part (prepare), before suspending it.
150  * The second part (send/broadcast) is executed after the task completed.
151  ******************************/
152 
153  namespace detail {
154  struct send_coro_promise_type;
155 
156  using send_coro_handle_type = ttg::coroutine_handle<send_coro_promise_type>;
157 
159  struct send_coro_state : public send_coro_handle_type {
160  using base_type = send_coro_handle_type;
161 
164 
165  using promise_type = send_coro_promise_type;
166 
168 
169  send_coro_state(base_type base) : base_type(std::move(base)) {}
170 
171  base_type &handle() { return *this; }
172 
174  inline bool ready() { return true; }
175 
177  inline bool completed();
178  };
179 
181  struct send_coro_promise_type {
182  /* do not suspend the coroutine on first invocation, we want to run
183  * the coroutine immediately and suspend only once.
184  */
185  ttg::suspend_never initial_suspend() { return {}; }
186 
187  /* we don't suspend the coroutine at the end.
188  * it can be destroyed once the send/broadcast is done
189  */
190  ttg::suspend_never final_suspend() noexcept { return {}; }
191 
192  send_coro_state get_return_object() { return send_coro_state{send_coro_handle_type::from_promise(*this)}; }
193 
194  /* the send coros only have an empty co_await */
195  ttg::suspend_always await_transform(ttg::Void) { return {}; }
196 
197  void unhandled_exception() {
198  std::cerr << "Send coroutine caught an unhandled exception!" << std::endl;
199  throw; // fwd
200  }
201 
202  void return_void() {}
203  };
204 
/* Send coroutine for an Out<Key, Value> terminal. It runs eagerly
 * (send_coro_promise_type::initial_suspend never suspends). The key is copied
 * into the frame and prepare_send() is called. The coroutine then suspends at
 * `co_await ttg::Void{}`, and when resumed it performs send().
 * NOTE(review): original source line 207 (rest of the parameter list,
 * apparently the copy-handler parameter `ch` and the opening brace) is not
 * present in this listing.
 * NOTE(review): like `key`, `value` is a reference parameter, so only the
 * reference is kept in the frame. The referent must outlive the suspension
 * (presumably ensured by the copy handler; verify). It is also forwarded
 * twice (prepare_send and send), which is only safe if prepare_send does not
 * move from it. */
205  template <typename Key, typename Value, ttg::Runtime Runtime = ttg::ttg_runtime>
206  inline send_coro_state send_coro(const Key &key, Value &&value, ttg::Out<Key, std::decay_t<Value>> &t,
208  ttg::detail::value_copy_handler<Runtime> copy_handler = std::move(ch); // destroyed at the end of the coro
// copy the key: the `key` parameter is a reference and may dangle after the suspension
209  Key k = key;
210  t.prepare_send(k, std::forward<Value>(value));
211  co_await ttg::Void{}; // we'll come back once the task is done
212  t.send(k, std::forward<Value>(value));
213  };
214 
/* Send coroutine for a value-only Out<void, Value> terminal. It calls
 * prepare_send() eagerly, suspends at `co_await ttg::Void{}`, and calls
 * sendv() when resumed.
 * NOTE(review): original source line 217 (end of the parameter list,
 * apparently the copy-handler parameter `ch`) is not present in this listing.
 * NOTE(review): `value` is forwarded twice; this is only safe if prepare_send
 * does not move from it — confirm. */
215  template <typename Value, ttg::Runtime Runtime = ttg::ttg_runtime>
216  inline send_coro_state sendv_coro(Value &&value, ttg::Out<void, std::decay_t<Value>> &t,
218  ttg::detail::value_copy_handler<Runtime> copy_handler = std::move(ch); // destroyed at the end of the coro
219  t.prepare_send(std::forward<Value>(value));
220  co_await ttg::Void{}; // we'll come back once the task is done
221  t.sendv(std::forward<Value>(value));
222  };
223 
224  template <typename Key, ttg::Runtime Runtime = ttg::ttg_runtime>
225  inline send_coro_state sendk_coro(const Key &key, ttg::Out<Key, void> &t) {
226  // no need to prepare the send but we have to suspend once
227  Key k = key;
228  co_await ttg::Void{}; // we'll come back once the task is done
229  t.sendk(k);
230  };
231 
232  template <ttg::Runtime Runtime = ttg::ttg_runtime>
233  inline send_coro_state send_coro(ttg::Out<void, void> &t) {
234  // no need to prepare the send but we have to suspend once
235  co_await ttg::Void{}; // we'll come back once the task is done
236  t.send();
237  };
238 
239  struct send_t {
240  send_coro_state coro;
241  };
242  } // namespace detail
243 
/* Send `value` with `key` through terminal i of the explicitly passed terminal
 * tuple `t`. The value is passed through the copy handler first. The returned
 * send_t wraps a send coroutine that has already run prepare_send() and is
 * suspended until it is resumed to perform the send.
 * NOTE(review): original source lines 245 (end of the template parameter
 * list) and 247 (apparently the declaration of `copy_handler`) are not
 * present in this listing. */
244  template <size_t i, typename keyT, typename valueT, typename... out_keysT, typename... out_valuesT,
246  inline detail::send_t send(const keyT &key, valueT &&value, std::tuple<ttg::Out<out_keysT, out_valuesT>...> &t) {
248  return detail::send_t{
249  detail::send_coro(key, copy_handler(std::forward<valueT>(value)), std::get<i>(t), copy_handler)};
250  }
251 
/* Send `value` (no key) through terminal i of the explicitly passed terminal
 * tuple `t`. The value goes through the copy handler, and the returned send_t
 * wraps the suspended send coroutine.
 * NOTE(review): original source lines 253 and 255 (end of the template
 * parameter list and, apparently, the `copy_handler` declaration) are not
 * present in this listing. */
252  template <size_t i, typename valueT, typename... out_keysT, typename... out_valuesT,
254  inline detail::send_t sendv(valueT &&value, std::tuple<ttg::Out<out_keysT, out_valuesT>...> &t) {
256  return detail::send_t{detail::sendv_coro(copy_handler(std::forward<valueT>(value)), std::get<i>(t), copy_handler)};
257  }
258 
/* Send `key` (no value) through terminal i of the explicitly passed terminal
 * tuple `t`. The returned send_t wraps the suspended send coroutine.
 * NOTE(review): original source line 260 (end of the template parameter
 * list) is not present in this listing. */
259  template <size_t i, typename Key, typename... out_keysT, typename... out_valuesT,
261  inline detail::send_t sendk(const Key &key, std::tuple<ttg::Out<out_keysT, out_valuesT>...> &t) {
262  return detail::send_t{detail::sendk_coro(key, std::get<i>(t))};
263  }
264 
265  // clang-format off
270  // clang-format on
/* Send `value` with `key` to the output terminal with runtime index i, looked
 * up via ttg::detail::get_out_terminal. The returned send_t wraps the suspended
 * send coroutine.
 * NOTE(review): original source line 273 (apparently the declaration of
 * `copy_handler`) is not present in this listing. */
271  template <typename keyT, typename valueT, ttg::Runtime Runtime = ttg::ttg_runtime>
272  inline detail::send_t send(size_t i, const keyT &key, valueT &&value) {
274  auto *terminal_ptr = ttg::detail::get_out_terminal<keyT, valueT>(i, "ttg::device::send(i, key, value)");
275  return detail::send_t{detail::send_coro(key, copy_handler(std::forward<valueT>(value)), *terminal_ptr, copy_handler)};
276  }
277 
278  // clang-format off
284  // clang-format on
285  template <size_t i, typename keyT, typename valueT>
286  inline auto send(const keyT &key, valueT &&value) {
287  return ttg::device::send(i, key, std::forward<valueT>(value));
288  }
289 
290 
/* Send `value` (no key) to the output terminal with runtime index i. The
 * returned send_t wraps the suspended send coroutine.
 * NOTE(review): original source line 294 (apparently the declaration of
 * `copy_handler`) is not present in this listing.
 * NOTE(review): the description string passed to get_out_terminal names
 * "send(i, key, value)", although this is sendv(i, value); sibling functions
 * pass their own name. */
291  template <typename valueT, ttg::Runtime Runtime = ttg::ttg_runtime>
292  inline detail::send_t sendv(std::size_t i, valueT &&value) {
293  auto *terminal_ptr = ttg::detail::get_out_terminal<void, valueT>(i, "ttg::device::send(i, key, value)");
295  return detail::send_t{detail::sendv_coro(copy_handler(std::forward<valueT>(value)), *terminal_ptr, copy_handler)};
296  }
297 
298  template <typename Key, ttg::Runtime Runtime = ttg::ttg_runtime>
299  inline detail::send_t sendk(std::size_t i, const Key& key) {
300  auto *terminal_ptr = ttg::detail::get_out_terminal<Key, void>(i, "ttg::device::send(i, key, value)");
301  return detail::send_t{detail::sendk_coro(key, *terminal_ptr)};
302  }
303 
304  template <ttg::Runtime Runtime = ttg::ttg_runtime>
305  inline detail::send_t send(std::size_t i) {
306  auto *terminal_ptr = ttg::detail::get_out_terminal<void, void>(i, "ttg::device::send(i, key, value)");
307  return detail::send_t{detail::send_coro(*terminal_ptr)};
308  }
309 
310 
311  template <std::size_t i, typename valueT, ttg::Runtime Runtime = ttg::ttg_runtime>
312  inline detail::send_t sendv(valueT &&value) {
313  return sendv(i, std::forward<valueT>(value));
314  }
315 
316  template <size_t i, typename Key, ttg::Runtime Runtime = ttg::ttg_runtime>
317  inline detail::send_t sendk(const Key& key) {
318  return sendk(i, key);
319  }
320 
321  template <size_t i, ttg::Runtime Runtime = ttg::ttg_runtime>
322  inline detail::send_t sendk() {
323  return send(i);
324  }
325 
326  namespace detail {
327 
/* Key-type trait for broadcast keylists. The implicit-terminal broadcast
 * helpers read broadcast_keylist_trait<...>::key_type. For a non-iterable T
 * the "keylist" is a single key, so key_type is T itself. Previously only
 * `type` was defined here, so those helpers failed to compile for such T;
 * `type` is kept for existing users. */
template<typename T, typename Enabler = void>
struct broadcast_keylist_trait {
  using type = T;
  using key_type = T;
};
332 
333  /* overload for iterable types that extracts the type of the first element */
334  template<typename T>
335  struct broadcast_keylist_trait<T, std::enable_if_t<ttg::meta::is_iterable_v<T>>> {
336  using key_type = decltype(*std::begin(std::declval<T>()));
337  };
338 
339  template <size_t KeyId, size_t I, size_t... Is, typename... RangesT, typename valueT,
340  typename... out_keysT, typename... out_valuesT>
341  inline void prepare_broadcast(const std::tuple<RangesT...> &keylists, valueT &&value,
342  std::tuple<ttg::Out<out_keysT, out_valuesT>...> &t) {
343  std::get<I>(t).prepare_send(std::get<KeyId>(keylists), std::forward<valueT>(value));
344  if constexpr (sizeof...(Is) > 0) {
345  prepare_broadcast<KeyId+1, Is...>(keylists, std::forward<valueT>(value), t);
346  }
347  }
348 
349  template <size_t KeyId, size_t I, size_t... Is, typename... RangesT, typename valueT>
350  inline void prepare_broadcast(const std::tuple<RangesT...> &keylists, valueT &&value) {
351  using key_t = typename broadcast_keylist_trait<
352  std::tuple_element_t<KeyId, std::tuple<std::remove_reference_t<RangesT>...>>
353  >::key_type;
354  auto *terminal_ptr = ttg::detail::get_out_terminal<key_t, valueT>(I, "ttg::device::broadcast(keylists, value)");
355  terminal_ptr->prepare_send(std::get<KeyId>(keylists), value);
356  if constexpr (sizeof...(Is) > 0) {
357  prepare_broadcast<KeyId+1, Is...>(keylists, std::forward<valueT>(value));
358  }
359  }
360 
361  template <size_t KeyId, size_t I, size_t... Is, typename... RangesT, typename valueT,
362  typename... out_keysT, typename... out_valuesT>
363  inline void broadcast(const std::tuple<RangesT...> &keylists, valueT &&value,
364  std::tuple<ttg::Out<out_keysT, out_valuesT>...> &t) {
365  std::get<I>(t).broadcast(std::get<KeyId>(keylists), std::forward<valueT>(value));
366  if constexpr (sizeof...(Is) > 0) {
367  detail::broadcast<KeyId+1, Is...>(keylists, std::forward<valueT>(value), t);
368  }
369  }
370 
371  template <size_t KeyId, size_t I, size_t... Is, typename... RangesT, typename valueT>
372  inline void broadcast(const std::tuple<RangesT...> &keylists, valueT &&value) {
373  using key_t = typename broadcast_keylist_trait<
374  std::tuple_element_t<KeyId, std::tuple<std::remove_reference_t<RangesT>...>>
375  >::key_type;
376  auto *terminal_ptr = ttg::detail::get_out_terminal<key_t, valueT>(I, "ttg::device::broadcast(keylists, value)");
377  terminal_ptr->broadcast(std::get<KeyId>(keylists), value);
378  if constexpr (sizeof...(Is) > 0) {
379  ttg::device::detail::broadcast<KeyId+1, Is...>(keylists, std::forward<valueT>(value));
380  }
381  }
382 
383  /* overload with explicit terminals */
// Broadcast coroutine for explicit terminals. It runs prepare_broadcast() on
// terminals I, Is... right away, suspends at `co_await ttg::Void{}`, and calls
// broadcast() on the same terminals when resumed. A non-tuple keylist is
// wrapped with std::tie so both paths use the tuple-based helpers.
// NOTE(review): original source lines 386 and 390 (end of the template
// parameter list and end of the function parameter list, apparently the copy
// handler `ch`) are not present in this listing.
// NOTE(review): std::forward<std::decay_t<decltype(value)>>(value) always
// yields an rvalue, even when valueT is an lvalue reference, and `value` is
// used again after the suspension. Confirm that prepare_broadcast/broadcast
// do not move from it.
384  template <size_t I, size_t... Is, typename RangesT, typename valueT,
385  typename... out_keysT, typename... out_valuesT,
387  inline send_coro_state
388  broadcast_coro(RangesT &&keylists, valueT &&value,
389  std::tuple<ttg::Out<out_keysT, out_valuesT>...> &t,
391  ttg::detail::value_copy_handler<Runtime> copy_handler = std::move(ch); // destroyed at the end of the coro
// NOTE(review): when RangesT is deduced as an lvalue reference, `kl` is itself a
// reference, so the caller's keylist is not copied and must outlive the suspension
392  RangesT kl = std::forward<RangesT>(keylists); // capture the keylist(s)
393  if constexpr (ttg::meta::is_tuple_v<RangesT>) {
394  // treat as tuple
395  prepare_broadcast<0, I, Is...>(kl, std::forward<std::decay_t<decltype(value)>>(value), t);
396  co_await ttg::Void{}; // we'll come back once the task is done
397  ttg::device::detail::broadcast<0, I, Is...>(kl, std::forward<std::decay_t<decltype(value)>>(value), t);
398  } else if constexpr (!ttg::meta::is_tuple_v<RangesT>) {
399  // create a tie to the captured keylist
400  prepare_broadcast<0, I, Is...>(std::tie(kl), std::forward<std::decay_t<decltype(value)>>(value), t);
401  co_await ttg::Void{}; // we'll come back once the task is done
402  ttg::device::detail::broadcast<0, I, Is...>(std::tie(kl), std::forward<std::decay_t<decltype(value)>>(value), t);
403  }
404  }
405 
406  /* overload with implicit terminals */
// Broadcast coroutine for the current task's output terminals I, Is...
// (looked up by index). It prepares eagerly, suspends once, and broadcasts
// when resumed. A tuple of keylists must provide exactly one keylist per
// terminal index (static_assert below).
// NOTE(review): original source lines 408 and 411 (end of the template
// parameter list and end of the function parameter list, apparently the copy
// handler `ch`) are not present in this listing.
// NOTE(review): std::tuple_size_v<RangesT> needs a non-reference RangesT, so the
// tuple path presumably only works for rvalue tuples, such as the std::tie
// result passed by the public broadcast() — confirm.
// NOTE(review): std::forward<std::decay_t<decltype(value)>>(value) always yields
// an rvalue, and `value` is used again after the suspension.
407  template <size_t I, size_t... Is, typename RangesT, typename valueT,
409  inline send_coro_state
410  broadcast_coro(RangesT &&keylists, valueT &&value,
412  ttg::detail::value_copy_handler<Runtime> copy_handler = std::move(ch); // destroyed at the end of the coro
// when RangesT is an lvalue reference, `kl` is a reference rather than a copy
413  RangesT kl = std::forward<RangesT>(keylists); // capture the keylist(s)
414  if constexpr (ttg::meta::is_tuple_v<RangesT>) {
415  // treat as tuple
416  static_assert(sizeof...(Is)+1 == std::tuple_size_v<RangesT>,
417  "Size of keylist tuple must match the number of output terminals");
418  prepare_broadcast<0, I, Is...>(kl, std::forward<std::decay_t<decltype(value)>>(value));
419  co_await ttg::Void{}; // we'll come back once the task is done
420  ttg::device::detail::broadcast<0, I, Is...>(kl, std::forward<std::decay_t<decltype(value)>>(value));
421  } else if constexpr (!ttg::meta::is_tuple_v<RangesT>) {
422  // create a tie to the captured keylist
423  prepare_broadcast<0, I, Is...>(std::tie(kl), std::forward<std::decay_t<decltype(value)>>(value));
424  co_await ttg::Void{}; // we'll come back once the task is done
425  ttg::device::detail::broadcast<0, I, Is...>(std::tie(kl), std::forward<std::decay_t<decltype(value)>>(value));
426  }
427  }
428 
433  template <size_t KeyId, size_t I, size_t... Is, typename... RangesT,
434  typename... out_keysT, typename... out_valuesT>
435  inline void broadcastk(const std::tuple<RangesT...> &keylists,
436  std::tuple<ttg::Out<out_keysT, out_valuesT>...> &t) {
437  std::get<I>(t).broadcast(std::get<KeyId>(keylists));
438  if constexpr (sizeof...(Is) > 0) {
439  detail::broadcastk<KeyId+1, Is...>(keylists, t);
440  }
441  }
442 
443  template <size_t KeyId, size_t I, size_t... Is, typename... RangesT>
444  inline void broadcastk(const std::tuple<RangesT...> &keylists) {
445  using key_t = typename broadcast_keylist_trait<
446  std::tuple_element_t<KeyId, std::tuple<std::remove_reference_t<RangesT>...>>
447  >::key_type;
448  auto *terminal_ptr = ttg::detail::get_out_terminal<key_t, void>(I, "ttg::device::broadcastk(keylists)");
449  terminal_ptr->broadcast(std::get<KeyId>(keylists));
450  if constexpr (sizeof...(Is) > 0) {
451  ttg::device::detail::broadcastk<KeyId+1, Is...>(keylists);
452  }
453  }
454 
455  /* overload with explicit terminals */
// Key-only broadcast coroutine for explicit terminals. There is nothing to
// prepare: it captures the keylist(s), suspends once, and broadcasts on
// terminals I, Is... when resumed.
// NOTE(review): original source line 458 (end of the template parameter list)
// is not present in this listing.
456  template <size_t I, size_t... Is, typename RangesT,
457  typename... out_keysT, typename... out_valuesT,
459  inline send_coro_state
460  broadcastk_coro(RangesT &&keylists,
461  std::tuple<ttg::Out<out_keysT, out_valuesT>...> &t) {
// when RangesT is an lvalue reference, `kl` is a reference and the caller's
// keylist must outlive the suspension
462  RangesT kl = std::forward<RangesT>(keylists); // capture the keylist(s)
463  if constexpr (ttg::meta::is_tuple_v<RangesT>) {
464  // treat as tuple
465  co_await ttg::Void{}; // we'll come back once the task is done
466  ttg::device::detail::broadcastk<0, I, Is...>(kl, t);
467  } else if constexpr (!ttg::meta::is_tuple_v<RangesT>) {
468  // create a tie to the captured keylist
469  co_await ttg::Void{}; // we'll come back once the task is done
470  ttg::device::detail::broadcastk<0, I, Is...>(std::tie(kl), t);
471  }
472  }
473 
474  /* overload with implicit terminals */
// Key-only broadcast coroutine for the current task's output terminals I, Is...
// It captures the keylist(s), suspends once, and broadcasts when resumed. A
// tuple of keylists must provide one keylist per terminal index.
// NOTE(review): original source line 476 (end of the template parameter list)
// is not present in this listing.
475  template <size_t I, size_t... Is, typename RangesT,
477  inline send_coro_state
478  broadcastk_coro(RangesT &&keylists) {
// a non-reference RangesT (rvalue argument) is moved into the frame; an lvalue
// reference RangesT keeps only a reference
479  RangesT kl = std::forward<RangesT>(keylists); // capture the keylist(s)
480  if constexpr (ttg::meta::is_tuple_v<RangesT>) {
481  // treat as tuple
482  static_assert(sizeof...(Is)+1 == std::tuple_size_v<RangesT>,
483  "Size of keylist tuple must match the number of output terminals");
484  co_await ttg::Void{}; // we'll come back once the task is done
485  ttg::device::detail::broadcastk<0, I, Is...>(kl);
486  } else if constexpr (!ttg::meta::is_tuple_v<RangesT>) {
487  // create a tie to the captured keylist
488  co_await ttg::Void{}; // we'll come back once the task is done
489  ttg::device::detail::broadcastk<0, I, Is...>(std::tie(kl));
490  }
491  }
492  } // namespace detail
493 
494  /* overload with explicit terminals; keylist passed as a forwarding reference */
// Broadcast `value` to terminals I, Is... of `t`. The value goes through the
// copy handler, and the keylist(s) are forwarded into broadcast_coro. The
// returned send_t wraps the suspended broadcast coroutine.
// NOTE(review): original source lines 496 and 501 (end of the template
// parameter list and, apparently, the `copy_handler` declaration) are not
// present in this listing.
495  template <size_t I, size_t... Is, typename rangeT, typename valueT, typename... out_keysT, typename... out_valuesT,
497  [[nodiscard]]
498  inline detail::send_t broadcast(rangeT &&keylist,
499  valueT &&value,
500  std::tuple<ttg::Out<out_keysT, out_valuesT>...> &t) {
502  return detail::send_t{
503  detail::broadcast_coro<I, Is...>(std::forward<rangeT>(keylist),
504  copy_handler(std::forward<valueT>(value)),
505  t, std::move(copy_handler))};
506  }
507 
508  /* overload with implicit terminals; keylist passed as a forwarding reference */
// Broadcast `value` to output terminal i of the current task. The returned
// send_t wraps the suspended broadcast coroutine.
// NOTE(review): the keylist is always wrapped in std::tie, so the coroutine
// keeps only a reference to it. For an rvalue (temporary) keylist that
// reference dangles once the caller's full-expression ends, while broadcastk()
// moves rvalue keylists into the coroutine instead — confirm intended.
// NOTE(review): original source lines 510 and 512 (end of the template
// parameter list and, apparently, the `copy_handler` declaration) are not
// present in this listing.
509  template <size_t i, typename rangeT, typename valueT,
511  inline detail::send_t broadcast(rangeT &&keylist, valueT &&value) {
513  return detail::send_t{detail::broadcast_coro<i>(std::tie(keylist),
514  copy_handler(std::forward<valueT>(value)),
515  std::move(copy_handler))};
516  }
517 
518  /* overload with explicit terminals; keylist passed as a forwarding reference */
// Key-only broadcast to terminals I, Is... of `t`. The returned send_t wraps
// the suspended broadcastk coroutine.
// NOTE(review): original source lines 520 and 524 are not present in this listing.
519  template <size_t I, size_t... Is, typename rangeT, typename... out_keysT, typename... out_valuesT,
521  [[nodiscard]]
522  inline detail::send_t broadcastk(rangeT &&keylist,
523  std::tuple<ttg::Out<out_keysT, out_valuesT>...> &t) {
525  return detail::send_t{
526  detail::broadcastk_coro<I, Is...>(std::forward<rangeT>(keylist), t)};
527  }
528 
529  /* overload with implicit terminals; keylist passed as a forwarding reference */
// Key-only broadcast to output terminal i of the current task. An rvalue
// keylist is forwarded (and thus moved into the coroutine frame). An lvalue
// keylist is wrapped in std::tie, so it must outlive the suspension.
// NOTE(review): original source line 531 (end of the template parameter list)
// is not present in this listing.
530  template <size_t i, typename rangeT,
532  inline detail::send_t broadcastk(rangeT &&keylist) {
533  if constexpr (std::is_rvalue_reference_v<decltype(keylist)>) {
534  return detail::send_t{detail::broadcastk_coro<i>(std::forward<rangeT>(keylist))};
535  } else {
536  return detail::send_t{detail::broadcastk_coro<i>(std::tie(keylist))};
537  }
538  }
539 
540  template<typename... Args, ttg::Runtime Runtime = ttg::ttg_runtime>
541  [[nodiscard]]
542  std::vector<device::detail::send_t> forward(Args&&... args) {
543  // TODO: check the cost of this!
544  return std::vector<device::detail::send_t>{std::forward<Args>(args)...};
545  }
546 
547  /*******************************************
548  * Device task promise and coroutine handle
549  *******************************************/
550 
551  namespace detail {
552  // fwd-decl
553  struct device_task_promise_type;
554  // base type for ttg::device::Task
555  using device_task_handle_type = ttg::coroutine_handle<device_task_promise_type>;
556  } // namespace detail
557 
559 
566  struct Task : public detail::device_task_handle_type {
567  using base_type = detail::device_task_handle_type;
568 
571 
572  using promise_type = detail::device_task_promise_type;
573 
575 
576  Task(base_type base) : base_type(std::move(base)) {}
577 
578  base_type& handle() { return *this; }
579 
581  inline bool ready() {
582  return true;
583  }
584 
586  inline bool completed();
587  };
588 
589  namespace detail {
590 
// Promise type of ttg::device::Task coroutines. Every await_transform overload
// below records in m_state which phase the task reached (see
// ttg_device_coro_state). The state is readable through state() and complete().
591  /* The promise type that stores the views provided by the
592  * application task coroutine on the first co_yield. It subsequently
593  * tracks the state of the task when it moves from waiting for transfers
594  * to waiting for the submitted kernel to complete. */
595  struct device_task_promise_type {
596 
597  /* do not suspend the coroutine on first invocation, we want to run
598  * the coroutine immediately and suspend when we get the device transfers.
599  */
600  ttg::suspend_never initial_suspend() {
601  m_state = ttg::device::detail::TTG_DEVICE_CORO_INIT;
602  return {};
603  }
604 
605  /* suspend the coroutine at the end of the execution
606  * so we can access the promise.
607  * TODO: necessary? maybe we can save one suspend here
608  */
609  ttg::suspend_always final_suspend() noexcept {
610  m_state = ttg::device::detail::TTG_DEVICE_CORO_COMPLETE;
611  return {};
612  }
613 
614  /* Allow co_await on a tuple */
// NOTE(review): yield_value() is not declared in this promise type. The call
// depends on a template parameter, so it is only looked up on instantiation,
// and this overload would fail to compile if ever used. It looks like a
// leftover from an older co_yield-based interface — confirm and remove.
615  template<typename... Views>
616  ttg::suspend_always await_transform(std::tuple<Views&...> &views) {
617  return yield_value(views);
618  }
619 
// co_await ttg::device::select(...): hand the selected buffers to the backend
// via register_device_memory() and suspend in the WAIT_TRANSFER state.
// NOTE(review): `need_transfer` is computed but never used; the coroutine always suspends.
620  template<typename... Ts>
621  ttg::suspend_always await_transform(detail::to_device_t<Ts...>&& a) {
622  auto arr = detail::extract_buffer_data(a, std::make_index_sequence<sizeof...(Ts)>{});
623  bool need_transfer = !(TTG_IMPL_NS::register_device_memory(ttg::span(arr)));
624  /* TODO: are we allowed to not suspend here and launch the kernel directly? */
625  m_state = ttg::device::detail::TTG_DEVICE_CORO_WAIT_TRANSFER;
626  return {};
627  }
628 
// same as above for a runtime-sized Input collection (select(Input&))
629  ttg::suspend_always await_transform(detail::to_device_t<Input>&& a) {
630  bool need_transfer = !(TTG_IMPL_NS::register_device_memory(a.input.span()));
631  /* TODO: are we allowed to not suspend here and launch the kernel directly? */
632  m_state = ttg::device::detail::TTG_DEVICE_CORO_WAIT_TRANSFER;
633  return {};
634  }
635 
// co_await ttg::device::wait(...): enter WAIT_KERNEL and return the
// wait_kernel_t awaitable itself (its await_ready() is always false).
// NOTE(review): original source line 640 (the body of the `if constexpr`
// branch) is not present in this listing.
636  template<typename... Ts>
637  auto await_transform(detail::wait_kernel_t<Ts...>&& a) {
638  //std::cout << "yield_value: wait_kernel_t" << std::endl;
639  if constexpr (sizeof...(Ts) > 0) {
641  }
642  m_state = ttg::device::detail::TTG_DEVICE_CORO_WAIT_KERNEL;
643  return a;
644  }
645 
// co_await of sends bundled by ttg::device::forward(): replaces the pending
// sends and enters SENDOUT
646  ttg::suspend_always await_transform(std::vector<device::detail::send_t>&& v) {
647  m_sends = std::forward<std::vector<device::detail::send_t>>(v);
648  m_state = ttg::device::detail::TTG_DEVICE_CORO_SENDOUT;
649  return {};
650  }
651 
// co_await of a single send/broadcast: it replaces any pending sends
652  ttg::suspend_always await_transform(device::detail::send_t&& v) {
653  m_sends.clear();
654  m_sends.push_back(std::forward<device::detail::send_t>(v));
655  m_state = ttg::device::detail::TTG_DEVICE_CORO_SENDOUT;
656  return {};
657  }
658 
659  void return_void() {
660  m_state = ttg::device::detail::TTG_DEVICE_CORO_COMPLETE;
661  }
662 
663  bool complete() const {
664  return m_state == ttg::device::detail::TTG_DEVICE_CORO_COMPLETE;
665  }
666 
667  ttg::device::Task get_return_object() { return {detail::device_task_handle_type::from_promise(*this)}; }
668 
// logs, then rethrows the in-flight exception
669  void unhandled_exception() {
670  std::cerr << "Task coroutine caught an unhandled exception!" << std::endl;
671  throw; // fwd
672  }
673 
674  //using iterator = std::vector<device_obj_view>::iterator;
675 
676  /* execute all pending send and broadcast operations */
// resuming each send coroutine runs the part after its `co_await ttg::Void{}`
677  void do_sends() {
678  for (auto& send : m_sends) {
679  send.coro();
680  }
681  m_sends.clear();
682  }
683 
// current phase of the task
684  auto state() {
685  return m_state;
686  }
687 
688  private:
// pending send/broadcast coroutines, resumed by do_sends()
689  std::vector<device::detail::send_t> m_sends;
690  ttg_device_coro_state m_state = ttg::device::detail::TTG_DEVICE_CORO_STATE_NONE;
691 
692  };
693 
694  } // namespace detail
695 
696  bool Task::completed() { return base_type::promise().state() == ttg::device::detail::TTG_DEVICE_CORO_COMPLETE; }
697 
/* Empty tag type; it is not referenced elsewhere in this header. */
struct device_wait_kernel {};
700 
701 
702  /* NOTE: below is preliminary for reductions on the device, which is not available yet */
// Everything up to the matching #endif is disabled with #if 0 and never compiled.
703 #if 0
704  /**************************
705  * Device reduction coros *
706  **************************/
707 
708  struct device_reducer_promise_type;
709 
710  using device_reducer_handle_type = ttg::coroutine_handle<device_reducer_promise_type>;
711 
713  struct device_reducer : public device_reducer_handle_type {
714  using base_type = device_reducer_handle_type;
715 
718 
719  using promise_type = device_reducer_promise_type;
720 
722 
723  device_reducer(base_type base) : base_type(std::move(base)) {}
724 
725  base_type& handle() { return *this; }
726 
728  inline bool ready() {
729  return true;
730  }
731 
733  inline bool completed();
734  };
735 
736 
737  /* The promise type that stores the views provided by the
738  * application task coroutine on the first co_yield. It subsequently
739  * tracks the state of the task when it moves from waiting for transfers
740  * to waiting for the submitted kernel to complete. */
741  struct device_reducer_promise_type {
742 
743  /* do not suspend the coroutine on first invocation, we want to run
744  * the coroutine immediately and suspend when we get the device transfers.
745  */
746  ttg::suspend_never initial_suspend() {
747  m_state = ttg::device::detail::TTG_DEVICE_CORO_INIT;
748  return {};
749  }
750 
751  /* suspend the coroutine at the end of the execution
752  * so we can access the promise.
753  * TODO: necessary? maybe we can save one suspend here
754  */
755  ttg::suspend_always final_suspend() noexcept {
756  m_state = ttg::device::detail::TTG_DEVICE_CORO_COMPLETE;
757  return {};
758  }
759 
// NOTE(review): passes a.ties directly, whereas the active
// device_task_promise_type passes a span built by extract_buffer_data();
// update this if the code is ever re-enabled.
760  template<typename... Ts>
761  ttg::suspend_always await_transform(detail::to_device_t<Ts...>&& a) {
762  bool need_transfer = !(TTG_IMPL_NS::register_device_memory(a.ties));
763  /* TODO: are we allowed to not suspend here and launch the kernel directly? */
764  m_state = ttg::device::detail::TTG_DEVICE_CORO_WAIT_TRANSFER;
765  return {};
766  }
767 
768  void return_void() {
769  m_state = ttg::device::detail::TTG_DEVICE_CORO_COMPLETE;
770  }
771 
772  bool complete() const {
773  return m_state == ttg::device::detail::TTG_DEVICE_CORO_COMPLETE;
774  }
775 
776  device_reducer get_return_object() { return device_reducer{device_reducer_handle_type::from_promise(*this)}; }
777 
// NOTE(review): exceptions escaping the reducer coroutine are silently swallowed here
778  void unhandled_exception() { }
779 
780  auto state() {
781  return m_state;
782  }
783 
784 
785  private:
786  ttg::device::detail::ttg_device_coro_state m_state = ttg::device::detail::TTG_DEVICE_CORO_STATE_NONE;
787 
788  };
789 
790  bool device_reducer::completed() { return base_type::promise().state() == ttg::device::detail::TTG_DEVICE_CORO_COMPLETE; }
791 #endif // 0
792 
793 } // namespace ttg::device
794 
795 #endif // TTG_HAVE_COROUTINE
796 
797 #endif // TTG_DEVICE_TASK_H
std::enable_if_t< meta::is_all_void_v< Key, Value >, void > send()
Definition: terminal.h:514
std::enable_if_t<!meta::is_void_v< Key > &&meta::is_void_v< Value >, void > sendk(const Key &key)
Definition: terminal.h:475
A complete version of void.
Definition: void.h:11
constexpr auto data(C &c) -> decltype(c.data())
Definition: span.h:189
typename make_index_sequence_t< I... >::type make_index_sequence
std::integral_constant< bool,(Flags &const_) !=0 > is_const
void broadcast(const std::tuple< RangesT... > &keylists, valueT &&value, std::tuple< ttg::Out< out_keysT, out_valuesT >... > &t)
Definition: func.h:347
constexpr bool is_devicescratch_v
Definition: meta.h:318
constexpr bool is_buffer_v
Definition: meta.h:311
auto buffer_data(const Buffer< T, A > &buffer)
Definition: devicefunc.h:9
void mark_device_out(std::tuple< Buffer &... > &b)
bool register_device_memory(std::tuple< Views &... > &views)
Definition: devicefunc.h:15
void post_device_out(std::tuple< Buffer &... > &b)
TTG_CXX_COROUTINE_NAMESPACE::suspend_always suspend_always
Definition: coroutine.h:21
void send(const keyT &key, valueT &&value, ttg::Out< keyT, valueT > &t)
Sends a task id and a value to the given output terminal.
Definition: func.h:158
constexpr const ttg::Runtime ttg_runtime
Definition: import.h:20
scope
Definition: devicescope.h:5
Runtime
Definition: runtimes.h:15
TTG_CXX_COROUTINE_NAMESPACE::suspend_never suspend_never
Definition: coroutine.h:22
TTG_CXX_COROUTINE_NAMESPACE::coroutine_handle< Promise > coroutine_handle
Definition: coroutine.h:24
void sendk(const keyT &key, ttg::Out< keyT, void > &t)
Sends a task id (without an accompanying value) to the given output terminal.
Definition: func.h:169
void sendv(valueT &&value, ttg::Out< void, valueT > &t)
Sends a value (without an accompanying task id) to the given output terminal.
Definition: func.h:179
void broadcastk(const rangeT &keylist, std::tuple< ttg::Out< out_keysT, out_valuesT >... > &t)
Definition: func.h:452