Fix CI test failures #55

Closed · wants to merge 26 commits

7 changes: 4 additions & 3 deletions .ci/tritonbench/test-gpu.sh
@@ -8,7 +8,8 @@ fi

. "${SETUP_SCRIPT}"

# install deps
pip install psutil tabulate
# FIXME: patch hstu
sudo apt-get install -y patch
python install.py --hstu

python -m unittest test.test_gpu.main
python -m unittest test.test_gpu.main -v
1 change: 1 addition & 0 deletions .github/workflows/pr.yaml
@@ -3,6 +3,7 @@ on:
pull_request:
paths:
- .ci/*
- test/test_gpu/*
- tritonbench/*
- .github/workflows/pr.yaml
push:
6 changes: 0 additions & 6 deletions run.py
@@ -55,12 +55,6 @@ def _run(args: argparse.Namespace, extra_args: List[str]) -> BenchmarkOperatorRe
Opbench = load_opbench_by_name_from_loader(args)
else:
Opbench = load_opbench_by_name(args.op)
if args.fwd_bwd:
args.mode = "fwd_bwd"
if args.bwd:
args.mode = "bwd"
if args.fwd_no_grad:
args.mode = "fwd_no_grad"
opbench = Opbench(
tb_args=args,
extra_args=extra_args,
42 changes: 32 additions & 10 deletions test/test_gpu/main.py
@@ -53,10 +53,9 @@ def check_ci_output(op):
), f"output impls: {output_impls} != ci_enabled impls: {ci_enabled_impls}"


def _run_one_operator(
tb_args: argparse.Namespace,
extra_args: Optional[List[str]] = None,
):
def _run_one_operator(args: List[str]):
parser = get_parser(args)
tb_args, extra_args = parser.parse_known_args(args)
if tb_args.op in skip_tests:
# If the op itself is in the skip list, skip all tests
if not skip_tests[tb_args.op]:
@@ -80,6 +79,31 @@ def _run_one_operator(
)


def _run_operator_in_task(op: str, args: List[str]):
from tritonbench.operators.op_task import OpTask

if op in skip_tests:
# If the op itself is in the skip list, skip all tests
if not skip_tests[op]:
return
skip = ",".join(skip_tests[op])
args.extend(["--skip", skip])
task = OpTask(op)
task.make_operator_instance(args=args)
task.run()
task.check_output()
task.del_op_instance()
# Test backward (if applicable)
try:
args.extend(["--bwd"])
task.make_operator_instance(args=args)
task.run()
task.check_output()
except NotImplementedError:
# Operator does not support backward, skip the test
pass


def make_test(operator):
def test_case(self):
# Add `--test-only` to disable Triton autotune in tests
@@ -92,12 +116,10 @@ def test_case(self):
"1",
"--test-only",
]
parser = get_parser(args)
tb_args, extra_args = parser.parse_known_args(args)
_run_one_operator(
tb_args,
extra_args,
)
if IS_FBCODE:
_run_one_operator(args)
else:
_run_operator_in_task(op=operator, args=args)

return test_case

10 changes: 5 additions & 5 deletions test/test_gpu/skip_tests_h100_pytorch.yaml
@@ -17,6 +17,9 @@ flash_attention:
# - flex_attention
fp8_attention:
- colfax_fmha
# triton_flash_v2 now requires the main branch of Triton
# pytorch version does not work
- triton_flash_v2
fp8_fused_quant_gemm_rowwise:
fp8_gemm:
- triton_persistent_fp8_gemm
@@ -29,9 +32,6 @@ jagged_layer_norm:
jagged_mean:
jagged_softmax:
jagged_sum:
layer_norm:
low_mem_dropout:
rms_norm:
Member:
Liger kernels require the triton package rather than pytorch-triton. I assume triton does not conflict with pytorch-triton, because pytorch-triton does not cover "import triton". I tested it in a local environment and it works well, but I am not sure whether this is a safe way to do it.

xuzhao9 (Contributor Author), Nov 19, 2024:
We are planning to have separate tests for Triton main and pytorch-triton. Our Docker image has two conda environments, pytorch and triton-main, so that both can be tested in the same container.

Right now we are only deploying tests against pytorch-triton. We will set up the Triton main config as skip_tests_h100_triton_main.yaml.
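
For reference (not part of the comments above): a minimal sketch of how one could check which Triton distribution is actually installed in a given environment. The two distribution names come from the discussion; everything else is an assumption.

from importlib import metadata

# Report which of the two Triton distributions is present in this environment.
for dist in ("triton", "pytorch-triton"):
    try:
        print(f"{dist}: {metadata.version(dist)}")
    except metadata.PackageNotFoundError:
        print(f"{dist}: not installed")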

rope:
template_attention:
ragged_attention:
- hstu_triton_ragged_attention_persistent
test_op:
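
Not part of this PR: a minimal sketch of how a skip file like the one above maps onto the --skip flag built by the new _run_operator_in_task helper in test/test_gpu/main.py. The YAML path and the --skip flag appear in this PR; the standalone loader below is an assumption.

import yaml  # assumes PyYAML is available

with open("test/test_gpu/skip_tests_h100_pytorch.yaml") as f:
    skip_tests = yaml.safe_load(f) or {}

op = "fp8_attention"
skipped_impls = skip_tests.get(op)
if op in skip_tests and not skipped_impls:
    print(f"{op}: skip the operator entirely")   # an empty entry means skip all tests for the op
elif skipped_impls:
    print(["--skip", ",".join(skipped_impls)])   # e.g. --skip colfax_fmha,triton_flash_v2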
7 changes: 2 additions & 5 deletions tritonbench/operators/addmm/hstu.py
@@ -6,11 +6,8 @@
import triton
from tritonbench.utils.path_utils import add_path, SUBMODULE_PATH

with add_path(str(SUBMODULE_PATH)):
triton_addmm = importlib.import_module(
"generative-recommenders.ops.triton.triton_addmm"
)
_addmm_fwd = triton_addmm._addmm_fwd
with add_path(str(SUBMODULE_PATH.joinpath("generative-recommenders"))):
from generative_recommenders.ops.triton.triton_addmm import _addmm_fwd


class _AddMmFunction(torch.autograd.Function):
4 changes: 2 additions & 2 deletions tritonbench/operators/fp8_attention/operator.py
@@ -10,7 +10,7 @@

import torch

from tritonbench.kernels.triton_fused_attention import attention as triton_attention
from tritonbench.kernels.triton_fused_attention import attention_opt as triton_attention
from tritonbench.utils.triton_op import (
BenchmarkOperator,
BenchmarkOperatorMetrics,
@@ -110,7 +110,7 @@ def triton_flash_v2(
triton_q, triton_k, triton_v = self.triton_preprocess(q, k, v)
# full fp8 will be enabled if type of q,k,v is fp8
return lambda: triton_attention(
triton_q, triton_k, triton_v, False, self.sm_scale
triton_q, triton_k, triton_v, False, self.sm_scale, "base"
xuzhao9 (Contributor Author):
cc @manman-ren: attention_opt fails to compile on the PyTorch version of Triton. Does it require the latest Triton main branch?

Contributor:
Sorry for the breakage. What is the error message?

xuzhao9 (Contributor Author):
@manman-ren here is the error message: https://github.com/pytorch-labs/tritonbench/actions/runs/11903695593/job/33171153429?pr=55. By default we are using the PyTorch built-in Triton in the CI.

)

def get_x_val(self, _example_inputs) -> Tuple[int, int, int, int]:
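
Regarding the thread above: a minimal sketch, not part of this PR, of gating the optimized kernel on the installed Triton version. The version threshold is a placeholder, and it assumes both attention and attention_opt remain importable from triton_fused_attention.

import triton
from packaging.version import Version  # assumes the packaging module is installed

# Placeholder threshold: the exact Triton version that attention_opt needs is not known here.
if Version(triton.__version__) >= Version("3.2.0"):
    from tritonbench.kernels.triton_fused_attention import attention_opt as triton_attention
else:
    # Fall back to the baseline kernel on the PyTorch-bundled Triton.
    from tritonbench.kernels.triton_fused_attention import attention as triton_attention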
36 changes: 27 additions & 9 deletions tritonbench/operators/op_task.py
@@ -53,7 +53,6 @@ class OpDetails:

name: str
exists: bool
metadata: Dict[str, Any]


class OpTask(base_task.TaskBase):
@@ -111,7 +110,6 @@ def _maybe_import_operator(package: str, op_name: str) -> Dict[str, Any]:
return {
"name": op_name,
"exists": Operator is not None,
"metadata": {},
}

# =========================================================================
@@ -121,23 +119,24 @@ def _maybe_import_operator(package: str, op_name: str) -> Dict[str, Any]:
@base_task.run_in_worker(scoped=True)
@staticmethod
def make_operator_instance(
mode: str,
device: str,
extra_args: Optional[List[str]] = None,
args: List[str],
) -> None:
from tritonbench.utils.parser import get_parser

parser = get_parser()
tb_args, extra_args = parser.parse_known_args(args)
Operator = globals()["Operator"]
parser = get_parser()
op = Operator(
mode=mode,
device=device,
tb_args=tb_args,
extra_args=extra_args,
)

import gc

gc.collect()

if device == "cuda":
torch.cuda.empty_cache()
if op.device == "cuda":
maybe_sync = torch.cuda.synchronize
else:
maybe_sync = lambda: None
@@ -181,6 +180,25 @@ def get_attribute(
else:
return None

# =========================================================================
# == Check output is expected in the child process ========================
# =========================================================================
@base_task.run_in_worker(scoped=True)
@staticmethod
def check_output() -> None:
op = globals()["op"]
from tritonbench.utils.triton_op import REGISTERED_BENCHMARKS

output = op.output
output_impls = output.result[0][1].keys()
ci_enabled_impls = [
x for x in REGISTERED_BENCHMARKS[output.op_name].keys() if x not in op._skip
]
# Make sure that all the ci_enabled impls are in the output
assert set(output_impls) == set(
ci_enabled_impls
), f"output impls: {output_impls} != ci_enabled impls: {ci_enabled_impls}"

def del_op_instance(self):
self.worker.run(
"""
7 changes: 3 additions & 4 deletions tritonbench/operators/ragged_attention/hstu.py
@@ -19,8 +19,8 @@
_ragged_hstu_attn_fwd_persistent = (
triton_ragged_hstu_attention._ragged_hstu_attn_fwd_persistent
)
_RaggedAttentionRelativeBiasFunction = (
triton_ragged_hstu_attention._RaggedAttentionRelativeBiasFunction
RaggedAttentionRelativeBiasFunction = (
triton_ragged_hstu_attention.RaggedAttentionRelativeBiasFunction
)

@torch.fx.wrap
@@ -150,7 +150,7 @@ def forward(
grid = (1216,)
_ragged_hstu_attn_fwd_persistent[grid](**kwargs)
else:
out = _RaggedAttentionRelativeBiasFunction.apply(
out = RaggedAttentionRelativeBiasFunction.apply(
self.max_seq_len, # N
kwargs["alpha"],
q,
@@ -169,7 +169,6 @@ def forward(
kwargs["time_delta"], # time_delta
kwargs["max_pos_ind"], # max_pos_ind
kwargs["num_targets"],
None, # attn_scale
kwargs["ATTN_BIAS_TYPE"], # relative_bias_type
kwargs["MAX_ATTN_LEN"], # max_attn_len
kwargs["contextual_seq_len"], # contextual_seq_len
20 changes: 18 additions & 2 deletions tritonbench/utils/triton_op.py
@@ -522,6 +522,22 @@ def __call__(cls, *args, **kwargs):
return obj


def _translate_mode(tb_args):
def _has_and_true(attr):
if hasattr(tb_args, attr) and getattr(tb_args, attr):
return True
return False

if _has_and_true("fwd"):
tb_args.mode = "fwd"
if _has_and_true("bwd"):
tb_args.mode = "bwd"
if _has_and_true("fwd_bwd"):
tb_args.mode = "fwd_bwd"
if _has_and_true("fwd_no_grad"):
tb_args.mode = "fwd_no_grad"


class BenchmarkOperator(metaclass=PostInitProcessor):
mode: Mode = Mode.FWD
test: str = "eval"
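
A minimal usage sketch for the _translate_mode helper added above (not part of the PR; the Namespace fields mirror the flags that run.py used to translate). Because the checks run in the order fwd, bwd, fwd_bwd, fwd_no_grad, the last matching flag wins if several are set.

import argparse

from tritonbench.utils.triton_op import _translate_mode  # added by this PR

# Simulate parsed CLI flags; only --bwd is set here.
tb_args = argparse.Namespace(mode="fwd", fwd=False, bwd=True, fwd_bwd=False, fwd_no_grad=False)
_translate_mode(tb_args)
print(tb_args.mode)  # -> "bwd"
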
@@ -555,7 +571,7 @@ def __init__(
self.use_cuda_graphs = (
self.tb_args.cudagraph if self.tb_args.cudagraph else self.use_cuda_graphs
)
# we accept both "fwd" and "eval"
_translate_mode(self.tb_args)
if self.tb_args.mode == "fwd":
self.mode = Mode.FWD
elif self.tb_args.mode == "fwd_bwd":
@@ -565,7 +581,7 @@
else:
assert (
self.tb_args.mode == "bwd"
), f"We only accept 3 test modes: fwd(eval), fwd_bwd(train), or bwd."
), "We only accept test modes: fwd, bwd, fwd_bwd, or fwd_no_grad."
self.mode = Mode.BWD
self.device = tb_args.device
self.required_metrics = (