ttg 1.0.0
Template Task Graph (TTG): flowgraph-based programming model for high-performance distributed-memory algorithms
Loading...
Searching...
No Matches
task.h
Go to the documentation of this file.
1// SPDX-License-Identifier: BSD-3-Clause
2#ifndef TTG_DEVICE_TASK_H
3#define TTG_DEVICE_TASK_H
4
5#include <array>
6#include <type_traits>
7#include <span>
8
9
10#include "ttg/fwd.h"
11#include "ttg/impl_selector.h"
12#include "ttg/ptr.h"
13#include "ttg/devicescope.h"
14
15#ifdef TTG_HAVE_COROUTINE
16
17namespace ttg::device {
18
19 namespace detail {
20
21 struct device_input_data_t {
22 using impl_data_t = decltype(TTG_IMPL_NS::buffer_data(std::declval<ttg::Buffer<int>>()));
23
24 device_input_data_t(impl_data_t data, ttg::scope scope, bool isconst, bool isscratch)
25 : impl_data(data), scope(scope), is_const(isconst), is_scratch(isscratch)
26 { }
27 impl_data_t impl_data;
29 bool is_const;
30 bool is_scratch;
31 };
32
// Bundle of lvalue references to the objects named in a device-transfer
// request. `ties` stores one `T&` per template argument; the struct owns
// nothing, so the referenced objects must outlive it.
33 template <typename... Ts>
34 struct to_device_t {
35 std::tuple<std::add_lvalue_reference_t<Ts>...> ties;
36 };
37
38 /* extract buffer information from to_device_t */
39 template<typename... Ts, std::size_t... Is>
40 auto extract_buffer_data(detail::to_device_t<Ts...>& a, std::index_sequence<Is...>) {
41 using arg_types = std::tuple<Ts...>;
42 return std::array<device_input_data_t, sizeof...(Ts)>{
43 device_input_data_t{TTG_IMPL_NS::buffer_data(std::get<Is>(a.ties)),
44 std::get<Is>(a.ties).scope(),
47 }
48 } // namespace detail
49
50 struct Input {
51 private:
52 std::vector<detail::device_input_data_t> m_data;
53
54 public:
55 Input() { }
56 template<typename... Args>
57 Input(Args&&... args)
58 : m_data{{TTG_IMPL_NS::buffer_data(args), args.scope(),
59 std::is_const_v<std::remove_reference_t<Args>>,
61 { }
62
63 template<typename T>
64 void add(T&& v) {
65 using type = std::remove_reference_t<T>;
66 m_data.emplace_back(TTG_IMPL_NS::buffer_data(v), v.scope(), std::is_const_v<type>,
68 }
69
70 ttg::span<detail::device_input_data_t> span() {
71 return ttg::span(m_data);
72 }
73 };
74
75 namespace detail {
76 // full specialization for Input: keeps a reference to the Input object itself
// (instead of a tuple of ties); the Input must outlive this object.
77 template <>
78 struct to_device_t<Input> {
79 Input& input;
80 };
81 } // namespace detail
82
// Builds a detail::to_device_t that ties (by lvalue reference, via std::tie)
// every argument. The result is what the task promise's
// await_transform(detail::to_device_t<Ts...>&&) accepts, i.e. it is meant to be
// co_await'ed from a device task.
// Because std::tie binds lvalue references, the arguments have to be lvalues
// and must stay alive until the returned object has been consumed.
// [[nodiscard]]: the returned object has no effect unless it is awaited.
90 template <typename... Args>
91 [[nodiscard]]
92 inline auto select(Args &&...args) {
93 return detail::to_device_t<std::remove_reference_t<Args>...>{std::tie(std::forward<Args>(args)...)};
94 }
95
// Overload for a pre-assembled Input: wraps a reference to `input` in
// detail::to_device_t<Input>. `input` must outlive the returned object.
96 [[nodiscard]]
97 inline auto select(Input& input) {
98 return detail::to_device_t<Input>{input};
99 }
100
101 namespace detail {
102
// States recorded by device_task_promise_type in its m_state member.
// The per-enumerator notes name the promise member that sets each value.
103 enum ttg_device_coro_state {
104 TTG_DEVICE_CORO_STATE_NONE, // default member initializer of m_state
105 TTG_DEVICE_CORO_INIT, // set in initial_suspend()
106 TTG_DEVICE_CORO_WAIT_TRANSFER, // set when a to_device_t is co_await'ed
107 TTG_DEVICE_CORO_WAIT_KERNEL, // set when a wait_kernel_t is co_await'ed
108 TTG_DEVICE_CORO_SENDOUT, // set when a send_t / vector of send_t is co_await'ed
109 TTG_DEVICE_CORO_COMPLETE // set in return_void() and final_suspend()
110 };
111
// Awaiter type holding lvalue references to the objects a task waits on.
// It always suspends the awaiting coroutine; when the coroutine is resumed,
// await_resume() passes the referenced objects to the backend hook
// TTG_IMPL_NS::post_device_out (skipped when the pack is empty).
112 template <typename... Ts>
113 struct wait_kernel_t {
// one `T&` per waited-on object; nothing is owned here
114 std::tuple<std::add_lvalue_reference_t<Ts>...> ties;
115
116 /* never ready: always suspend */
117 constexpr bool await_ready() const noexcept { return false; }
118
119 /* nothing to do at suspension; the handle is ignored */
120 template <typename Promise>
121 constexpr void await_suspend(ttg::coroutine_handle<Promise>) const noexcept {}
122
123 void await_resume() noexcept {
124 if constexpr (sizeof...(Ts) > 0) {
125 /* hook to allow the backend to handle the data after pushout */
126 TTG_IMPL_NS::post_device_out(ties);
127 }
128 }
129 };
130 } // namespace detail
131
137 template <typename... Buffers>
138 [[nodiscard]]
139 inline auto wait(Buffers &&...args) {
140 static_assert(
142 ...),
143 "Only ttg::Buffer and ttg::devicescratch can be waited on!");
144 return detail::wait_kernel_t<std::remove_reference_t<Buffers>...>{std::tie(std::forward<Buffers>(args)...)};
145 }
146
147 /******************************
148 * Send/Broadcast handling
149 * We pass the value returned by the backend's copy handler into a coroutine
150 * and execute the first part (prepare), before suspending it.
151 * The second part (send/broadcast) is executed after the task completed.
152 ******************************/
153
154 namespace detail {
155 struct send_coro_promise_type;
156
157 using send_coro_handle_type = ttg::coroutine_handle<send_coro_promise_type>;
158
// Return object of the send/broadcast coroutines. It IS a coroutine handle
// for send_coro_promise_type (public inheritance), so it can be resumed like
// any handle.
160 struct send_coro_state : public send_coro_handle_type {
161 using base_type = send_coro_handle_type;
162
165
// nested promise_type: how the language finds the promise for coroutines
// returning send_coro_state
166 using promise_type = send_coro_promise_type;
167
169
// non-explicit on purpose: allows implicit construction from a handle
170 send_coro_state(base_type base) : base_type(std::move(base)) {}
171
// access to the underlying handle (this object viewed as its base)
172 base_type &handle() { return *this; }
173
// unconditionally true
175 inline bool ready() { return true; }
176
// declared only; NOTE(review): no definition is visible in this listing
178 inline bool completed();
179 };
180
// Promise type of the send/broadcast coroutines (send_coro, sendv_coro,
// sendk_coro, broadcast_coro, ...). Those coroutines run eagerly up to their
// single `co_await ttg::Void{}`, stay suspended there, and finish when resumed.
182 struct send_coro_promise_type {
183 /* do not suspend the coroutine on first invocation, we want to run
184 * the coroutine immediately and suspend only once.
185 */
186 ttg::suspend_never initial_suspend() { return {}; }
187
188 /* we don't suspend the coroutine at the end.
189 * it can be destroyed once the send/broadcast is done
190 */
191 ttg::suspend_never final_suspend() noexcept { return {}; }
192
// the return object wraps the handle of the coroutine owning this promise
193 send_coro_state get_return_object() { return send_coro_state{send_coro_handle_type::from_promise(*this)}; }
194
195 /* the send coros only have an empty co_await; ttg::Void is the only
 * awaitable accepted and it always suspends */
196 ttg::suspend_always await_transform(ttg::Void) { return {}; }
197
// report on stderr, then rethrow the in-flight exception to the resumer
198 void unhandled_exception() {
199 std::cerr << "Send coroutine caught an unhandled exception!" << std::endl;
200 throw; // fwd
201 }
202
203 void return_void() {}
204 };
205
206 template <typename Key, typename Value, ttg::Runtime Runtime = ttg::ttg_runtime>
207 inline send_coro_state send_coro(const Key &key, Value &&value, ttg::Out<Key, std::decay_t<Value>> &t,
209 ttg::detail::value_copy_handler<Runtime> copy_handler = std::move(ch); // destroyed at the end of the coro
210 Key k = key;
211 t.prepare_send(k, std::forward<Value>(value));
212 co_await ttg::Void{}; // we'll come back once the task is done
213 t.send(k, std::forward<Value>(value));
214 };
215
216 template <typename Value, ttg::Runtime Runtime = ttg::ttg_runtime>
217 inline send_coro_state sendv_coro(Value &&value, ttg::Out<void, std::decay_t<Value>> &t,
219 ttg::detail::value_copy_handler<Runtime> copy_handler = std::move(ch); // destroyed at the end of the coro
220 t.prepare_send(std::forward<Value>(value));
221 co_await ttg::Void{}; // we'll come back once the task is done
222 t.sendv(std::forward<Value>(value));
223 };
224
// Coroutine that defers a key-only send on terminal `t`: it suspends once and
// calls t.sendk() with a copy of `key` when resumed.
// `t` is held by reference and is used after the suspension, so it must still
// be valid when the coroutine is resumed. `Runtime` is not used in the body.
225 template <typename Key, ttg::Runtime Runtime = ttg::ttg_runtime>
226 inline send_coro_state sendk_coro(const Key &key, ttg::Out<Key, void> &t) {
227 // no need to prepare the send but we have to suspend once
// copy the key into the coroutine frame: `key` is only a reference
228 Key k = key;
229 co_await ttg::Void{}; // we'll come back once the task is done
230 t.sendk(k);
231 };
232
// Coroutine that defers a send without key and value on terminal `t`: it
// suspends once and calls t.send() when resumed. `t` is held by reference and
// must still be valid at that point. `Runtime` is not used in the body.
233 template <ttg::Runtime Runtime = ttg::ttg_runtime>
234 inline send_coro_state send_coro(ttg::Out<void, void> &t) {
235 // no need to prepare the send but we have to suspend once
236 co_await ttg::Void{}; // we'll come back once the task is done
237 t.send();
238 };
239
// Value returned by the ttg::device send*/broadcast* functions: wraps the
// suspended send coroutine. The task promise stores these and resumes `coro`
// in do_sends().
240 struct send_t {
241 send_coro_state coro;
242 };
243 } // namespace detail
244
245 template <size_t i, typename keyT, typename valueT, typename... out_keysT, typename... out_valuesT,
247 inline detail::send_t send(const keyT &key, valueT &&value, std::tuple<ttg::Out<out_keysT, out_valuesT>...> &t) {
249 return detail::send_t{
250 detail::send_coro(key, copy_handler(std::forward<valueT>(value)), std::get<i>(t), copy_handler)};
251 }
252
253 template <size_t i, typename valueT, typename... out_keysT, typename... out_valuesT,
255 inline detail::send_t sendv(valueT &&value, std::tuple<ttg::Out<out_keysT, out_valuesT>...> &t) {
257 return detail::send_t{detail::sendv_coro(copy_handler(std::forward<valueT>(value)), std::get<i>(t), copy_handler)};
258 }
259
260 template <size_t i, typename Key, typename... out_keysT, typename... out_valuesT,
262 inline detail::send_t sendk(const Key &key, std::tuple<ttg::Out<out_keysT, out_valuesT>...> &t) {
263 return detail::send_t{detail::sendk_coro(key, std::get<i>(t))};
264 }
265
266 // clang-format off
271 // clang-format on
272 template <typename keyT, typename valueT, ttg::Runtime Runtime = ttg::ttg_runtime>
273 inline detail::send_t send(size_t i, const keyT &key, valueT &&value) {
275 auto *terminal_ptr = ttg::detail::get_out_terminal<keyT, valueT>(i, "ttg::device::send(i, key, value)");
276 return detail::send_t{detail::send_coro(key, copy_handler(std::forward<valueT>(value)), *terminal_ptr, copy_handler)};
277 }
278
279 // clang-format off
285 // clang-format on
// Compile-time-index convenience wrapper: forwards to
// ttg::device::send(i, key, value) with `i` passed as a run-time argument.
286 template <size_t i, typename keyT, typename valueT>
287 inline auto send(const keyT &key, valueT &&value) {
288 return ttg::device::send(i, key, std::forward<valueT>(value));
289 }
290
291
// Deferred value-only send on output terminal `i` (looked up at run time via
// get_out_terminal<void, valueT>). Returns a send_t wrapping the suspended
// sendv_coro coroutine.
// The string handed to get_out_terminal names the calling API, as the
// broadcast helpers in this file do; it previously named
// send(i, key, value) by mistake.
// NOTE(review): `copy_handler` is not declared in the visible listing (listing
// line 295 is absent from this extract) -- confirm against the real header.
292 template <typename valueT, ttg::Runtime Runtime = ttg::ttg_runtime>
293 inline detail::send_t sendv(std::size_t i, valueT &&value) {
294 auto *terminal_ptr = ttg::detail::get_out_terminal<void, valueT>(i, "ttg::device::sendv(i, value)");
296 return detail::send_t{detail::sendv_coro(copy_handler(std::forward<valueT>(value)), *terminal_ptr, copy_handler)};
297 }
298
// Deferred key-only send on output terminal `i` (looked up at run time via
// get_out_terminal<Key, void>). Returns a send_t wrapping the suspended
// sendk_coro coroutine.
// The string handed to get_out_terminal names the calling API, as the
// broadcast helpers in this file do; it previously named
// send(i, key, value) by mistake.
299 template <typename Key, ttg::Runtime Runtime = ttg::ttg_runtime>
300 inline detail::send_t sendk(std::size_t i, const Key& key) {
301 auto *terminal_ptr = ttg::detail::get_out_terminal<Key, void>(i, "ttg::device::sendk(i, key)");
302 return detail::send_t{detail::sendk_coro(key, *terminal_ptr)};
303 }
304
// Deferred send without key and value on output terminal `i` (looked up at
// run time via get_out_terminal<void, void>). Returns a send_t wrapping the
// suspended send_coro coroutine.
// The string handed to get_out_terminal names the calling API, as the
// broadcast helpers in this file do; it previously named
// send(i, key, value) by mistake.
305 template <ttg::Runtime Runtime = ttg::ttg_runtime>
306 inline detail::send_t send(std::size_t i) {
307 auto *terminal_ptr = ttg::detail::get_out_terminal<void, void>(i, "ttg::device::send(i)");
308 return detail::send_t{detail::send_coro(*terminal_ptr)};
309 }
310
311
// Compile-time-index convenience wrapper: forwards to sendv(i, value) with
// `i` passed as a run-time argument. `Runtime` is not used in the body.
312 template <std::size_t i, typename valueT, ttg::Runtime Runtime = ttg::ttg_runtime>
313 inline detail::send_t sendv(valueT &&value) {
314 return sendv(i, std::forward<valueT>(value));
315 }
316
// Compile-time-index convenience wrapper: forwards to sendk(i, key) with
// `i` passed as a run-time argument. `Runtime` is not used in the body.
317 template <size_t i, typename Key, ttg::Runtime Runtime = ttg::ttg_runtime>
318 inline detail::send_t sendk(const Key& key) {
319 return sendk(i, key);
320 }
321
// Compile-time-index wrapper that takes no key and forwards to send(i), i.e.
// the send without key and value. `Runtime` is not used in the body.
// NOTE(review): the name says "sendk" although no key is involved; possibly
// this was meant to be a `send<i>()` overload -- confirm before renaming,
// since callers may already rely on this spelling.
322 template <size_t i, ttg::Runtime Runtime = ttg::ttg_runtime>
323 inline detail::send_t sendk() {
324 return send(i);
325 }
326
327 namespace detail {
328
// Maps the type of one broadcast key list to the key type used to look up the
// output terminal.
// Primary template (T is not iterable): T itself is the key type.
// Both templates expose the same two aliases, `type` and `key_type`; the
// helpers in this file read `key_type`, which the primary template used to
// lack, so a non-iterable key list could not be compiled.
329 template<typename T, typename Enabler = void>
330 struct broadcast_keylist_trait {
331 using type = T;
using key_type = T;
332 };
333
334 /* specialization for iterable types that extracts the type of the first element.
 * key_type is the decltype of a dereferenced begin iterator, so it is typically
 * a reference type. */
335 template<typename T>
336 struct broadcast_keylist_trait<T, std::enable_if_t<ttg::meta::is_iterable_v<T>>> {
337 using key_type = decltype(*std::begin(std::declval<T>()));
using type = key_type;
338 };
339
// Explicit-terminal variant: calls prepare_send(keylist, value) on terminal I
// of `t` with key list number KeyId, then recurses over the remaining terminal
// indices Is..., advancing KeyId by one per step. The recursion is resolved at
// compile time (if constexpr) and ends when Is... is empty.
// NOTE(review): `value` is forwarded at every step; if valueT is an rvalue and
// prepare_send consumes it, later terminals would see a moved-from value --
// confirm prepare_send's contract.
340 template <size_t KeyId, size_t I, size_t... Is, typename... RangesT, typename valueT,
341 typename... out_keysT, typename... out_valuesT>
342 inline void prepare_broadcast(const std::tuple<RangesT...> &keylists, valueT &&value,
343 std::tuple<ttg::Out<out_keysT, out_valuesT>...> &t) {
344 std::get<I>(t).prepare_send(std::get<KeyId>(keylists), std::forward<valueT>(value));
345 if constexpr (sizeof...(Is) > 0) {
346 prepare_broadcast<KeyId+1, Is...>(keylists, std::forward<valueT>(value), t);
347 }
348 }
349
// Implicit-terminal variant: looks up output terminal I at run time via
// get_out_terminal, using the key type derived from key list number KeyId,
// calls prepare_send(keylist, value) on it, then recurses over Is... with
// KeyId advanced by one per step.
350 template <size_t KeyId, size_t I, size_t... Is, typename... RangesT, typename valueT>
351 inline void prepare_broadcast(const std::tuple<RangesT...> &keylists, valueT &&value) {
// key type of the KeyId-th key list (references stripped from the list type first)
352 using key_t = typename broadcast_keylist_trait<
353 std::tuple_element_t<KeyId, std::tuple<std::remove_reference_t<RangesT>...>>
354 >::key_type;
355 auto *terminal_ptr = ttg::detail::get_out_terminal<key_t, valueT>(I, "ttg::device::broadcast(keylists, value)");
// `value` is passed as an lvalue here; it is only forwarded into the recursion
356 terminal_ptr->prepare_send(std::get<KeyId>(keylists), value);
357 if constexpr (sizeof...(Is) > 0) {
358 prepare_broadcast<KeyId+1, Is...>(keylists, std::forward<valueT>(value));
359 }
360 }
361
// Explicit-terminal variant: calls broadcast(keylist, value) on terminal I of
// `t` with key list number KeyId, then recurses over the remaining terminal
// indices Is..., advancing KeyId by one per step (compile-time recursion).
// NOTE(review): `value` is forwarded at every step; if valueT is an rvalue and
// Out::broadcast consumes it, later terminals would see a moved-from value --
// confirm Out::broadcast's contract.
362 template <size_t KeyId, size_t I, size_t... Is, typename... RangesT, typename valueT,
363 typename... out_keysT, typename... out_valuesT>
364 inline void broadcast(const std::tuple<RangesT...> &keylists, valueT &&value,
365 std::tuple<ttg::Out<out_keysT, out_valuesT>...> &t) {
366 std::get<I>(t).broadcast(std::get<KeyId>(keylists), std::forward<valueT>(value));
367 if constexpr (sizeof...(Is) > 0) {
368 detail::broadcast<KeyId+1, Is...>(keylists, std::forward<valueT>(value), t);
369 }
370 }
371
// Implicit-terminal variant: looks up output terminal I at run time via
// get_out_terminal, using the key type derived from key list number KeyId,
// calls broadcast(keylist, value) on it, then recurses over Is... with KeyId
// advanced by one per step.
372 template <size_t KeyId, size_t I, size_t... Is, typename... RangesT, typename valueT>
373 inline void broadcast(const std::tuple<RangesT...> &keylists, valueT &&value) {
// key type of the KeyId-th key list (references stripped from the list type first)
374 using key_t = typename broadcast_keylist_trait<
375 std::tuple_element_t<KeyId, std::tuple<std::remove_reference_t<RangesT>...>>
376 >::key_type;
377 auto *terminal_ptr = ttg::detail::get_out_terminal<key_t, valueT>(I, "ttg::device::broadcast(keylists, value)");
// `value` is passed as an lvalue here; it is only forwarded into the recursion
378 terminal_ptr->broadcast(std::get<KeyId>(keylists), value);
379 if constexpr (sizeof...(Is) > 0) {
380 ttg::device::detail::broadcast<KeyId+1, Is...>(keylists, std::forward<valueT>(value));
381 }
382 }
383
384 /* overload with explicit terminals */
385 template <size_t I, size_t... Is, typename RangesT, typename valueT,
386 typename... out_keysT, typename... out_valuesT,
388 inline send_coro_state
389 broadcast_coro(RangesT &&keylists, valueT &&value,
390 std::tuple<ttg::Out<out_keysT, out_valuesT>...> &t,
392 ttg::detail::value_copy_handler<Runtime> copy_handler = std::move(ch); // destroyed at the end of the coro
393 RangesT kl = std::forward<RangesT>(keylists); // capture the keylist(s)
394 if constexpr (ttg::meta::is_tuple_v<RangesT>) {
395 // treat as tuple
396 prepare_broadcast<0, I, Is...>(kl, std::forward<valueT>(value), t);
397 co_await ttg::Void{}; // we'll come back once the task is done
398 ttg::device::detail::broadcast<0, I, Is...>(kl, std::forward<valueT>(value), t);
399 } else if constexpr (!ttg::meta::is_tuple_v<RangesT>) {
400 // create a tie to the captured keylist
401 prepare_broadcast<0, I, Is...>(std::tie(kl), std::forward<valueT>(value), t);
402 co_await ttg::Void{}; // we'll come back once the task is done
403 ttg::device::detail::broadcast<0, I, Is...>(std::tie(kl), std::forward<valueT>(value), t);
404 }
405 }
406
407 /* overload with implicit terminals */
408 template <size_t I, size_t... Is, typename RangesT, typename valueT,
410 inline send_coro_state
411 broadcast_coro(RangesT &&keylists, valueT &&value,
413 ttg::detail::value_copy_handler<Runtime> copy_handler = std::move(ch); // destroyed at the end of the coro
414 RangesT kl = std::forward<RangesT>(keylists); // capture the keylist(s)
415 if constexpr (ttg::meta::is_tuple_v<RangesT>) {
416 // treat as tuple
417 static_assert(sizeof...(Is)+1 == std::tuple_size_v<RangesT>,
418 "Size of keylist tuple must match the number of output terminals");
419 prepare_broadcast<0, I, Is...>(kl, std::forward<valueT>(value));
420 co_await ttg::Void{}; // we'll come back once the task is done
421 ttg::device::detail::broadcast<0, I, Is...>(kl, std::forward<valueT>(value));
422 } else if constexpr (!ttg::meta::is_tuple_v<RangesT>) {
423 // create a tie to the captured keylist
424 prepare_broadcast<0, I, Is...>(std::tie(kl), std::forward<valueT>(value));
425 co_await ttg::Void{}; // we'll come back once the task is done
426 ttg::device::detail::broadcast<0, I, Is...>(std::tie(kl), std::forward<valueT>(value));
427 }
428 }
429
// Key-only broadcast, explicit-terminal variant: calls broadcast(keylist) on
// terminal I of `t` with key list number KeyId, then recurses over the
// remaining terminal indices Is..., advancing KeyId by one per step
// (compile-time recursion, ends when Is... is empty).
434 template <size_t KeyId, size_t I, size_t... Is, typename... RangesT,
435 typename... out_keysT, typename... out_valuesT>
436 inline void broadcastk(const std::tuple<RangesT...> &keylists,
437 std::tuple<ttg::Out<out_keysT, out_valuesT>...> &t) {
438 std::get<I>(t).broadcast(std::get<KeyId>(keylists));
439 if constexpr (sizeof...(Is) > 0) {
440 detail::broadcastk<KeyId+1, Is...>(keylists, t);
441 }
442 }
443
// Key-only broadcast, implicit-terminal variant: looks up output terminal I at
// run time via get_out_terminal<key_t, void>, calls broadcast(keylist) on it,
// then recurses over Is... with KeyId advanced by one per step.
444 template <size_t KeyId, size_t I, size_t... Is, typename... RangesT>
445 inline void broadcastk(const std::tuple<RangesT...> &keylists) {
// key type of the KeyId-th key list (references stripped from the list type first)
446 using key_t = typename broadcast_keylist_trait<
447 std::tuple_element_t<KeyId, std::tuple<std::remove_reference_t<RangesT>...>>
448 >::key_type;
449 auto *terminal_ptr = ttg::detail::get_out_terminal<key_t, void>(I, "ttg::device::broadcastk(keylists)");
450 terminal_ptr->broadcast(std::get<KeyId>(keylists));
451 if constexpr (sizeof...(Is) > 0) {
452 ttg::device::detail::broadcastk<KeyId+1, Is...>(keylists);
453 }
454 }
455
456 /* overload with explicit terminals */
457 template <size_t I, size_t... Is, typename RangesT,
458 typename... out_keysT, typename... out_valuesT,
460 inline send_coro_state
461 broadcastk_coro(RangesT &&keylists,
462 std::tuple<ttg::Out<out_keysT, out_valuesT>...> &t) {
463 RangesT kl = std::forward<RangesT>(keylists); // capture the keylist(s)
464 if constexpr (ttg::meta::is_tuple_v<RangesT>) {
465 // treat as tuple
466 co_await ttg::Void{}; // we'll come back once the task is done
467 ttg::device::detail::broadcastk<0, I, Is...>(kl, t);
468 } else if constexpr (!ttg::meta::is_tuple_v<RangesT>) {
469 // create a tie to the captured keylist
470 co_await ttg::Void{}; // we'll come back once the task is done
471 ttg::device::detail::broadcastk<0, I, Is...>(std::tie(kl), t);
472 }
473 }
474
475 /* overload with implicit terminals */
476 template <size_t I, size_t... Is, typename RangesT,
478 inline send_coro_state
479 broadcastk_coro(RangesT &&keylists) {
480 RangesT kl = std::forward<RangesT>(keylists); // capture the keylist(s)
481 if constexpr (ttg::meta::is_tuple_v<RangesT>) {
482 // treat as tuple
483 static_assert(sizeof...(Is)+1 == std::tuple_size_v<RangesT>,
484 "Size of keylist tuple must match the number of output terminals");
485 co_await ttg::Void{}; // we'll come back once the task is done
486 ttg::device::detail::broadcastk<0, I, Is...>(kl);
487 } else if constexpr (!ttg::meta::is_tuple_v<RangesT>) {
488 // create a tie to the captured keylist
489 co_await ttg::Void{}; // we'll come back once the task is done
490 ttg::device::detail::broadcastk<0, I, Is...>(std::tie(kl));
491 }
492 }
493 } // namespace detail
494
495 /* overload with explicit terminals and keylist passed by const reference */
496 template <size_t I, size_t... Is, typename rangeT, typename valueT, typename... out_keysT, typename... out_valuesT,
498 [[nodiscard]]
499 inline detail::send_t broadcast(rangeT &&keylist,
500 valueT &&value,
501 std::tuple<ttg::Out<out_keysT, out_valuesT>...> &t) {
503 return detail::send_t{
504 detail::broadcast_coro<I, Is...>(std::forward<rangeT>(keylist),
505 copy_handler(std::forward<valueT>(value)),
506 t, std::move(copy_handler))};
507 }
508
509 /* overload with implicit terminals and keylist passed by const reference */
510 template <size_t i, typename rangeT, typename valueT,
512 inline detail::send_t broadcast(rangeT &&keylist, valueT &&value) {
514 return detail::send_t{detail::broadcast_coro<i>(std::tie(keylist),
515 copy_handler(std::forward<valueT>(value)),
516 std::move(copy_handler))};
517 }
518
519 /* overload with explicit terminals and keylist passed by const reference */
520 template <size_t I, size_t... Is, typename rangeT, typename... out_keysT, typename... out_valuesT,
522 [[nodiscard]]
523 inline detail::send_t broadcastk(rangeT &&keylist,
524 std::tuple<ttg::Out<out_keysT, out_valuesT>...> &t) {
526 return detail::send_t{
527 detail::broadcastk_coro<I, Is...>(std::forward<rangeT>(keylist), t)};
528 }
529
530 /* overload with implicit terminals and keylist passed by const reference */
531 template <size_t i, typename rangeT,
533 inline detail::send_t broadcastk(rangeT &&keylist) {
534 if constexpr (std::is_rvalue_reference_v<decltype(keylist)>) {
535 return detail::send_t{detail::broadcastk_coro<i>(std::forward<rangeT>(keylist))};
536 } else {
537 return detail::send_t{detail::broadcastk_coro<i>(std::tie(keylist))};
538 }
539 }
540
// Collects several send_t objects (results of the send*/broadcast* functions)
// into one std::vector, which the task promise accepts in
// await_transform(std::vector<send_t>&&).
// The vector is built from a braced list, so the elements are copied out of
// the initializer list rather than moved.
// `Runtime` follows the parameter pack and is not used in the body.
// [[nodiscard]]: the sends only happen if the result is co_await'ed.
541 template<typename... Args, ttg::Runtime Runtime = ttg::ttg_runtime>
542 [[nodiscard]]
543 std::vector<device::detail::send_t> forward(Args&&... args) {
544 // TODO: check the cost of this!
545 return std::vector<device::detail::send_t>{std::forward<Args>(args)...};
546 }
547
548 /*******************************************
549 * Device task promise and coroutine handle
550 *******************************************/
551
552 namespace detail {
553 // fwd-decl
554 struct device_task_promise_type;
555 // base type for ttg::device::Task
556 using device_task_handle_type = ttg::coroutine_handle<device_task_promise_type>;
557 } // namespace detail
558
560
// Return type of device task coroutines. It IS a coroutine handle for
// detail::device_task_promise_type (public inheritance).
567 struct Task : public detail::device_task_handle_type {
568 using base_type = detail::device_task_handle_type;
569
572
// nested promise_type: how the language finds the promise for coroutines
// returning Task
573 using promise_type = detail::device_task_promise_type;
574
576
// non-explicit on purpose: the promise's get_return_object() returns `{handle}`
577 Task(base_type base) : base_type(std::move(base)) {}
578
// access to the underlying handle (this object viewed as its base)
579 base_type& handle() { return *this; }
580
// unconditionally true
582 inline bool ready() {
583 return true;
584 }
585
// true once the promise reports TTG_DEVICE_CORO_COMPLETE; defined after the
// promise type, because it needs the complete promise definition
587 inline bool completed();
588 };
589
590 namespace detail {
591
592 /* Promise type of ttg::device::Task coroutines.
593 * It maps the objects a task co_awaits (to_device_t, wait_kernel_t, send_t,
594 * std::vector<send_t>) to the state recorded in m_state and keeps the pending
595 * send/broadcast coroutines until do_sends() resumes them. */
596 struct device_task_promise_type {
597
598 /* do not suspend the coroutine on first invocation, we want to run
599 * the coroutine immediately and suspend when we get the device transfers.
600 */
601 ttg::suspend_never initial_suspend() {
602 m_state = ttg::device::detail::TTG_DEVICE_CORO_INIT;
603 return {};
604 }
605
606 /* suspend the coroutine at the end of the execution
607 * so we can access the promise.
608 * TODO: necessary? maybe we can save one suspend here
609 */
610 ttg::suspend_always final_suspend() noexcept {
611 m_state = ttg::device::detail::TTG_DEVICE_CORO_COMPLETE;
612 return {};
613 }
614
615 /* Allow co_await on a tuple.
 * NOTE(review): yield_value is not declared in this class as listed; this
 * template can only compile while it is never instantiated -- confirm. */
616 template<typename... Views>
617 ttg::suspend_always await_transform(std::tuple<Views&...> &views) {
618 return yield_value(views);
619 }
620
/* co_await on the result of select(buffers...): registers the tied objects
 * with the backend and always suspends in state WAIT_TRANSFER */
621 template<typename... Ts>
622 ttg::suspend_always await_transform(detail::to_device_t<Ts...>&& a) {
623 auto arr = detail::extract_buffer_data(a, std::make_index_sequence<sizeof...(Ts)>{});
/* the result is not acted upon yet (see TODO below) */
624 [[maybe_unused]] const bool need_transfer = !(TTG_IMPL_NS::register_device_memory(ttg::span(arr)));
625 /* TODO: are we allowed to not suspend here and launch the kernel directly? */
626 m_state = ttg::device::detail::TTG_DEVICE_CORO_WAIT_TRANSFER;
627 return {};
628 }
629
/* co_await on the result of select(Input&): same as above, with the buffer
 * descriptions taken from the Input object */
630 ttg::suspend_always await_transform(detail::to_device_t<Input>&& a) {
/* the result is not acted upon yet (see TODO below) */
631 [[maybe_unused]] const bool need_transfer = !(TTG_IMPL_NS::register_device_memory(a.input.span()));
632 /* TODO: are we allowed to not suspend here and launch the kernel directly? */
633 m_state = ttg::device::detail::TTG_DEVICE_CORO_WAIT_TRANSFER;
634 return {};
635 }
636
/* co_await on a wait_kernel_t: the awaiter itself is returned, so its
 * await_resume() runs when the task is resumed */
637 template<typename... Ts>
638 auto await_transform(detail::wait_kernel_t<Ts...>&& a) {
639 /* hand the waited-on objects to the backend before suspending */
640 if constexpr (sizeof...(Ts) > 0) {
641 TTG_IMPL_NS::mark_device_out(a.ties);
642 }
643 m_state = ttg::device::detail::TTG_DEVICE_CORO_WAIT_KERNEL;
644 return a;
645 }
646
/* co_await on the result of forward(...): replaces the pending sends */
647 ttg::suspend_always await_transform(std::vector<device::detail::send_t>&& v) {
/* v is a plain rvalue reference (not a forwarding reference): move it */
648 m_sends = std::move(v);
649 m_state = ttg::device::detail::TTG_DEVICE_CORO_SENDOUT;
650 return {};
651 }
652
/* co_await on a single send_t: replaces the pending sends with this one */
653 ttg::suspend_always await_transform(device::detail::send_t&& v) {
654 m_sends.clear();
655 m_sends.push_back(std::move(v));
656 m_state = ttg::device::detail::TTG_DEVICE_CORO_SENDOUT;
657 return {};
658 }
659
660 void return_void() {
661 m_state = ttg::device::detail::TTG_DEVICE_CORO_COMPLETE;
662 }
663
/* true once the coroutine has returned or reached its final suspend point */
664 bool complete() const {
665 return m_state == ttg::device::detail::TTG_DEVICE_CORO_COMPLETE;
666 }
667
668 ttg::device::Task get_return_object() { return {detail::device_task_handle_type::from_promise(*this)}; }
669
/* report on stderr, then rethrow the in-flight exception to the resumer */
670 void unhandled_exception() {
671 std::cerr << "Task coroutine caught an unhandled exception!" << std::endl;
672 throw; // fwd
673 }
674
675 //using iterator = std::vector<device_obj_view>::iterator;
676
677 /* execute all pending send and broadcast operations: each stored send
 * coroutine is resumed through its handle's call operator, then the
 * list is emptied */
678 void do_sends() {
679 for (auto& send : m_sends) {
680 send.coro();
681 }
682 m_sends.clear();
683 }
684
/* current state of the task coroutine */
685 auto state() const {
686 return m_state;
687 }
688
689 private:
690 std::vector<device::detail::send_t> m_sends;
691 ttg_device_coro_state m_state = ttg::device::detail::TTG_DEVICE_CORO_STATE_NONE;
692
693 };
694
695 } // namespace detail
696
// Out-of-class definition of Task::completed() (declared `inline` inside the
// class): true when the promise's state is TTG_DEVICE_CORO_COMPLETE. It is
// placed after device_task_promise_type because it calls into the promise.
697 bool Task::completed() { return base_type::promise().state() == ttg::device::detail::TTG_DEVICE_CORO_COMPLETE; }
698
// Empty tag type. NOTE(review): nothing in the visible listing uses it;
// presumably kept for code outside this file -- confirm before removing.
699 struct device_wait_kernel
700 { };
701
702
703 /* NOTE: below is preliminary for reductions on the device, which is not available yet */
704#if 0
705 /**************************
706 * Device reduction coros *
707 **************************/
708
709 struct device_reducer_promise_type;
710
711 using device_reducer_handle_type = ttg::coroutine_handle<device_reducer_promise_type>;
712
714 struct device_reducer : public device_reducer_handle_type {
715 using base_type = device_reducer_handle_type;
716
719
720 using promise_type = device_reducer_promise_type;
721
723
724 device_reducer(base_type base) : base_type(std::move(base)) {}
725
726 base_type& handle() { return *this; }
727
729 inline bool ready() {
730 return true;
731 }
732
734 inline bool completed();
735 };
736
737
738 /* The promise type that stores the views provided by the
739 * application task coroutine on the first co_yield. It subsequently
740 * tracks the state of the task when it moves from waiting for transfers
741 * to waiting for the submitted kernel to complete. */
742 struct device_reducer_promise_type {
743
744 /* do not suspend the coroutine on first invocation, we want to run
745 * the coroutine immediately and suspend when we get the device transfers.
746 */
747 ttg::suspend_never initial_suspend() {
748 m_state = ttg::device::detail::TTG_DEVICE_CORO_INIT;
749 return {};
750 }
751
752 /* suspend the coroutine at the end of the execution
753 * so we can access the promise.
754 * TODO: necessary? maybe we can save one suspend here
755 */
756 ttg::suspend_always final_suspend() noexcept {
757 m_state = ttg::device::detail::TTG_DEVICE_CORO_COMPLETE;
758 return {};
759 }
760
761 template<typename... Ts>
762 ttg::suspend_always await_transform(detail::to_device_t<Ts...>&& a) {
763 bool need_transfer = !(TTG_IMPL_NS::register_device_memory(a.ties));
764 /* TODO: are we allowed to not suspend here and launch the kernel directly? */
765 m_state = ttg::device::detail::TTG_DEVICE_CORO_WAIT_TRANSFER;
766 return {};
767 }
768
769 void return_void() {
770 m_state = ttg::device::detail::TTG_DEVICE_CORO_COMPLETE;
771 }
772
773 bool complete() const {
774 return m_state == ttg::device::detail::TTG_DEVICE_CORO_COMPLETE;
775 }
776
777 device_reducer get_return_object() { return device_reducer{device_reducer_handle_type::from_promise(*this)}; }
778
779 void unhandled_exception() { }
780
781 auto state() {
782 return m_state;
783 }
784
785
786 private:
787 ttg::device::detail::ttg_device_coro_state m_state = ttg::device::detail::TTG_DEVICE_CORO_STATE_NONE;
788
789 };
790
791 bool device_reducer::completed() { return base_type::promise().state() == ttg::device::detail::TTG_DEVICE_CORO_COMPLETE; }
792#endif // 0
793
794} // namespace ttg::device
795
796#endif // TTG_HAVE_COROUTINE
797
798#endif // TTG_DEVICE_TASK_H
std::enable_if_t<!meta::is_void_v< Key > &&meta::is_void_v< Value >, void > sendk(const Key &key)
Definition terminal.h:476
std::enable_if_t< meta::is_all_void_v< Key, Value >, void > send()
Definition terminal.h:515
A complete version of void.
Definition void.h:12
constexpr auto data(C &c) -> decltype(c.data())
Definition span.h:191
std::integral_constant< bool,(Flags &const_) !=0 > is_const
STL namespace.
void broadcast(const std::tuple< RangesT... > &keylists, valueT &&value, std::tuple< ttg::Out< out_keysT, out_valuesT >... > &t)
Definition func.h:348
constexpr auto get(typelist< T, RestOfTs... >)
Definition typelist.h:102
TTG_CXX_COROUTINE_NAMESPACE::suspend_always suspend_always
Definition coroutine.h:22
void send(const keyT &key, valueT &&value, ttg::Out< keyT, valueT > &t)
Sends a task id and a value to the given output terminal.
Definition func.h:159
constexpr const ttg::Runtime ttg_runtime
Definition import.h:21
Runtime
Definition runtimes.h:16
TTG_CXX_COROUTINE_NAMESPACE::suspend_never suspend_never
Definition coroutine.h:23
TTG_CXX_COROUTINE_NAMESPACE::coroutine_handle< Promise > coroutine_handle
Definition coroutine.h:25
void sendk(const keyT &key, ttg::Out< keyT, void > &t)
Sends a task id (without an accompanying value) to the given output terminal.
Definition func.h:170
void sendv(valueT &&value, ttg::Out< void, valueT > &t)
Sends a value (without an accompanying task id) to the given output terminal.
Definition func.h:180
void broadcastk(const rangeT &keylist, std::tuple< ttg::Out< out_keysT, out_valuesT >... > &t)
Definition func.h:453