ttg 1.0.0-alpha
Template Task Graph (TTG): flowgraph-based programming model for high-performance distributed-memory algorithms
Loading...
Searching...
No Matches
task.h
Go to the documentation of this file.
1#ifndef TTG_DEVICE_TASK_H
2#define TTG_DEVICE_TASK_H
3
4#include <array>
5#include <type_traits>
6#include <span>
7
8
9#include "ttg/fwd.h"
10#include "ttg/impl_selector.h"
11#include "ttg/ptr.h"
12#include "ttg/devicescope.h"
13
14#ifdef TTG_HAVE_COROUTINE
15
16namespace ttg::device {
17
18 namespace detail {
19
20 struct device_input_data_t {
21 using impl_data_t = decltype(TTG_IMPL_NS::buffer_data(std::declval<ttg::Buffer<int>>()));
22
23 device_input_data_t(impl_data_t data, ttg::scope scope, bool isconst, bool isscratch)
24 : impl_data(data), scope(scope), is_const(isconst), is_scratch(isscratch)
25 { }
26 impl_data_t impl_data;
28 bool is_const;
29 bool is_scratch;
30 };
31
/* Aggregate produced by ttg::device::select(): holds a tuple of lvalue
 * references to the selected objects. The device task promise consumes it
 * in await_transform(). Only references are stored, so the referenced
 * objects must outlive this object. */
32 template <typename... Ts>
33 struct to_device_t {
34 std::tuple<std::add_lvalue_reference_t<Ts>...> ties;
35 };
36
37 /* extract buffer information from to_device_t */
38 template<typename... Ts, std::size_t... Is>
39 auto extract_buffer_data(detail::to_device_t<Ts...>& a, std::index_sequence<Is...>) {
40 using arg_types = std::tuple<Ts...>;
41 return std::array<device_input_data_t, sizeof...(Ts)>{
42 device_input_data_t{TTG_IMPL_NS::buffer_data(std::get<Is>(a.ties)),
43 std::get<Is>(a.ties).scope(),
46 }
47 } // namespace detail
48
49 struct Input {
50 private:
51 std::vector<detail::device_input_data_t> m_data;
52
53 public:
54 Input() { }
55 template<typename... Args>
56 Input(Args&&... args)
57 : m_data{{TTG_IMPL_NS::buffer_data(args), args.scope(),
58 std::is_const_v<std::remove_reference_t<Args>>,
60 { }
61
62 template<typename T>
63 void add(T&& v) {
64 using type = std::remove_reference_t<T>;
65 m_data.emplace_back(TTG_IMPL_NS::buffer_data(v), v.scope(), std::is_const_v<type>,
67 }
68
69 ttg::span<detail::device_input_data_t> span() {
70 return ttg::span(m_data);
71 }
72 };
73
74 namespace detail {
75 // explicit specialization for Input: stores a reference to the Input object
// itself instead of a tuple of references (the Input must outlive this object)
76 template <>
77 struct to_device_t<Input> {
78 Input& input;
79 };
80 } // namespace detail
81
89 template <typename... Args>
90 [[nodiscard]]
91 inline auto select(Args &&...args) {
92 return detail::to_device_t<std::remove_reference_t<Args>...>{std::tie(std::forward<Args>(args)...)};
93 }
94
95 [[nodiscard]]
96 inline auto select(Input& input) {
97 return detail::to_device_t<Input>{input};
98 }
99
100 namespace detail {
101
/* State of a device task coroutine, as tracked by its promise. */
enum ttg_device_coro_state {
  TTG_DEVICE_CORO_STATE_NONE = 0,  // default value of the promise's state member
  TTG_DEVICE_CORO_INIT,            // set by initial_suspend()
  TTG_DEVICE_CORO_WAIT_TRANSFER,   // set after co_await on a select() result
  TTG_DEVICE_CORO_WAIT_KERNEL,     // set after co_await on a wait() result
  TTG_DEVICE_CORO_SENDOUT,         // set after co_await on send/broadcast results
  TTG_DEVICE_CORO_COMPLETE         // set by final_suspend() and return_void()
};
110
111 template <typename... Ts>
112 struct wait_kernel_t {
113 std::tuple<std::add_lvalue_reference_t<Ts>...> ties;
114
115 /* always suspend */
116 constexpr bool await_ready() const noexcept { return false; }
117
118 /* always suspend */
119 template <typename Promise>
120 constexpr void await_suspend(ttg::coroutine_handle<Promise>) const noexcept {}
121
122 void await_resume() noexcept {
123 if constexpr (sizeof...(Ts) > 0) {
124 /* hook to allow the backend to handle the data after pushout */
125 TTG_IMPL_NS::post_device_out(ties);
126 }
127 }
128 };
129 } // namespace detail
130
136 template <typename... Buffers>
137 [[nodiscard]]
138 inline auto wait(Buffers &&...args) {
139 static_assert(
141 ...),
142 "Only ttg::Buffer and ttg::devicescratch can be waited on!");
143 return detail::wait_kernel_t<std::remove_reference_t<Buffers>...>{std::tie(std::forward<Buffers>(args)...)};
144 }
145
146 /******************************
147 * Send/Broadcast handling
148 * We pass the value returned by the backend's copy handler into a coroutine
149 * and execute the first part (prepare), before suspending it.
150 * The second part (send/broadcast) is executed after the task completed.
151 ******************************/
152
153 namespace detail {
154 struct send_coro_promise_type;
155
156 using send_coro_handle_type = ttg::coroutine_handle<send_coro_promise_type>;
157
159 struct send_coro_state : public send_coro_handle_type {
160 using base_type = send_coro_handle_type;
161
164
165 using promise_type = send_coro_promise_type;
166
168
169 send_coro_state(base_type base) : base_type(std::move(base)) {}
170
171 base_type &handle() { return *this; }
172
174 inline bool ready() { return true; }
175
177 inline bool completed();
178 };
179
181 struct send_coro_promise_type {
182 /* do not suspend the coroutine on first invocation, we want to run
183 * the coroutine immediately and suspend only once.
184 */
185 ttg::suspend_never initial_suspend() { return {}; }
186
187 /* we don't suspend the coroutine at the end.
188 * it can be destroyed once the send/broadcast is done
189 */
190 ttg::suspend_never final_suspend() noexcept { return {}; }
191
192 send_coro_state get_return_object() { return send_coro_state{send_coro_handle_type::from_promise(*this)}; }
193
194 /* the send coros only have an empty co_await */
195 ttg::suspend_always await_transform(ttg::Void) { return {}; }
196
197 void unhandled_exception() {
198 std::cerr << "Send coroutine caught an unhandled exception!" << std::endl;
199 throw; // fwd
200 }
201
202 void return_void() {}
203 };
204
205 template <typename Key, typename Value, ttg::Runtime Runtime = ttg::ttg_runtime>
206 inline send_coro_state send_coro(const Key &key, Value &&value, ttg::Out<Key, std::decay_t<Value>> &t,
208 ttg::detail::value_copy_handler<Runtime> copy_handler = std::move(ch); // destroyed at the end of the coro
209 Key k = key;
210 t.prepare_send(k, std::forward<Value>(value));
211 co_await ttg::Void{}; // we'll come back once the task is done
212 t.send(k, std::forward<Value>(value));
213 };
214
215 template <typename Value, ttg::Runtime Runtime = ttg::ttg_runtime>
216 inline send_coro_state sendv_coro(Value &&value, ttg::Out<void, std::decay_t<Value>> &t,
218 ttg::detail::value_copy_handler<Runtime> copy_handler = std::move(ch); // destroyed at the end of the coro
219 t.prepare_send(std::forward<Value>(value));
220 co_await ttg::Void{}; // we'll come back once the task is done
221 t.sendv(std::forward<Value>(value));
222 };
223
224 template <typename Key, ttg::Runtime Runtime = ttg::ttg_runtime>
225 inline send_coro_state sendk_coro(const Key &key, ttg::Out<Key, void> &t) {
226 // no need to prepare the send but we have to suspend once
227 Key k = key;
228 co_await ttg::Void{}; // we'll come back once the task is done
229 t.sendk(k);
230 };
231
232 template <ttg::Runtime Runtime = ttg::ttg_runtime>
233 inline send_coro_state send_coro(ttg::Out<void, void> &t) {
234 // no need to prepare the send but we have to suspend once
235 co_await ttg::Void{}; // we'll come back once the task is done
236 t.send();
237 };
238
/* Result type of the ttg::device send/broadcast functions: wraps the
 * suspended send coroutine. The device task promise's do_sends() invokes
 * coro() on each stored send_t to run the deferred send. */
239 struct send_t {
240 send_coro_state coro;
241 };
242 } // namespace detail
243
244 template <size_t i, typename keyT, typename valueT, typename... out_keysT, typename... out_valuesT,
246 inline detail::send_t send(const keyT &key, valueT &&value, std::tuple<ttg::Out<out_keysT, out_valuesT>...> &t) {
248 return detail::send_t{
249 detail::send_coro(key, copy_handler(std::forward<valueT>(value)), std::get<i>(t), copy_handler)};
250 }
251
252 template <size_t i, typename valueT, typename... out_keysT, typename... out_valuesT,
254 inline detail::send_t sendv(valueT &&value, std::tuple<ttg::Out<out_keysT, out_valuesT>...> &t) {
256 return detail::send_t{detail::sendv_coro(copy_handler(std::forward<valueT>(value)), std::get<i>(t), copy_handler)};
257 }
258
259 template <size_t i, typename Key, typename... out_keysT, typename... out_valuesT,
261 inline detail::send_t sendk(const Key &key, std::tuple<ttg::Out<out_keysT, out_valuesT>...> &t) {
262 return detail::send_t{detail::sendk_coro(key, std::get<i>(t))};
263 }
264
265 // clang-format off
270 // clang-format on
271 template <typename keyT, typename valueT, ttg::Runtime Runtime = ttg::ttg_runtime>
272 inline detail::send_t send(size_t i, const keyT &key, valueT &&value) {
274 auto *terminal_ptr = ttg::detail::get_out_terminal<keyT, valueT>(i, "ttg::device::send(i, key, value)");
275 return detail::send_t{detail::send_coro(key, copy_handler(std::forward<valueT>(value)), *terminal_ptr, copy_handler)};
276 }
277
278 // clang-format off
284 // clang-format on
/* Compile-time-index variant: forwards to ttg::device::send(i, key, value)
 * with the terminal index taken from the template parameter i. */
285 template <size_t i, typename keyT, typename valueT>
286 inline auto send(const keyT &key, valueT &&value) {
287 return ttg::device::send(i, key, std::forward<valueT>(value));
288 }
289
290
291 template <typename valueT, ttg::Runtime Runtime = ttg::ttg_runtime>
292 inline detail::send_t sendv(std::size_t i, valueT &&value) {
293 auto *terminal_ptr = ttg::detail::get_out_terminal<void, valueT>(i, "ttg::device::send(i, key, value)");
295 return detail::send_t{detail::sendv_coro(copy_handler(std::forward<valueT>(value)), *terminal_ptr, copy_handler)};
296 }
297
298 template <typename Key, ttg::Runtime Runtime = ttg::ttg_runtime>
299 inline detail::send_t sendk(std::size_t i, const Key& key) {
300 auto *terminal_ptr = ttg::detail::get_out_terminal<Key, void>(i, "ttg::device::send(i, key, value)");
301 return detail::send_t{detail::sendk_coro(key, *terminal_ptr)};
302 }
303
304 template <ttg::Runtime Runtime = ttg::ttg_runtime>
305 inline detail::send_t send(std::size_t i) {
306 auto *terminal_ptr = ttg::detail::get_out_terminal<void, void>(i, "ttg::device::send(i, key, value)");
307 return detail::send_t{detail::send_coro(*terminal_ptr)};
308 }
309
310
/* Compile-time-index variant: forwards to sendv(i, value) with the terminal
 * index taken from the template parameter i. */
311 template <std::size_t i, typename valueT, ttg::Runtime Runtime = ttg::ttg_runtime>
312 inline detail::send_t sendv(valueT &&value) {
313 return sendv(i, std::forward<valueT>(value));
314 }
315
/* Compile-time-index variant: forwards to sendk(i, key) with the terminal
 * index taken from the template parameter i. */
316 template <size_t i, typename Key, ttg::Runtime Runtime = ttg::ttg_runtime>
317 inline detail::send_t sendk(const Key& key) {
318 return sendk(i, key);
319 }
320
/* Compile-time-index variant without arguments: forwards to send(i), i.e. it
 * sends neither a key nor a value through terminal i.
 * NOTE(review): the name sendk with no key parameter looks like a copy-paste
 * of the overload taking a key; presumably send<i>() was intended -- confirm
 * before relying on this name. */
321 template <size_t i, ttg::Runtime Runtime = ttg::ttg_runtime>
322 inline detail::send_t sendk() {
323 return send(i);
324 }
325
326 namespace detail {
327
/* Maps a key list type to the key type it carries.
 * Primary template: T is not iterable, so T itself is treated as the key.
 * The broadcast helpers look up ::key_type, so it has to exist here as well
 * as in the specialization for iterable types; ::type is kept for existing
 * users. */
template<typename T, typename Enabler = void>
struct broadcast_keylist_trait {
  using type = T;
  using key_type = T;
};
332
333 /* partial specialization for iterable types: key_type is the type obtained by
 * dereferencing std::begin() of the list, taken as-is from decltype (it is not
 * decayed here) */
334 template<typename T>
335 struct broadcast_keylist_trait<T, std::enable_if_t<ttg::meta::is_iterable_v<T>>> {
336 using key_type = decltype(*std::begin(std::declval<T>()));
337 };
338
339 template <size_t KeyId, size_t I, size_t... Is, typename... RangesT, typename valueT,
340 typename... out_keysT, typename... out_valuesT>
341 inline void prepare_broadcast(const std::tuple<RangesT...> &keylists, valueT &&value,
342 std::tuple<ttg::Out<out_keysT, out_valuesT>...> &t) {
343 std::get<I>(t).prepare_send(std::get<KeyId>(keylists), std::forward<valueT>(value));
344 if constexpr (sizeof...(Is) > 0) {
345 prepare_broadcast<KeyId+1, Is...>(keylists, std::forward<valueT>(value), t);
346 }
347 }
348
349 template <size_t KeyId, size_t I, size_t... Is, typename... RangesT, typename valueT>
350 inline void prepare_broadcast(const std::tuple<RangesT...> &keylists, valueT &&value) {
351 using key_t = typename broadcast_keylist_trait<
352 std::tuple_element_t<KeyId, std::tuple<std::remove_reference_t<RangesT>...>>
353 >::key_type;
354 auto *terminal_ptr = ttg::detail::get_out_terminal<key_t, valueT>(I, "ttg::device::broadcast(keylists, value)");
355 terminal_ptr->prepare_send(std::get<KeyId>(keylists), value);
356 if constexpr (sizeof...(Is) > 0) {
357 prepare_broadcast<KeyId+1, Is...>(keylists, std::forward<valueT>(value));
358 }
359 }
360
361 template <size_t KeyId, size_t I, size_t... Is, typename... RangesT, typename valueT,
362 typename... out_keysT, typename... out_valuesT>
363 inline void broadcast(const std::tuple<RangesT...> &keylists, valueT &&value,
364 std::tuple<ttg::Out<out_keysT, out_valuesT>...> &t) {
365 std::get<I>(t).broadcast(std::get<KeyId>(keylists), std::forward<valueT>(value));
366 if constexpr (sizeof...(Is) > 0) {
367 detail::broadcast<KeyId+1, Is...>(keylists, std::forward<valueT>(value), t);
368 }
369 }
370
371 template <size_t KeyId, size_t I, size_t... Is, typename... RangesT, typename valueT>
372 inline void broadcast(const std::tuple<RangesT...> &keylists, valueT &&value) {
373 using key_t = typename broadcast_keylist_trait<
374 std::tuple_element_t<KeyId, std::tuple<std::remove_reference_t<RangesT>...>>
375 >::key_type;
376 auto *terminal_ptr = ttg::detail::get_out_terminal<key_t, valueT>(I, "ttg::device::broadcast(keylists, value)");
377 terminal_ptr->broadcast(std::get<KeyId>(keylists), value);
378 if constexpr (sizeof...(Is) > 0) {
379 ttg::device::detail::broadcast<KeyId+1, Is...>(keylists, std::forward<valueT>(value));
380 }
381 }
382
383 /* overload with explicit terminals */
384 template <size_t I, size_t... Is, typename RangesT, typename valueT,
385 typename... out_keysT, typename... out_valuesT,
387 inline send_coro_state
388 broadcast_coro(RangesT &&keylists, valueT &&value,
389 std::tuple<ttg::Out<out_keysT, out_valuesT>...> &t,
391 ttg::detail::value_copy_handler<Runtime> copy_handler = std::move(ch); // destroyed at the end of the coro
392 RangesT kl = std::forward<RangesT>(keylists); // capture the keylist(s)
393 if constexpr (ttg::meta::is_tuple_v<RangesT>) {
394 // treat as tuple
395 prepare_broadcast<0, I, Is...>(kl, std::forward<valueT>(value), t);
396 co_await ttg::Void{}; // we'll come back once the task is done
397 ttg::device::detail::broadcast<0, I, Is...>(kl, std::forward<valueT>(value), t);
398 } else if constexpr (!ttg::meta::is_tuple_v<RangesT>) {
399 // create a tie to the captured keylist
400 prepare_broadcast<0, I, Is...>(std::tie(kl), std::forward<valueT>(value), t);
401 co_await ttg::Void{}; // we'll come back once the task is done
402 ttg::device::detail::broadcast<0, I, Is...>(std::tie(kl), std::forward<valueT>(value), t);
403 }
404 }
405
406 /* overload with implicit terminals */
407 template <size_t I, size_t... Is, typename RangesT, typename valueT,
409 inline send_coro_state
410 broadcast_coro(RangesT &&keylists, valueT &&value,
412 ttg::detail::value_copy_handler<Runtime> copy_handler = std::move(ch); // destroyed at the end of the coro
413 RangesT kl = std::forward<RangesT>(keylists); // capture the keylist(s)
414 if constexpr (ttg::meta::is_tuple_v<RangesT>) {
415 // treat as tuple
416 static_assert(sizeof...(Is)+1 == std::tuple_size_v<RangesT>,
417 "Size of keylist tuple must match the number of output terminals");
418 prepare_broadcast<0, I, Is...>(kl, std::forward<valueT>(value));
419 co_await ttg::Void{}; // we'll come back once the task is done
420 ttg::device::detail::broadcast<0, I, Is...>(kl, std::forward<valueT>(value));
421 } else if constexpr (!ttg::meta::is_tuple_v<RangesT>) {
422 // create a tie to the captured keylist
423 prepare_broadcast<0, I, Is...>(std::tie(kl), std::forward<valueT>(value));
424 co_await ttg::Void{}; // we'll come back once the task is done
425 ttg::device::detail::broadcast<0, I, Is...>(std::tie(kl), std::forward<valueT>(value));
426 }
427 }
428
433 template <size_t KeyId, size_t I, size_t... Is, typename... RangesT,
434 typename... out_keysT, typename... out_valuesT>
435 inline void broadcastk(const std::tuple<RangesT...> &keylists,
436 std::tuple<ttg::Out<out_keysT, out_valuesT>...> &t) {
437 std::get<I>(t).broadcast(std::get<KeyId>(keylists));
438 if constexpr (sizeof...(Is) > 0) {
439 detail::broadcastk<KeyId+1, Is...>(keylists, t);
440 }
441 }
442
443 template <size_t KeyId, size_t I, size_t... Is, typename... RangesT>
444 inline void broadcastk(const std::tuple<RangesT...> &keylists) {
445 using key_t = typename broadcast_keylist_trait<
446 std::tuple_element_t<KeyId, std::tuple<std::remove_reference_t<RangesT>...>>
447 >::key_type;
448 auto *terminal_ptr = ttg::detail::get_out_terminal<key_t, void>(I, "ttg::device::broadcastk(keylists)");
449 terminal_ptr->broadcast(std::get<KeyId>(keylists));
450 if constexpr (sizeof...(Is) > 0) {
451 ttg::device::detail::broadcastk<KeyId+1, Is...>(keylists);
452 }
453 }
454
455 /* overload with explicit terminals */
456 template <size_t I, size_t... Is, typename RangesT,
457 typename... out_keysT, typename... out_valuesT,
459 inline send_coro_state
460 broadcastk_coro(RangesT &&keylists,
461 std::tuple<ttg::Out<out_keysT, out_valuesT>...> &t) {
462 RangesT kl = std::forward<RangesT>(keylists); // capture the keylist(s)
463 if constexpr (ttg::meta::is_tuple_v<RangesT>) {
464 // treat as tuple
465 co_await ttg::Void{}; // we'll come back once the task is done
466 ttg::device::detail::broadcastk<0, I, Is...>(kl, t);
467 } else if constexpr (!ttg::meta::is_tuple_v<RangesT>) {
468 // create a tie to the captured keylist
469 co_await ttg::Void{}; // we'll come back once the task is done
470 ttg::device::detail::broadcastk<0, I, Is...>(std::tie(kl), t);
471 }
472 }
473
474 /* overload with implicit terminals */
475 template <size_t I, size_t... Is, typename RangesT,
477 inline send_coro_state
478 broadcastk_coro(RangesT &&keylists) {
479 RangesT kl = std::forward<RangesT>(keylists); // capture the keylist(s)
480 if constexpr (ttg::meta::is_tuple_v<RangesT>) {
481 // treat as tuple
482 static_assert(sizeof...(Is)+1 == std::tuple_size_v<RangesT>,
483 "Size of keylist tuple must match the number of output terminals");
484 co_await ttg::Void{}; // we'll come back once the task is done
485 ttg::device::detail::broadcastk<0, I, Is...>(kl);
486 } else if constexpr (!ttg::meta::is_tuple_v<RangesT>) {
487 // create a tie to the captured keylist
488 co_await ttg::Void{}; // we'll come back once the task is done
489 ttg::device::detail::broadcastk<0, I, Is...>(std::tie(kl));
490 }
491 }
492 } // namespace detail
493
494 /* overload with explicit terminals and keylist passed by forwarding reference */
495 template <size_t I, size_t... Is, typename rangeT, typename valueT, typename... out_keysT, typename... out_valuesT,
497 [[nodiscard]]
498 inline detail::send_t broadcast(rangeT &&keylist,
499 valueT &&value,
500 std::tuple<ttg::Out<out_keysT, out_valuesT>...> &t) {
502 return detail::send_t{
503 detail::broadcast_coro<I, Is...>(std::forward<rangeT>(keylist),
504 copy_handler(std::forward<valueT>(value)),
505 t, std::move(copy_handler))};
506 }
507
508 /* overload with implicit terminals and keylist passed by forwarding reference */
509 template <size_t i, typename rangeT, typename valueT,
511 inline detail::send_t broadcast(rangeT &&keylist, valueT &&value) {
513 return detail::send_t{detail::broadcast_coro<i>(std::tie(keylist),
514 copy_handler(std::forward<valueT>(value)),
515 std::move(copy_handler))};
516 }
517
518 /* key-only overload with explicit terminals and keylist passed by forwarding reference */
519 template <size_t I, size_t... Is, typename rangeT, typename... out_keysT, typename... out_valuesT,
521 [[nodiscard]]
522 inline detail::send_t broadcastk(rangeT &&keylist,
523 std::tuple<ttg::Out<out_keysT, out_valuesT>...> &t) {
525 return detail::send_t{
526 detail::broadcastk_coro<I, Is...>(std::forward<rangeT>(keylist), t)};
527 }
528
529 /* key-only overload with implicit terminals and keylist passed by forwarding reference */
530 template <size_t i, typename rangeT,
532 inline detail::send_t broadcastk(rangeT &&keylist) {
533 if constexpr (std::is_rvalue_reference_v<decltype(keylist)>) {
534 return detail::send_t{detail::broadcastk_coro<i>(std::forward<rangeT>(keylist))};
535 } else {
536 return detail::send_t{detail::broadcastk_coro<i>(std::tie(keylist))};
537 }
538 }
539
/* Collect several send/broadcast results into one vector so that they can be
 * co_await'ed together (the device task promise accepts a
 * std::vector<send_t> in await_transform). */
540 template<typename... Args, ttg::Runtime Runtime = ttg::ttg_runtime>
541 [[nodiscard]]
542 std::vector<device::detail::send_t> forward(Args&&... args) {
543 // TODO: check the cost of this!
// braced construction: the arguments become the elements of the vector
544 return std::vector<device::detail::send_t>{std::forward<Args>(args)...};
545 }
546
547 /*******************************************
548 * Device task promise and coroutine handle
549 *******************************************/
550
551 namespace detail {
552 // fwd-decl
553 struct device_task_promise_type;
554 // base type for ttg::device::Task
555 using device_task_handle_type = ttg::coroutine_handle<device_task_promise_type>;
556 } // namespace detail
557
559
566 struct Task : public detail::device_task_handle_type {
567 using base_type = detail::device_task_handle_type;
568
571
572 using promise_type = detail::device_task_promise_type;
573
575
576 Task(base_type base) : base_type(std::move(base)) {}
577
578 base_type& handle() { return *this; }
579
581 inline bool ready() {
582 return true;
583 }
584
586 inline bool completed();
587 };
588
589 namespace detail {
590
591 /* The promise type that stores the views provided by the
592 * application task coroutine on the first co_yield. It subsequently
593 * tracks the state of the task when it moves from waiting for transfers
594 * to waiting for the submitted kernel to complete. */
595 struct device_task_promise_type {
596
597 /* do not suspend the coroutine on first invocation, we want to run
598 * the coroutine immediately and suspend when we get the device transfers.
599 */
600 ttg::suspend_never initial_suspend() {
601 m_state = ttg::device::detail::TTG_DEVICE_CORO_INIT;
602 return {};
603 }
604
605 /* suspend the coroutine at the end of the execution
606 * so we can access the promise.
607 * TODO: necessary? maybe we can save one suspend here
608 */
609 ttg::suspend_always final_suspend() noexcept {
610 m_state = ttg::device::detail::TTG_DEVICE_CORO_COMPLETE;
611 return {};
612 }
613
614 /* Allow co_await on a tuple */
615 template<typename... Views>
616 ttg::suspend_always await_transform(std::tuple<Views&...> &views) {
617 return yield_value(views);
618 }
619
620 template<typename... Ts>
621 ttg::suspend_always await_transform(detail::to_device_t<Ts...>&& a) {
622 auto arr = detail::extract_buffer_data(a, std::make_index_sequence<sizeof...(Ts)>{});
623 bool need_transfer = !(TTG_IMPL_NS::register_device_memory(ttg::span(arr)));
624 /* TODO: are we allowed to not suspend here and launch the kernel directly? */
625 m_state = ttg::device::detail::TTG_DEVICE_CORO_WAIT_TRANSFER;
626 return {};
627 }
628
629 ttg::suspend_always await_transform(detail::to_device_t<Input>&& a) {
630 bool need_transfer = !(TTG_IMPL_NS::register_device_memory(a.input.span()));
631 /* TODO: are we allowed to not suspend here and launch the kernel directly? */
632 m_state = ttg::device::detail::TTG_DEVICE_CORO_WAIT_TRANSFER;
633 return {};
634 }
635
636 template<typename... Ts>
637 auto await_transform(detail::wait_kernel_t<Ts...>&& a) {
638 //std::cout << "yield_value: wait_kernel_t" << std::endl;
639 if constexpr (sizeof...(Ts) > 0) {
640 TTG_IMPL_NS::mark_device_out(a.ties);
641 }
642 m_state = ttg::device::detail::TTG_DEVICE_CORO_WAIT_KERNEL;
643 return a;
644 }
645
646 ttg::suspend_always await_transform(std::vector<device::detail::send_t>&& v) {
647 m_sends = std::forward<std::vector<device::detail::send_t>>(v);
648 m_state = ttg::device::detail::TTG_DEVICE_CORO_SENDOUT;
649 return {};
650 }
651
652 ttg::suspend_always await_transform(device::detail::send_t&& v) {
653 m_sends.clear();
654 m_sends.push_back(std::forward<device::detail::send_t>(v));
655 m_state = ttg::device::detail::TTG_DEVICE_CORO_SENDOUT;
656 return {};
657 }
658
659 void return_void() {
660 m_state = ttg::device::detail::TTG_DEVICE_CORO_COMPLETE;
661 }
662
663 bool complete() const {
664 return m_state == ttg::device::detail::TTG_DEVICE_CORO_COMPLETE;
665 }
666
667 ttg::device::Task get_return_object() { return {detail::device_task_handle_type::from_promise(*this)}; }
668
669 void unhandled_exception() {
670 std::cerr << "Task coroutine caught an unhandled exception!" << std::endl;
671 throw; // fwd
672 }
673
674 //using iterator = std::vector<device_obj_view>::iterator;
675
676 /* execute all pending send and broadcast operations */
677 void do_sends() {
678 for (auto& send : m_sends) {
679 send.coro();
680 }
681 m_sends.clear();
682 }
683
684 auto state() {
685 return m_state;
686 }
687
688 private:
689 std::vector<device::detail::send_t> m_sends;
690 ttg_device_coro_state m_state = ttg::device::detail::TTG_DEVICE_CORO_STATE_NONE;
691
692 };
693
694 } // namespace detail
695
696 bool Task::completed() { return base_type::promise().state() == ttg::device::detail::TTG_DEVICE_CORO_COMPLETE; }
697
/* Empty tag type; it is not referenced anywhere else in the visible part of
 * this file. NOTE(review): presumably a leftover or used by other headers --
 * confirm before removing. */
698 struct device_wait_kernel
699 { };
700
701
702 /* NOTE: below is preliminary for reductions on the device, which is not available yet */
703#if 0
704 /**************************
705 * Device reduction coros *
706 **************************/
707
708 struct device_reducer_promise_type;
709
710 using device_reducer_handle_type = ttg::coroutine_handle<device_reducer_promise_type>;
711
713 struct device_reducer : public device_reducer_handle_type {
714 using base_type = device_reducer_handle_type;
715
718
719 using promise_type = device_reducer_promise_type;
720
722
723 device_reducer(base_type base) : base_type(std::move(base)) {}
724
725 base_type& handle() { return *this; }
726
728 inline bool ready() {
729 return true;
730 }
731
733 inline bool completed();
734 };
735
736
737 /* The promise type that stores the views provided by the
738 * application task coroutine on the first co_yield. It subsequently
739 * tracks the state of the task when it moves from waiting for transfers
740 * to waiting for the submitted kernel to complete. */
741 struct device_reducer_promise_type {
742
743 /* do not suspend the coroutine on first invocation, we want to run
744 * the coroutine immediately and suspend when we get the device transfers.
745 */
746 ttg::suspend_never initial_suspend() {
747 m_state = ttg::device::detail::TTG_DEVICE_CORO_INIT;
748 return {};
749 }
750
751 /* suspend the coroutine at the end of the execution
752 * so we can access the promise.
753 * TODO: necessary? maybe we can save one suspend here
754 */
755 ttg::suspend_always final_suspend() noexcept {
756 m_state = ttg::device::detail::TTG_DEVICE_CORO_COMPLETE;
757 return {};
758 }
759
760 template<typename... Ts>
761 ttg::suspend_always await_transform(detail::to_device_t<Ts...>&& a) {
762 bool need_transfer = !(TTG_IMPL_NS::register_device_memory(a.ties));
763 /* TODO: are we allowed to not suspend here and launch the kernel directly? */
764 m_state = ttg::device::detail::TTG_DEVICE_CORO_WAIT_TRANSFER;
765 return {};
766 }
767
768 void return_void() {
769 m_state = ttg::device::detail::TTG_DEVICE_CORO_COMPLETE;
770 }
771
772 bool complete() const {
773 return m_state == ttg::device::detail::TTG_DEVICE_CORO_COMPLETE;
774 }
775
776 device_reducer get_return_object() { return device_reducer{device_reducer_handle_type::from_promise(*this)}; }
777
778 void unhandled_exception() { }
779
780 auto state() {
781 return m_state;
782 }
783
784
785 private:
786 ttg::device::detail::ttg_device_coro_state m_state = ttg::device::detail::TTG_DEVICE_CORO_STATE_NONE;
787
788 };
789
790 bool device_reducer::completed() { return base_type::promise().state() == ttg::device::detail::TTG_DEVICE_CORO_COMPLETE; }
791#endif // 0
792
793} // namespace ttg::device
794
795#endif // TTG_HAVE_COROUTINE
796
797#endif // TTG_DEVICE_TASK_H
std::enable_if_t<!meta::is_void_v< Key > &&meta::is_void_v< Value >, void > sendk(const Key &key)
Definition terminal.h:475
std::enable_if_t< meta::is_all_void_v< Key, Value >, void > send()
Definition terminal.h:514
A complete version of void.
Definition void.h:11
constexpr auto data(C &c) -> decltype(c.data())
Definition span.h:190
std::integral_constant< bool,(Flags &const_) !=0 > is_const
STL namespace.
void broadcast(const std::tuple< RangesT... > &keylists, valueT &&value, std::tuple< ttg::Out< out_keysT, out_valuesT >... > &t)
Definition func.h:347
constexpr auto get(typelist< T, RestOfTs... >)
Definition typelist.h:101
TTG_CXX_COROUTINE_NAMESPACE::suspend_always suspend_always
Definition coroutine.h:21
void send(const keyT &key, valueT &&value, ttg::Out< keyT, valueT > &t)
Sends a task id and a value to the given output terminal.
Definition func.h:158
constexpr const ttg::Runtime ttg_runtime
Definition import.h:20
Runtime
Definition runtimes.h:15
TTG_CXX_COROUTINE_NAMESPACE::suspend_never suspend_never
Definition coroutine.h:22
TTG_CXX_COROUTINE_NAMESPACE::coroutine_handle< Promise > coroutine_handle
Definition coroutine.h:24
void sendk(const keyT &key, ttg::Out< keyT, void > &t)
Sends a task id (without an accompanying value) to the given output terminal.
Definition func.h:169
void sendv(valueT &&value, ttg::Out< void, valueT > &t)
Sends a value (without an accompanying task id) to the given output terminal.
Definition func.h:179
void broadcastk(const rangeT &keylist, std::tuple< ttg::Out< out_keysT, out_valuesT >... > &t)
Definition func.h:452