traverse.h
Go to the documentation of this file.
1 #ifndef TTG_TRAVERSE_H
2 #define TTG_TRAVERSE_H
3 
4 #include <iostream>
5 #include <set>
6 
7 #include "ttg/tt.h"
8 #include "ttg/util/meta.h"
9 
10 namespace ttg {
11 
12  namespace detail {
14  class Traverse {
15  std::set<TTBase *> seen;
16 
17  bool visited(TTBase *p) { return !seen.insert(p).second; }
18 
19  public:
20  virtual void ttfunc(TTBase *tt) = 0;
21 
22  virtual void infunc(TerminalBase *in) = 0;
23 
24  virtual void outfunc(TerminalBase *out) = 0;
25 
26  void reset() { seen.clear(); }
27 
28  // Returns true if no null pointers encountered (i.e., if all
29  // encountered terminals/operations are connected)
30  bool traverse(TTBase *tt) {
31  if (!tt) {
32  std::cout << "ttg::Traverse: got a null op!\n";
33  return false;
34  }
35 
36  if (visited(tt)) return true;
37 
38  bool status = true;
39 
40  ttfunc(tt);
41 
42  int count = 0;
43  for (auto in : tt->get_inputs()) {
44  if (!in) {
45  std::cout << "ttg::Traverse: got a null in!\n";
46  status = false;
47  } else {
48  infunc(in);
49  if (!in->is_connected()) {
50  std::cout << "ttg::Traverse: " << tt->get_name() << " input terminal #" << count << " " << in->get_name()
51  << " is not connected\n";
52  status = false;
53  }
54  }
55  count++;
56  }
57 
58  for (auto in : tt->get_inputs()) {
59  if (in) {
60  for (auto predecessor : in->get_predecessors()) {
61  if (!predecessor) {
62  std::cout << "ttg::Traverse: got a null predecessor!\n";
63  status = false;
64  } else {
65  status &= traverse(predecessor->get_tt());
66  }
67  }
68  }
69  }
70 
71  count = 0;
72  for (auto out : tt->get_outputs()) {
73  if (!out) {
74  std::cout << "ttg::Traverse: got a null out!\n";
75  status = false;
76  } else {
77  outfunc(out);
78  if (!out->is_connected()) {
79  std::cout << "ttg::Traverse: " << tt->get_name() << " output terminal #" << count << " "
80  << out->get_name() << " is not connected\n";
81  status = false;
82  }
83  }
84  count++;
85  }
86 
87  for (auto out : tt->get_outputs()) {
88  if (out) {
89  for (auto successor : out->get_connections()) {
90  if (!successor) {
91  std::cout << "ttg::Traverse: got a null successor!\n";
92  status = false;
93  } else {
94  status &= traverse(successor->get_tt());
95  }
96  }
97  }
98  }
99 
100  return status;
101  }
102 
103  template <typename TT>
104  std::enable_if_t<std::is_base_of_v<TTBase, TT> && !std::is_same_v<TT, TTBase>,
105  bool>
106  traverse(TT* tt) {
107  return traverse(static_cast<TTBase*>(tt));
108  }
109 
110  template <typename TT>
111  std::enable_if_t<std::is_base_of_v<TTBase, TT>,
112  bool>
113  traverse(const std::shared_ptr<TTBase>& tt) {
114  return traverse(tt.get());
115  }
116 
117  template <typename TT, typename Deleter>
118  std::enable_if_t<std::is_base_of_v<TTBase, TT>,
119  bool>
120  traverse(const std::unique_ptr<TT, Deleter>& tt) {
121  return traverse(tt.get());
122  }
123 
126  template <typename Visitable>
127  struct null_visitor {
129  void operator()(Visitable*) {};
131  void operator()(const Visitable*) {};
132  };
133 
134  };
135  } // namespace detail
136 
141  template <typename TTVisitor = detail::Traverse::null_visitor<TTBase>,
142  typename InVisitor = detail::Traverse::null_visitor<TerminalBase>,
143  typename OutVisitor = detail::Traverse::null_visitor<TerminalBase>>
144  class Traverse : private detail::Traverse {
145  public:
146  static_assert(
147  std::is_void_v<meta::void_t<decltype(std::declval<TTVisitor>()(std::declval<TTBase *>()))>>,
148  "Traverse<TTVisitor,...>: TTVisitor(TTBase *op) must be a valid expression");
149  static_assert(
150  std::is_void_v<meta::void_t<decltype(std::declval<InVisitor>()(std::declval<TerminalBase *>()))>>,
151  "Traverse<,InVisitor,>: InVisitor(TerminalBase *op) must be a valid expression");
152  static_assert(
153  std::is_void_v<meta::void_t<decltype(std::declval<OutVisitor>()(std::declval<TerminalBase *>()))>>,
154  "Traverse<...,OutVisitor>: OutVisitor(TerminalBase *op) must be a valid expression");
155 
156  template <typename TTVisitor_ = detail::Traverse::null_visitor<TTBase>,
157  typename InVisitor_ = detail::Traverse::null_visitor<TerminalBase>,
158  typename OutVisitor_ = detail::Traverse::null_visitor<TerminalBase>>
159  Traverse(TTVisitor_ &&tt_v = TTVisitor_{}, InVisitor_ &&in_v = InVisitor_{}, OutVisitor_ &&out_v = OutVisitor_{})
160  : tt_visitor_(std::forward<TTVisitor_>(tt_v))
161  , in_visitor_(std::forward<InVisitor_>(in_v))
162  , out_visitor_(std::forward<OutVisitor_>(out_v)){};
163 
164  const TTVisitor &tt_visitor() const { return tt_visitor_; }
165  const InVisitor &in_visitor() const { return in_visitor_; }
166  const OutVisitor &out_visitor() const { return out_visitor_; }
167 
169  template <typename TTBasePtr, typename ... TTBasePtrs>
170  std::enable_if_t<std::is_base_of_v<TTBase, std::decay_t<decltype(*(std::declval<TTBasePtr>()))>> && (std::is_base_of_v<TTBase, std::decay_t<decltype(*(std::declval<TTBasePtrs>()))>> && ...),
171  bool>
173  TTBasePtr&& op, TTBasePtrs && ... ops) {
174  reset();
175  bool result = traverse_all(std::forward<TTBasePtr>(op), std::forward<TTBasePtrs>(ops)...);
176  reset();
177  return result;
178  }
179 
180  private:
181  TTVisitor tt_visitor_;
182  InVisitor in_visitor_;
183  OutVisitor out_visitor_;
184 
185  template <typename TTBasePtr, typename ... TTBasePtrs>
186  bool traverse_all(TTBasePtr&& op, TTBasePtrs && ... ops) {
187  bool result = traverse(op);
188  if constexpr(sizeof...(ops) > 0) {
189  result &= traverse_all(std::forward<TTBasePtrs>(ops)...);
190  }
191  return result;
192  }
193 
194  void ttfunc(TTBase *tt) { tt_visitor_(tt); }
195 
196  void infunc(TerminalBase *in) { in_visitor_(in); }
197 
198  void outfunc(TerminalBase *out) { out_visitor_(out); }
199  };
200 
201  namespace {
202  auto trivial_1param_lambda = [](auto &&op) {};
203  }
204  template <typename TTVisitor = decltype(trivial_1param_lambda)&, typename InVisitor = decltype(trivial_1param_lambda)&, typename OutVisitor = decltype(trivial_1param_lambda)&>
205  auto make_traverse(TTVisitor &&tt_v = trivial_1param_lambda, InVisitor &&in_v = trivial_1param_lambda, OutVisitor &&out_v = trivial_1param_lambda) {
206  return Traverse<std::remove_reference_t<TTVisitor>, std::remove_reference_t<InVisitor>,
207  std::remove_reference_t<OutVisitor>>{std::forward<TTVisitor>(tt_v), std::forward<InVisitor>(in_v),
208  std::forward<OutVisitor>(out_v)};
209  };
210 
212  static Traverse<> verify{};
213 
215  static auto print_ttg = make_traverse(
216  [](auto *tt) {
217  std::cout << "tt: " << (void *)tt << " " << tt->get_name() << " numin " << tt->get_inputs().size() << " numout "
218  << tt->get_outputs().size() << std::endl;
219  },
220  [](auto *in) {
221  std::cout << " in: " << in->get_index() << " " << in->get_name() << " " << in->get_key_type_str() << " "
222  << in->get_value_type_str() << std::endl;
223  },
224  [](auto *out) {
225  std::cout << " out: " << out->get_index() << " " << out->get_name() << " " << out->get_key_type_str() << " "
226  << out->get_value_type_str() << std::endl;
227  });
228 
229 
230 } // namespace ttg
231 
232 #endif // TTG_TRAVERSE_H
A base class for all template tasks.
Definition: tt.h:30
const std::vector< TerminalBase * > & get_inputs() const
Returns the vector of input terminals.
Definition: tt.h:223
const std::vector< TerminalBase * > & get_outputs() const
Returns the vector of output terminals.
Definition: tt.h:226
const std::string & get_name() const
Gets the name of this operation.
Definition: tt.h:217
Traverses a graph of ops in depth-first manner following out edges.
Definition: traverse.h:144
const InVisitor & in_visitor() const
Definition: traverse.h:165
const OutVisitor & out_visitor() const
Definition: traverse.h:166
const TTVisitor & tt_visitor() const
Definition: traverse.h:164
Traverse(TTVisitor_ &&tt_v=TTVisitor_{}, InVisitor_ &&in_v=InVisitor_{}, OutVisitor_ &&out_v=OutVisitor_{})
Definition: traverse.h:159
std::enable_if_t< std::is_base_of_v< TTBase, std::decay_t< decltype(*(std::declval< TTBasePtr >()))> > &&(std::is_base_of_v< TTBase, std::decay_t< decltype(*(std::declval< TTBasePtrs >()))> > &&...), bool > operator()(TTBasePtr &&op, TTBasePtrs &&... ops)
Definition: traverse.h:172
Traverses a graph of TTs in depth-first manner following out edges.
Definition: traverse.h:14
std::enable_if_t< std::is_base_of_v< TTBase, TT >, bool > traverse(const std::unique_ptr< TT, Deleter > &tt)
Definition: traverse.h:120
virtual void ttfunc(TTBase *tt)=0
std::enable_if_t< std::is_base_of_v< TTBase, TT >, bool > traverse(const std::shared_ptr< TTBase > &tt)
Definition: traverse.h:113
virtual void infunc(TerminalBase *in)=0
std::enable_if_t< std::is_base_of_v< TTBase, TT > &&!std::is_same_v< TT, TTBase >, bool > traverse(TT *tt)
Definition: traverse.h:106
bool traverse(TTBase *tt)
Definition: traverse.h:30
virtual void outfunc(TerminalBase *out)=0
void void_t
Definition: meta.h:23
constexpr bool is_void_v
Definition: meta.h:209
top-level TTG namespace contains runtime-neutral functionality
Definition: keymap.h:8
auto make_traverse(TTVisitor &&tt_v=trivial_1param_lambda, InVisitor &&in_v=trivial_1param_lambda, OutVisitor &&out_v=trivial_1param_lambda)
Definition: traverse.h:205
void operator()(Visitable *)
visits a non-const Visitable object
Definition: traverse.h:129
void operator()(const Visitable *)
visits a const Visitable object
Definition: traverse.h:131