ttg 1.0.0
Template Task Graph (TTG): flowgraph-based programming model for high-performance distributed-memory algorithms
Loading...
Searching...
No Matches
tt.h
Go to the documentation of this file.
1// SPDX-License-Identifier: BSD-3-Clause
2#ifndef TTG_TT_H
3#define TTG_TT_H
4
5#include "ttg/config.h"
6#include "ttg/fwd.h"
7
8#include "ttg/base/tt.h"
9#include "ttg/edge.h"
10
11#ifdef TTG_HAVE_COROUTINE
12#include "ttg/coroutine.h"
13#endif
14
15#include <cassert>
16#include <memory>
17#include <vector>
18
19namespace ttg {
20
21 // TODO describe TT concept (preferably as a C++20 concept)
22 // N.B. TT::op returns void or ttg::coroutine_handle<>
23 // see TTG_PROCESS_TT_OP_RETURN below
24
26
31 template <typename input_terminalsT, typename output_terminalsT>
32 class TTG : public TTBase {
33 public:
34 static constexpr int numins = std::tuple_size_v<input_terminalsT>; // number of input arguments
35 static constexpr int numouts = std::tuple_size_v<output_terminalsT>; // number of outputs or results
36
37 using input_terminals_type = input_terminalsT;
38 using output_terminals_type = output_terminalsT;
39
40 private:
41 std::vector<std::unique_ptr<TTBase>> tts;
44
45 // not copyable
46 TTG(const TTG &) = delete;
47 TTG &operator=(const TTG &) = delete;
48 // movable
49 TTG(TTG &&other)
50 : TTBase(static_cast<TTBase &&>(other))
51 , tts(std::move(other.tts))
52 , ins(std::move(other.ins))
53 , outs(std::move(other.outs)) {
54 is_ttg_ = true;
55 own_my_tts();
56 }
57 TTG &operator=(TTG &&other) {
58 static_cast<TTBase &>(*this) = static_cast<TTBase &&>(other);
59 is_ttg_ = true;
60 tts = std::move(other.tts);
61 ins = std::move(other.ins);
62 outs = std::move(other.outs);
63 own_my_tts();
64 return *this;
65 };
66
67 public:
69 template <typename ttseqT>
70 TTG(ttseqT &&tts,
71 const input_terminals_type &ins, // tuple of pointers to input terminals
72 const output_terminals_type &outs, // tuple of pointers to output terminals
73 const std::string &name = "ttg")
74 : TTBase(name, numins, numouts), tts(std::forward<ttseqT>(tts)), ins(ins), outs(outs) {
75 if (this->tts.size() == 0) throw name + ":TTG: need to wrap at least one TT"; // see fence
76
79 is_ttg_ = true;
80 own_my_tts();
81
82 // traversal is still broken ... need to add checking for composite
83 }
84
86 template <std::size_t i>
87 auto in() {
88 return std::get<i>(ins);
89 }
90
92 template <std::size_t i>
93 auto out() {
94 return std::get<i>(outs);
95 }
96
97 TTBase *get_op(std::size_t i) { return tts.at(i).get(); }
98
99 ttg::World get_world() const override final { return tts[0]->get_world(); }
100
101 void fence() override { tts[0]->fence(); }
102
103 void make_executable() override {
104 for (auto &op : tts) op->make_executable();
105 }
106
107 virtual void print_incomplete_tasks() const override {
108 for (auto& tt : tts) {
109 tt->print_incomplete_tasks();
110 }
111 }
112
113 private:
114 void own_my_tts() const {
115 for (auto &op : tts) op->owning_ttg = this;
116 }
117 };
118
119 template <typename ttseqT, typename input_terminalsT, typename output_terminalsT>
120 auto make_ttg(ttseqT &&tts, const input_terminalsT &ins, const output_terminalsT &outs,
121 const std::string &name = "ttg") {
122 return std::make_unique<TTG<input_terminalsT, output_terminalsT>>(std::forward<ttseqT>(tts), ins, outs, name);
123 }
124
126 template <typename keyT, typename input_valueT>
127 class SinkTT : public TTBase {
128 static constexpr int numins = 1;
129 static constexpr int numouts = 0;
130
131 using input_terminals_type = std::tuple<ttg::In<keyT, input_valueT>>;
132 using input_edges_type = std::tuple<ttg::Edge<keyT, std::decay_t<input_valueT>>>;
133 using output_terminals_type = std::tuple<>;
134
135 private:
136 input_terminals_type input_terminals;
137 output_terminals_type output_terminals;
138
139 SinkTT(const SinkTT &other) = delete;
140 SinkTT &operator=(const SinkTT &other) = delete;
141 SinkTT(SinkTT &&other) = delete;
142 SinkTT &operator=(SinkTT &&other) = delete;
143
144 template <typename terminalT>
145 void register_input_callback(terminalT &input) {
146 using valueT = std::decay_t<typename terminalT::value_type>;
147 auto move_callback = [](const keyT &key, valueT &&value) {};
148 auto send_callback = [](const keyT &key, const valueT &value) {};
149 auto broadcast_callback = [](const ttg::span<const keyT> &key, const valueT &value) {};
150 auto setsize_callback = [](const keyT &key, std::size_t size) {};
151 auto finalize_callback = [](const keyT &key) {};
152
153 input.set_callback(send_callback, move_callback, broadcast_callback, setsize_callback, finalize_callback);
154 }
155
156 public:
157 SinkTT(const std::string &inname = "junk") : TTBase("sink", numins, numouts) {
158 register_input_terminals(input_terminals, std::vector<std::string>{inname});
159 register_input_callback(std::get<0>(input_terminals));
160 }
161
162 SinkTT(const input_edges_type &inedges, const std::string &inname = "junk") : TTBase("sink", numins, numouts) {
163 register_input_terminals(input_terminals, std::vector<std::string>{inname});
164 register_input_callback(std::get<0>(input_terminals));
165 std::get<0>(inedges).set_out(&std::get<0>(input_terminals));
166 }
167
168 virtual ~SinkTT() {}
169
170 void fence() override final {}
171
172 void make_executable() override final { TTBase::make_executable(); }
173
174 World get_world() const override final { return get_default_world(); }
175
177 template <std::size_t i>
178 std::tuple_element_t<i, input_terminals_type> *in() {
179 static_assert(i == 0);
180 return &std::get<i>(input_terminals);
181 }
182 };
183
184} // namespace ttg
185
// TTG_PROCESS_TT_OP_RETURN(result, id, invoke): evaluates `invoke` (a call to a TT's op) exactly once
// and post-processes what it returns.
//
// With TTG_HAVE_COROUTINE:
//   - if `invoke` has type void: `id` is set to ttg::TaskCoroutineID::Invalid, `result` is left alone;
//   - otherwise the returned object must derive from ttg::coroutine_handle<ttg::resumable_task_state>
//     or ttg::coroutine_handle<ttg::device::detail::device_task_promise_type> (checked at compile time);
//     `id` is set to ResumableTask or DeviceTask accordingly and `result` receives the handle's address().
// Without TTG_HAVE_COROUTINE the macro expands to `invoke` alone; `result` and `id` are not touched.
//
// Defining the macro before this point is treated as an error.
#ifndef TTG_PROCESS_TT_OP_RETURN
#ifdef TTG_HAVE_COROUTINE
#define TTG_PROCESS_TT_OP_RETURN(result, id, invoke)                                                        \
  {                                                                                                         \
    using return_type = decltype(invoke);                                                                   \
    if constexpr (std::is_same_v<return_type, void>) {                                                      \
      invoke;                                                                                               \
      id = ttg::TaskCoroutineID::Invalid;                                                                   \
    } else {                                                                                                \
      auto coro_return = invoke;                                                                            \
      constexpr bool is_resumable_handle =                                                                  \
          std::is_base_of_v<ttg::coroutine_handle<ttg::resumable_task_state>, decltype(coro_return)>;       \
      constexpr bool is_device_handle =                                                                     \
          std::is_base_of_v<ttg::coroutine_handle<ttg::device::detail::device_task_promise_type>,           \
                            decltype(coro_return)>;                                                         \
      static_assert(std::is_same_v<return_type, void> || is_resumable_handle || is_device_handle);          \
      if constexpr (is_resumable_handle) {                                                                  \
        id = ttg::TaskCoroutineID::ResumableTask;                                                           \
      } else if constexpr (is_device_handle) {                                                              \
        id = ttg::TaskCoroutineID::DeviceTask;                                                              \
      } else {                                                                                              \
        std::abort();                                                                                       \
      }                                                                                                     \
      result = coro_return.address();                                                                       \
    }                                                                                                       \
  }
#else
#define TTG_PROCESS_TT_OP_RETURN(result, id, invoke) invoke
#endif
#else
#error "TTG_PROCESS_TT_OP_RETURN already defined in ttg/tt.h, check your header guards"
#endif  // !defined(TTG_PROCESS_TT_OP_RETURN)
217
218#endif // TTG_TT_H
A data sink for one input.
Definition tt.h:127
SinkTT(const std::string &inname="junk")
Definition tt.h:157
std::tuple_element_t< i, input_terminals_type > * in()
Returns a pointer to input terminal i to facilitate connection; the terminal cannot be copied, so it is exposed by address.
Definition tt.h:178
void fence() override final
Definition tt.h:170
virtual ~SinkTT()
Definition tt.h:168
World get_world() const override final
Definition tt.h:174
SinkTT(const input_edges_type &inedges, const std::string &inname="junk")
Definition tt.h:162
void make_executable() override final
Marks this executable.
Definition tt.h:172
A base class for all template tasks.
Definition tt.h:32
void set_terminals(std::index_sequence< IS... >, terminalsT &terms, const setfuncT setfunc)
Definition tt.h:100
virtual void make_executable()=0
Marks this executable.
Definition tt.h:292
void register_input_terminals(terminalsT &terms, const namesT &names)
Definition tt.h:86
friend class TTG
Definition tt.h:42
a template task graph implementation
Definition tt.h:32
auto in()
Return a pointer to the i'th input terminal.
Definition tt.h:87
auto out()
Return a pointer to the i'th output terminal.
Definition tt.h:93
void fence() override
Definition tt.h:101
TTBase * get_op(std::size_t i)
Definition tt.h:97
input_terminalsT input_terminals_type
Definition tt.h:37
ttg::World get_world() const override final
Definition tt.h:99
virtual void print_incomplete_tasks() const override
Definition tt.h:107
static constexpr int numins
Definition tt.h:34
void make_executable() override
Marks this executable.
Definition tt.h:103
static constexpr int numouts
Definition tt.h:35
output_terminalsT output_terminals_type
Definition tt.h:38
TTG(ttseqT &&tts, const input_terminals_type &ins, const output_terminals_type &outs, const std::string &name="ttg")
Definition tt.h:70
STL namespace.
top-level TTG namespace contains runtime-neutral functionality
Definition keymap.h:9
int size(World world=default_execution_context())
Definition run.h:131
ttg::World & get_default_world()
Definition world.h:81
auto make_ttg(ttseqT &&tts, const input_terminalsT &ins, const output_terminalsT &outs, const std::string &name="ttg")
Definition tt.h:120