tt.h
Go to the documentation of this file.
1 #ifndef TTG_TT_H
2 #define TTG_TT_H
3 
4 #include "ttg/config.h"
5 #include "ttg/fwd.h"
6 
7 #include "ttg/base/tt.h"
8 #include "ttg/edge.h"
9 
10 #ifdef TTG_HAVE_COROUTINE
11 #include "ttg/coroutine.h"
12 #endif
13 
14 #include <cassert>
15 #include <memory>
16 #include <vector>
17 
18 namespace ttg {
19 
20  // TODO describe TT concept (preferably as a C++20 concept)
21  // N.B. TT::op returns void or ttg::coroutine_handle<>
22  // see TTG_PROCESS_TT_OP_RETURN below
23 
25 
30  template <typename input_terminalsT, typename output_terminalsT>
31  class TTG : public TTBase {
32  public:
33  static constexpr int numins = std::tuple_size_v<input_terminalsT>; // number of input arguments
34  static constexpr int numouts = std::tuple_size_v<output_terminalsT>; // number of outputs or results
35 
36  using input_terminals_type = input_terminalsT;
37  using output_terminals_type = output_terminalsT;
38 
39  private:
40  std::vector<std::unique_ptr<TTBase>> tts;
43 
44  // not copyable
45  TTG(const TTG &) = delete;
46  TTG &operator=(const TTG &) = delete;
47  // movable
48  TTG(TTG &&other)
49  : TTBase(static_cast<TTBase &&>(other))
50  , tts(other.tts)
51  , ins(std::move(other.ins))
52  , outs(std::move(other.outs)) {
53  is_ttg_ = true;
54  own_my_tts();
55  }
56  TTG &operator=(TTG &&other) {
57  static_cast<TTBase &>(*this) = static_cast<TTBase &&>(other);
58  is_ttg_ = true;
59  tts = std::move(other.tts);
60  ins = std::move(other.ins);
61  outs = std::move(other.outs);
62  own_my_tts();
63  return *this;
64  };
65 
66  public:
68  template <typename ttseqT>
69  TTG(ttseqT &&tts,
70  const input_terminals_type &ins, // tuple of pointers to input terminals
71  const output_terminals_type &outs, // tuple of pointers to output terminals
72  const std::string &name = "ttg")
73  : TTBase(name, numins, numouts), tts(std::forward<ttseqT>(tts)), ins(ins), outs(outs) {
74  if (this->tts.size() == 0) throw name + ":TTG: need to wrap at least one TT"; // see fence
75 
78  is_ttg_ = true;
79  own_my_tts();
80 
81  // traversal is still broken ... need to add checking for composite
82  }
83 
85  template <std::size_t i>
86  auto in() {
87  return std::get<i>(ins);
88  }
89 
91  template <std::size_t i>
92  auto out() {
93  return std::get<i>(outs);
94  }
95 
96  TTBase *get_op(std::size_t i) { return tts.at(i).get(); }
97 
98  ttg::World get_world() const override final { return tts[0]->get_world(); }
99 
100  void fence() override { tts[0]->fence(); }
101 
102  void make_executable() override {
103  for (auto &op : tts) op->make_executable();
104  }
105 
106  private:
107  void own_my_tts() const {
108  for (auto &op : tts) op->owning_ttg = this;
109  }
110  };
111 
112  template <typename ttseqT, typename input_terminalsT, typename output_terminalsT>
113  auto make_ttg(ttseqT &&tts, const input_terminalsT &ins, const output_terminalsT &outs,
114  const std::string &name = "ttg") {
115  return std::make_unique<TTG<input_terminalsT, output_terminalsT>>(std::forward<ttseqT>(tts), ins, outs, name);
116  }
117 
119  template <typename keyT, typename input_valueT>
120  class SinkTT : public TTBase {
121  static constexpr int numins = 1;
122  static constexpr int numouts = 0;
123 
124  using input_terminals_type = std::tuple<ttg::In<keyT, input_valueT>>;
125  using input_edges_type = std::tuple<ttg::Edge<keyT, std::decay_t<input_valueT>>>;
126  using output_terminals_type = std::tuple<>;
127 
128  private:
129  input_terminals_type input_terminals;
130  output_terminals_type output_terminals;
131 
132  SinkTT(const SinkTT &other) = delete;
133  SinkTT &operator=(const SinkTT &other) = delete;
134  SinkTT(SinkTT &&other) = delete;
135  SinkTT &operator=(SinkTT &&other) = delete;
136 
137  template <typename terminalT>
138  void register_input_callback(terminalT &input) {
139  using valueT = std::decay_t<typename terminalT::value_type>;
140  auto move_callback = [](const keyT &key, valueT &&value) {};
141  auto send_callback = [](const keyT &key, const valueT &value) {};
142  auto broadcast_callback = [](const ttg::span<const keyT> &key, const valueT &value) {};
143  auto setsize_callback = [](const keyT &key, std::size_t size) {};
144  auto finalize_callback = [](const keyT &key) {};
145 
146  input.set_callback(send_callback, move_callback, broadcast_callback, setsize_callback, finalize_callback);
147  }
148 
149  public:
150  SinkTT(const std::string &inname = "junk") : TTBase("sink", numins, numouts) {
151  register_input_terminals(input_terminals, std::vector<std::string>{inname});
152  register_input_callback(std::get<0>(input_terminals));
153  }
154 
155  SinkTT(const input_edges_type &inedges, const std::string &inname = "junk") : TTBase("sink", numins, numouts) {
156  register_input_terminals(input_terminals, std::vector<std::string>{inname});
157  register_input_callback(std::get<0>(input_terminals));
158  std::get<0>(inedges).set_out(&std::get<0>(input_terminals));
159  }
160 
161  virtual ~SinkTT() {}
162 
163  void fence() override final {}
164 
165  void make_executable() override final { TTBase::make_executable(); }
166 
167  World get_world() const override final { return get_default_world(); }
168 
170  template <std::size_t i>
171  std::tuple_element_t<i, input_terminals_type> *in() {
172  static_assert(i == 0);
173  return &std::get<i>(input_terminals);
174  }
175  };
176 
177 } // namespace ttg
178 
// TTG_PROCESS_TT_OP_RETURN(result, id, invoke)
//
// Evaluates the TT op call expression `invoke` exactly once and records what it returned.
//  - Without TTG_HAVE_COROUTINE, it expands to just `invoke`; `result` and `id` are untouched.
//  - With TTG_HAVE_COROUTINE:
//      * if `invoke` has type void, it is evaluated and `id` is set to
//        ttg::TaskCoroutineID::Invalid (`result` is untouched);
//      * otherwise the returned object must derive from the coroutine handle of a resumable
//        task or of a device task (enforced by static_assert). `id` is set to ResumableTask or
//        DeviceTask accordingly, and `result` receives the handle's address().
// Use it as a statement followed by `;`. The do/while(0) wrapper keeps
// `if (c) TTG_PROCESS_TT_OP_RETURN(...); else ...` well-formed.
// NOTE(review): the coroutine variant relies on `if constexpr` discarding the unused branch,
// which only happens when `invoke` is type-dependent (i.e. inside a template); confirm at call sites.
// The macro-local names are prefixed so they cannot capture identifiers used inside `invoke`.
#ifndef TTG_PROCESS_TT_OP_RETURN
#ifdef TTG_HAVE_COROUTINE
#define TTG_PROCESS_TT_OP_RETURN(result, id, invoke)                                                   \
  do {                                                                                                 \
    using ttg_op_return_t_ = decltype(invoke);                                                         \
    if constexpr (std::is_same_v<ttg_op_return_t_, void>) {                                            \
      invoke;                                                                                          \
      id = ttg::TaskCoroutineID::Invalid;                                                              \
    } else {                                                                                           \
      auto ttg_op_coro_return_ = invoke;                                                               \
      constexpr bool ttg_op_is_resumable_task_ =                                                       \
          std::is_base_of_v<ttg::coroutine_handle<ttg::resumable_task_state>,                          \
                            decltype(ttg_op_coro_return_)>;                                            \
      constexpr bool ttg_op_is_device_task_ =                                                          \
          std::is_base_of_v<ttg::coroutine_handle<ttg::device::detail::device_task_promise_type>,      \
                            decltype(ttg_op_coro_return_)>;                                            \
      static_assert(ttg_op_is_resumable_task_ || ttg_op_is_device_task_,                               \
                    "TT op must return void, a resumable task, or a device task");                    \
      if constexpr (ttg_op_is_resumable_task_)                                                         \
        id = ttg::TaskCoroutineID::ResumableTask;                                                      \
      else                                                                                             \
        id = ttg::TaskCoroutineID::DeviceTask;                                                         \
      result = ttg_op_coro_return_.address();                                                          \
    }                                                                                                  \
  } while (0)
#else
#define TTG_PROCESS_TT_OP_RETURN(result, id, invoke) invoke
#endif
#else
#error "TTG_PROCESS_TT_OP_RETURN already defined in ttg/tt.h, check your header guards"
#endif  // !defined(TTG_PROCESS_TT_OP_RETURN)
210 
211 #endif // TTG_TT_H
A data sink for one input.
Definition: tt.h:120
SinkTT(const std::string &inname="junk")
Definition: tt.h:150
void fence() override final
Definition: tt.h:163
std::tuple_element_t< i, input_terminals_type > * in()
Returns a pointer to input terminal i to facilitate connection — the terminal cannot be copied, ...
Definition: tt.h:171
virtual ~SinkTT()
Definition: tt.h:161
World get_world() const override final
Definition: tt.h:167
SinkTT(const input_edges_type &inedges, const std::string &inname="junk")
Definition: tt.h:155
void make_executable() override final
Marks this executable.
Definition: tt.h:165
A base class for all template tasks.
Definition: tt.h:31
void set_terminals(std::index_sequence< IS... >, terminalsT &terms, const setfuncT setfunc)
Definition: tt.h:99
virtual void make_executable()=0
Marks this executable.
Definition: tt.h:287
void register_input_terminals(terminalsT &terms, const namesT &names)
Definition: tt.h:85
friend class TTG
Definition: tt.h:41
A template task graph implementation.
Definition: tt.h:31
auto in()
Returns a pointer to the i-th input terminal.
Definition: tt.h:86
auto out()
Returns a pointer to the i-th output terminal.
Definition: tt.h:92
void fence() override
Definition: tt.h:100
input_terminalsT input_terminals_type
Definition: tt.h:36
ttg::World get_world() const override final
Definition: tt.h:98
static constexpr int numins
Definition: tt.h:33
TTBase * get_op(std::size_t i)
Definition: tt.h:96
void make_executable() override
Marks this executable.
Definition: tt.h:102
static constexpr int numouts
Definition: tt.h:34
output_terminalsT output_terminals_type
Definition: tt.h:37
TTG(ttseqT &&tts, const input_terminals_type &ins, const output_terminals_type &outs, const std::string &name="ttg")
Definition: tt.h:69
Top-level TTG namespace; contains runtime-neutral functionality.
Definition: keymap.h:8
int size(World world=default_execution_context())
Definition: run.h:89
ttg::World & get_default_world()
Definition: world.h:80
auto make_ttg(ttseqT &&tts, const input_terminalsT &ins, const output_terminalsT &outs, const std::string &name="ttg")
Definition: tt.h:113