From 789dd31e070abfbf913fbb3eef57b93d404aee4f Mon Sep 17 00:00:00 2001 From: Xu Zhao Date: Tue, 5 Nov 2024 20:01:14 -0800 Subject: [PATCH] Fix a few operators on AMD Summary: Need to make a few changes to enable more operators on AMD: 1. use `HAS_LIGER_KERNEL` to enable liger layer_norm on demand 2. fix a bug when running with `--isolate`. `--isolate` will run each operator in subprocess, to avoid interference between operators. 3. we need to explicitly add some dependencies on AMD Reviewed By: adamomainz Differential Revision: D65450899 fbshipit-source-id: 9cd1406123b0d2a5fbddf15f4a711e88bbd977d6 --- test/test_gpu/main.py | 2 +- tritonbench/operators/layer_norm/operator.py | 2 +- tritonbench/utils/parser.py | 2 ++ 3 files changed, 4 insertions(+), 2 deletions(-) diff --git a/test/test_gpu/main.py b/test/test_gpu/main.py index 32885988..36ec6cdf 100644 --- a/test/test_gpu/main.py +++ b/test/test_gpu/main.py @@ -39,7 +39,7 @@ def check_ci_output(op): output = op.output output_impls = output.result[0][1].keys() - skiped_impls = op.tb_args.skip + skiped_impls = op.tb_args.skip if op.tb_args.skip else [] ci_enabled_impls = [ x for x in REGISTERED_BENCHMARKS[output.op_name].keys() if x not in skiped_impls ] diff --git a/tritonbench/operators/layer_norm/operator.py b/tritonbench/operators/layer_norm/operator.py index 73a0344c..d2e1df29 100644 --- a/tritonbench/operators/layer_norm/operator.py +++ b/tritonbench/operators/layer_norm/operator.py @@ -39,7 +39,7 @@ def inner(*args): return lambda: inner(*args) - @register_benchmark() + @register_benchmark(enabled=HAS_LIGER_KERNEL) def liger_layer_norm(self, *args): (x, w_shape, weight, bias, eps) = args return lambda: LigerLayerNormFunction.apply(x, weight, bias, eps) diff --git a/tritonbench/utils/parser.py b/tritonbench/utils/parser.py index f87d8f9d..02dadae5 100644 --- a/tritonbench/utils/parser.py +++ b/tritonbench/utils/parser.py @@ -190,6 +190,8 @@ def _find_param_loc(params, key: str) -> int: def _remove_params(params, loc): if loc == -1: return params + if params[loc + 1].startswith("--"): + return params[:loc] + params[loc + 1 :] return params[:loc] + params[loc + 2 :]