ttg 1.0.0
Template Task Graph (TTG): flowgraph-based programming model for high-performance distributed-memory algorithms
Loading...
Searching...
No Matches
tt.h
Go to the documentation of this file.
1// SPDX-License-Identifier: BSD-3-Clause
2#ifndef TTG_TT_H
3#define TTG_TT_H
4
5#include "ttg/config.h"
6#include "ttg/fwd.h"
7
8#include "ttg/base/tt.h"
9#include "ttg/edge.h"
10
11#ifdef TTG_HAVE_COROUTINE
12#include "ttg/coroutine.h"
13#endif
14
15#include <cassert>
16#include <memory>
17#include <vector>
18
19namespace ttg {
20
21 // TODO describe TT concept (preferably as a C++20 concept)
22 // N.B. TT::op returns void or ttg::coroutine_handle<>
23 // see TTG_PROCESS_TT_OP_RETURN below
24
26
31 template <typename input_terminalsT, typename output_terminalsT>
32 class TTG : public TTBase {
33 public:
34 static constexpr int numins = std::tuple_size_v<input_terminalsT>; // number of input arguments
35 static constexpr int numouts = std::tuple_size_v<output_terminalsT>; // number of outputs or results
36
37 using input_terminals_type = input_terminalsT;
38 using output_terminals_type = output_terminalsT;
39
40 private:
41 std::vector<std::unique_ptr<TTBase>> tts;
44
45 // not copyable
46 TTG(const TTG &) = delete;
47 TTG &operator=(const TTG &) = delete;
48 // movable
49 TTG(TTG &&other)
50 : TTBase(static_cast<TTBase &&>(other))
51 , tts(std::move(other.tts))
52 , ins(std::move(other.ins))
53 , outs(std::move(other.outs)) {
54 is_ttg_ = true;
55 own_my_tts();
56 }
57 TTG &operator=(TTG &&other) {
58 static_cast<TTBase &>(*this) = static_cast<TTBase &&>(other);
59 is_ttg_ = true;
60 tts = std::move(other.tts);
61 ins = std::move(other.ins);
62 outs = std::move(other.outs);
63 own_my_tts();
64 return *this;
65 };
66
67 public:
69 template <typename ttseqT>
70 TTG(ttseqT &&tts,
71 const input_terminals_type &ins, // tuple of pointers to input terminals
72 const output_terminals_type &outs, // tuple of pointers to output terminals
73 const std::string &name = "ttg")
74 : TTBase(name, numins, numouts), tts(std::forward<ttseqT>(tts)), ins(ins), outs(outs) {
75 if (this->tts.size() == 0) throw name + ":TTG: need to wrap at least one TT"; // see fence
76
79 is_ttg_ = true;
80 own_my_tts();
81
82 // traversal is still broken ... need to add checking for composite
83 }
84
86 template <std::size_t i>
87 auto in() {
88 return std::get<i>(ins);
89 }
90
92 template <std::size_t i>
93 auto out() {
94 return std::get<i>(outs);
95 }
96
97 TTBase *get_op(std::size_t i) { return tts.at(i).get(); }
98
99 ttg::World get_world() const override final { return tts[0]->get_world(); }
100
101 void fence() override { tts[0]->fence(); }
102
103 void make_executable() override {
104 for (auto &op : tts) op->make_executable();
105 }
106
107 private:
108 void own_my_tts() const {
109 for (auto &op : tts) op->owning_ttg = this;
110 }
111 };
112
113 template <typename ttseqT, typename input_terminalsT, typename output_terminalsT>
114 auto make_ttg(ttseqT &&tts, const input_terminalsT &ins, const output_terminalsT &outs,
115 const std::string &name = "ttg") {
116 return std::make_unique<TTG<input_terminalsT, output_terminalsT>>(std::forward<ttseqT>(tts), ins, outs, name);
117 }
118
120 template <typename keyT, typename input_valueT>
121 class SinkTT : public TTBase {
122 static constexpr int numins = 1;
123 static constexpr int numouts = 0;
124
125 using input_terminals_type = std::tuple<ttg::In<keyT, input_valueT>>;
126 using input_edges_type = std::tuple<ttg::Edge<keyT, std::decay_t<input_valueT>>>;
127 using output_terminals_type = std::tuple<>;
128
129 private:
130 input_terminals_type input_terminals;
131 output_terminals_type output_terminals;
132
133 SinkTT(const SinkTT &other) = delete;
134 SinkTT &operator=(const SinkTT &other) = delete;
135 SinkTT(SinkTT &&other) = delete;
136 SinkTT &operator=(SinkTT &&other) = delete;
137
138 template <typename terminalT>
139 void register_input_callback(terminalT &input) {
140 using valueT = std::decay_t<typename terminalT::value_type>;
141 auto move_callback = [](const keyT &key, valueT &&value) {};
142 auto send_callback = [](const keyT &key, const valueT &value) {};
143 auto broadcast_callback = [](const ttg::span<const keyT> &key, const valueT &value) {};
144 auto setsize_callback = [](const keyT &key, std::size_t size) {};
145 auto finalize_callback = [](const keyT &key) {};
146
147 input.set_callback(send_callback, move_callback, broadcast_callback, setsize_callback, finalize_callback);
148 }
149
150 public:
151 SinkTT(const std::string &inname = "junk") : TTBase("sink", numins, numouts) {
152 register_input_terminals(input_terminals, std::vector<std::string>{inname});
153 register_input_callback(std::get<0>(input_terminals));
154 }
155
156 SinkTT(const input_edges_type &inedges, const std::string &inname = "junk") : TTBase("sink", numins, numouts) {
157 register_input_terminals(input_terminals, std::vector<std::string>{inname});
158 register_input_callback(std::get<0>(input_terminals));
159 std::get<0>(inedges).set_out(&std::get<0>(input_terminals));
160 }
161
162 virtual ~SinkTT() {}
163
164 void fence() override final {}
165
166 void make_executable() override final { TTBase::make_executable(); }
167
168 World get_world() const override final { return get_default_world(); }
169
171 template <std::size_t i>
172 std::tuple_element_t<i, input_terminals_type> *in() {
173 static_assert(i == 0);
174 return &std::get<i>(input_terminals);
175 }
176 };
177
178} // namespace ttg
179
// TTG_PROCESS_TT_OP_RETURN(result, id, invoke): evaluates the expression `invoke` once.
// With TTG_HAVE_COROUTINE defined:
//  - if decltype(invoke) is void, `invoke` is evaluated and `id` is set to
//    ttg::TaskCoroutineID::Invalid; `result` is not assigned in this case;
//  - otherwise the returned object must derive from
//    ttg::coroutine_handle<ttg::resumable_task_state> (id = ttg::TaskCoroutineID::ResumableTask) or
//    ttg::coroutine_handle<ttg::device::detail::device_task_promise_type>
//    (id = ttg::TaskCoroutineID::DeviceTask). Any other return type is rejected by the
//    static_assert, so the std::abort() branch is not reachable in a program that compiles.
//    `result` receives the handle's address().
// Without TTG_HAVE_COROUTINE the macro expands to just `invoke`; `result` and `id` are unused.
// If the macro is already defined when this point is reached, compilation stops with #error.
180#ifndef TTG_PROCESS_TT_OP_RETURN
181#ifdef TTG_HAVE_COROUTINE
182#define TTG_PROCESS_TT_OP_RETURN(result, id, invoke) \
183 { \
184 using return_type = decltype(invoke); \
185 if constexpr (std::is_same_v<return_type, void>) { \
186 invoke; \
187 id = ttg::TaskCoroutineID::Invalid; \
188 } else { \
189 auto coro_return = invoke; \
190 static_assert(std::is_same_v<return_type, void> || \
191 std::is_base_of_v<ttg::coroutine_handle<ttg::resumable_task_state>, decltype(coro_return)>|| \
192 std::is_base_of_v<ttg::coroutine_handle<ttg::device::detail::device_task_promise_type>, \
193 decltype(coro_return)>); \
194 if constexpr (std::is_base_of_v<ttg::coroutine_handle<ttg::resumable_task_state>, decltype(coro_return)>) \
195 id = ttg::TaskCoroutineID::ResumableTask; \
196 else if constexpr (std::is_base_of_v< \
197 ttg::coroutine_handle<ttg::device::detail::device_task_promise_type>, \
198 decltype(coro_return)>) \
199 id = ttg::TaskCoroutineID::DeviceTask; \
200 else \
201 std::abort(); \
202 result = coro_return.address(); \
203 } \
204 }
205#else  // !TTG_HAVE_COROUTINE: no coroutine support, just evaluate the op
206#define TTG_PROCESS_TT_OP_RETURN(result, id, invoke) invoke
207#endif
208#else  // TTG_PROCESS_TT_OP_RETURN was already defined before this header
209#error "TTG_PROCESS_TT_OP_RETURN already defined in ttg/tt.h, check your header guards"
210#endif // !defined(TTG_PROCESS_TT_OP_RETURN)
211
212#endif // TTG_TT_H
A data sink for one input.
Definition tt.h:121
SinkTT(const std::string &inname="junk")
Definition tt.h:151
std::tuple_element_t< i, input_terminals_type > * in()
Returns pointer to input terminal i to facilitate connection — terminal cannot be copied,...
Definition tt.h:172
void fence() override final
Definition tt.h:164
virtual ~SinkTT()
Definition tt.h:162
World get_world() const override final
Definition tt.h:168
SinkTT(const input_edges_type &inedges, const std::string &inname="junk")
Definition tt.h:156
void make_executable() override final
Marks this executable.
Definition tt.h:166
A base class for all template tasks.
Definition tt.h:32
void set_terminals(std::index_sequence< IS... >, terminalsT &terms, const setfuncT setfunc)
Definition tt.h:100
virtual void make_executable()=0
Marks this executable.
Definition tt.h:288
void register_input_terminals(terminalsT &terms, const namesT &names)
Definition tt.h:86
friend class TTG
Definition tt.h:42
A template task graph implementation.
Definition tt.h:32
auto in()
Return a pointer to i'th input terminal.
Definition tt.h:87
auto out()
Return a pointer to i'th output terminal.
Definition tt.h:93
void fence() override
Definition tt.h:101
TTBase * get_op(std::size_t i)
Definition tt.h:97
input_terminalsT input_terminals_type
Definition tt.h:37
ttg::World get_world() const override final
Definition tt.h:99
static constexpr int numins
Definition tt.h:34
void make_executable() override
Marks this executable.
Definition tt.h:103
static constexpr int numouts
Definition tt.h:35
output_terminalsT output_terminals_type
Definition tt.h:38
TTG(ttseqT &&tts, const input_terminals_type &ins, const output_terminals_type &outs, const std::string &name="ttg")
Definition tt.h:70
STL namespace.
top-level TTG namespace contains runtime-neutral functionality
Definition keymap.h:9
int size(World world=default_execution_context())
Definition run.h:131
ttg::World & get_default_world()
Definition world.h:81
auto make_ttg(ttseqT &&tts, const input_terminalsT &ins, const output_terminalsT &outs, const std::string &name="ttg")
Definition tt.h:114