Skip to content

Commit

Permalink
Use code detection to check bwd method override. (#57)
Browse files Browse the repository at this point in the history
Summary:
We can use a smarter way to detect whether an operator has implemented backward. This can help speedup the unit test.

Pull Request resolved: #57

Reviewed By: adamomainz

Differential Revision: D66212702

Pulled By: xuzhao9

fbshipit-source-id: 9cbb806d546f151eddceafe5dad0d3d7f05b6cb6
  • Loading branch information
xuzhao9 authored and facebook-github-bot committed Nov 20, 2024
1 parent b151b84 commit 6091289
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 15 deletions.
15 changes: 4 additions & 11 deletions test/test_gpu/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,17 +66,13 @@ def _run_one_operator(args: List[str]):
op = Operator(tb_args=tb_args, extra_args=extra_args)
op.run()
check_ci_output(op)
del op
# Test backward (if applicable)
try:
if op.has_bwd():
del op
tb_args.mode = "bwd"
op = Operator(tb_args=tb_args, extra_args=extra_args)
op.run()
check_ci_output(op)
except NotImplementedError:
logger.info(
f"Operator {op.name} does not support backward, skipping backward test."
)


def _run_operator_in_task(op: str, args: List[str]):
Expand All @@ -92,16 +88,13 @@ def _run_operator_in_task(op: str, args: List[str]):
task.make_operator_instance(args=args)
task.run()
task.check_output()
task.del_op_instance()
# Test backward (if applicable)
try:
if task.get_attribute("has_bwd", method=True):
task.del_op_instance()
args.extend(["--bwd"])
task.make_operator_instance(args=args)
task.run()
task.check_output()
except NotImplementedError:
# Operator does not support backward, skip the test
pass


def make_test(operator):
Expand Down
11 changes: 7 additions & 4 deletions tritonbench/operators/op_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,18 +165,21 @@ def run(self) -> None:
@base_task.run_in_worker(scoped=True)
@staticmethod
def get_attribute(
attr: str, field: Optional[str] = None, classattr: bool = False
attr: str,
field: Optional[str] = None,
classattr: bool = False,
method: bool = False,
) -> Any:
if classattr:
op = globals()["Operator"]
else:
op = globals()["op"]
if hasattr(op, attr):
if field:
op_attr = getattr(op, attr)
return getattr(op_attr, field)
op_attr = getattr(getattr(op, attr), field)
else:
return getattr(op, attr)
op_attr = getattr(op, attr)
return op_attr() if method else op_attr
else:
return None

Expand Down
4 changes: 4 additions & 0 deletions tritonbench/utils/triton_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -1517,3 +1517,7 @@ def run_and_capture(self, *args, **kwargs):
ir_dir / f"{fn._name}_{kernel.name}_{input_id}.sass", "w"
) as f:
f.write(sass)

@classmethod
def has_bwd(cls) -> bool:
return cls.get_bwd_fn is not BenchmarkOperator.get_bwd_fn

0 comments on commit 6091289

Please sign in to comment.