dot.h
Go to the documentation of this file.
1 #ifndef TTG_UTIL_DOT_H
2 #define TTG_UTIL_DOT_H
3 
4 #include <sstream>
5 #include <map>
6 #include <string>
7 
8 #include "ttg/base/terminal.h"
9 #include "ttg/traverse.h"
10 
11 namespace ttg {
13  class Dot : private detail::Traverse {
14  std::stringstream edges;
15  std::map<const TTBase*, std::stringstream> tt_nodes;
16  std::multimap<const TTBase *, const TTBase *> ttg_hierarchy;
17  int cluster_cnt;
18  bool disable_type;
19 
20  public:
23  Dot(bool disable_type = false) : disable_type(disable_type){};
24 
25  // Insert backslash before characters that dot is interpreting
26  std::string escape(const std::string &in) {
27  std::stringstream s;
28  for (char c : in) {
29  if (c == '<' || c == '>' || c == '"' || c == '|')
30  s << "\\" << c;
31  else
32  s << c;
33  }
34  return s.str();
35  }
36 
37  // A unique name for the node derived from the pointer
38  std::string nodename(const TTBase *op) {
39  std::stringstream s;
40  s << "n" << (void *)op;
41  return s.str();
42  }
43 
44  void build_ttg_hierarchy(const TTBase *tt) {
45  if(nullptr == tt) {
46  return;
47  }
48  auto search = ttg_hierarchy.find(tt->ttg_ptr());
49  if(search == ttg_hierarchy.end()) {
50  build_ttg_hierarchy(tt->ttg_ptr()); // make sure the parent is in the hierarchy
51  }
52  search = ttg_hierarchy.find(tt);
53  if(search == ttg_hierarchy.end()) {
54  ttg_hierarchy.insert( decltype(ttg_hierarchy)::value_type(tt->ttg_ptr(), tt) );
55  }
56  }
57 
58  void ttfunc(TTBase *tt) {
59  std::string ttnm = nodename(tt);
60  bool is_ttg = true;
61 
62  const TTBase *ttc = reinterpret_cast<const TTBase*>(tt);
64  if(!tt->is_ttg()) {
65  std::stringstream ttss;
66 
67  ttss << " " << ttnm << " [shape=record,style=filled,fillcolor=gray90,label=\"{";
68 
69  size_t count = 0;
70  if (tt->get_inputs().size() > 0) ttss << "{";
71  for (auto in : tt->get_inputs()) {
72  if (in) {
73  if (count != in->get_index()) throw "ttg::Dot: lost count of ins";
74  if (disable_type) {
75  ttss << " <in" << count << ">"
76  << " " << escape(in->get_key_type_str()) << " " << escape(in->get_name());
77  } else {
78  ttss << " <in" << count << ">"
79  << " " << escape("<" + in->get_key_type_str() + "," + in->get_value_type_str() + ">") << " "
80  << escape(in->get_name());
81  }
82  } else {
83  ttss << " <in" << count << ">"
84  << " unknown ";
85  }
86  count++;
87  if (count < tt->get_inputs().size()) ttss << " |";
88  }
89  if (tt->get_inputs().size() > 0) ttss << "} |";
90 
91  ttss << tt->get_name() << " ";
92 
93  if (tt->get_outputs().size() > 0) ttss << " | {";
94 
95  count = 0;
96  for (auto out : tt->get_outputs()) {
97  if (out) {
98  if (count != out->get_index()) throw "ttg::Dot: lost count of outs";
99  if (disable_type) {
100  ttss << " <out" << count << ">"
101  << " " << escape(out->get_key_type_str()) << " " << out->get_name();
102  } else {
103  ttss << " <out" << count << ">"
104  << " " << escape("<" + out->get_key_type_str() + "," + out->get_value_type_str() + ">") << " "
105  << out->get_name();
106  }
107  } else {
108  ttss << " <out" << count << ">"
109  << " unknown ";
110  }
111  count++;
112  if (count < tt->get_outputs().size()) ttss << " |";
113  }
114 
115  if (tt->get_outputs().size() > 0) ttss << "}";
116 
117  ttss << " } \"];\n";
118 
119  auto search = tt_nodes.find(ttc);
120  if( tt_nodes.end() == search ) {
121  tt_nodes.insert( {ttc, std::move(ttss)} );
122  } else {
123  search->second << ttss.str();
124  }
125  } else {
126  std::cout << ttnm << " is a TTG" << std::endl;
127  }
128 
129  for (auto out : tt->get_outputs()) {
130  if (out) {
131  for (auto successor : out->get_connections()) {
132  if (successor) {
133  edges << ttnm << ":out" << out->get_index() << ":s -> " << nodename(successor->get_tt()) << ":in"
134  << successor->get_index() << ":n;\n";
135  }
136  }
137  }
138  }
139  }
140 
141  void infunc(TerminalBase *in) {}
142 
143  void outfunc(TerminalBase *out) {}
144 
145  void tree_down(int level, const TTBase *node, std::stringstream &buf) {
146  if(node == nullptr || node->is_ttg()) {
147  if(nullptr != node) {
148  buf << "subgraph cluster_" << cluster_cnt++ << " {\n";
149  }
150  auto children = ttg_hierarchy.equal_range(node);
151  for(auto child = children.first; child != children.second; child++) {
152  assert(child->first == node);
153  tree_down(level+1, child->second, buf);
154  }
155  if(nullptr != node) {
156  buf << " label = \"" << node->get_name() << "\";\n";
157  buf << "}\n";
158  }
159  } else {
160  auto child = tt_nodes.find(node);
161  if( child != tt_nodes.end()) {
162  assert(child->first == node);
163  buf << child->second.str();
164  }
165  }
166  }
167 
168  public:
170  template <typename... TTBasePtrs>
171  std::enable_if_t<(std::is_convertible_v<std::remove_const_t<std::remove_reference_t<TTBasePtrs>>, TTBase *> && ...),
172  std::string>
173  operator()(TTBasePtrs &&... ops) {
174  reset();
175  std::stringstream buf;
176  buf.str(std::string());
177  buf.clear();
178 
179  edges.str(std::string());
180  edges.clear();
181 
182  tt_nodes.clear();
183  ttg_hierarchy.clear();
184 
185  buf << "digraph G {\n";
186  buf << " ranksep=1.5;\n";
187  bool t = true;
188  t &= (traverse(std::forward<TTBasePtrs>(ops)) && ... );
189 
190  cluster_cnt = 0;
191  tree_down(0, nullptr, buf);
192 
193  buf << edges.str();
194  buf << "}\n";
195 
196  reset();
197  std::string result = buf.str();
198  buf.str(std::string());
199  buf.clear();
200 
201  return result;
202  }
203  };
204 } // namespace ttg
205 #endif // TTG_UTIL_DOT_H
Prints the graph to a std::string in the format understood by GraphViz's dot program.
Definition: dot.h:13
void ttfunc(TTBase *tt)
Definition: dot.h:58
void infunc(TerminalBase *in)
Definition: dot.h:141
void outfunc(TerminalBase *out)
Definition: dot.h:143
std::string nodename(const TTBase *op)
Definition: dot.h:38
std::string escape(const std::string &in)
Definition: dot.h:26
void tree_down(int level, const TTBase *node, std::stringstream &buf)
Definition: dot.h:145
Dot(bool disable_type=false)
Definition: dot.h:23
std::enable_if_t<(std::is_convertible_v< std::remove_const_t< std::remove_reference_t< TTBasePtrs >>, TTBase * > &&...), std::string > operator()(TTBasePtrs &&... ops)
Definition: dot.h:173
void build_ttg_hierarchy(const TTBase *tt)
Definition: dot.h:44
A base class for all template tasks.
Definition: tt.h:30
const std::vector< TerminalBase * > & get_inputs() const
Returns the vector of input terminals.
Definition: tt.h:223
const std::vector< TerminalBase * > & get_outputs() const
Returns the vector of output terminals.
Definition: tt.h:226
const std::string & get_name() const
Gets the name of this operation.
Definition: tt.h:217
bool is_ttg() const
Definition: tt.h:209
const TTBase * ttg_ptr() const
Definition: tt.h:205
Traverses a graph of TTs in depth-first manner following out edges.
Definition: traverse.h:14
bool traverse(TTBase *tt)
Definition: traverse.h:30
Top-level TTG namespace; contains runtime-neutral functionality.
Definition: keymap.h:8
int size(World world=default_execution_context())
Definition: run.h:89
auto edges(inedgesT &&...args)
Make a tuple of Edges to pass to ttg::make_tt.
Definition: func.h:147