7
7
import shutil
8
8
import subprocess
9
9
from pathlib import Path
10
+ from enum import Enum
10
11
from utils .result import BenchmarkMetadata , BenchmarkTag , Result
11
12
from options import options
12
13
from utils .utils import download , run
13
14
from abc import ABC , abstractmethod
14
15
from utils .unitrace import get_unitrace
16
+ from utils .flamegraph import get_flamegraph
15
17
from utils .logger import log
16
18
19
+
20
+ class TracingType (Enum ):
21
+ """Enumeration of available tracing types."""
22
+
23
+ NONE = ""
24
+ UNITRACE = "unitrace"
25
+ FLAMEGRAPH = "flamegraph"
26
+
27
+
17
28
benchmark_tags = [
18
29
BenchmarkTag ("SYCL" , "Benchmark uses SYCL runtime" ),
19
30
BenchmarkTag ("UR" , "Benchmark uses Unified Runtime API" ),
@@ -62,12 +73,17 @@ def enabled(self) -> bool:
62
73
By default, it returns True, but can be overridden to disable a benchmark."""
63
74
return True
64
75
65
- def traceable (self ) -> bool :
66
- """Returns whether this benchmark should be traced by Unitrace.
67
- By default, it returns True, but can be overridden to disable tracing for a benchmark.
76
+ def traceable (self , tracing_type : TracingType ) -> bool :
77
+ """Returns whether this benchmark should be traced by the specified tracing method.
78
+ By default, it returns True for all tracing types, but can be overridden
79
+ to disable specific tracing methods for a benchmark.
68
80
"""
69
81
return True
70
82
83
+ def tracing_enabled (self , run_trace , force_trace , tr_type : TracingType ):
84
+ """Returns whether tracing is enabled for the given type."""
85
+ return (self .traceable (tr_type ) or force_trace ) and run_trace == tr_type
86
+
71
87
@abstractmethod
72
88
def setup (self ):
73
89
pass
@@ -77,12 +93,18 @@ def teardown(self):
77
93
pass
78
94
79
95
@abstractmethod
80
- def run (self , env_vars , run_unitrace : bool = False ) -> list [Result ]:
96
+ def run (
97
+ self ,
98
+ env_vars ,
99
+ run_trace : TracingType = TracingType .NONE ,
100
+ force_trace : bool = False ,
101
+ ) -> list [Result ]:
81
102
"""Execute the benchmark with the given environment variables.
82
103
83
104
Args:
84
105
env_vars: Environment variables to use when running the benchmark.
85
- run_unitrace: Whether to run benchmark under Unitrace.
106
+ run_trace: The type of tracing to run (NONE, UNITRACE, or FLAMEGRAPH).
107
+ force_trace: If True, ignore the traceable() method and force tracing.
86
108
87
109
Returns:
88
110
A list of Result objects with the benchmark results.
@@ -111,8 +133,9 @@ def run_bench(
111
133
ld_library = [],
112
134
add_sycl = True ,
113
135
use_stdout = True ,
114
- run_unitrace = False ,
115
- extra_unitrace_opt = None ,
136
+ run_trace : TracingType = TracingType .NONE ,
137
+ extra_trace_opt = None ,
138
+ force_trace : bool = False ,
116
139
):
117
140
env_vars = env_vars .copy ()
118
141
if options .ur is not None :
@@ -125,15 +148,26 @@ def run_bench(
125
148
ld_libraries = options .extra_ld_libraries .copy ()
126
149
ld_libraries .extend (ld_library )
127
150
128
- if self .traceable () and run_unitrace :
129
- if extra_unitrace_opt is None :
130
- extra_unitrace_opt = []
151
+ unitrace_output = None
152
+ if self .tracing_enabled (run_trace , force_trace , TracingType .UNITRACE ):
153
+ if extra_trace_opt is None :
154
+ extra_trace_opt = []
131
155
unitrace_output , command = get_unitrace ().setup (
132
- self .name (), command , extra_unitrace_opt
156
+ self .name (), command , extra_trace_opt
133
157
)
134
158
log .debug (f"Unitrace output: { unitrace_output } " )
135
159
log .debug (f"Unitrace command: { ' ' .join (command )} " )
136
160
161
+ # flamegraph run
162
+
163
+ perf_data_file = None
164
+ if self .tracing_enabled (run_trace , force_trace , TracingType .FLAMEGRAPH ):
165
+ perf_data_file , command = get_flamegraph ().setup (
166
+ self .name (), self .get_suite_name (), command
167
+ )
168
+ log .debug (f"FlameGraph perf data: { perf_data_file } " )
169
+ log .debug (f"FlameGraph command: { ' ' .join (command )} " )
170
+
137
171
try :
138
172
result = run (
139
173
command = command ,
@@ -143,13 +177,27 @@ def run_bench(
143
177
ld_library = ld_libraries ,
144
178
)
145
179
except subprocess .CalledProcessError :
146
- if run_unitrace :
180
+ if run_trace == TracingType . UNITRACE and unitrace_output :
147
181
get_unitrace ().cleanup (options .benchmark_cwd , unitrace_output )
182
+ if run_trace == TracingType .FLAMEGRAPH and perf_data_file :
183
+ get_flamegraph ().cleanup (perf_data_file )
148
184
raise
149
185
150
- if self .traceable () and run_unitrace :
186
+ if (
187
+ self .tracing_enabled (run_trace , force_trace , TracingType .UNITRACE )
188
+ and unitrace_output
189
+ ):
151
190
get_unitrace ().handle_output (unitrace_output )
152
191
192
+ if (
193
+ self .tracing_enabled (run_trace , force_trace , TracingType .FLAMEGRAPH )
194
+ and perf_data_file
195
+ ):
196
+ svg_file = get_flamegraph ().handle_output (
197
+ self .name (), perf_data_file , self .get_suite_name ()
198
+ )
199
+ log .info (f"FlameGraph generated: { svg_file } " )
200
+
153
201
if use_stdout :
154
202
return result .stdout .decode ()
155
203
else :
0 commit comments