1 #ifndef TTG_DEVICE_TASK_H
2 #define TTG_DEVICE_TASK_H
14 #ifdef TTG_HAVE_COROUTINE
20 struct device_input_data_t {
26 impl_data_t impl_data;
32 template <
typename... Ts>
34 std::tuple<std::add_lvalue_reference_t<Ts>...> ties;
38 template<
typename... Ts, std::size_t... Is>
39 auto extract_buffer_data(detail::to_device_t<Ts...>& a, std::index_sequence<Is...>) {
40 using arg_types = std::tuple<Ts...>;
41 return std::array<device_input_data_t,
sizeof...(Ts)>{
43 std::get<Is>(a.ties).scope(),
44 ttg::meta::is_const_v<std::tuple_element_t<Is, arg_types>>,
45 ttg::meta::is_devicescratch_v<std::tuple_element_t<Is, arg_types>>}...};
51 std::vector<detail::device_input_data_t> m_data;
55 template<
typename... Args>
58 std::is_const_v<std::remove_reference_t<Args>>,
59 ttg::meta::is_devicescratch_v<std::decay_t<Args>>}...}
64 using type = std::remove_reference_t<T>;
66 ttg::meta::is_devicescratch_v<type>);
69 ttg::span<detail::device_input_data_t> span() {
70 return ttg::span(m_data);
77 struct to_device_t<Input> {
89 template <
typename... Args>
91 inline auto select(Args &&...args) {
92 return detail::to_device_t<std::remove_reference_t<Args>...>{std::tie(std::forward<Args>(args)...)};
96 inline auto select(Input& input) {
97 return detail::to_device_t<Input>{input};
102 enum ttg_device_coro_state {
103 TTG_DEVICE_CORO_STATE_NONE,
104 TTG_DEVICE_CORO_INIT,
105 TTG_DEVICE_CORO_WAIT_TRANSFER,
106 TTG_DEVICE_CORO_WAIT_KERNEL,
107 TTG_DEVICE_CORO_SENDOUT,
108 TTG_DEVICE_CORO_COMPLETE
111 template <
typename... Ts>
112 struct wait_kernel_t {
113 std::tuple<std::add_lvalue_reference_t<Ts>...> ties;
116 constexpr
bool await_ready() const noexcept {
return false; }
119 template <
typename Promise>
122 void await_resume() noexcept {
123 if constexpr (
sizeof...(Ts) > 0) {
136 template <
typename... Buffers>
138 inline auto wait(Buffers &&...args) {
142 "Only ttg::Buffer and ttg::devicescratch can be waited on!");
143 return detail::wait_kernel_t<std::remove_reference_t<Buffers>...>{std::tie(std::forward<Buffers>(args)...)};
154 struct send_coro_promise_type;
159 struct send_coro_state :
public send_coro_handle_type {
160 using base_type = send_coro_handle_type;
165 using promise_type = send_coro_promise_type;
169 send_coro_state(base_type base) : base_type(std::move(base)) {}
171 base_type &handle() {
return *
this; }
/// Readiness query for a send coroutine: always answers `true`.
inline bool ready() {
  constexpr bool always_ready = true;
  return always_ready;
}
177 inline bool completed();
181 struct send_coro_promise_type {
192 send_coro_state get_return_object() {
return send_coro_state{send_coro_handle_type::from_promise(*
this)}; }
197 void unhandled_exception() {
198 std::cerr <<
"Send coroutine caught an unhandled exception!" << std::endl;
/// Promise hook for a coroutine that finishes without a value; nothing to record.
void return_void() {
}
205 template <
typename Key,
typename Value, ttg::Runtime Runtime = ttg::ttg_runtime>
206 inline send_coro_state send_coro(
const Key &key, Value &&value,
ttg::Out<Key, std::decay_t<Value>> &t,
210 t.prepare_send(k, std::forward<Value>(value));
212 t.send(k, std::forward<Value>(value));
215 template <
typename Value, ttg::Runtime Runtime = ttg::ttg_runtime>
216 inline send_coro_state sendv_coro(Value &&value,
ttg::Out<
void, std::decay_t<Value>> &t,
219 t.prepare_send(std::forward<Value>(value));
221 t.sendv(std::forward<Value>(value));
224 template <
typename Key, ttg::Runtime Runtime = ttg::ttg_runtime>
232 template <ttg::Runtime Runtime = ttg::ttg_runtime>
240 send_coro_state coro;
244 template <
size_t i,
typename keyT,
typename valueT,
typename... out_keysT,
typename... out_valuesT,
248 return detail::send_t{
249 detail::send_coro(key, copy_handler(std::forward<valueT>(value)), std::get<i>(t), copy_handler)};
252 template <
size_t i,
typename valueT,
typename... out_keysT,
typename... out_valuesT,
256 return detail::send_t{detail::sendv_coro(copy_handler(std::forward<valueT>(value)), std::get<i>(t), copy_handler)};
259 template <
size_t i,
typename Key,
typename... out_keysT,
typename... out_valuesT,
262 return detail::send_t{detail::sendk_coro(key, std::get<i>(t))};
271 template <
typename keyT,
typename valueT, ttg::Runtime Runtime = ttg::ttg_runtime>
272 inline detail::send_t
send(
size_t i,
const keyT &key, valueT &&value) {
274 auto *terminal_ptr = ttg::detail::get_out_terminal<keyT, valueT>(i,
"ttg::device::send(i, key, value)");
275 return detail::send_t{detail::send_coro(key, copy_handler(std::forward<valueT>(value)), *terminal_ptr, copy_handler)};
285 template <
size_t i,
typename keyT,
typename valueT>
286 inline auto send(
const keyT &key, valueT &&value) {
291 template <
typename valueT, ttg::Runtime Runtime = ttg::ttg_runtime>
292 inline detail::send_t
sendv(std::size_t i, valueT &&value) {
293 auto *terminal_ptr = ttg::detail::get_out_terminal<void, valueT>(i,
"ttg::device::send(i, key, value)");
295 return detail::send_t{detail::sendv_coro(copy_handler(std::forward<valueT>(value)), *terminal_ptr, copy_handler)};
298 template <
typename Key, ttg::Runtime Runtime = ttg::ttg_runtime>
299 inline detail::send_t
sendk(std::size_t i,
const Key& key) {
300 auto *terminal_ptr = ttg::detail::get_out_terminal<Key, void>(i,
"ttg::device::send(i, key, value)");
301 return detail::send_t{detail::sendk_coro(key, *terminal_ptr)};
304 template <ttg::Runtime Runtime = ttg::ttg_runtime>
305 inline detail::send_t
send(std::size_t i) {
306 auto *terminal_ptr = ttg::detail::get_out_terminal<void, void>(i,
"ttg::device::send(i, key, value)");
307 return detail::send_t{detail::send_coro(*terminal_ptr)};
311 template <std::
size_t i,
typename valueT, ttg::Runtime Runtime = ttg::ttg_runtime>
312 inline detail::send_t
sendv(valueT &&value) {
313 return sendv(i, std::forward<valueT>(value));
316 template <
size_t i,
typename Key, ttg::Runtime Runtime = ttg::ttg_runtime>
317 inline detail::send_t
sendk(
const Key& key) {
318 return sendk(i, key);
321 template <
size_t i, ttg::Runtime Runtime = ttg::ttg_runtime>
322 inline detail::send_t
sendk() {
328 template<
typename T,
typename Enabler =
void>
329 struct broadcast_keylist_trait {
335 struct broadcast_keylist_trait<T, std::enable_if_t<ttg::meta::is_iterable_v<T>>> {
336 using key_type = decltype(*std::begin(std::declval<T>()));
339 template <
size_t KeyId,
size_t I,
size_t... Is,
typename... RangesT,
typename valueT,
340 typename... out_keysT,
typename... out_valuesT>
341 inline void prepare_broadcast(
const std::tuple<RangesT...> &keylists, valueT &&value,
343 std::get<I>(t).prepare_send(std::get<KeyId>(keylists), std::forward<valueT>(value));
344 if constexpr (
sizeof...(Is) > 0) {
345 prepare_broadcast<KeyId+1, Is...>(keylists, std::forward<valueT>(value), t);
349 template <
size_t KeyId,
size_t I,
size_t... Is,
typename... RangesT,
typename valueT>
350 inline void prepare_broadcast(
const std::tuple<RangesT...> &keylists, valueT &&value) {
351 using key_t =
typename broadcast_keylist_trait<
352 std::tuple_element_t<KeyId, std::tuple<std::remove_reference_t<RangesT>...>>
354 auto *terminal_ptr = ttg::detail::get_out_terminal<key_t, valueT>(I,
"ttg::device::broadcast(keylists, value)");
355 terminal_ptr->prepare_send(std::get<KeyId>(keylists), value);
356 if constexpr (
sizeof...(Is) > 0) {
357 prepare_broadcast<KeyId+1, Is...>(keylists, std::forward<valueT>(value));
361 template <
size_t KeyId,
size_t I,
size_t... Is,
typename... RangesT,
typename valueT,
362 typename... out_keysT,
typename... out_valuesT>
363 inline void broadcast(
const std::tuple<RangesT...> &keylists, valueT &&value,
365 std::get<I>(t).broadcast(std::get<KeyId>(keylists), std::forward<valueT>(value));
366 if constexpr (
sizeof...(Is) > 0) {
371 template <
size_t KeyId,
size_t I,
size_t... Is,
typename... RangesT,
typename valueT>
372 inline void broadcast(
const std::tuple<RangesT...> &keylists, valueT &&value) {
373 using key_t =
typename broadcast_keylist_trait<
374 std::tuple_element_t<KeyId, std::tuple<std::remove_reference_t<RangesT>...>>
376 auto *terminal_ptr = ttg::detail::get_out_terminal<key_t, valueT>(I,
"ttg::device::broadcast(keylists, value)");
377 terminal_ptr->broadcast(std::get<KeyId>(keylists), value);
378 if constexpr (
sizeof...(Is) > 0) {
384 template <
size_t I,
size_t... Is,
typename RangesT,
typename valueT,
385 typename... out_keysT,
typename... out_valuesT,
387 inline send_coro_state
388 broadcast_coro(RangesT &&keylists, valueT &&value,
392 RangesT kl = std::forward<RangesT>(keylists);
393 if constexpr (ttg::meta::is_tuple_v<RangesT>) {
395 prepare_broadcast<0, I, Is...>(kl, std::forward<std::decay_t<decltype(value)>>(value), t);
398 }
else if constexpr (!ttg::meta::is_tuple_v<RangesT>) {
400 prepare_broadcast<0, I, Is...>(std::tie(kl), std::forward<std::decay_t<decltype(value)>>(value), t);
407 template <
size_t I,
size_t... Is,
typename RangesT,
typename valueT,
409 inline send_coro_state
410 broadcast_coro(RangesT &&keylists, valueT &&value,
413 RangesT kl = std::forward<RangesT>(keylists);
414 if constexpr (ttg::meta::is_tuple_v<RangesT>) {
416 static_assert(
sizeof...(Is)+1 == std::tuple_size_v<RangesT>,
417 "Size of keylist tuple must match the number of output terminals");
418 prepare_broadcast<0, I, Is...>(kl, std::forward<std::decay_t<decltype(value)>>(value));
421 }
else if constexpr (!ttg::meta::is_tuple_v<RangesT>) {
423 prepare_broadcast<0, I, Is...>(std::tie(kl), std::forward<std::decay_t<decltype(value)>>(value));
433 template <
size_t KeyId,
size_t I,
size_t... Is,
typename... RangesT,
434 typename... out_keysT,
typename... out_valuesT>
435 inline void broadcastk(
const std::tuple<RangesT...> &keylists,
437 std::get<I>(t).broadcast(std::get<KeyId>(keylists));
438 if constexpr (
sizeof...(Is) > 0) {
443 template <
size_t KeyId,
size_t I,
size_t... Is,
typename... RangesT>
444 inline void broadcastk(
const std::tuple<RangesT...> &keylists) {
445 using key_t =
typename broadcast_keylist_trait<
446 std::tuple_element_t<KeyId, std::tuple<std::remove_reference_t<RangesT>...>>
448 auto *terminal_ptr = ttg::detail::get_out_terminal<key_t, void>(I,
"ttg::device::broadcastk(keylists)");
449 terminal_ptr->broadcast(std::get<KeyId>(keylists));
450 if constexpr (
sizeof...(Is) > 0) {
456 template <
size_t I,
size_t... Is,
typename RangesT,
457 typename... out_keysT,
typename... out_valuesT,
459 inline send_coro_state
460 broadcastk_coro(RangesT &&keylists,
462 RangesT kl = std::forward<RangesT>(keylists);
463 if constexpr (ttg::meta::is_tuple_v<RangesT>) {
467 }
else if constexpr (!ttg::meta::is_tuple_v<RangesT>) {
475 template <
size_t I,
size_t... Is,
typename RangesT,
477 inline send_coro_state
478 broadcastk_coro(RangesT &&keylists) {
479 RangesT kl = std::forward<RangesT>(keylists);
480 if constexpr (ttg::meta::is_tuple_v<RangesT>) {
482 static_assert(
sizeof...(Is)+1 == std::tuple_size_v<RangesT>,
483 "Size of keylist tuple must match the number of output terminals");
486 }
else if constexpr (!ttg::meta::is_tuple_v<RangesT>) {
495 template <
size_t I,
size_t... Is,
typename rangeT,
typename valueT,
typename... out_keysT,
typename... out_valuesT,
498 inline detail::send_t
broadcast(rangeT &&keylist,
502 return detail::send_t{
503 detail::broadcast_coro<I, Is...>(std::forward<rangeT>(keylist),
504 copy_handler(std::forward<valueT>(value)),
505 t, std::move(copy_handler))};
509 template <
size_t i,
typename rangeT,
typename valueT,
511 inline detail::send_t
broadcast(rangeT &&keylist, valueT &&value) {
513 return detail::send_t{detail::broadcast_coro<i>(std::tie(keylist),
514 copy_handler(std::forward<valueT>(value)),
515 std::move(copy_handler))};
519 template <
size_t I,
size_t... Is,
typename rangeT,
typename... out_keysT,
typename... out_valuesT,
522 inline detail::send_t
broadcastk(rangeT &&keylist,
525 return detail::send_t{
526 detail::broadcastk_coro<I, Is...>(std::forward<rangeT>(keylist), t)};
530 template <
size_t i,
typename rangeT,
532 inline detail::send_t
broadcastk(rangeT &&keylist) {
533 if constexpr (std::is_rvalue_reference_v<decltype(keylist)>) {
534 return detail::send_t{detail::broadcastk_coro<i>(std::forward<rangeT>(keylist))};
536 return detail::send_t{detail::broadcastk_coro<i>(std::tie(keylist))};
542 std::vector<device::detail::send_t> forward(Args&&... args) {
544 return std::vector<device::detail::send_t>{std::forward<Args>(args)...};
553 struct device_task_promise_type;
566 struct Task :
public detail::device_task_handle_type {
567 using base_type = detail::device_task_handle_type;
572 using promise_type = detail::device_task_promise_type;
576 Task(base_type base) : base_type(std::move(base)) {}
578 base_type& handle() {
return *
this; }
581 inline bool ready() {
586 inline bool completed();
595 struct device_task_promise_type {
601 m_state = ttg::device::detail::TTG_DEVICE_CORO_INIT;
610 m_state = ttg::device::detail::TTG_DEVICE_CORO_COMPLETE;
615 template<
typename... Views>
617 return yield_value(views);
620 template<
typename... Ts>
625 m_state = ttg::device::detail::TTG_DEVICE_CORO_WAIT_TRANSFER;
632 m_state = ttg::device::detail::TTG_DEVICE_CORO_WAIT_TRANSFER;
636 template<
typename... Ts>
637 auto await_transform(detail::wait_kernel_t<Ts...>&& a) {
639 if constexpr (
sizeof...(Ts) > 0) {
642 m_state = ttg::device::detail::TTG_DEVICE_CORO_WAIT_KERNEL;
647 m_sends = std::forward<std::vector<device::detail::send_t>>(v);
648 m_state = ttg::device::detail::TTG_DEVICE_CORO_SENDOUT;
654 m_sends.push_back(std::forward<device::detail::send_t>(v));
655 m_state = ttg::device::detail::TTG_DEVICE_CORO_SENDOUT;
660 m_state = ttg::device::detail::TTG_DEVICE_CORO_COMPLETE;
663 bool complete()
const {
664 return m_state == ttg::device::detail::TTG_DEVICE_CORO_COMPLETE;
667 ttg::device::Task get_return_object() {
return {detail::device_task_handle_type::from_promise(*
this)}; }
669 void unhandled_exception() {
670 std::cerr <<
"Task coroutine caught an unhandled exception!" << std::endl;
678 for (
auto&
send : m_sends) {
689 std::vector<device::detail::send_t> m_sends;
690 ttg_device_coro_state m_state = ttg::device::detail::TTG_DEVICE_CORO_STATE_NONE;
696 bool Task::completed() {
return base_type::promise().state() == ttg::device::detail::TTG_DEVICE_CORO_COMPLETE; }
698 struct device_wait_kernel
708 struct device_reducer_promise_type;
713 struct device_reducer :
public device_reducer_handle_type {
714 using base_type = device_reducer_handle_type;
719 using promise_type = device_reducer_promise_type;
723 device_reducer(base_type base) : base_type(std::move(base)) {}
725 base_type& handle() {
return *
this; }
728 inline bool ready() {
733 inline bool completed();
741 struct device_reducer_promise_type {
747 m_state = ttg::device::detail::TTG_DEVICE_CORO_INIT;
756 m_state = ttg::device::detail::TTG_DEVICE_CORO_COMPLETE;
760 template<
typename... Ts>
764 m_state = ttg::device::detail::TTG_DEVICE_CORO_WAIT_TRANSFER;
769 m_state = ttg::device::detail::TTG_DEVICE_CORO_COMPLETE;
772 bool complete()
const {
773 return m_state == ttg::device::detail::TTG_DEVICE_CORO_COMPLETE;
776 device_reducer get_return_object() {
return device_reducer{device_reducer_handle_type::from_promise(*
this)}; }
/// Promise hook invoked when an exception escapes the reducer coroutine body.
/// Reports the failure on std::cerr instead of discarding it without a trace,
/// matching the messages printed by the send and task promise types in this header.
void unhandled_exception() {
  // NOTE(review): only the logging part of the sibling handlers is visible here;
  // confirm whether they also rethrow and whether the reducer should follow suit.
  std::cerr << "Reducer coroutine caught an unhandled exception!" << std::endl;
}
786 ttg::device::detail::ttg_device_coro_state m_state = ttg::device::detail::TTG_DEVICE_CORO_STATE_NONE;
790 bool device_reducer::completed() {
return base_type::promise().state() == ttg::device::detail::TTG_DEVICE_CORO_COMPLETE; }
std::enable_if_t< meta::is_all_void_v< Key, Value >, void > send()
std::enable_if_t<!meta::is_void_v< Key > &&meta::is_void_v< Value >, void > sendk(const Key &key)
A complete version of void.
constexpr auto data(C &c) -> decltype(c.data())
typename make_index_sequence_t< I... >::type make_index_sequence
std::integral_constant< bool,(Flags &const_) !=0 > is_const
void broadcast(const std::tuple< RangesT... > &keylists, valueT &&value, std::tuple< ttg::Out< out_keysT, out_valuesT >... > &t)
auto buffer_data(const Buffer< T, A > &buffer)
void mark_device_out(std::tuple< Buffer &... > &b)
bool register_device_memory(std::tuple< Views &... > &views)
void post_device_out(std::tuple< Buffer &... > &b)
TTG_CXX_COROUTINE_NAMESPACE::suspend_always suspend_always
void send(const keyT &key, valueT &&value, ttg::Out< keyT, valueT > &t)
Sends a task id and a value to the given output terminal.
constexpr const ttg::Runtime ttg_runtime
TTG_CXX_COROUTINE_NAMESPACE::suspend_never suspend_never
TTG_CXX_COROUTINE_NAMESPACE::coroutine_handle< Promise > coroutine_handle
void sendk(const keyT &key, ttg::Out< keyT, void > &t)
Sends a task id (without an accompanying value) to the given output terminal.
void sendv(valueT &&value, ttg::Out< void, valueT > &t)
Sends a value (without an accompanying task id) to the given output terminal.
void broadcastk(const rangeT &keylist, std::tuple< ttg::Out< out_keysT, out_valuesT >... > &t)