diff --git a/tritonbench/utils/triton_op.py b/tritonbench/utils/triton_op.py index 6b179f4d..a1b9b5d5 100644 --- a/tritonbench/utils/triton_op.py +++ b/tritonbench/utils/triton_op.py @@ -521,6 +521,19 @@ def __call__(cls, *args, **kwargs): obj.__post__init__() return obj +def _translate_mode(tb_args): + def _has_and_true(attr): + if hasattr(tb_args, attr) and attr: + return True + return False + if _has_and_true("fwd"): + tb_args.mode = "fwd" + if _has_and_true("bwd"): + tb_args.mode = "bwd" + if _has_and_true("fwd_bwd"): + tb_args.mode = "fwd_bwd" + if _has_and_true("fwd_no_grad"): + tb_args.mode = "fwd_no_grad" class BenchmarkOperator(metaclass=PostInitProcessor): mode: Mode = Mode.FWD @@ -556,11 +569,12 @@ def __init__( self.tb_args.cudagraph if self.tb_args.cudagraph else self.use_cuda_graphs ) # we accept both "fwd" and "eval" - if self.tb_args.mode == "fwd" or self.tb_args.fwd: + _translate_mode(self.tb_args) + if self.tb_args.mode == "fwd": self.mode = Mode.FWD - elif self.tb_args.mode == "fwd_bwd" or self.tb_args.fwd_bwd: + elif self.tb_args.mode == "fwd_bwd": self.mode = Mode.FWD_BWD - elif self.tb_args.mode == "fwd_no_grad" or self.tb_args.fwd_no_grad: + elif self.tb_args.mode == "fwd_no_grad": self.mode = Mode.FWD_NO_GRAD else: assert (