1 #ifndef TTG_DEVICE_TASK_H
2 #define TTG_DEVICE_TASK_H
12 #ifdef TTG_HAVE_COROUTINE
// Member of a variadic class template whose own declaration line is not visible in this excerpt.
// NOTE(review): select() builds detail::to_device_t<...> from std::tie(...), so this is
// presumably to_device_t::ties — confirm against the full header.
17 template <
typename... Ts>
// One lvalue reference per element of Ts, stored as a tuple.
19 std::tuple<std::add_lvalue_reference_t<Ts>...> ties;
/// Collects references to all arguments into a detail::to_device_t.
/// The arguments are bound with std::tie, which only accepts lvalues.
/// @param args objects to reference
/// @return detail::to_device_t instantiated with the reference-stripped argument types
30 template <
typename... Args>
32 inline auto select(Args &&...args) {
33 return detail::to_device_t<std::remove_reference_t<Args>...>{std::tie(std::forward<Args>(args)...)};
/// States stored in the m_state member of the coroutine promise types.
// NOTE(review): the promise types also assign TTG_DEVICE_CORO_INIT, which is not among the
// enumerators visible in this excerpt.
38 enum ttg_device_coro_state {
// Default initializer of m_state.
39 TTG_DEVICE_CORO_STATE_NONE,
41 TTG_DEVICE_CORO_WAIT_TRANSFER,
// Set by device_task_promise_type::await_transform(detail::wait_kernel_t<Ts...>&&).
42 TTG_DEVICE_CORO_WAIT_KERNEL,
// Set by the task promise after it stores send_t objects in m_sends.
43 TTG_DEVICE_CORO_SENDOUT,
// Value that complete() and completed() compare against.
44 TTG_DEVICE_CORO_COMPLETE
/// Awaitable returned by wait(); holds references to the objects being waited on.
/// @tparam Ts types of the referenced objects
47 template <
typename... Ts>
48 struct wait_kernel_t {
// One lvalue reference per element of Ts; wait() fills it from std::tie.
49 std::tuple<Ts &...> ties;
52 constexpr
bool await_ready() const noexcept {
return false; }
// NOTE(review): the member template introduced by this header is not visible in this excerpt.
55 template <
typename Promise>
/// Runs when the awaiting coroutine is resumed.
58 void await_resume() noexcept {
// The guarded code is only compiled when at least one object is referenced.
59 if constexpr (
sizeof...(Ts) > 0) {
/// Creates a detail::wait_kernel_t that references the given objects.
/// The arguments are bound with std::tie, which only accepts lvalues.
/// @param args objects to wait on
/// @return detail::wait_kernel_t instantiated with the reference-stripped argument types
72 template <
typename... Buffers>
74 inline auto wait(Buffers &&...args) {
// Diagnostic text of a check whose condition is not visible in this excerpt.
78 "Only ttg::Buffer and ttg::devicescratch can be waited on!");
79 return detail::wait_kernel_t<std::remove_reference_t<Buffers>...>{std::tie(std::forward<Buffers>(args)...)};
// Forward declaration so send_coro_state can name its promise type.
90 struct send_coro_promise_type;
/// Return object of the send coroutines; derives from the handle type send_coro_handle_type.
95 struct send_coro_state :
public send_coro_handle_type {
// Shorthand for the handle base class.
96 using base_type = send_coro_handle_type;
// Promise type associated with coroutines that return send_coro_state.
101 using promise_type = send_coro_promise_type;
105 send_coro_state(base_type base) : base_type(std::move(base)) {}
107 base_type &handle() {
return *
this; }
/// Readiness query for the send coroutine.
/// @return always true
inline bool ready() {
  return true;
}
// Declared here only; the definition is not visible in this excerpt.
113 inline bool completed();
/// Promise type of the send coroutines (named by send_coro_state::promise_type).
117 struct send_coro_promise_type {
128 send_coro_state get_return_object() {
return send_coro_state{send_coro_handle_type::from_promise(*
this)}; }
/// Promise hook for an exception that escapes the coroutine body; writes a message to std::cerr.
// NOTE(review): the rest of this function is not visible in this excerpt.
133 void unhandled_exception() {
134 std::cerr <<
"Send coroutine caught an unhandled exception!" << std::endl;
/// Promise hook for a coroutine that finishes without a value; nothing is recorded.
void return_void() {
}
/// Send of a key/value pair to output terminal t, returning a send_coro_state.
/// The visible statements call t.prepare_send(...) and then t.send(...).
// NOTE(review): part of the parameter list and body is missing from this excerpt. `k` is not
// declared in the visible lines — presumably a copy of `key`; confirm.
// NOTE(review): value is forwarded to prepare_send and again to send; confirm prepare_send does
// not consume an rvalue argument.
141 template <
typename Key,
typename Value, ttg::Runtime Runtime = ttg::ttg_runtime>
142 inline send_coro_state send_coro(
const Key &key, Value &&value,
ttg::Out<Key, std::decay_t<Value>> &t,
146 t.prepare_send(k, std::forward<Value>(value));
148 t.send(k, std::forward<Value>(value));
/// Send of a value without a key to output terminal t, returning a send_coro_state.
/// The visible statements call t.prepare_send(...) and then t.sendv(...).
// NOTE(review): part of the parameter list and body is missing from this excerpt.
// NOTE(review): value is forwarded to prepare_send and again to sendv; confirm prepare_send does
// not consume an rvalue argument.
151 template <
typename Value, ttg::Runtime Runtime = ttg::ttg_runtime>
152 inline send_coro_state sendv_coro(Value &&value,
ttg::Out<
void, std::decay_t<Value>> &t,
155 t.prepare_send(std::forward<Value>(value));
157 t.sendv(std::forward<Value>(value));
// NOTE(review): the declarations introduced by the next two template headers are missing from
// this excerpt. Other code here calls detail::sendk_coro(key, terminal) and
// detail::send_coro(terminal), so these presumably belong to those overloads — confirm.
160 template <
typename Key, ttg::Runtime Runtime = ttg::ttg_runtime>
168 template <ttg::Runtime Runtime = ttg::ttg_runtime>
// Data member holding a send coroutine's return object.
// NOTE(review): presumably a member of detail::send_t, which is constructed from a
// send_coro_state throughout this file — confirm.
176 send_coro_state coro;
/// Sends key and value through element i of the terminal tuple t.
/// The value is passed through copy_handler first; the coroutine returned by detail::send_coro
/// is wrapped in a detail::send_t.
// NOTE(review): the function signature and the definition of copy_handler are missing from
// this excerpt.
180 template <
size_t i,
typename keyT,
typename valueT,
typename... out_keysT,
typename... out_valuesT,
184 return detail::send_t{
185 detail::send_coro(key, copy_handler(std::forward<valueT>(value)), std::get<i>(t), copy_handler)};
/// Sends a value without a key through element i of the terminal tuple t.
/// The value is passed through copy_handler first; the coroutine returned by detail::sendv_coro
/// is wrapped in a detail::send_t.
// NOTE(review): the function signature and the definition of copy_handler are missing from
// this excerpt.
188 template <
size_t i,
typename valueT,
typename... out_keysT,
typename... out_valuesT,
192 return detail::send_t{detail::sendv_coro(copy_handler(std::forward<valueT>(value)), std::get<i>(t), copy_handler)};
/// Sends a key without a value through element i of the terminal tuple t.
/// The coroutine returned by detail::sendk_coro is wrapped in a detail::send_t.
// NOTE(review): the function signature is missing from this excerpt.
195 template <
size_t i,
typename Key,
typename... out_keysT,
typename... out_valuesT,
198 return detail::send_t{detail::sendk_coro(key, std::get<i>(t))};
/// Sends key and value through the output terminal selected by the runtime index i.
/// The terminal is obtained from ttg::detail::get_out_terminal<keyT, valueT>, which also receives
/// a string naming this function.
/// @param i index of the output terminal
/// @param key task id to send
/// @param value value to send; passed through copy_handler before detail::send_coro sees it
/// @return detail::send_t wrapping the coroutine returned by detail::send_coro
// NOTE(review): the definition of copy_handler is missing from this excerpt.
207 template <
typename keyT,
typename valueT, ttg::Runtime Runtime = ttg::ttg_runtime>
208 inline detail::send_t
send(
size_t i,
const keyT &key, valueT &&value) {
210 auto *terminal_ptr = ttg::detail::get_out_terminal<keyT, valueT>(i,
"ttg::device::send(i, key, value)");
211 return detail::send_t{detail::send_coro(key, copy_handler(std::forward<valueT>(value)), *terminal_ptr, copy_handler)};
/// Variant of send taking the terminal index i as a template argument.
/// @param key task id to send
/// @param value value to send
// NOTE(review): the body is missing from this excerpt.
221 template <
size_t i,
typename keyT,
typename valueT>
222 inline auto send(
const keyT &key, valueT &&value) {
/// Sends a value without a key through the output terminal selected by the runtime index i.
/// @param i index of the output terminal
/// @param value value to send; passed through copy_handler before detail::sendv_coro sees it
/// @return detail::send_t wrapping the coroutine returned by detail::sendv_coro
// NOTE(review): the definition of copy_handler is missing from this excerpt.
227 template <
typename valueT, ttg::Runtime Runtime = ttg::ttg_runtime>
228 inline detail::send_t
sendv(std::size_t i, valueT &&value) {
// NOTE(review): the string below names send(i, key, value) although this function is
// sendv(i, value); it looks copied from send() — confirm and correct.
229 auto *terminal_ptr = ttg::detail::get_out_terminal<void, valueT>(i,
"ttg::device::send(i, key, value)");
231 return detail::send_t{detail::sendv_coro(copy_handler(std::forward<valueT>(value)), *terminal_ptr, copy_handler)};
234 template <
typename Key, ttg::Runtime Runtime = ttg::ttg_runtime>
235 inline detail::send_t
sendk(std::size_t i,
const Key& key) {
236 auto *terminal_ptr = ttg::detail::get_out_terminal<Key, void>(i,
"ttg::device::send(i, key, value)");
237 return detail::send_t{detail::sendk_coro(key, *terminal_ptr)};
240 template <ttg::Runtime Runtime = ttg::ttg_runtime>
241 inline detail::send_t
send(std::size_t i) {
242 auto *terminal_ptr = ttg::detail::get_out_terminal<void, void>(i,
"ttg::device::send(i, key, value)");
243 return detail::send_t{detail::send_coro(*terminal_ptr)};
/// Variant of sendv taking the terminal index i as a template argument; forwards to sendv(i, value).
/// @param value value to send
/// @return the detail::send_t produced by sendv(i, value)
// NOTE(review): the end of the template parameter list is missing from this excerpt.
247 template <std::size_t i,
typename valueT,
typename... out_keysT,
typename... out_valuesT,
249 inline detail::send_t
sendv(valueT &&value) {
250 return sendv(i, std::forward<valueT>(value));
/// Variant of sendk taking the terminal index i as a template argument; forwards to sendk(i, key).
/// @param key task id to send
/// @return the detail::send_t produced by sendk(i, key)
253 template <
size_t i,
typename Key, ttg::Runtime Runtime = ttg::ttg_runtime>
254 inline detail::send_t
sendk(
const Key& key) {
255 return sendk(i, key);
/// Overload taking only the terminal index i as a template argument and no key.
// NOTE(review): the body is missing from this excerpt. The function is named sendk although it
// takes no key — confirm the intended name.
258 template <
size_t i, ttg::Runtime Runtime = ttg::ttg_runtime>
259 inline detail::send_t
sendk() {
/// Trait used by the broadcast helpers to obtain the key type of a key list.
// NOTE(review): the members of the primary template are missing from this excerpt.
265 template<
typename T,
typename Enabler =
void>
266 struct broadcast_keylist_trait {
// Specialization chosen when ttg::meta::is_iterable_v<T> is true.
// NOTE(review): its template header is missing from this excerpt.
272 struct broadcast_keylist_trait<T, std::enable_if_t<ttg::meta::is_iterable_v<T>>> {
// key_type is the type of dereferencing std::begin of element 0 of T (T is accessed with std::get<0>).
273 using key_type = decltype(*std::begin(std::get<0>(std::declval<T>())));
/// Calls prepare_send on element I of t with key list number KeyId, then recurses over the
/// remaining terminal indices Is... using the next key list (KeyId+1).
// NOTE(review): the declaration of parameter `t` is missing from this excerpt.
// NOTE(review): value is forwarded to prepare_send and again to the recursive call; confirm
// prepare_send does not consume an rvalue argument.
276 template <
size_t KeyId,
size_t I,
size_t... Is,
typename... RangesT,
typename valueT,
277 typename... out_keysT,
typename... out_valuesT>
278 inline void prepare_broadcast(
const std::tuple<RangesT...> &keylists, valueT &&value,
280 std::get<I>(t).prepare_send(std::get<KeyId>(keylists), std::forward<valueT>(value));
// Recurse only while terminal indices remain.
281 if constexpr (
sizeof...(Is) > 0) {
282 prepare_broadcast<KeyId+1, Is...>(keylists, std::forward<valueT>(value), t);
/// Overload without an explicit terminal tuple: looks up terminal I through
/// ttg::detail::get_out_terminal, calls prepare_send on it with key list number KeyId, then
/// recurses over the remaining terminal indices Is... using the next key list (KeyId+1).
/// @param keylists tuple of key lists, one per terminal index
/// @param value value to be broadcast
286 template <
size_t KeyId,
size_t I,
size_t... Is,
typename... RangesT,
typename valueT,
287 typename... out_keysT,
typename... out_valuesT>
288 inline void prepare_broadcast(
const std::tuple<RangesT...> &keylists, valueT &&value) {
// Key type derived from the KeyId-th key list via broadcast_keylist_trait.
// NOTE(review): the end of this alias declaration is missing from this excerpt.
289 using key_t =
typename broadcast_keylist_trait<
290 std::tuple_element_t<KeyId, std::tuple<std::remove_reference_t<RangesT>...>>
292 auto *terminal_ptr = ttg::detail::get_out_terminal<key_t, valueT>(I,
"ttg::device::broadcast(keylists, value)");
// value is passed as an lvalue here; it is only forwarded in the recursive call below.
293 terminal_ptr->prepare_send(std::get<KeyId>(keylists), value);
// Recurse only while terminal indices remain.
294 if constexpr (
sizeof...(Is) > 0) {
295 prepare_broadcast<KeyId+1, Is...>(keylists, std::forward<valueT>(value));
/// Calls broadcast on element I of t with key list number KeyId and the forwarded value.
// NOTE(review): the declaration of parameter `t` and the body of the trailing `if constexpr`
// are missing from this excerpt.
299 template <
size_t KeyId,
size_t I,
size_t... Is,
typename... RangesT,
typename valueT,
300 typename... out_keysT,
typename... out_valuesT>
301 inline void broadcast(
const std::tuple<RangesT...> &keylists, valueT &&value,
303 std::get<I>(t).broadcast(std::get<KeyId>(keylists), std::forward<valueT>(value));
// Guard taken only while terminal indices remain.
304 if constexpr (
sizeof...(Is) > 0) {
/// Overload without an explicit terminal tuple: looks up terminal I through
/// ttg::detail::get_out_terminal and calls broadcast on it with key list number KeyId.
/// @param keylists tuple of key lists, one per terminal index
/// @param value value to be broadcast
// NOTE(review): the body of the trailing `if constexpr` is missing from this excerpt.
309 template <
size_t KeyId,
size_t I,
size_t... Is,
typename... RangesT,
typename valueT,
310 typename... out_keysT,
typename... out_valuesT>
311 inline void broadcast(
const std::tuple<RangesT...> &keylists, valueT &&value) {
// Key type derived from the KeyId-th key list via broadcast_keylist_trait.
// NOTE(review): the end of this alias declaration is missing from this excerpt.
312 using key_t =
typename broadcast_keylist_trait<
313 std::tuple_element_t<KeyId, std::tuple<std::remove_reference_t<RangesT>...>>
315 auto *terminal_ptr = ttg::detail::get_out_terminal<key_t, valueT>(I,
"ttg::device::broadcast(keylists, value)");
// value is passed as an lvalue here.
316 terminal_ptr->broadcast(std::get<KeyId>(keylists), value);
// Guard taken only while terminal indices remain.
317 if constexpr (
sizeof...(Is) > 0) {
/// Returns a send_coro_state for a broadcast over terminals I, Is... of an explicit terminal
/// tuple t. The visible part keeps the key lists in the local `kl` and calls prepare_broadcast.
// NOTE(review): the end of the parameter list and parts of the body are missing from this excerpt.
// NOTE(review): if RangesT deduces to an lvalue reference type, `kl` is a reference, not a
// copy — confirm the key list outlives the coroutine.
323 template <
size_t I,
size_t... Is,
typename RangesT,
typename valueT,
324 typename... out_keysT,
typename... out_valuesT,
326 inline send_coro_state
327 broadcast_coro(RangesT &&keylists, valueT &&value,
331 RangesT kl = std::forward<RangesT>(keylists);
// A tuple of key lists is handed to prepare_broadcast as is.
332 if constexpr (ttg::meta::is_tuple_v<RangesT>) {
// std::forward with a decayed (non-reference) type argument always yields an rvalue.
334 prepare_broadcast<0, I, Is...>(kl, std::forward<std::decay_t<decltype(value)>>(value), t);
337 }
// A single key list is wrapped with std::tie so prepare_broadcast still receives a tuple.
else if constexpr (!ttg::meta::is_tuple_v<RangesT>) {
339 prepare_broadcast<0, I, Is...>(std::tie(kl), std::forward<std::decay_t<decltype(value)>>(value), t);
/// Returns a send_coro_state for a broadcast over terminals I, Is... without an explicit
/// terminal tuple. The visible part keeps the key lists in the local `kl` and calls
/// prepare_broadcast.
// NOTE(review): the end of the parameter list and parts of the body are missing from this excerpt.
// NOTE(review): if RangesT deduces to an lvalue reference type, `kl` is a reference, not a
// copy — confirm the key list outlives the coroutine.
346 template <
size_t I,
size_t... Is,
typename RangesT,
typename valueT,
348 inline send_coro_state
349 broadcast_coro(RangesT &&keylists, valueT &&value,
352 RangesT kl = std::forward<RangesT>(keylists);
// A tuple of key lists must provide exactly one list per terminal index.
353 if constexpr (ttg::meta::is_tuple_v<RangesT>) {
355 static_assert(
sizeof...(Is)+1 == std::tuple_size_v<RangesT>,
356 "Size of keylist tuple must match the number of output terminals");
// std::forward with a decayed (non-reference) type argument always yields an rvalue.
357 prepare_broadcast<0, I, Is...>(kl, std::forward<std::decay_t<decltype(value)>>(value));
360 }
// A single key list is wrapped with std::tie so prepare_broadcast still receives a tuple.
else if constexpr (!ttg::meta::is_tuple_v<RangesT>) {
362 prepare_broadcast<0, I, Is...>(std::tie(kl), std::forward<std::decay_t<decltype(value)>>(value));
/// Broadcasts value to terminals I, Is... of the terminal tuple t.
/// The value is passed through copy_handler; the coroutine returned by
/// detail::broadcast_coro<I, Is...> is wrapped in a detail::send_t.
// NOTE(review): the end of the parameter lists and the definition of copy_handler are missing
// from this excerpt.
// NOTE(review): copy_handler is called and then moved from within the same argument list;
// confirm the result does not depend on argument evaluation order.
370 template <
size_t I,
size_t... Is,
typename rangeT,
typename valueT,
typename... out_keysT,
typename... out_valuesT,
373 inline detail::send_t
broadcast(rangeT &&keylist,
377 return detail::send_t{
378 detail::broadcast_coro<I, Is...>(std::forward<rangeT>(keylist),
379 copy_handler(std::forward<valueT>(value)),
380 t, std::move(copy_handler))};
/// Broadcasts value to the single terminal with index i.
/// The key list is wrapped with std::tie and the value is passed through copy_handler before
/// both are handed to broadcast_coro<i>.
/// @param keylist keys to broadcast to
/// @param value value to be broadcast
/// @return detail::send_t wrapping the coroutine returned by broadcast_coro<i>
// NOTE(review): the end of the template parameter list and the definition of copy_handler are
// missing from this excerpt.
// NOTE(review): std::tie(keylist) stores a reference to keylist — confirm it outlives the coroutine.
384 template <
size_t i,
typename rangeT,
typename valueT,
386 inline detail::send_t
broadcast(rangeT &&keylist, valueT &&value) {
388 return detail::send_t{broadcast_coro<i>(std::tie(keylist), copy_handler(std::forward<valueT>(value)),
389 std::move(copy_handler))};
/// Collects the given arguments into a std::vector of device::detail::send_t.
/// @param args objects used to brace-initialize the vector
/// @return the vector built from args
// NOTE(review): the template header is missing from this excerpt.
// NOTE(review): std::vector{...} relies on class template argument deduction from the
// arguments; confirm behaviour for an empty argument list.
394 std::vector<device::detail::send_t> forward(Args&&... args) {
396 return std::vector{std::forward<Args>(args)...};
// Forward declaration so Task can name its promise type.
405 struct device_task_promise_type;
/// Return object of device task coroutines; derives from detail::device_task_handle_type.
418 struct Task :
public detail::device_task_handle_type {
// Shorthand for the handle base class.
419 using base_type = detail::device_task_handle_type;
// Promise type associated with coroutines that return Task.
424 using promise_type = detail::device_task_promise_type;
428 Task(base_type base) : base_type(std::move(base)) {}
430 base_type& handle() {
return *
this; }
/// Readiness query for the task.
// NOTE(review): the body is missing from this excerpt.
433 inline bool ready() {
/// Completion query; defined out of line as Task::completed().
438 inline bool completed();
/// Promise type of device task coroutines (named by Task::promise_type). It records the
/// coroutine's progress in m_state and the pending sends in m_sends.
// NOTE(review): most member function headers are missing from this excerpt; the statements
// below are the visible parts of their bodies.
447 struct device_task_promise_type {
// State assignment from a member whose header is not visible.
453 m_state = ttg::device::detail::TTG_DEVICE_CORO_INIT;
// State assignment from a member whose header is not visible.
462 m_state = ttg::device::detail::TTG_DEVICE_CORO_COMPLETE;
// Member template that delegates to yield_value(views); its signature is not visible.
467 template<
typename... Views>
469 return yield_value(views);
// Member template that marks the coroutine as waiting for a transfer; its signature is not visible.
472 template<
typename... Ts>
476 m_state = ttg::device::detail::TTG_DEVICE_CORO_WAIT_TRANSFER;
/// Handles co_await on the object returned by ttg::device::wait(); marks the coroutine as
/// waiting for a kernel.
480 template<
typename... Ts>
481 auto await_transform(detail::wait_kernel_t<Ts...>&& a) {
// The guarded code is only compiled when at least one object is waited on.
483 if constexpr (
sizeof...(Ts) > 0) {
486 m_state = ttg::device::detail::TTG_DEVICE_CORO_WAIT_KERNEL;
// Stores a whole vector of sends and switches to the send-out state (header not visible).
491 m_sends = std::forward<std::vector<device::detail::send_t>>(v);
492 m_state = ttg::device::detail::TTG_DEVICE_CORO_SENDOUT;
// Appends a single send and switches to the send-out state (header not visible).
498 m_sends.push_back(std::forward<device::detail::send_t>(v));
499 m_state = ttg::device::detail::TTG_DEVICE_CORO_SENDOUT;
// State assignment from a member whose header is not visible.
504 m_state = ttg::device::detail::TTG_DEVICE_CORO_COMPLETE;
/// @return true when m_state equals TTG_DEVICE_CORO_COMPLETE
507 bool complete()
const {
508 return m_state == ttg::device::detail::TTG_DEVICE_CORO_COMPLETE;
511 ttg::device::Task get_return_object() {
return {detail::device_task_handle_type::from_promise(*
this)}; }
/// Promise hook for an exception that escapes the coroutine body; writes a message to std::cerr.
// NOTE(review): the rest of this function is missing from this excerpt.
513 void unhandled_exception() {
514 std::cerr <<
"Task coroutine caught an unhandled exception!" << std::endl;
// Loop over the stored sends, from a member whose header and loop body are not visible.
522 for (
auto&
send : m_sends) {
// Sends collected from the coroutine.
533 std::vector<device::detail::send_t> m_sends;
// Current coroutine state; starts as TTG_DEVICE_CORO_STATE_NONE.
534 ttg_device_coro_state m_state = ttg::device::detail::TTG_DEVICE_CORO_STATE_NONE;
540 bool Task::completed() {
return base_type::promise().state() == ttg::device::detail::TTG_DEVICE_CORO_COMPLETE; }
// NOTE(review): only the first line of this declaration is visible in this excerpt.
542 struct device_wait_kernel
// Forward declaration so device_reducer can name its promise type.
552 struct device_reducer_promise_type;
/// Return object of device reducer coroutines; derives from device_reducer_handle_type.
557 struct device_reducer :
public device_reducer_handle_type {
// Shorthand for the handle base class.
558 using base_type = device_reducer_handle_type;
// Promise type associated with coroutines that return device_reducer.
563 using promise_type = device_reducer_promise_type;
567 device_reducer(base_type base) : base_type(std::move(base)) {}
569 base_type& handle() {
return *
this; }
/// Readiness query for the reducer.
// NOTE(review): the body is missing from this excerpt.
572 inline bool ready() {
/// Completion query; defined out of line as device_reducer::completed().
577 inline bool completed();
/// Promise type of device reducer coroutines (named by device_reducer::promise_type). It
/// records the coroutine's progress in m_state.
// NOTE(review): most member function headers are missing from this excerpt; the statements
// below are the visible parts of their bodies.
585 struct device_reducer_promise_type {
// State assignment from a member whose header is not visible.
591 m_state = ttg::device::detail::TTG_DEVICE_CORO_INIT;
// State assignment from a member whose header is not visible.
600 m_state = ttg::device::detail::TTG_DEVICE_CORO_COMPLETE;
// Member template that marks the coroutine as waiting for a transfer; its signature is not visible.
604 template<
typename... Ts>
608 m_state = ttg::device::detail::TTG_DEVICE_CORO_WAIT_TRANSFER;
// State assignment from a member whose header is not visible.
613 m_state = ttg::device::detail::TTG_DEVICE_CORO_COMPLETE;
/// @return true when m_state equals TTG_DEVICE_CORO_COMPLETE
616 bool complete()
const {
617 return m_state == ttg::device::detail::TTG_DEVICE_CORO_COMPLETE;
620 device_reducer get_return_object() {
return device_reducer{device_reducer_handle_type::from_promise(*
this)}; }
/// Promise hook for an exception that escapes the reducer coroutine body.
/// Writes a message to std::cerr, like the send and task promise types in this file, so the
/// failure is not dropped silently. It does not rethrow or terminate.
void unhandled_exception() {
  std::cerr << "Reducer coroutine caught an unhandled exception!" << std::endl;
}
// Current coroutine state; starts as TTG_DEVICE_CORO_STATE_NONE.
630 ttg::device::detail::ttg_device_coro_state m_state = ttg::device::detail::TTG_DEVICE_CORO_STATE_NONE;
634 bool device_reducer::completed() {
return base_type::promise().state() == ttg::device::detail::TTG_DEVICE_CORO_COMPLETE; }
std::enable_if_t< meta::is_all_void_v< Key, Value >, void > send()
std::enable_if_t<!meta::is_void_v< Key > &&meta::is_void_v< Value >, void > sendk(const Key &key)
A complete version of void.
void broadcast(const std::tuple< RangesT... > &keylists, valueT &&value, std::tuple< ttg::Out< out_keysT, out_valuesT >... > &t)
void mark_device_out(std::tuple< Buffer &... > &b)
bool register_device_memory(std::tuple< Views &... > &views)
void post_device_out(std::tuple< Buffer &... > &b)
TTG_CXX_COROUTINE_NAMESPACE::suspend_always suspend_always
void send(const keyT &key, valueT &&value, ttg::Out< keyT, valueT > &t)
Sends a task id and a value to the given output terminal.
constexpr const ttg::Runtime ttg_runtime
TTG_CXX_COROUTINE_NAMESPACE::suspend_never suspend_never
TTG_CXX_COROUTINE_NAMESPACE::coroutine_handle< Promise > coroutine_handle
void sendk(const keyT &key, ttg::Out< keyT, void > &t)
Sends a task id (without an accompanying value) to the given output terminal.
void sendv(valueT &&value, ttg::Out< void, valueT > &t)
Sends a value (without an accompanying task id) to the given output terminal.
void broadcast(const rangeT &keylist, valueT &&value, std::tuple< ttg::Out< out_keysT, out_valuesT >... > &t)