Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.

Commit 12e078f

Browse files
[WIP]Debug print graph
1 parent 9653ab4 commit 12e078f

File tree

4 files changed

+97
-0
lines changed

4 files changed

+97
-0
lines changed

src/common/exec_utils.cc

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,5 +75,25 @@ bool CheckForInputNameDuplicates(const nnvm::IndexedGraph& idx) {
7575
return true;
7676
}
7777

78+
void PrintGraph(const nnvm::IndexedGraph& idx, std::ostream& os) {
79+
auto node_str = [&idx](uint32_t nid) {
80+
return std::to_string(nid) + " " + idx[nid].source->attrs.name;
81+
};
82+
for (size_t i = 0; i < idx.num_nodes(); ++i) {
83+
const auto& attrs = idx[i].source->attrs;
84+
os << "node " << node_str(i) << " " << (attrs.op ? attrs.op->name : "(var)") << "\n";
85+
for (auto [k, v] : attrs.dict)
86+
os << "attr " << k << " " << v << "\n";
87+
for (const auto& inp : idx[i].inputs)
88+
os << "inp " << node_str(inp.node_id) << " " << inp.index << " " << inp.version << "\n";
89+
for (auto dep : idx[i].control_deps)
90+
os << "dep " << node_str(dep) << "\n";
91+
for (const auto& sub : attrs.subgraphs) {
92+
std::string name;
93+
os << "sub " << (sub->GetAttr("name", &name) ? name : "(noname)") << "\n";
94+
}
95+
}
96+
}
97+
7898
} // namespace common
7999
} // namespace mxnet

src/common/exec_utils.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
#include <nnvm/graph.h>
2828
#include <nnvm/pass_functions.h>
2929
#include <map>
30+
#include <ostream>
3031
#include <vector>
3132
#include <string>
3233
#include <utility>
@@ -570,6 +571,14 @@ void CopyGraph(nnvm::Graph* dst, const nnvm::Graph& src, bool copy_variables);
570571
*/
571572
bool CheckForInputNameDuplicates(const nnvm::IndexedGraph& idx);
572573

574+
/*!
575+
* \brief Prints graph to the specified stream.
576+
*
577+
* \param idx Indexed graph to print
578+
* \param os Output stream
579+
*/
580+
void PrintGraph(const nnvm::IndexedGraph& idx, std::ostream& os);
581+
573582
} // namespace common
574583
} // namespace mxnet
575584
#endif // MXNET_COMMON_EXEC_UTILS_H_

src/imperative/cached_op.h

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
#define MXNET_IMPERATIVE_CACHED_OP_H_
2222

2323
#include <mxnet/imperative.h>
24+
#include <fstream>
2425
#include <vector>
2526
#include <numeric>
2627
#include <atomic>
@@ -29,6 +30,7 @@
2930
#include <unordered_map>
3031
#include <map>
3132
#include "../common/alm.h"
33+
#include "../common/exec_utils.h"
3234
#include "../operator/operator_common.h"
3335
#include "../operator/subgraph/common.h"
3436
#include "./imperative_utils.h"
@@ -330,13 +332,34 @@ void SetRefCounts(nnvm::Graph* fwd_graph, const nnvm::Graph& full_graph) {
330332
std::make_shared<dmlc::any>(std::move(full_ref_count));
331333
}
332334

335+
void MaybePrintGraph(const nnvm::IndexedGraph& idx, const std::string& msg) {
336+
if (!dmlc::GetEnv("MXNET_DEBUG_PRINT_GRAPH", false))
337+
return;
338+
339+
std::ofstream f;
340+
std::ostream* dest = &std::cout;
341+
std::string dest_name = dmlc::GetEnv("MXNET_DEBUG_PRINT_GRAPH_PATH", std::string("stdout"));
342+
if (dest_name == "stderr") {
343+
dest = &std::cerr;
344+
} else if (dest_name != "stdout") {
345+
f.open(dest_name.c_str(), std::ios::app);
346+
CHECK(f.good());
347+
dest = &f;
348+
}
349+
350+
*dest << "[[[ " << msg << "\n";
351+
common::PrintGraph(idx, *dest);
352+
*dest << "]]] " << msg << "\n";
353+
}
354+
333355
void OptimizeGraph(nnvm::Graph* full_graph,
334356
nnvm::Graph* fwd_graph,
335357
nnvm::Graph* grad_graph,
336358
std::vector<size_t>* input_map,
337359
const Context& context,
338360
size_t num_forward_outputs,
339361
const bool inlining) {
362+
MaybePrintGraph(full_graph->indexed_graph(), "graph before optimization");
340363
input_map->resize(full_graph->indexed_graph().input_nodes().size());
341364
std::iota(input_map->begin(), input_map->end(), 0);
342365
#if MXNET_USE_CUDA && !defined(_WIN32)
@@ -386,6 +409,7 @@ void OptimizeGraph(nnvm::Graph* full_graph,
386409
grad_graph->outputs = std::vector<nnvm::NodeEntry>(
387410
full_graph->outputs.begin() + num_forward_outputs, full_graph->outputs.end());
388411
SetRefCounts(fwd_graph, *full_graph);
412+
MaybePrintGraph(full_graph->indexed_graph(), "graph after optimization");
389413
}
390414

391415
/* \brief Check if param indices and data indices are set, if not then set data indices */

tools/print_graph.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
#!/usr/bin/env python3
2+
3+
import re
4+
import sys
5+
6+
7+
RE_NODE = re.compile(r'node\s(.+)\n')
8+
RE_ATTR = re.compile(r'attr\s(.+)\n')
9+
RE_INP = re.compile(r'inp\s(.+)\n')
10+
RE_DEP = re.compile(r'dep\s(.+)\n')
11+
RE_SUB = re.compile(r'node\s(.+)\n')
12+
13+
14+
def to_dot(f):
15+
print('digraph Net {')
16+
for line in f:
17+
m = RE_NODE.fullmatch(line)
18+
if m:
19+
nid, name, op = m.group(1).split()
20+
shape = 'ellipse' if op == '(var)' else 'rectangle'
21+
print(f' node_{nid} [shape={shape}, label={name}]')
22+
continue
23+
m = RE_ATTR.fullmatch(line)
24+
if m:
25+
continue
26+
m = RE_INP.fullmatch(line)
27+
if m:
28+
njd, _name, index, _version = m.group(1).split()
29+
print(f' node_{njd} -> node_{nid} [label={index}, style=solid]')
30+
continue
31+
m = RE_DEP.fullmatch(line)
32+
if m:
33+
njd, _name = m.group(1).split()
34+
print(f' node_{njd} -> node_{nid} [style=dashed]')
35+
continue
36+
m = RE_SUB.fullmatch(line)
37+
if m:
38+
continue
39+
break
40+
print('}')
41+
42+
43+
if __name__ == '__main__':
44+
to_dot(sys.stdin)

0 commit comments

Comments
 (0)