task.h
Go to the documentation of this file.
1 #ifndef TTG_DEVICE_TASK_H
2 #define TTG_DEVICE_TASK_H
3 
4 #include <array>
5 #include <type_traits>
6 #include <span>
7 
8 #include "ttg/fwd.h"
9 #include "ttg/impl_selector.h"
10 #include "ttg/ptr.h"
11 
12 #ifdef TTG_HAVE_COROUTINE
13 
14 namespace ttg::device {
15 
16  namespace detail {
17  template <typename... Ts>
18  struct to_device_t {
19  std::tuple<std::add_lvalue_reference_t<Ts>...> ties;
20  };
21  } // namespace detail
22 
30  template <typename... Args>
31  [[nodiscard]]
32  inline auto select(Args &&...args) {
33  return detail::to_device_t<std::remove_reference_t<Args>...>{std::tie(std::forward<Args>(args)...)};
34  }
35 
36  namespace detail {
37 
38  enum ttg_device_coro_state {
39  TTG_DEVICE_CORO_STATE_NONE,
40  TTG_DEVICE_CORO_INIT,
41  TTG_DEVICE_CORO_WAIT_TRANSFER,
42  TTG_DEVICE_CORO_WAIT_KERNEL,
43  TTG_DEVICE_CORO_SENDOUT,
44  TTG_DEVICE_CORO_COMPLETE
45  };
46 
47  template <typename... Ts>
48  struct wait_kernel_t {
49  std::tuple<std::add_lvalue_reference_t<Ts>...> ties;
50 
51  /* always suspend */
52  constexpr bool await_ready() const noexcept { return false; }
53 
54  /* always suspend */
55  template <typename Promise>
56  constexpr void await_suspend(ttg::coroutine_handle<Promise>) const noexcept {}
57 
58  void await_resume() noexcept {
59  if constexpr (sizeof...(Ts) > 0) {
60  /* hook to allow the backend to handle the data after pushout */
62  }
63  }
64  };
65  } // namespace detail
66 
72  template <typename... Buffers>
73  [[nodiscard]]
74  inline auto wait(Buffers &&...args) {
75  static_assert(
76  ((ttg::meta::is_buffer_v<std::decay_t<Buffers>> || ttg::meta::is_devicescratch_v<std::decay_t<Buffers>>) &&
77  ...),
78  "Only ttg::Buffer and ttg::devicescratch can be waited on!");
79  return detail::wait_kernel_t<std::remove_reference_t<Buffers>...>{std::tie(std::forward<Buffers>(args)...)};
80  }
81 
82  /******************************
83  * Send/Broadcast handling
84  * We pass the value returned by the backend's copy handler into a coroutine
85  * and execute the first part (prepare), before suspending it.
86  * The second part (send/broadcast) is executed after the task completed.
87  ******************************/
88 
89  namespace detail {
90  struct send_coro_promise_type;
91 
92  using send_coro_handle_type = ttg::coroutine_handle<send_coro_promise_type>;
93 
95  struct send_coro_state : public send_coro_handle_type {
96  using base_type = send_coro_handle_type;
97 
100 
101  using promise_type = send_coro_promise_type;
102 
104 
105  send_coro_state(base_type base) : base_type(std::move(base)) {}
106 
107  base_type &handle() { return *this; }
108 
110  inline bool ready() { return true; }
111 
113  inline bool completed();
114  };
115 
117  struct send_coro_promise_type {
118  /* do not suspend the coroutine on first invocation, we want to run
119  * the coroutine immediately and suspend only once.
120  */
121  ttg::suspend_never initial_suspend() { return {}; }
122 
123  /* we don't suspend the coroutine at the end.
124  * it can be destroyed once the send/broadcast is done
125  */
126  ttg::suspend_never final_suspend() noexcept { return {}; }
127 
128  send_coro_state get_return_object() { return send_coro_state{send_coro_handle_type::from_promise(*this)}; }
129 
130  /* the send coros only have an empty co_await */
131  ttg::suspend_always await_transform(ttg::Void) { return {}; }
132 
133  void unhandled_exception() {
134  std::cerr << "Send coroutine caught an unhandled exception!" << std::endl;
135  throw; // fwd
136  }
137 
138  void return_void() {}
139  };
140 
141  template <typename Key, typename Value, ttg::Runtime Runtime = ttg::ttg_runtime>
142  inline send_coro_state send_coro(const Key &key, Value &&value, ttg::Out<Key, std::decay_t<Value>> &t,
144  ttg::detail::value_copy_handler<Runtime> copy_handler = std::move(ch); // destroyed at the end of the coro
145  Key k = key;
146  t.prepare_send(k, std::forward<Value>(value));
147  co_await ttg::Void{}; // we'll come back once the task is done
148  t.send(k, std::forward<Value>(value));
149  };
150 
151  template <typename Value, ttg::Runtime Runtime = ttg::ttg_runtime>
152  inline send_coro_state sendv_coro(Value &&value, ttg::Out<void, std::decay_t<Value>> &t,
154  ttg::detail::value_copy_handler<Runtime> copy_handler = std::move(ch); // destroyed at the end of the coro
155  t.prepare_send(std::forward<Value>(value));
156  co_await ttg::Void{}; // we'll come back once the task is done
157  t.sendv(std::forward<Value>(value));
158  };
159 
160  template <typename Key, ttg::Runtime Runtime = ttg::ttg_runtime>
161  inline send_coro_state sendk_coro(const Key &key, ttg::Out<Key, void> &t) {
162  // no need to prepare the send but we have to suspend once
163  Key k = key;
164  co_await ttg::Void{}; // we'll come back once the task is done
165  t.sendk(k);
166  };
167 
168  template <ttg::Runtime Runtime = ttg::ttg_runtime>
169  inline send_coro_state send_coro(ttg::Out<void, void> &t) {
170  // no need to prepare the send but we have to suspend once
171  co_await ttg::Void{}; // we'll come back once the task is done
172  t.send();
173  };
174 
175  struct send_t {
176  send_coro_state coro;
177  };
178  } // namespace detail
179 
180  template <size_t i, typename keyT, typename valueT, typename... out_keysT, typename... out_valuesT,
182  inline detail::send_t send(const keyT &key, valueT &&value, std::tuple<ttg::Out<out_keysT, out_valuesT>...> &t) {
184  return detail::send_t{
185  detail::send_coro(key, copy_handler(std::forward<valueT>(value)), std::get<i>(t), copy_handler)};
186  }
187 
188  template <size_t i, typename valueT, typename... out_keysT, typename... out_valuesT,
190  inline detail::send_t sendv(valueT &&value, std::tuple<ttg::Out<out_keysT, out_valuesT>...> &t) {
192  return detail::send_t{detail::sendv_coro(copy_handler(std::forward<valueT>(value)), std::get<i>(t), copy_handler)};
193  }
194 
195  template <size_t i, typename Key, typename... out_keysT, typename... out_valuesT,
197  inline detail::send_t sendk(const Key &key, std::tuple<ttg::Out<out_keysT, out_valuesT>...> &t) {
198  return detail::send_t{detail::sendk_coro(key, std::get<i>(t))};
199  }
200 
201  // clang-format off
206  // clang-format on
207  template <typename keyT, typename valueT, ttg::Runtime Runtime = ttg::ttg_runtime>
208  inline detail::send_t send(size_t i, const keyT &key, valueT &&value) {
210  auto *terminal_ptr = ttg::detail::get_out_terminal<keyT, valueT>(i, "ttg::device::send(i, key, value)");
211  return detail::send_t{detail::send_coro(key, copy_handler(std::forward<valueT>(value)), *terminal_ptr, copy_handler)};
212  }
213 
214  // clang-format off
220  // clang-format on
221  template <size_t i, typename keyT, typename valueT>
222  inline auto send(const keyT &key, valueT &&value) {
223  return ttg::device::send(i, key, std::forward<valueT>(value));
224  }
225 
226 
227  template <typename valueT, ttg::Runtime Runtime = ttg::ttg_runtime>
228  inline detail::send_t sendv(std::size_t i, valueT &&value) {
229  auto *terminal_ptr = ttg::detail::get_out_terminal<void, valueT>(i, "ttg::device::send(i, key, value)");
231  return detail::send_t{detail::sendv_coro(copy_handler(std::forward<valueT>(value)), *terminal_ptr, copy_handler)};
232  }
233 
234  template <typename Key, ttg::Runtime Runtime = ttg::ttg_runtime>
235  inline detail::send_t sendk(std::size_t i, const Key& key) {
236  auto *terminal_ptr = ttg::detail::get_out_terminal<Key, void>(i, "ttg::device::send(i, key, value)");
237  return detail::send_t{detail::sendk_coro(key, *terminal_ptr)};
238  }
239 
240  template <ttg::Runtime Runtime = ttg::ttg_runtime>
241  inline detail::send_t send(std::size_t i) {
242  auto *terminal_ptr = ttg::detail::get_out_terminal<void, void>(i, "ttg::device::send(i, key, value)");
243  return detail::send_t{detail::send_coro(*terminal_ptr)};
244  }
245 
246 
247  template <std::size_t i, typename valueT, typename... out_keysT, typename... out_valuesT,
249  inline detail::send_t sendv(valueT &&value) {
250  return sendv(i, std::forward<valueT>(value));
251  }
252 
253  template <size_t i, typename Key, ttg::Runtime Runtime = ttg::ttg_runtime>
254  inline detail::send_t sendk(const Key& key) {
255  return sendk(i, key);
256  }
257 
258  template <size_t i, ttg::Runtime Runtime = ttg::ttg_runtime>
259  inline detail::send_t sendk() {
260  return send(i);
261  }
262 
263  namespace detail {
264 
265  template<typename T, typename Enabler = void>
266  struct broadcast_keylist_trait {
267  using type = T;
268  };
269 
270  /* overload for iterable types that extracts the type of the first element */
271  template<typename T>
272  struct broadcast_keylist_trait<T, std::enable_if_t<ttg::meta::is_iterable_v<T>>> {
273  using key_type = decltype(*std::begin(std::declval<T>()));
274  };
275 
276  template <size_t KeyId, size_t I, size_t... Is, typename... RangesT, typename valueT,
277  typename... out_keysT, typename... out_valuesT>
278  inline void prepare_broadcast(const std::tuple<RangesT...> &keylists, valueT &&value,
279  std::tuple<ttg::Out<out_keysT, out_valuesT>...> &t) {
280  std::get<I>(t).prepare_send(std::get<KeyId>(keylists), std::forward<valueT>(value));
281  if constexpr (sizeof...(Is) > 0) {
282  prepare_broadcast<KeyId+1, Is...>(keylists, std::forward<valueT>(value), t);
283  }
284  }
285 
286  template <size_t KeyId, size_t I, size_t... Is, typename... RangesT, typename valueT,
287  typename... out_keysT, typename... out_valuesT>
288  inline void prepare_broadcast(const std::tuple<RangesT...> &keylists, valueT &&value) {
289  using key_t = typename broadcast_keylist_trait<
290  std::tuple_element_t<KeyId, std::tuple<std::remove_reference_t<RangesT>...>>
291  >::key_type;
292  auto *terminal_ptr = ttg::detail::get_out_terminal<key_t, valueT>(I, "ttg::device::broadcast(keylists, value)");
293  terminal_ptr->prepare_send(std::get<KeyId>(keylists), value);
294  if constexpr (sizeof...(Is) > 0) {
295  prepare_broadcast<KeyId+1, Is...>(keylists, std::forward<valueT>(value));
296  }
297  }
298 
299  template <size_t KeyId, size_t I, size_t... Is, typename... RangesT, typename valueT,
300  typename... out_keysT, typename... out_valuesT>
301  inline void broadcast(const std::tuple<RangesT...> &keylists, valueT &&value,
302  std::tuple<ttg::Out<out_keysT, out_valuesT>...> &t) {
303  std::get<I>(t).broadcast(std::get<KeyId>(keylists), std::forward<valueT>(value));
304  if constexpr (sizeof...(Is) > 0) {
305  detail::broadcast<KeyId+1, Is...>(keylists, std::forward<valueT>(value), t);
306  }
307  }
308 
309  template <size_t KeyId, size_t I, size_t... Is, typename... RangesT, typename valueT>
310  inline void broadcast(const std::tuple<RangesT...> &keylists, valueT &&value) {
311  using key_t = typename broadcast_keylist_trait<
312  std::tuple_element_t<KeyId, std::tuple<std::remove_reference_t<RangesT>...>>
313  >::key_type;
314  auto *terminal_ptr = ttg::detail::get_out_terminal<key_t, valueT>(I, "ttg::device::broadcast(keylists, value)");
315  terminal_ptr->broadcast(std::get<KeyId>(keylists), value);
316  if constexpr (sizeof...(Is) > 0) {
317  ttg::device::detail::broadcast<KeyId+1, Is...>(keylists, std::forward<valueT>(value));
318  }
319  }
320 
321  /* overload with explicit terminals */
322  template <size_t I, size_t... Is, typename RangesT, typename valueT,
323  typename... out_keysT, typename... out_valuesT,
325  inline send_coro_state
326  broadcast_coro(RangesT &&keylists, valueT &&value,
327  std::tuple<ttg::Out<out_keysT, out_valuesT>...> &t,
329  ttg::detail::value_copy_handler<Runtime> copy_handler = std::move(ch); // destroyed at the end of the coro
330  RangesT kl = std::forward<RangesT>(keylists); // capture the keylist(s)
331  if constexpr (ttg::meta::is_tuple_v<RangesT>) {
332  // treat as tuple
333  prepare_broadcast<0, I, Is...>(kl, std::forward<std::decay_t<decltype(value)>>(value), t);
334  co_await ttg::Void{}; // we'll come back once the task is done
335  ttg::device::detail::broadcast<0, I, Is...>(kl, std::forward<std::decay_t<decltype(value)>>(value), t);
336  } else if constexpr (!ttg::meta::is_tuple_v<RangesT>) {
337  // create a tie to the captured keylist
338  prepare_broadcast<0, I, Is...>(std::tie(kl), std::forward<std::decay_t<decltype(value)>>(value), t);
339  co_await ttg::Void{}; // we'll come back once the task is done
340  ttg::device::detail::broadcast<0, I, Is...>(std::tie(kl), std::forward<std::decay_t<decltype(value)>>(value), t);
341  }
342  }
343 
344  /* overload with implicit terminals */
345  template <size_t I, size_t... Is, typename RangesT, typename valueT,
347  inline send_coro_state
348  broadcast_coro(RangesT &&keylists, valueT &&value,
350  ttg::detail::value_copy_handler<Runtime> copy_handler = std::move(ch); // destroyed at the end of the coro
351  RangesT kl = std::forward<RangesT>(keylists); // capture the keylist(s)
352  if constexpr (ttg::meta::is_tuple_v<RangesT>) {
353  // treat as tuple
354  static_assert(sizeof...(Is)+1 == std::tuple_size_v<RangesT>,
355  "Size of keylist tuple must match the number of output terminals");
356  prepare_broadcast<0, I, Is...>(kl, std::forward<std::decay_t<decltype(value)>>(value));
357  co_await ttg::Void{}; // we'll come back once the task is done
358  ttg::device::detail::broadcast<0, I, Is...>(kl, std::forward<std::decay_t<decltype(value)>>(value));
359  } else if constexpr (!ttg::meta::is_tuple_v<RangesT>) {
360  // create a tie to the captured keylist
361  prepare_broadcast<0, I, Is...>(std::tie(kl), std::forward<std::decay_t<decltype(value)>>(value));
362  co_await ttg::Void{}; // we'll come back once the task is done
363  ttg::device::detail::broadcast<0, I, Is...>(std::tie(kl), std::forward<std::decay_t<decltype(value)>>(value));
364  }
365  }
366 
371  template <size_t KeyId, size_t I, size_t... Is, typename... RangesT,
372  typename... out_keysT, typename... out_valuesT>
373  inline void broadcastk(const std::tuple<RangesT...> &keylists,
374  std::tuple<ttg::Out<out_keysT, out_valuesT>...> &t) {
375  std::get<I>(t).broadcast(std::get<KeyId>(keylists));
376  if constexpr (sizeof...(Is) > 0) {
377  detail::broadcastk<KeyId+1, Is...>(keylists, t);
378  }
379  }
380 
381  template <size_t KeyId, size_t I, size_t... Is, typename... RangesT>
382  inline void broadcastk(const std::tuple<RangesT...> &keylists) {
383  using key_t = typename broadcast_keylist_trait<
384  std::tuple_element_t<KeyId, std::tuple<std::remove_reference_t<RangesT>...>>
385  >::key_type;
386  auto *terminal_ptr = ttg::detail::get_out_terminal<key_t, void>(I, "ttg::device::broadcastk(keylists)");
387  terminal_ptr->broadcast(std::get<KeyId>(keylists));
388  if constexpr (sizeof...(Is) > 0) {
389  ttg::device::detail::broadcastk<KeyId+1, Is...>(keylists);
390  }
391  }
392 
393  /* overload with explicit terminals */
394  template <size_t I, size_t... Is, typename RangesT,
395  typename... out_keysT, typename... out_valuesT,
397  inline send_coro_state
398  broadcastk_coro(RangesT &&keylists,
399  std::tuple<ttg::Out<out_keysT, out_valuesT>...> &t) {
400  RangesT kl = std::forward<RangesT>(keylists); // capture the keylist(s)
401  if constexpr (ttg::meta::is_tuple_v<RangesT>) {
402  // treat as tuple
403  co_await ttg::Void{}; // we'll come back once the task is done
404  ttg::device::detail::broadcastk<0, I, Is...>(kl, t);
405  } else if constexpr (!ttg::meta::is_tuple_v<RangesT>) {
406  // create a tie to the captured keylist
407  co_await ttg::Void{}; // we'll come back once the task is done
408  ttg::device::detail::broadcastk<0, I, Is...>(std::tie(kl), t);
409  }
410  }
411 
412  /* overload with implicit terminals */
413  template <size_t I, size_t... Is, typename RangesT,
415  inline send_coro_state
416  broadcastk_coro(RangesT &&keylists) {
417  RangesT kl = std::forward<RangesT>(keylists); // capture the keylist(s)
418  if constexpr (ttg::meta::is_tuple_v<RangesT>) {
419  // treat as tuple
420  static_assert(sizeof...(Is)+1 == std::tuple_size_v<RangesT>,
421  "Size of keylist tuple must match the number of output terminals");
422  co_await ttg::Void{}; // we'll come back once the task is done
423  ttg::device::detail::broadcastk<0, I, Is...>(kl);
424  } else if constexpr (!ttg::meta::is_tuple_v<RangesT>) {
425  // create a tie to the captured keylist
426  co_await ttg::Void{}; // we'll come back once the task is done
427  ttg::device::detail::broadcastk<0, I, Is...>(std::tie(kl));
428  }
429  }
430  } // namespace detail
431 
432  /* overload with explicit terminals and keylist passed by const reference */
433  template <size_t I, size_t... Is, typename rangeT, typename valueT, typename... out_keysT, typename... out_valuesT,
435  [[nodiscard]]
436  inline detail::send_t broadcast(rangeT &&keylist,
437  valueT &&value,
438  std::tuple<ttg::Out<out_keysT, out_valuesT>...> &t) {
440  return detail::send_t{
441  detail::broadcast_coro<I, Is...>(std::forward<rangeT>(keylist),
442  copy_handler(std::forward<valueT>(value)),
443  t, std::move(copy_handler))};
444  }
445 
446  /* overload with implicit terminals and keylist passed by const reference */
447  template <size_t i, typename rangeT, typename valueT,
449  inline detail::send_t broadcast(rangeT &&keylist, valueT &&value) {
451  return detail::send_t{broadcast_coro<i>(std::tie(keylist), copy_handler(std::forward<valueT>(value)),
452  std::move(copy_handler))};
453  }
454 
455  /* overload with explicit terminals and keylist passed by const reference */
456  template <size_t I, size_t... Is, typename rangeT, typename... out_keysT, typename... out_valuesT,
458  [[nodiscard]]
459  inline detail::send_t broadcastk(rangeT &&keylist,
460  std::tuple<ttg::Out<out_keysT, out_valuesT>...> &t) {
462  return detail::send_t{
463  detail::broadcastk_coro<I, Is...>(std::forward<rangeT>(keylist), t)};
464  }
465 
466  /* overload with implicit terminals and keylist passed by const reference */
467  template <size_t i, typename rangeT,
469  inline detail::send_t broadcastk(rangeT &&keylist) {
470  if constexpr (std::is_rvalue_reference_v<decltype(keylist)>) {
471  return detail::send_t{detail::broadcastk_coro<i>(std::forward<rangeT>(keylist))};
472  } else {
473  return detail::send_t{detail::broadcastk_coro<i>(std::tie(keylist))};
474  }
475  }
476 
477  template<typename... Args, ttg::Runtime Runtime = ttg::ttg_runtime>
478  [[nodiscard]]
479  std::vector<device::detail::send_t> forward(Args&&... args) {
480  // TODO: check the cost of this!
481  return std::vector<device::detail::send_t>{std::forward<Args>(args)...};
482  }
483 
484  /*******************************************
485  * Device task promise and coroutine handle
486  *******************************************/
487 
488  namespace detail {
489  // fwd-decl
490  struct device_task_promise_type;
491  // base type for ttg::device::Task
492  using device_task_handle_type = ttg::coroutine_handle<device_task_promise_type>;
493  } // namespace detail
494 
496 
503  struct Task : public detail::device_task_handle_type {
504  using base_type = detail::device_task_handle_type;
505 
508 
509  using promise_type = detail::device_task_promise_type;
510 
512 
513  Task(base_type base) : base_type(std::move(base)) {}
514 
515  base_type& handle() { return *this; }
516 
518  inline bool ready() {
519  return true;
520  }
521 
523  inline bool completed();
524  };
525 
526  namespace detail {
527 
528  /* The promise type that stores the views provided by the
529  * application task coroutine on the first co_yield. It subsequently
530  * tracks the state of the task when it moves from waiting for transfers
531  * to waiting for the submitted kernel to complete. */
532  struct device_task_promise_type {
533 
534  /* do not suspend the coroutine on first invocation, we want to run
535  * the coroutine immediately and suspend when we get the device transfers.
536  */
537  ttg::suspend_never initial_suspend() {
538  m_state = ttg::device::detail::TTG_DEVICE_CORO_INIT;
539  return {};
540  }
541 
542  /* suspend the coroutine at the end of the execution
543  * so we can access the promise.
544  * TODO: necessary? maybe we can save one suspend here
545  */
546  ttg::suspend_always final_suspend() noexcept {
547  m_state = ttg::device::detail::TTG_DEVICE_CORO_COMPLETE;
548  return {};
549  }
550 
551  /* Allow co_await on a tuple */
552  template<typename... Views>
553  ttg::suspend_always await_transform(std::tuple<Views&...> &views) {
554  return yield_value(views);
555  }
556 
557  template<typename... Ts>
558  ttg::suspend_always await_transform(detail::to_device_t<Ts...>&& a) {
559  bool need_transfer = !(TTG_IMPL_NS::register_device_memory(a.ties));
560  /* TODO: are we allowed to not suspend here and launch the kernel directly? */
561  m_state = ttg::device::detail::TTG_DEVICE_CORO_WAIT_TRANSFER;
562  return {};
563  }
564 
565  template<typename... Ts>
566  auto await_transform(detail::wait_kernel_t<Ts...>&& a) {
567  //std::cout << "yield_value: wait_kernel_t" << std::endl;
568  if constexpr (sizeof...(Ts) > 0) {
570  }
571  m_state = ttg::device::detail::TTG_DEVICE_CORO_WAIT_KERNEL;
572  return a;
573  }
574 
575  ttg::suspend_always await_transform(std::vector<device::detail::send_t>&& v) {
576  m_sends = std::forward<std::vector<device::detail::send_t>>(v);
577  m_state = ttg::device::detail::TTG_DEVICE_CORO_SENDOUT;
578  return {};
579  }
580 
581  ttg::suspend_always await_transform(device::detail::send_t&& v) {
582  m_sends.clear();
583  m_sends.push_back(std::forward<device::detail::send_t>(v));
584  m_state = ttg::device::detail::TTG_DEVICE_CORO_SENDOUT;
585  return {};
586  }
587 
588  void return_void() {
589  m_state = ttg::device::detail::TTG_DEVICE_CORO_COMPLETE;
590  }
591 
592  bool complete() const {
593  return m_state == ttg::device::detail::TTG_DEVICE_CORO_COMPLETE;
594  }
595 
596  ttg::device::Task get_return_object() { return {detail::device_task_handle_type::from_promise(*this)}; }
597 
598  void unhandled_exception() {
599  std::cerr << "Task coroutine caught an unhandled exception!" << std::endl;
600  throw; // fwd
601  }
602 
603  //using iterator = std::vector<device_obj_view>::iterator;
604 
605  /* execute all pending send and broadcast operations */
606  void do_sends() {
607  for (auto& send : m_sends) {
608  send.coro();
609  }
610  m_sends.clear();
611  }
612 
613  auto state() {
614  return m_state;
615  }
616 
617  private:
618  std::vector<device::detail::send_t> m_sends;
619  ttg_device_coro_state m_state = ttg::device::detail::TTG_DEVICE_CORO_STATE_NONE;
620 
621  };
622 
623  } // namespace detail
624 
625  bool Task::completed() { return base_type::promise().state() == ttg::device::detail::TTG_DEVICE_CORO_COMPLETE; }
626 
627  struct device_wait_kernel
628  { };
629 
630 
631  /* NOTE: below is preliminary for reductions on the device, which is not available yet */
632 #if 0
633  /**************************
634  * Device reduction coros *
635  **************************/
636 
637  struct device_reducer_promise_type;
638 
639  using device_reducer_handle_type = ttg::coroutine_handle<device_reducer_promise_type>;
640 
642  struct device_reducer : public device_reducer_handle_type {
643  using base_type = device_reducer_handle_type;
644 
647 
648  using promise_type = device_reducer_promise_type;
649 
651 
652  device_reducer(base_type base) : base_type(std::move(base)) {}
653 
654  base_type& handle() { return *this; }
655 
657  inline bool ready() {
658  return true;
659  }
660 
662  inline bool completed();
663  };
664 
665 
666  /* The promise type that stores the views provided by the
667  * application task coroutine on the first co_yield. It subsequently
668  * tracks the state of the task when it moves from waiting for transfers
669  * to waiting for the submitted kernel to complete. */
670  struct device_reducer_promise_type {
671 
672  /* do not suspend the coroutine on first invocation, we want to run
673  * the coroutine immediately and suspend when we get the device transfers.
674  */
675  ttg::suspend_never initial_suspend() {
676  m_state = ttg::device::detail::TTG_DEVICE_CORO_INIT;
677  return {};
678  }
679 
680  /* suspend the coroutine at the end of the execution
681  * so we can access the promise.
682  * TODO: necessary? maybe we can save one suspend here
683  */
684  ttg::suspend_always final_suspend() noexcept {
685  m_state = ttg::device::detail::TTG_DEVICE_CORO_COMPLETE;
686  return {};
687  }
688 
689  template<typename... Ts>
690  ttg::suspend_always await_transform(detail::to_device_t<Ts...>&& a) {
691  bool need_transfer = !(TTG_IMPL_NS::register_device_memory(a.ties));
692  /* TODO: are we allowed to not suspend here and launch the kernel directly? */
693  m_state = ttg::device::detail::TTG_DEVICE_CORO_WAIT_TRANSFER;
694  return {};
695  }
696 
697  void return_void() {
698  m_state = ttg::device::detail::TTG_DEVICE_CORO_COMPLETE;
699  }
700 
701  bool complete() const {
702  return m_state == ttg::device::detail::TTG_DEVICE_CORO_COMPLETE;
703  }
704 
705  device_reducer get_return_object() { return device_reducer{device_reducer_handle_type::from_promise(*this)}; }
706 
707  void unhandled_exception() { }
708 
709  auto state() {
710  return m_state;
711  }
712 
713 
714  private:
715  ttg::device::detail::ttg_device_coro_state m_state = ttg::device::detail::TTG_DEVICE_CORO_STATE_NONE;
716 
717  };
718 
719  bool device_reducer::completed() { return base_type::promise().state() == ttg::device::detail::TTG_DEVICE_CORO_COMPLETE; }
720 #endif // 0
721 
722 } // namespace ttg::device
723 
724 #endif // TTG_HAVE_COROUTINE
725 
726 #endif // TTG_DEVICE_TASK_H
std::enable_if_t< meta::is_all_void_v< Key, Value >, void > send()
Definition: terminal.h:514
std::enable_if_t<!meta::is_void_v< Key > &&meta::is_void_v< Value >, void > sendk(const Key &key)
Definition: terminal.h:475
A complete version of void.
Definition: void.h:11
void broadcast(const std::tuple< RangesT... > &keylists, valueT &&value, std::tuple< ttg::Out< out_keysT, out_valuesT >... > &t)
Definition: func.h:347
constexpr bool is_devicescratch_v
Definition: meta.h:340
constexpr bool is_buffer_v
Definition: meta.h:325
void mark_device_out(std::tuple< Buffer &... > &b)
bool register_device_memory(std::tuple< Views &... > &views)
void post_device_out(std::tuple< Buffer &... > &b)
TTG_CXX_COROUTINE_NAMESPACE::suspend_always suspend_always
Definition: coroutine.h:21
void send(const keyT &key, valueT &&value, ttg::Out< keyT, valueT > &t)
Sends a task id and a value to the given output terminal.
Definition: func.h:158
constexpr const ttg::Runtime ttg_runtime
Definition: import.h:20
Runtime
Definition: runtimes.h:15
TTG_CXX_COROUTINE_NAMESPACE::suspend_never suspend_never
Definition: coroutine.h:22
TTG_CXX_COROUTINE_NAMESPACE::coroutine_handle< Promise > coroutine_handle
Definition: coroutine.h:24
void sendk(const keyT &key, ttg::Out< keyT, void > &t)
Sends a task id (without an accompanying value) to the given output terminal.
Definition: func.h:169
void sendv(valueT &&value, ttg::Out< void, valueT > &t)
Sends a value (without an accompanying task id) to the given output terminal.
Definition: func.h:179
void broadcast(const rangeT &keylist, valueT &&value, std::tuple< ttg::Out< out_keysT, out_valuesT >... > &t)
Definition: func.h:414
void broadcastk(const rangeT &keylist, std::tuple< ttg::Out< out_keysT, out_valuesT >... > &t)
Definition: func.h:452