forked from mstfbl/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
test_script_profile.cpp
62 lines (53 loc) · 1.59 KB
/
test_script_profile.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
#include <gtest/gtest.h>
#include <c10/util/Optional.h>
#include <test/cpp/jit/test_utils.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/ir/irparser.h>
#include <torch/csrc/jit/runtime/script_profile.h>
namespace torch {
namespace jit {

// Builds a one-node view of a tiny graph whose source range covers the
// `return a + b` text of `source_string`, opens three InstructionSpans over
// that node while profiling is enabled, and expects dumpStats() to report a
// single source containing a single line entry with count 3.
TEST(ScriptProfileTest, Basic) {
  const std::string source_string = R"V0G0N(
def foo(a, b):
return a + b #
)V0G0N";
  // Byte offsets delimiting the `return a + b` span inside source_string.
  auto begin = source_string.find("return");
  auto end = source_string.find(" #");

  Graph g;
  const auto graph_string = R"IR(
graph(%a : Tensor,
%b : Tensor):
%2 : int = prim::Constant[value=1]()
%3 : Tensor = aten::add(%a, %b, %2)
return (%3))IR";
  torch::jit::parseIR(graph_string, &g);

  // Attach the `return a + b` range to the graph's first node. The assertions
  // below expect the spans over this node to be reported under the line that
  // contains offset `begin`.
  auto source = std::make_shared<Source>(source_string, "", 0);
  auto node = *g.nodes().begin();
  node->setSourceRange(SourceRange{source, begin, end});

  ScriptProfile p;
  p.enable();
  {
    // Three spans over the same node; all three are destroyed at the end of
    // this scope, before profiling is disabled.
    profiling::InstructionSpan g0(*node);
    profiling::InstructionSpan g1(*node);
    profiling::InstructionSpan g2(*node);
  }
  p.disable();

  auto stats = p.dumpStats();
  // Unsigned literals: comparing size() against a plain int trips
  // -Wsign-compare inside gtest's comparison helper.
  EXPECT_EQ(stats.size(), 1u);
  auto it = stats.find(*source);
  // Must be fatal: with EXPECT_NE a miss would fall through and dereference
  // stats.end() on the next line, which is undefined behavior.
  ASSERT_NE(it, stats.end());
  const auto& lines = it->second;
  // Fatal so a missing line entry is reported as a clear failure rather than
  // surfacing as an exception from lines.at() below.
  ASSERT_EQ(lines.size(), 1u);
  const auto& stat = lines.at(source->lineno_for_offset(begin));
  EXPECT_EQ(stat.count, 3);
}

// Checks ScriptProfile's calling-order preconditions: dumpStats() is expected
// to throw c10::Error while profiling is still enabled, and addDatapoint() is
// expected to throw c10::Error once profiling has been disabled.
TEST(ScriptProfileTest, CallingOrder) {
  ScriptProfile p;
  p.enable();
  EXPECT_THROW(p.dumpStats(), c10::Error);
  p.disable();
  auto dp = std::make_shared<profiling::Datapoint>(SourceRange{});
  EXPECT_THROW(p.addDatapoint(std::move(dp)), c10::Error);
}

} // namespace jit
} // namespace torch