Fix CI test failures (#55)
Summary:
The unit test workflow seems to hang and needs to be fixed: https://github.com/pytorch-labs/tritonbench/actions/runs/11898546601/job/33155282740

This PR rewrites the unit test function to run each test in an individual subprocess.
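
For reviewers unfamiliar with the test harness, the shape of the change is roughly as follows: each unittest case now drives the operator in its own subprocess, so a hanging or crashing operator can no longer stall the whole workflow. Below is a minimal, generic sketch of that per-test-subprocess pattern; it is not the actual `OpTask`-based implementation in this PR, and the `run.py` command line, operator list, and timeout are assumptions for illustration only.

```python
import subprocess
import sys
import unittest


def make_subprocess_test(op: str):
    def test_case(self):
        # Run the operator benchmark in a separate Python process so that a
        # hang or crash in one operator cannot take down the whole test run.
        cmd = [
            sys.executable,
            "run.py",  # hypothetical entry point, for illustration only
            "--op",
            op,
            "--num-inputs",
            "1",
            "--test-only",
        ]
        completed = subprocess.run(cmd, capture_output=True, timeout=600)
        self.assertEqual(completed.returncode, 0, completed.stderr.decode())

    return test_case


class TestGpu(unittest.TestCase):
    pass


# Hypothetical operator list; the real suite enumerates registered operators.
for op_name in ("softmax", "addmm"):
    setattr(TestGpu, f"test_gpu_{op_name}", make_subprocess_test(op_name))


if __name__ == "__main__":
    unittest.main()
```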

Pull Request resolved: #55

Reviewed By: FindHao

Differential Revision: D66167407

Pulled By: xuzhao9

fbshipit-source-id: 704db2c4f53d2c5f205f5f478e7e553dc274083d
xuzhao9 authored and facebook-github-bot committed Nov 19, 2024
1 parent 6e52ed2 commit d9633be
Showing 9 changed files with 91 additions and 42 deletions.
7 changes: 4 additions & 3 deletions .ci/tritonbench/test-gpu.sh
@@ -8,7 +8,8 @@ fi

. "${SETUP_SCRIPT}"

# install deps
pip install psutil tabulate
# FIXME: patch hstu
sudo apt-get install -y patch
python install.py --hstu

python -m unittest test.test_gpu.main
python -m unittest test.test_gpu.main -v
1 change: 1 addition & 0 deletions .github/workflows/pr.yaml
@@ -3,6 +3,7 @@ on:
  pull_request:
    paths:
      - .ci/*
      - test/test_gpu/*
      - tritonbench/*
      - .github/workflows/pr.yaml
  push:
6 changes: 0 additions & 6 deletions run.py
@@ -55,12 +55,6 @@ def _run(args: argparse.Namespace, extra_args: List[str]) -> BenchmarkOperatorRe
        Opbench = load_opbench_by_name_from_loader(args)
    else:
        Opbench = load_opbench_by_name(args.op)
    if args.fwd_bwd:
        args.mode = "fwd_bwd"
    if args.bwd:
        args.mode = "bwd"
    if args.fwd_no_grad:
        args.mode = "fwd_no_grad"
    opbench = Opbench(
        tb_args=args,
        extra_args=extra_args,
42 changes: 32 additions & 10 deletions test/test_gpu/main.py
@@ -53,10 +53,9 @@ def check_ci_output(op):
    ), f"output impls: {output_impls} != ci_enabled impls: {ci_enabled_impls}"


def _run_one_operator(
    tb_args: argparse.Namespace,
    extra_args: Optional[List[str]] = None,
):
def _run_one_operator(args: List[str]):
    parser = get_parser(args)
    tb_args, extra_args = parser.parse_known_args(args)
    if tb_args.op in skip_tests:
        # If the op itself is in the skip list, skip all tests
        if not skip_tests[tb_args.op]:
@@ -80,6 +79,31 @@ def _run_one_operator(
    )


def _run_operator_in_task(op: str, args: List[str]):
    from tritonbench.operators.op_task import OpTask

    if op in skip_tests:
        # If the op itself is in the skip list, skip all tests
        if not skip_tests[op]:
            return
        skip = ",".join(skip_tests[op])
        args.extend(["--skip", skip])
    task = OpTask(op)
    task.make_operator_instance(args=args)
    task.run()
    task.check_output()
    task.del_op_instance()
    # Test backward (if applicable)
    try:
        args.extend(["--bwd"])
        task.make_operator_instance(args=args)
        task.run()
        task.check_output()
    except NotImplementedError:
        # Operator does not support backward, skip the test
        pass


def make_test(operator):
    def test_case(self):
        # Add `--test-only` to disable Triton autotune in tests
@@ -92,12 +116,10 @@ def test_case(self):
            "1",
            "--test-only",
        ]
        parser = get_parser(args)
        tb_args, extra_args = parser.parse_known_args(args)
        _run_one_operator(
            tb_args,
            extra_args,
        )
        if IS_FBCODE:
            _run_one_operator(args)
        else:
            _run_operator_in_task(op=operator, args=args)

    return test_case

10 changes: 5 additions & 5 deletions test/test_gpu/skip_tests_h100_pytorch.yaml
@@ -17,6 +17,9 @@ flash_attention:
# - flex_attention
fp8_attention:
- colfax_fmha
# triton_flash_v2 now requires the main branch of Triton
# pytorch version does not work
- triton_flash_v2
fp8_fused_quant_gemm_rowwise:
fp8_gemm:
- triton_persistent_fp8_gemm
@@ -29,9 +32,6 @@ jagged_layer_norm:
jagged_mean:
jagged_softmax:
jagged_sum:
layer_norm:
low_mem_dropout:
rms_norm:
rope:
template_attention:
ragged_attention:
- hstu_triton_ragged_attention_persistent
test_op:
7 changes: 2 additions & 5 deletions tritonbench/operators/addmm/hstu.py
@@ -6,11 +6,8 @@
import triton
from tritonbench.utils.path_utils import add_path, SUBMODULE_PATH

with add_path(str(SUBMODULE_PATH)):
    triton_addmm = importlib.import_module(
        "generative-recommenders.ops.triton.triton_addmm"
    )
    _addmm_fwd = triton_addmm._addmm_fwd
with add_path(str(SUBMODULE_PATH.joinpath("generative-recommenders"))):
    from generative_recommenders.ops.triton.triton_addmm import _addmm_fwd


class _AddMmFunction(torch.autograd.Function):
4 changes: 2 additions & 2 deletions tritonbench/operators/fp8_attention/operator.py
@@ -10,7 +10,7 @@

import torch

from tritonbench.kernels.triton_fused_attention import attention as triton_attention
from tritonbench.kernels.triton_fused_attention import attention_opt as triton_attention
from tritonbench.utils.triton_op import (
    BenchmarkOperator,
    BenchmarkOperatorMetrics,
@@ -110,7 +110,7 @@ def triton_flash_v2(
        triton_q, triton_k, triton_v = self.triton_preprocess(q, k, v)
        # full fp8 will be enabled if type of q,k,v is fp8
        return lambda: triton_attention(
            triton_q, triton_k, triton_v, False, self.sm_scale
            triton_q, triton_k, triton_v, False, self.sm_scale, "base"
        )

    def get_x_val(self, _example_inputs) -> Tuple[int, int, int, int]:
36 changes: 27 additions & 9 deletions tritonbench/operators/op_task.py
@@ -53,7 +53,6 @@ class OpDetails:

    name: str
    exists: bool
    metadata: Dict[str, Any]


class OpTask(base_task.TaskBase):
@@ -111,7 +110,6 @@ def _maybe_import_operator(package: str, op_name: str) -> Dict[str, Any]:
        return {
            "name": op_name,
            "exists": Operator is not None,
            "metadata": {},
        }

    # =========================================================================
@@ -121,23 +119,24 @@ def _maybe_import_operator(package: str, op_name: str) -> Dict[str, Any]:
    @base_task.run_in_worker(scoped=True)
    @staticmethod
    def make_operator_instance(
        mode: str,
        device: str,
        extra_args: Optional[List[str]] = None,
        args: List[str],
    ) -> None:
        from tritonbench.utils.parser import get_parser

        parser = get_parser()
        tb_args, extra_args = parser.parse_known_args(args)
        Operator = globals()["Operator"]
        parser = get_parser()
        op = Operator(
            mode=mode,
            device=device,
            tb_args=tb_args,
            extra_args=extra_args,
        )

        import gc

        gc.collect()

        if device == "cuda":
            torch.cuda.empty_cache()
        if op.device == "cuda":
            maybe_sync = torch.cuda.synchronize
        else:
            maybe_sync = lambda: None
@@ -181,6 +180,25 @@ def get_attribute(
        else:
            return None

    # =========================================================================
    # == Check output is expected in the child process ========================
    # =========================================================================
    @base_task.run_in_worker(scoped=True)
    @staticmethod
    def check_output() -> None:
        op = globals()["op"]
        from tritonbench.utils.triton_op import REGISTERED_BENCHMARKS

        output = op.output
        output_impls = output.result[0][1].keys()
        ci_enabled_impls = [
            x for x in REGISTERED_BENCHMARKS[output.op_name].keys() if x not in op._skip
        ]
        # Make sure that all the ci_enabled impls are in the output
        assert set(output_impls) == set(
            ci_enabled_impls
        ), f"output impls: {output_impls} != ci_enabled impls: {ci_enabled_impls}"

    def del_op_instance(self):
        self.worker.run(
            """
20 changes: 18 additions & 2 deletions tritonbench/utils/triton_op.py
@@ -522,6 +522,22 @@ def __call__(cls, *args, **kwargs):
        return obj


def _translate_mode(tb_args):
    def _has_and_true(attr):
        if hasattr(tb_args, attr) and getattr(tb_args, attr):
            return True
        return False

    if _has_and_true("fwd"):
        tb_args.mode = "fwd"
    if _has_and_true("bwd"):
        tb_args.mode = "bwd"
    if _has_and_true("fwd_bwd"):
        tb_args.mode = "fwd_bwd"
    if _has_and_true("fwd_no_grad"):
        tb_args.mode = "fwd_no_grad"


class BenchmarkOperator(metaclass=PostInitProcessor):
    mode: Mode = Mode.FWD
    test: str = "eval"
@@ -555,7 +571,7 @@ def __init__(
        self.use_cuda_graphs = (
            self.tb_args.cudagraph if self.tb_args.cudagraph else self.use_cuda_graphs
        )
        # we accept both "fwd" and "eval"
        _translate_mode(self.tb_args)
        if self.tb_args.mode == "fwd":
            self.mode = Mode.FWD
        elif self.tb_args.mode == "fwd_bwd":
@@ -565,7 +581,7 @@
        else:
            assert (
                self.tb_args.mode == "bwd"
            ), f"We only accept 3 test modes: fwd(eval), fwd_bwd(train), or bwd."
            ), "We only accept test modes: fwd, bwd, fwd_bwd, or fwd_no_grad."
            self.mode = Mode.BWD
        self.device = tb_args.device
        self.required_metrics = (
