1 #ifndef TTG_DEVICE_TASK_H
2 #define TTG_DEVICE_TASK_H
12 #ifdef TTG_HAVE_COROUTINE
17 template <
typename... Ts>
19 std::tuple<std::add_lvalue_reference_t<Ts>...> ties;
30 template <
typename... Args>
32 inline auto select(Args &&...args) {
33 return detail::to_device_t<std::remove_reference_t<Args>...>{std::tie(std::forward<Args>(args)...)};
38 enum ttg_device_coro_state {
39 TTG_DEVICE_CORO_STATE_NONE,
41 TTG_DEVICE_CORO_WAIT_TRANSFER,
42 TTG_DEVICE_CORO_WAIT_KERNEL,
43 TTG_DEVICE_CORO_SENDOUT,
44 TTG_DEVICE_CORO_COMPLETE
47 template <
typename... Ts>
48 struct wait_kernel_t {
49 std::tuple<std::add_lvalue_reference_t<Ts>...> ties;
// Awaiter readiness check for wait_kernel_t: always returns false, so a
// `co_await` on the object returned by ttg::device::wait(...) always suspends
// the awaiting coroutine. It has no side effects (constexpr, noexcept).
52 constexpr
bool await_ready() const noexcept {
return false; }
55 template <
typename Promise>
58 void await_resume() noexcept {
59 if constexpr (
sizeof...(Ts) > 0) {
72 template <
typename... Buffers>
74 inline auto wait(Buffers &&...args) {
78 "Only ttg::Buffer and ttg::devicescratch can be waited on!");
79 return detail::wait_kernel_t<std::remove_reference_t<Buffers>...>{std::tie(std::forward<Buffers>(args)...)};
90 struct send_coro_promise_type;
95 struct send_coro_state :
public send_coro_handle_type {
96 using base_type = send_coro_handle_type;
101 using promise_type = send_coro_promise_type;
// Adopts the coroutine handle created by
// send_coro_promise_type::get_return_object(); the handle is moved into the
// send_coro_handle_type base subobject.
105 send_coro_state(base_type base) : base_type(std::move(base)) {}
// send_coro_state inherits from its coroutine handle type; this returns that
// base subobject so callers can use the raw handle.
107 base_type &handle() {
return *
this; }
// Unconditionally reports the send coroutine as ready (always returns true).
// NOTE(review): the rationale is not visible in this excerpt.
110 inline bool ready() {
return true; }
// Whether the send coroutine has finished. Declared here and defined out of
// line; the definition is not visible in this excerpt.
113 inline bool completed();
117 struct send_coro_promise_type {
// Called by the compiler when a send coroutine is created. It wraps a handle
// to this promise (send_coro_handle_type::from_promise) in the send_coro_state
// returned to the coroutine's caller.
128 send_coro_state get_return_object() {
return send_coro_state{send_coro_handle_type::from_promise(*
this)}; }
133 void unhandled_exception() {
134 std::cerr <<
"Send coroutine caught an unhandled exception!" << std::endl;
// Send coroutines produce no value. `co_return;` (or flowing off the end of
// the body) requires no action from the promise.
138 void return_void() {}
141 template <
typename Key,
typename Value, ttg::Runtime Runtime = ttg::ttg_runtime>
142 inline send_coro_state send_coro(
const Key &key, Value &&value,
ttg::Out<Key, std::decay_t<Value>> &t,
146 t.prepare_send(k, std::forward<Value>(value));
148 t.send(k, std::forward<Value>(value));
151 template <
typename Value, ttg::Runtime Runtime = ttg::ttg_runtime>
152 inline send_coro_state sendv_coro(Value &&value,
ttg::Out<
void, std::decay_t<Value>> &t,
155 t.prepare_send(std::forward<Value>(value));
157 t.sendv(std::forward<Value>(value));
160 template <
typename Key, ttg::Runtime Runtime = ttg::ttg_runtime>
168 template <ttg::Runtime Runtime = ttg::ttg_runtime>
176 send_coro_state coro;
180 template <
size_t i,
typename keyT,
typename valueT,
typename... out_keysT,
typename... out_valuesT,
184 return detail::send_t{
185 detail::send_coro(key, copy_handler(std::forward<valueT>(value)), std::get<i>(t), copy_handler)};
188 template <
size_t i,
typename valueT,
typename... out_keysT,
typename... out_valuesT,
192 return detail::send_t{detail::sendv_coro(copy_handler(std::forward<valueT>(value)), std::get<i>(t), copy_handler)};
195 template <
size_t i,
typename Key,
typename... out_keysT,
typename... out_valuesT,
198 return detail::send_t{detail::sendk_coro(key, std::get<i>(t))};
207 template <
typename keyT,
typename valueT, ttg::Runtime Runtime = ttg::ttg_runtime>
208 inline detail::send_t
send(
size_t i,
const keyT &key, valueT &&value) {
210 auto *terminal_ptr = ttg::detail::get_out_terminal<keyT, valueT>(i,
"ttg::device::send(i, key, value)");
211 return detail::send_t{detail::send_coro(key, copy_handler(std::forward<valueT>(value)), *terminal_ptr, copy_handler)};
221 template <
size_t i,
typename keyT,
typename valueT>
222 inline auto send(
const keyT &key, valueT &&value) {
227 template <
typename valueT, ttg::Runtime Runtime = ttg::ttg_runtime>
228 inline detail::send_t
sendv(std::size_t i, valueT &&value) {
229 auto *terminal_ptr = ttg::detail::get_out_terminal<void, valueT>(i,
"ttg::device::send(i, key, value)");
231 return detail::send_t{detail::sendv_coro(copy_handler(std::forward<valueT>(value)), *terminal_ptr, copy_handler)};
234 template <
typename Key, ttg::Runtime Runtime = ttg::ttg_runtime>
235 inline detail::send_t
sendk(std::size_t i,
const Key& key) {
236 auto *terminal_ptr = ttg::detail::get_out_terminal<Key, void>(i,
"ttg::device::send(i, key, value)");
237 return detail::send_t{detail::sendk_coro(key, *terminal_ptr)};
240 template <ttg::Runtime Runtime = ttg::ttg_runtime>
241 inline detail::send_t
send(std::size_t i) {
242 auto *terminal_ptr = ttg::detail::get_out_terminal<void, void>(i,
"ttg::device::send(i, key, value)");
243 return detail::send_t{detail::send_coro(*terminal_ptr)};
247 template <std::size_t i,
typename valueT,
typename... out_keysT,
typename... out_valuesT,
249 inline detail::send_t
sendv(valueT &&value) {
250 return sendv(i, std::forward<valueT>(value));
253 template <
size_t i,
typename Key, ttg::Runtime Runtime = ttg::ttg_runtime>
254 inline detail::send_t
sendk(
const Key& key) {
255 return sendk(i, key);
258 template <
size_t i, ttg::Runtime Runtime = ttg::ttg_runtime>
259 inline detail::send_t
sendk() {
265 template<
typename T,
typename Enabler =
void>
266 struct broadcast_keylist_trait {
272 struct broadcast_keylist_trait<T, std::enable_if_t<ttg::meta::is_iterable_v<T>>> {
273 using key_type = decltype(*std::begin(std::declval<T>()));
276 template <
size_t KeyId,
size_t I,
size_t... Is,
typename... RangesT,
typename valueT,
277 typename... out_keysT,
typename... out_valuesT>
278 inline void prepare_broadcast(
const std::tuple<RangesT...> &keylists, valueT &&value,
280 std::get<I>(t).prepare_send(std::get<KeyId>(keylists), std::forward<valueT>(value));
281 if constexpr (
sizeof...(Is) > 0) {
282 prepare_broadcast<KeyId+1, Is...>(keylists, std::forward<valueT>(value), t);
286 template <
size_t KeyId,
size_t I,
size_t... Is,
typename... RangesT,
typename valueT,
287 typename... out_keysT,
typename... out_valuesT>
288 inline void prepare_broadcast(
const std::tuple<RangesT...> &keylists, valueT &&value) {
289 using key_t =
typename broadcast_keylist_trait<
290 std::tuple_element_t<KeyId, std::tuple<std::remove_reference_t<RangesT>...>>
292 auto *terminal_ptr = ttg::detail::get_out_terminal<key_t, valueT>(I,
"ttg::device::broadcast(keylists, value)");
293 terminal_ptr->prepare_send(std::get<KeyId>(keylists), value);
294 if constexpr (
sizeof...(Is) > 0) {
295 prepare_broadcast<KeyId+1, Is...>(keylists, std::forward<valueT>(value));
299 template <
size_t KeyId,
size_t I,
size_t... Is,
typename... RangesT,
typename valueT,
300 typename... out_keysT,
typename... out_valuesT>
301 inline void broadcast(
const std::tuple<RangesT...> &keylists, valueT &&value,
303 std::get<I>(t).broadcast(std::get<KeyId>(keylists), std::forward<valueT>(value));
304 if constexpr (
sizeof...(Is) > 0) {
309 template <
size_t KeyId,
size_t I,
size_t... Is,
typename... RangesT,
typename valueT>
310 inline void broadcast(
const std::tuple<RangesT...> &keylists, valueT &&value) {
311 using key_t =
typename broadcast_keylist_trait<
312 std::tuple_element_t<KeyId, std::tuple<std::remove_reference_t<RangesT>...>>
314 auto *terminal_ptr = ttg::detail::get_out_terminal<key_t, valueT>(I,
"ttg::device::broadcast(keylists, value)");
315 terminal_ptr->broadcast(std::get<KeyId>(keylists), value);
316 if constexpr (
sizeof...(Is) > 0) {
322 template <
size_t I,
size_t... Is,
typename RangesT,
typename valueT,
323 typename... out_keysT,
typename... out_valuesT,
325 inline send_coro_state
326 broadcast_coro(RangesT &&keylists, valueT &&value,
330 RangesT kl = std::forward<RangesT>(keylists);
331 if constexpr (ttg::meta::is_tuple_v<RangesT>) {
333 prepare_broadcast<0, I, Is...>(kl, std::forward<std::decay_t<decltype(value)>>(value), t);
336 }
else if constexpr (!ttg::meta::is_tuple_v<RangesT>) {
338 prepare_broadcast<0, I, Is...>(std::tie(kl), std::forward<std::decay_t<decltype(value)>>(value), t);
345 template <
size_t I,
size_t... Is,
typename RangesT,
typename valueT,
347 inline send_coro_state
348 broadcast_coro(RangesT &&keylists, valueT &&value,
351 RangesT kl = std::forward<RangesT>(keylists);
352 if constexpr (ttg::meta::is_tuple_v<RangesT>) {
354 static_assert(
sizeof...(Is)+1 == std::tuple_size_v<RangesT>,
355 "Size of keylist tuple must match the number of output terminals");
356 prepare_broadcast<0, I, Is...>(kl, std::forward<std::decay_t<decltype(value)>>(value));
359 }
else if constexpr (!ttg::meta::is_tuple_v<RangesT>) {
361 prepare_broadcast<0, I, Is...>(std::tie(kl), std::forward<std::decay_t<decltype(value)>>(value));
371 template <
size_t KeyId,
size_t I,
size_t... Is,
typename... RangesT,
372 typename... out_keysT,
typename... out_valuesT>
373 inline void broadcastk(
const std::tuple<RangesT...> &keylists,
375 std::get<I>(t).broadcast(std::get<KeyId>(keylists));
376 if constexpr (
sizeof...(Is) > 0) {
381 template <
size_t KeyId,
size_t I,
size_t... Is,
typename... RangesT>
382 inline void broadcastk(
const std::tuple<RangesT...> &keylists) {
383 using key_t =
typename broadcast_keylist_trait<
384 std::tuple_element_t<KeyId, std::tuple<std::remove_reference_t<RangesT>...>>
386 auto *terminal_ptr = ttg::detail::get_out_terminal<key_t, void>(I,
"ttg::device::broadcastk(keylists)");
387 terminal_ptr->broadcast(std::get<KeyId>(keylists));
388 if constexpr (
sizeof...(Is) > 0) {
394 template <
size_t I,
size_t... Is,
typename RangesT,
395 typename... out_keysT,
typename... out_valuesT,
397 inline send_coro_state
398 broadcastk_coro(RangesT &&keylists,
400 RangesT kl = std::forward<RangesT>(keylists);
401 if constexpr (ttg::meta::is_tuple_v<RangesT>) {
405 }
else if constexpr (!ttg::meta::is_tuple_v<RangesT>) {
413 template <
size_t I,
size_t... Is,
typename RangesT,
415 inline send_coro_state
416 broadcastk_coro(RangesT &&keylists) {
417 RangesT kl = std::forward<RangesT>(keylists);
418 if constexpr (ttg::meta::is_tuple_v<RangesT>) {
420 static_assert(
sizeof...(Is)+1 == std::tuple_size_v<RangesT>,
421 "Size of keylist tuple must match the number of output terminals");
424 }
else if constexpr (!ttg::meta::is_tuple_v<RangesT>) {
433 template <
size_t I,
size_t... Is,
typename rangeT,
typename valueT,
typename... out_keysT,
typename... out_valuesT,
436 inline detail::send_t
broadcast(rangeT &&keylist,
440 return detail::send_t{
441 detail::broadcast_coro<I, Is...>(std::forward<rangeT>(keylist),
442 copy_handler(std::forward<valueT>(value)),
443 t, std::move(copy_handler))};
447 template <
size_t i,
typename rangeT,
typename valueT,
449 inline detail::send_t
broadcast(rangeT &&keylist, valueT &&value) {
451 return detail::send_t{broadcast_coro<i>(std::tie(keylist), copy_handler(std::forward<valueT>(value)),
452 std::move(copy_handler))};
456 template <
size_t I,
size_t... Is,
typename rangeT,
typename... out_keysT,
typename... out_valuesT,
459 inline detail::send_t
broadcastk(rangeT &&keylist,
462 return detail::send_t{
463 detail::broadcastk_coro<I, Is...>(std::forward<rangeT>(keylist), t)};
467 template <
size_t i,
typename rangeT,
469 inline detail::send_t
broadcastk(rangeT &&keylist) {
470 if constexpr (std::is_rvalue_reference_v<decltype(keylist)>) {
471 return detail::send_t{detail::broadcastk_coro<i>(std::forward<rangeT>(keylist))};
473 return detail::send_t{detail::broadcastk_coro<i>(std::tie(keylist))};
479 std::vector<device::detail::send_t> forward(Args&&... args) {
481 return std::vector<device::detail::send_t>{std::forward<Args>(args)...};
490 struct device_task_promise_type;
503 struct Task :
public detail::device_task_handle_type {
504 using base_type = detail::device_task_handle_type;
509 using promise_type = detail::device_task_promise_type;
// Adopts the coroutine handle created by
// device_task_promise_type::get_return_object(); the handle is moved into the
// device_task_handle_type base subobject.
513 Task(base_type base) : base_type(std::move(base)) {}
// Task inherits from its coroutine handle type; this returns that base
// subobject so the runtime can resume/inspect the coroutine directly.
515 base_type& handle() {
return *
this; }
518 inline bool ready() {
523 inline bool completed();
532 struct device_task_promise_type {
538 m_state = ttg::device::detail::TTG_DEVICE_CORO_INIT;
547 m_state = ttg::device::detail::TTG_DEVICE_CORO_COMPLETE;
552 template<
typename... Views>
554 return yield_value(views);
557 template<
typename... Ts>
561 m_state = ttg::device::detail::TTG_DEVICE_CORO_WAIT_TRANSFER;
565 template<
typename... Ts>
566 auto await_transform(detail::wait_kernel_t<Ts...>&& a) {
568 if constexpr (
sizeof...(Ts) > 0) {
571 m_state = ttg::device::detail::TTG_DEVICE_CORO_WAIT_KERNEL;
576 m_sends = std::forward<std::vector<device::detail::send_t>>(v);
577 m_state = ttg::device::detail::TTG_DEVICE_CORO_SENDOUT;
583 m_sends.push_back(std::forward<device::detail::send_t>(v));
584 m_state = ttg::device::detail::TTG_DEVICE_CORO_SENDOUT;
589 m_state = ttg::device::detail::TTG_DEVICE_CORO_COMPLETE;
592 bool complete()
const {
593 return m_state == ttg::device::detail::TTG_DEVICE_CORO_COMPLETE;
// Called by the compiler when a device task coroutine is created. It builds
// the ttg::device::Task handed back to the caller from a handle to this
// promise.
596 ttg::device::Task get_return_object() {
return {detail::device_task_handle_type::from_promise(*
this)}; }
598 void unhandled_exception() {
599 std::cerr <<
"Task coroutine caught an unhandled exception!" << std::endl;
607 for (
auto&
send : m_sends) {
618 std::vector<device::detail::send_t> m_sends;
619 ttg_device_coro_state m_state = ttg::device::detail::TTG_DEVICE_CORO_STATE_NONE;
// True once the coroutine's promise reports TTG_DEVICE_CORO_COMPLETE.
// NOTE(review): promise().state() is not visible in this excerpt. Presumably
// it returns the promise's m_state, which is assigned
// TTG_DEVICE_CORO_COMPLETE in the promise type above.
625 bool Task::completed() {
return base_type::promise().state() == ttg::device::detail::TTG_DEVICE_CORO_COMPLETE; }
627 struct device_wait_kernel
637 struct device_reducer_promise_type;
642 struct device_reducer :
public device_reducer_handle_type {
643 using base_type = device_reducer_handle_type;
648 using promise_type = device_reducer_promise_type;
// Adopts the coroutine handle created by
// device_reducer_promise_type::get_return_object(); the handle is moved into
// the device_reducer_handle_type base subobject.
652 device_reducer(base_type base) : base_type(std::move(base)) {}
// device_reducer inherits from its coroutine handle type; this returns that
// base subobject.
654 base_type& handle() {
return *
this; }
657 inline bool ready() {
662 inline bool completed();
670 struct device_reducer_promise_type {
676 m_state = ttg::device::detail::TTG_DEVICE_CORO_INIT;
685 m_state = ttg::device::detail::TTG_DEVICE_CORO_COMPLETE;
689 template<
typename... Ts>
693 m_state = ttg::device::detail::TTG_DEVICE_CORO_WAIT_TRANSFER;
698 m_state = ttg::device::detail::TTG_DEVICE_CORO_COMPLETE;
701 bool complete()
const {
702 return m_state == ttg::device::detail::TTG_DEVICE_CORO_COMPLETE;
// Called by the compiler when a reducer coroutine is created. It wraps a
// handle to this promise in the device_reducer returned to the caller.
705 device_reducer get_return_object() {
return device_reducer{device_reducer_handle_type::from_promise(*
this)}; }
/// Invoked by the coroutine machinery when an exception escapes the body of a
/// reducer coroutine.
///
/// Reports the failure on std::cerr instead of discarding it. This matches
/// the send-coroutine and task-coroutine promise types in this header, which
/// print a diagnostic in the same situation. It deliberately does not abort,
/// so the existing non-fatal behavior is preserved.
void unhandled_exception() {
  std::cerr << "Reducer coroutine caught an unhandled exception!" << std::endl;
}
715 ttg::device::detail::ttg_device_coro_state m_state = ttg::device::detail::TTG_DEVICE_CORO_STATE_NONE;
// True once the reducer coroutine's promise reports TTG_DEVICE_CORO_COMPLETE.
// NOTE(review): promise().state() is not visible in this excerpt. Presumably
// it returns the promise's m_state, which is assigned
// TTG_DEVICE_CORO_COMPLETE in device_reducer_promise_type.
719 bool device_reducer::completed() {
return base_type::promise().state() == ttg::device::detail::TTG_DEVICE_CORO_COMPLETE; }
std::enable_if_t< meta::is_all_void_v< Key, Value >, void > send()
std::enable_if_t<!meta::is_void_v< Key > &&meta::is_void_v< Value >, void > sendk(const Key &key)
A complete version of void.
void broadcast(const std::tuple< RangesT... > &keylists, valueT &&value, std::tuple< ttg::Out< out_keysT, out_valuesT >... > &t)
void mark_device_out(std::tuple< Buffer &... > &b)
bool register_device_memory(std::tuple< Views &... > &views)
void post_device_out(std::tuple< Buffer &... > &b)
TTG_CXX_COROUTINE_NAMESPACE::suspend_always suspend_always
void send(const keyT &key, valueT &&value, ttg::Out< keyT, valueT > &t)
Sends a task id and a value to the given output terminal.
constexpr const ttg::Runtime ttg_runtime
TTG_CXX_COROUTINE_NAMESPACE::suspend_never suspend_never
TTG_CXX_COROUTINE_NAMESPACE::coroutine_handle< Promise > coroutine_handle
void sendk(const keyT &key, ttg::Out< keyT, void > &t)
Sends a task id (without an accompanying value) to the given output terminal.
void sendv(valueT &&value, ttg::Out< void, valueT > &t)
Sends a value (without an accompanying task id) to the given output terminal.
void broadcast(const rangeT &keylist, valueT &&value, std::tuple< ttg::Out< out_keysT, out_valuesT >... > &t)
void broadcastk(const rangeT &keylist, std::tuple< ttg::Out< out_keysT, out_valuesT >... > &t)