1515
1616from tritonbench .operators import load_opbench_by_name
1717from tritonbench .operators_collection import list_operators_by_collection
18+ from tritonbench .utils .ab_test import compare_ab_results , run_ab_test
1819from tritonbench .utils .env_utils import is_fbcode
1920from tritonbench .utils .gpu_utils import gpu_lockdown
2021from tritonbench .utils .list_operator_details import list_operator_details
2324
2425from tritonbench .utils .triton_op import BenchmarkOperatorResult
2526from tritonbench .utils .tritonparse_utils import tritonparse_init , tritonparse_parse
26- from tritonbench .utils .ab_test import run_ab_test , compare_ab_results
2727
2828try :
2929 if is_fbcode ():
3434 usage_report_logger = lambda * args , ** kwargs : None
3535
3636
37-
38-
3937def _run (args : argparse .Namespace , extra_args : List [str ]) -> BenchmarkOperatorResult :
4038 if is_loader_op (args .op ):
4139 Opbench = get_op_loader_bench_cls_by_name (args .op )
@@ -132,23 +130,26 @@ def run(args: List[str] = []):
132130 # Check if A/B testing mode is enabled
133131 if args .side_a is not None and args .side_b is not None :
134132 # A/B testing mode - only support single operator
135- assert len (ops ) == 1 , "A/B testing validation should have caught multiple operators"
133+ assert (
134+ len (ops ) == 1
135+ ), "A/B testing validation should have caught multiple operators"
136136 op = ops [0 ]
137137 args .op = op
138-
138+
139139 print ("[A/B Testing Mode Enabled]" )
140140 print (f"Operator: { op } " )
141141 print ()
142-
142+
143143 with gpu_lockdown (args .gpu_lockdown ):
144144 try :
145145 result_a , result_b = run_ab_test (args , extra_args , _run )
146-
146+
147147 from tritonbench .utils .ab_test import parse_ab_config
148+
148149 config_a_args = parse_ab_config (args .side_a )
149150 config_b_args = parse_ab_config (args .side_b )
150151 compare_ab_results (result_a , result_b , config_a_args , config_b_args )
151-
152+
152153 except Exception as e :
153154 print (f"A/B test failed: { e } " )
154155 if not args .bypass_fail :
@@ -166,7 +167,7 @@ def run(args: List[str] = []):
166167 run_in_task (op )
167168 else :
168169 _run (args , extra_args )
169-
170+
170171 tritonparse_parse (args .tritonparse )
171172
172173
0 commit comments