diff --git a/test/test_line_profiler.py b/test/test_line_profiler.py index d2d0064..6b38417 100644 --- a/test/test_line_profiler.py +++ b/test/test_line_profiler.py @@ -1,32 +1,40 @@ +import re + +import numpy as np import pytest import torch -import numpy as np +from pytorch_memlab import (LineProfiler, clear_global_line_profiler, profile, + profile_every, set_target_gpu) -from pytorch_memlab import LineProfiler, profile, profile_every, set_target_gpu, clear_global_line_profiler def test_display(): - def work(): - # comment + def main(): linear = torch.nn.Linear(100, 100).cuda() - linear_2 = torch.nn.Linear(100, 100).cuda() - linear_3 = torch.nn.Linear(100, 100).cuda() + part1() + part2() - def work_3(): + def part1(): lstm = torch.nn.LSTM(1000, 1000).cuda() + subpart11() - def work_2(): - # comment + def part2(): + linear_2 = torch.nn.Linear(100, 100).cuda() + linear_3 = torch.nn.Linear(100, 100).cuda() + + def subpart11(): linear = torch.nn.Linear(100, 100).cuda() linear_2 = torch.nn.Linear(100, 100).cuda() linear_3 = torch.nn.Linear(100, 100).cuda() - work_3() - - with LineProfiler(work, work_2) as prof: - work() - work_2() - return prof.display() + with LineProfiler(subpart11, part2) as prof: + main() + + s = str(prof.display()) # cast from line_records.RecordsDisplay + assert re.search("## .*subpart11", s) + assert "def subpart11():" in s + assert re.search("## .*part2", s) + assert "def part2():" in s def test_line_report(): @@ -56,6 +64,7 @@ def work_2(): line_profiler.disable() line_profiler.print_stats() + def test_line_report_decorator(): clear_global_line_profiler() @@ -72,11 +81,13 @@ def work2(): linear = torch.nn.Linear(100, 100).cuda() linear_2 = torch.nn.Linear(100, 100).cuda() linear_3 = torch.nn.Linear(100, 100).cuda() + work() work2() work() work() + def test_line_report_method(): clear_global_line_profiler() @@ -94,6 +105,7 @@ def forward(self, inp): inp = torch.Tensor(50, 100).cuda() net(inp) + def test_line_report_profile(): clear_global_line_profiler() @@ -107,6 +119,7 @@ def work(): work() work() + def test_line_report_profile_set_gpu(): clear_global_line_profiler() @@ -122,6 +135,7 @@ def work(): work() work() + def test_line_report_profile_interrupt(): clear_global_line_profiler() @@ -139,4 +153,4 @@ def work2(): work() work2() - raise KeyboardInterrupt \ No newline at end of file + raise KeyboardInterrupt