Skip to content

Commit

Permalink
Fix a few operators on AMD
Browse files Browse the repository at this point in the history
Summary:
Need to make a few changes to enable more operators on AMD:
1. use `HAS_LIGER_KERNEL` to enable liger layer_norm on demand
2. fix a bug when running with `--isolate`. `--isolate` will run each operator in subprocess, to avoid interference between operators.
3. we need to explicitly add some dependencies on AMD

Reviewed By: adamomainz

Differential Revision: D65450899

fbshipit-source-id: 9cd1406123b0d2a5fbddf15f4a711e88bbd977d6
  • Loading branch information
xuzhao9 authored and facebook-github-bot committed Nov 6, 2024
1 parent 6ae2eba commit 789dd31
Show file tree
Hide file tree
Showing 3 changed files with 4 additions and 2 deletions.
2 changes: 1 addition & 1 deletion test/test_gpu/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def check_ci_output(op):

output = op.output
output_impls = output.result[0][1].keys()
skiped_impls = op.tb_args.skip
skiped_impls = op.tb_args.skip if op.tb_args.skip else []
ci_enabled_impls = [
x for x in REGISTERED_BENCHMARKS[output.op_name].keys() if x not in skiped_impls
]
Expand Down
2 changes: 1 addition & 1 deletion tritonbench/operators/layer_norm/operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def inner(*args):

return lambda: inner(*args)

@register_benchmark()
@register_benchmark(enabled=HAS_LIGER_KERNEL)
def liger_layer_norm(self, *args):
(x, w_shape, weight, bias, eps) = args
return lambda: LigerLayerNormFunction.apply(x, weight, bias, eps)
Expand Down
2 changes: 2 additions & 0 deletions tritonbench/utils/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,8 @@ def _find_param_loc(params, key: str) -> int:
def _remove_params(params, loc):
if loc == -1:
return params
if params[loc + 1].startswith("--"):
return params[:loc] + params[loc + 1 :]
return params[:loc] + params[loc + 2 :]


Expand Down

0 comments on commit 789dd31

Please sign in to comment.