ttg 1.0.0
Template Task Graph (TTG): flowgraph-based programming model for high-performance distributed-memory algorithms
Loading...
Searching...
No Matches
traverse.h
Go to the documentation of this file.
1// SPDX-License-Identifier: BSD-3-Clause
2#ifndef TTG_TRAVERSE_H
3#define TTG_TRAVERSE_H
4
#include <iostream>
#include <memory>
#include <set>
#include <type_traits>
#include <utility>

#include "ttg/tt.h"
#include "ttg/util/meta.h"
10
11namespace ttg {
12
13 namespace detail {
15 class Traverse {
16 std::set<TTBase *> seen;
17
18 bool visited(TTBase *p) { return !seen.insert(p).second; }
19
20 public:
21 virtual void ttfunc(TTBase *tt) = 0;
22
23 virtual void infunc(TerminalBase *in) = 0;
24
25 virtual void outfunc(TerminalBase *out) = 0;
26
27 void reset() { seen.clear(); }
28
29 // Returns true if no null pointers encountered (i.e., if all
30 // encountered terminals/operations are connected)
31 bool traverse(TTBase *tt) {
32 if (!tt) {
33 std::cout << "ttg::Traverse: got a null op!\n";
34 return false;
35 }
36
37 if (visited(tt)) return true;
38
39 bool status = true;
40
41 ttfunc(tt);
42
43 int count = 0;
44 for (auto in : tt->get_inputs()) {
45 if (!in) {
46 std::cout << "ttg::Traverse: got a null in!\n";
47 status = false;
48 } else {
49 infunc(in);
50 if (!in->is_connected()) {
51 std::cout << "ttg::Traverse: " << tt->get_name() << " input terminal #" << count << " " << in->get_name()
52 << " is not connected\n";
53 status = false;
54 }
55 }
56 count++;
57 }
58
59 for (auto in : tt->get_inputs()) {
60 if (in) {
61 for (auto predecessor : in->get_predecessors()) {
62 if (!predecessor) {
63 std::cout << "ttg::Traverse: got a null predecessor!\n";
64 status = false;
65 } else {
66 status &= traverse(predecessor->get_tt());
67 }
68 }
69 }
70 }
71
72 count = 0;
73 for (auto out : tt->get_outputs()) {
74 if (!out) {
75 std::cout << "ttg::Traverse: got a null out!\n";
76 status = false;
77 } else {
78 outfunc(out);
79 if (!out->is_connected()) {
80 std::cout << "ttg::Traverse: " << tt->get_name() << " output terminal #" << count << " "
81 << out->get_name() << " is not connected\n";
82 status = false;
83 }
84 }
85 count++;
86 }
87
88 for (auto out : tt->get_outputs()) {
89 if (out) {
90 for (auto successor : out->get_connections()) {
91 if (!successor) {
92 std::cout << "ttg::Traverse: got a null successor!\n";
93 status = false;
94 } else {
95 status &= traverse(successor->get_tt());
96 }
97 }
98 }
99 }
100
101 return status;
102 }
103
104 template <typename TT>
105 std::enable_if_t<std::is_base_of_v<TTBase, TT> && !std::is_same_v<TT, TTBase>,
106 bool>
108 return traverse(static_cast<TTBase*>(tt));
109 }
110
111 template <typename TT>
112 std::enable_if_t<std::is_base_of_v<TTBase, TT>,
113 bool>
114 traverse(const std::shared_ptr<TTBase>& tt) {
115 return traverse(tt.get());
116 }
117
118 template <typename TT, typename Deleter>
119 std::enable_if_t<std::is_base_of_v<TTBase, TT>,
120 bool>
121 traverse(const std::unique_ptr<TT, Deleter>& tt) {
122 return traverse(tt.get());
123 }
124
127 template <typename Visitable>
130 void operator()(Visitable*) {};
132 void operator()(const Visitable*) {};
133 };
134
135 };
136 } // namespace detail
137
142 template <typename TTVisitor = detail::Traverse::null_visitor<TTBase>,
143 typename InVisitor = detail::Traverse::null_visitor<TerminalBase>,
144 typename OutVisitor = detail::Traverse::null_visitor<TerminalBase>>
145 class Traverse : private detail::Traverse {
146 public:
147 static_assert(
148 std::is_void_v<meta::void_t<decltype(std::declval<TTVisitor>()(std::declval<TTBase *>()))>>,
149 "Traverse<TTVisitor,...>: TTVisitor(TTBase *op) must be a valid expression");
150 static_assert(
151 std::is_void_v<meta::void_t<decltype(std::declval<InVisitor>()(std::declval<TerminalBase *>()))>>,
152 "Traverse<,InVisitor,>: InVisitor(TerminalBase *op) must be a valid expression");
153 static_assert(
154 std::is_void_v<meta::void_t<decltype(std::declval<OutVisitor>()(std::declval<TerminalBase *>()))>>,
155 "Traverse<...,OutVisitor>: OutVisitor(TerminalBase *op) must be a valid expression");
156
157 template <typename TTVisitor_ = detail::Traverse::null_visitor<TTBase>,
158 typename InVisitor_ = detail::Traverse::null_visitor<TerminalBase>,
159 typename OutVisitor_ = detail::Traverse::null_visitor<TerminalBase>>
160 Traverse(TTVisitor_ &&tt_v = TTVisitor_{}, InVisitor_ &&in_v = InVisitor_{}, OutVisitor_ &&out_v = OutVisitor_{})
161 : tt_visitor_(std::forward<TTVisitor_>(tt_v))
162 , in_visitor_(std::forward<InVisitor_>(in_v))
163 , out_visitor_(std::forward<OutVisitor_>(out_v)){};
164
165 const TTVisitor &tt_visitor() const { return tt_visitor_; }
166 const InVisitor &in_visitor() const { return in_visitor_; }
167 const OutVisitor &out_visitor() const { return out_visitor_; }
168
170 template <typename TTBasePtr, typename ... TTBasePtrs>
171 std::enable_if_t<std::is_base_of_v<TTBase, std::decay_t<decltype(*(std::declval<TTBasePtr>()))>> && (std::is_base_of_v<TTBase, std::decay_t<decltype(*(std::declval<TTBasePtrs>()))>> && ...),
172 bool>
174 TTBasePtr&& op, TTBasePtrs && ... ops) {
175 reset();
176 bool result = traverse_all(std::forward<TTBasePtr>(op), std::forward<TTBasePtrs>(ops)...);
177 reset();
178 return result;
179 }
180
181 private:
182 TTVisitor tt_visitor_;
183 InVisitor in_visitor_;
184 OutVisitor out_visitor_;
185
186 template <typename TTBasePtr, typename ... TTBasePtrs>
187 bool traverse_all(TTBasePtr&& op, TTBasePtrs && ... ops) {
188 bool result = traverse(op);
189 if constexpr(sizeof...(ops) > 0) {
190 result &= traverse_all(std::forward<TTBasePtrs>(ops)...);
191 }
192 return result;
193 }
194
195 void ttfunc(TTBase *tt) { tt_visitor_(tt); }
196
197 void infunc(TerminalBase *in) { in_visitor_(in); }
198
199 void outfunc(TerminalBase *out) { out_visitor_(out); }
200 };
201
/// No-op visitor accepting any single argument; the default visitor of make_traverse().
/// This is a C++17 inline variable rather than an anonymous-namespace variable, so that
/// every translation unit refers to the same entity. That keeps the default arguments
/// of make_traverse ODR-consistent across translation units.
inline auto trivial_1param_lambda = [](auto &&) {};
205 template <typename TTVisitor = decltype(trivial_1param_lambda)&, typename InVisitor = decltype(trivial_1param_lambda)&, typename OutVisitor = decltype(trivial_1param_lambda)&>
206 auto make_traverse(TTVisitor &&tt_v = trivial_1param_lambda, InVisitor &&in_v = trivial_1param_lambda, OutVisitor &&out_v = trivial_1param_lambda) {
207 return Traverse<std::remove_reference_t<TTVisitor>, std::remove_reference_t<InVisitor>,
208 std::remove_reference_t<OutVisitor>>{std::forward<TTVisitor>(tt_v), std::forward<InVisitor>(in_v),
209 std::forward<OutVisitor>(out_v)};
210 };
211
213 static Traverse<> verify{};
214
216 static auto print_ttg = make_traverse(
217 [](auto *tt) {
218 std::cout << "tt: " << (void *)tt << " " << tt->get_name() << " numin " << tt->get_inputs().size() << " numout "
219 << tt->get_outputs().size() << std::endl;
220 },
221 [](auto *in) {
222 std::cout << " in: " << in->get_index() << " " << in->get_name() << " " << in->get_key_type_str() << " "
223 << in->get_value_type_str() << std::endl;
224 },
225 [](auto *out) {
226 std::cout << " out: " << out->get_index() << " " << out->get_name() << " " << out->get_key_type_str() << " "
227 << out->get_value_type_str() << std::endl;
228 });
229
230
231} // namespace ttg
232
233#endif // TTG_TRAVERSE_H
A base class for all template tasks.
Definition tt.h:32
const std::vector< TerminalBase * > & get_outputs() const
Returns the vector of output terminals.
Definition tt.h:228
const std::vector< TerminalBase * > & get_inputs() const
Returns the vector of input terminals.
Definition tt.h:225
const std::string & get_name() const
Gets the name of this operation.
Definition tt.h:219
Traverses a graph of TTs in depth-first manner, following both input (predecessor) and output (successor) edges.
Definition traverse.h:145
const OutVisitor & out_visitor() const
Definition traverse.h:167
const TTVisitor & tt_visitor() const
Definition traverse.h:165
std::enable_if_t< std::is_base_of_v< TTBase, std::decay_t< decltype(*(std::declval< TTBasePtr >()))> > &&(std::is_base_of_v< TTBase, std::decay_t< decltype(*(std::declval< TTBasePtrs >()))> > &&...), bool > operator()(TTBasePtr &&op, TTBasePtrs &&... ops)
Definition traverse.h:173
const InVisitor & in_visitor() const
Definition traverse.h:166
Traverse(TTVisitor_ &&tt_v=TTVisitor_{}, InVisitor_ &&in_v=InVisitor_{}, OutVisitor_ &&out_v=OutVisitor_{})
Definition traverse.h:160
Traverses a graph of TTs in depth-first manner, following both input (predecessor) and output (successor) edges.
Definition traverse.h:15
std::enable_if_t< std::is_base_of_v< TTBase, TT >, bool > traverse(const std::shared_ptr< TTBase > &tt)
Definition traverse.h:114
virtual void ttfunc(TTBase *tt)=0
virtual void infunc(TerminalBase *in)=0
std::enable_if_t< std::is_base_of_v< TTBase, TT >, bool > traverse(const std::unique_ptr< TT, Deleter > &tt)
Definition traverse.h:121
bool traverse(TTBase *tt)
Definition traverse.h:31
virtual void outfunc(TerminalBase *out)=0
std::enable_if_t< std::is_base_of_v< TTBase, TT > &&!std::is_same_v< TT, TTBase >, bool > traverse(TT *tt)
Definition traverse.h:107
STL namespace.
top-level TTG namespace contains runtime-neutral functionality
Definition keymap.h:9
auto make_traverse(TTVisitor &&tt_v=trivial_1param_lambda, InVisitor &&in_v=trivial_1param_lambda, OutVisitor &&out_v=trivial_1param_lambda)
Definition traverse.h:206
void operator()(Visitable *)
visits a non-const Visitable object
Definition traverse.h:130
void operator()(const Visitable *)
visits a const Visitable object
Definition traverse.h:132