Add ufmt linter for pyproject #47

Closed
wants to merge 20 commits
3 changes: 3 additions & 0 deletions .ci/tritonbench/test-gpu.sh
@@ -8,4 +8,7 @@ fi

. "${SETUP_SCRIPT}"

# install deps
pip install psutil tabulate

python -m unittest test.test_gpu.main
29 changes: 29 additions & 0 deletions .github/workflows/linter.yaml
@@ -0,0 +1,29 @@
name: Linter
on:
pull_request:
push:
branches:
- main
workflow_dispatch:

jobs:
pylint:
permissions:
contents: read
runs-on: ubuntu-latest
steps:
- name: Checkout
uses: actions/checkout@v3
with:
path: tritonbench
- name: Install deps
run: |
pip install ruff-api==0.1.0
- name: Check Formatting
uses: omnilib/ufmt@action-v1
with:
path: tritonbench

concurrency:
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }}-${{ github.event_name == 'workflow_dispatch' }}
cancel-in-progress: true
11 changes: 9 additions & 2 deletions .github/workflows/pr.yaml
@@ -5,6 +5,9 @@ on:
- .ci/*
- tritonbench/*
- .github/workflows/pr.yaml
push:
branches:
- main

jobs:
h100-pytorch-test:
@@ -27,6 +30,10 @@ jobs:
sudo nvidia-smi -pm 1
sudo ldconfig
nvidia-smi
-- name: Test Tritonbench operators
+- name: Test Tritonbench operators on H100 GPU
run: |
-bash ./.ci/tritonbench/test-operators.sh
+bash ./.ci/tritonbench/test-gpu.sh

concurrency:
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }}-${{ github.event_name == 'workflow_dispatch' }}
cancel-in-progress: true
7 changes: 7 additions & 0 deletions pyproject.toml
@@ -0,0 +1,7 @@
[tool.ufmt]
formatter = "ruff-api"
excludes = ["submodules/"]

[tool.black]
line-length = 88
target-version = ["py312"]
2 changes: 2 additions & 0 deletions requirements.txt
@@ -1,3 +1,5 @@
packaging
pynvml
psutil
tabulate
transformers==4.46.1
1 change: 0 additions & 1 deletion test/test_cpu/main.py
@@ -7,7 +7,6 @@


class TestTritonbenchCpu(unittest.TestCase):

def _get_test_op(self):
parser = get_parser(["--device", "cpu", "--op", "test_op"])
tb_args, extra_args = parser.parse_known_args(
8 changes: 6 additions & 2 deletions test/test_gpu/main.py
@@ -18,7 +18,11 @@
fbcode_skip_file_path = "fb/skip_tests_h100_fbcode.yaml"
SKIP_FILE = importlib.resources.files(__package__).joinpath(fbcode_skip_file_path)
else:
SKIP_FILE = "skip_tests_h100_pytorch.yaml"
import os

SKIP_FILE = os.path.abspath(
os.path.join(os.path.dirname(__file__), "skip_tests_h100_pytorch.yaml")
)

with open(SKIP_FILE, "r") as f:
skip_tests = yaml.safe_load(f)
@@ -55,7 +59,7 @@ def _run_one_operator(
):
if tb_args.op in skip_tests:
# If the op itself is in the skip list, skip all tests
-if skip_tests[tb_args.op] is None:
+if not skip_tests[tb_args.op]:
return
tb_args.skip = ",".join(skip_tests[tb_args.op])
Operator = load_opbench_by_name(tb_args.op)
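Note on the `is None` → `not` change above: with the reworked skip file, an operator key with no impl list parses to `None`, and an explicitly empty list would parse to `[]`; both cases should skip the whole operator. A minimal, self-contained sketch of that logic (illustrative only, with hypothetical sample data, not code from this PR):

```python
# Hypothetical parsed skip entries; in the PR they come from yaml.safe_load(SKIP_FILE).
skip_tests = {
    "gemm": None,  # bare "gemm:" key -> skip the entire operator
    "fp8_gemm": ["triton_persistent_fp8_gemm"],  # skip only the listed impls
}

def skips_for(op: str):
    """Return None to skip the whole op, else a comma-separated impl skip list."""
    if op not in skip_tests:
        return ""  # nothing to skip for this operator
    if not skip_tests[op]:  # covers both None and an empty list
        return None  # skip every test for this operator
    return ",".join(skip_tests[op])  # e.g. what gets assigned to tb_args.skip

assert skips_for("gemm") is None
assert skips_for("fp8_gemm") == "triton_persistent_fp8_gemm"
```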
44 changes: 35 additions & 9 deletions test/test_gpu/skip_tests_h100_pytorch.yaml
@@ -1,11 +1,37 @@
# Tests we skip in OSS CI
# This file applies to the Triton version bundled with PyTorch
-# Use <op-name> to skip an entire operator
-# Use <op-name/impl-name> to skip an impl
-- test_op
-- bf16xint16_gemm/bf16xint16
-- fp8_attention/colfax_fmha
-- fp8_fused_quant_gemm_rowwise
-- fp8_gemm/triton_persistent_fp8_gemm
-- fp8_gemm/triton_tma_persistent_fp8_gemm
-- fp8_gemm_rowwise
+# Use <op-name:> to skip an entire operator
+# Use <op-name:\n - impl-name> to skip an impl
+bf16xint16_gemm:
+- bf16xint16
+# TODO: we have many buggy backends for flash_attention
+# Need to fix them in the CI
+flash_attention:
+# - triton_tutorial_flash_v2_tma
+# - triton_op_flash_v2
+# - xformers_splitk
+# - colfax_cutlass
+# - tk
+# - sdpa
+# - cudnn
+# - flex_attention
+fp8_attention:
+- colfax_fmha
+fp8_fused_quant_gemm_rowwise:
+fp8_gemm:
+- triton_persistent_fp8_gemm
+- triton_tma_persistent_fp8_gemm
+fp8_gemm_rowwise:
+gemm:
+grouped_gemm:
+int4_gemm:
+jagged_layer_norm:
+jagged_mean:
+jagged_softmax:
+jagged_sum:
+layer_norm:
+low_mem_dropout:
+rms_norm:
+rope:
+template_attention:
+test_op:
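For readers unfamiliar with the new schema, here is a small sketch of how a fragment of this file parses with `yaml.safe_load` (the fragment is abbreviated and only the shape of the result matters; note that `flash_attention:` above, whose impl entries are all commented out, likewise parses to `None` and is therefore skipped entirely by the check above):

```python
import yaml

fragment = """
bf16xint16_gemm:
  - bf16xint16
fp8_fused_quant_gemm_rowwise:
gemm:
"""

parsed = yaml.safe_load(fragment)
# Keys with an impl list map to that list; bare keys map to None, which the
# test harness treats as "skip the whole operator".
assert parsed == {
    "bf16xint16_gemm": ["bf16xint16"],
    "fp8_fused_quant_gemm_rowwise": None,
    "gemm": None,
}
```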
4 changes: 1 addition & 3 deletions tools/cuda_utils.py
@@ -238,7 +238,5 @@ def check_torch_nightly_version(force_date: Optional[str] = None):
if args.install_torch_nightly:
install_pytorch_nightly(cuda_version=args.cudaver, env=os.environ)
if args.check_torch_nightly_version:
-assert (
-not args.install_torch_nightly
-), "Error: Can't run install torch nightly and check version in the same command."
+assert not args.install_torch_nightly, "Error: Can't run install torch nightly and check version in the same command."
check_torch_nightly_version(args.force_date)
4 changes: 3 additions & 1 deletion tritonbench/components/workers/subprocess_rpc.py
@@ -274,8 +274,10 @@ def write(self, msg: bytes) -> None:
def get_writer_pid(self) -> int:
assert (
self._writer_pid is not None
), "Writer pid is not specified. Maybe calling from child process or input pipe.\
), (
"Writer pid is not specified. Maybe calling from child process or input pipe.\
Please report a bug."
)
return self._writer_pid

def set_writer_pid(self, writer_pid: int) -> None:
17 changes: 10 additions & 7 deletions tritonbench/kernels/triton_fused_attention.py
@@ -35,7 +35,6 @@


class TmaAutoTuneHelper:

# duck typing wrapper to implement the same interface as TmaDescKernelParam in Triton PR #4498
class KernelParamWrapper:
def __init__(self, desc):
@@ -457,7 +456,6 @@ def _attn_fwd_tma( # Q, V, desc_k, desc_v, sm_scale, M, Out, #
HEAD_DIM: tl.constexpr, #
STAGE: tl.constexpr, #
):

tl.static_assert(BLOCK_N <= HEAD_DIM)
start_m = tl.program_id(0)
off_hz = tl.program_id(1)
@@ -569,7 +567,14 @@ def _attn_fwd_tma( # Q, V, desc_k, desc_v, sm_scale, M, Out, #

@triton.jit
def _attn_bwd_preprocess(
-O, DO, Delta, Z, H, N_CTX, BLOCK_M: tl.constexpr, HEAD_DIM: tl.constexpr # # # #
+O,
+DO,
+Delta,
+Z,
+H,
+N_CTX,
+BLOCK_M: tl.constexpr,
+HEAD_DIM: tl.constexpr, # # # #
):
off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
off_hz = tl.program_id(1)
@@ -900,7 +905,6 @@ def _attn_bwd(


class _attention(torch.autograd.Function):

@staticmethod
def forward(ctx, q, k, v, causal, sm_scale):
# shape constraints
@@ -949,7 +953,7 @@ def forward(ctx, q, k, v, causal, sm_scale):
N_CTX=q.shape[2], #
HEAD_DIM=HEAD_DIM_K, #
STAGE=stage, #
-**extra_kern_args
+**extra_kern_args,
)

ctx.save_for_backward(q, k, v, o, M)
@@ -1021,7 +1025,6 @@ def backward(ctx, do):


class _attention_tma(torch.autograd.Function):

@staticmethod
def forward(ctx, q, k, v, causal, sm_scale):
# shape constraints
@@ -1175,7 +1178,7 @@ def grid_tma(META):
N_CTX=q.shape[2], #
HEAD_DIM=HEAD_DIM_K, #
STAGE=stage, #
-**extra_kern_args
+**extra_kern_args,
)

ctx.save_for_backward(q, k, v, o, M)
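The `**extra_kern_args` → `**extra_kern_args,` edits above add a trailing comma so the formatter keeps the kernel-launch call exploded one argument per line. A tiny hedged sketch of that "magic trailing comma" behavior in black/ruff-style formatters (a generic example, not this kernel's actual call):

```python
def fn(*args, **kwargs):
    return args, kwargs

a, b = 1, 2
extra_kern_args = {"num_warps": 4}  # hypothetical extra kwargs

# Without a trailing comma, the formatter may collapse a short call onto one line:
fn(a, b, **extra_kern_args)

# With a trailing comma after the last argument, the call stays exploded,
# one argument per line, so later argument edits produce minimal diffs:
fn(
    a,
    b,
    **extra_kern_args,
)
```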
1 change: 0 additions & 1 deletion tritonbench/operators/gather_gemv/operator.py
@@ -26,7 +26,6 @@


class Operator(BenchmarkOperator):

@register_metric()
def gbps(self, fn_name, example_inputs, metrics: BenchmarkOperatorMetrics):
arg0_1, arg1_1, arg2_1 = example_inputs
2 changes: 1 addition & 1 deletion tritonbench/operators/gather_gemv/triton_gather_gemv.py
@@ -82,7 +82,7 @@ def triton_red_fused_mv_0(
rbase = tl.arange(0, RBLOCK)[None, :].to(tl.int64)
x0 = xindex
# x0 // rnumel evaluates to either 0 or 1
-tmp0 = tl.load(in_ptr0 + ((x0 // rnumel)), None, eviction_policy="evict_last")
+tmp0 = tl.load(in_ptr0 + (x0 // rnumel), None, eviction_policy="evict_last")
_tmp11 = tl.full([XBLOCK, RBLOCK], 0, tl.float32)
for roffset in range(0, rnumel, RBLOCK):
rindex = roffset + rbase
9 changes: 4 additions & 5 deletions tritonbench/operators/jagged_layer_norm/operator.py
@@ -36,7 +36,6 @@ def parse_op_args(args: List[str]):


class Operator(BenchmarkOperator):

DEFAULT_METRICS = ["latency", "accuracy"]
DEFAULT_PRECISION = "fp32"

@@ -48,8 +47,8 @@ def __init__(
self, tb_args: argparse.Namespace, extra_args: Optional[List[str]] = None
):
super().__init__(tb_args, extra_args)
-self.sizes = list(range(2, 12, 4)) + list(
-range(12, 23, 3)
+self.sizes = (
+list(range(2, 12, 4)) + list(range(12, 23, 3))
) # bias towards larger sizes, which are more representative of real-world shapes

args = parse_op_args(self.extra_args)
@@ -105,8 +104,8 @@ def _inner():
) # a sum over (1, 2) ensures layer norm, whereas a sum over (1) would be an instance norm

padded_normalized = (
-padded_values - mean
-) * padded_mask_values # mask elements outside of the ragged dimension size for correct variance calculation
+(padded_values - mean) * padded_mask_values
+) # mask elements outside of the ragged dimension size for correct variance calculation

variance = (
torch.sum(
20 changes: 12 additions & 8 deletions tritonbench/operators/jagged_mean/kernels.py
@@ -53,8 +53,9 @@ def triton_jagged_mean_kernel_simple_fused_sum_then_buffer(
offsets_m = block_start_m + tl.arange(0, BLOCK_SIZE_M)
mask_m = offsets_m < M

-ragged_start, ragged_end = tl.load(input_ptr_offsets + pid_b), tl.load(
-input_ptr_offsets + (pid_b + 1)
+ragged_start, ragged_end = (
+tl.load(input_ptr_offsets + pid_b),
+tl.load(input_ptr_offsets + (pid_b + 1)),
) # load start and end offsets for current program, similar to offsets[i] and offsets[i + 1]
ragged_len = ragged_end - ragged_start

@@ -133,8 +134,9 @@ def triton_jagged_mean_kernel_simple_fused_buffer_then_sum(
offsets_m = block_start_m + tl.arange(0, BLOCK_SIZE_M)
mask_m = offsets_m < M

-ragged_start, ragged_end = tl.load(input_ptr_offsets + pid_b), tl.load(
-input_ptr_offsets + (pid_b + 1)
+ragged_start, ragged_end = (
+tl.load(input_ptr_offsets + pid_b),
+tl.load(input_ptr_offsets + (pid_b + 1)),
) # load start and end offsets for current program, similar to offsets[i] and offsets[i + 1]
ragged_len = ragged_end - ragged_start

@@ -212,8 +214,9 @@ def triton_jagged_mean_kernel_variable_length_loop_sum_then_buffer(
offsets_m = block_start_m + tl.arange(0, BLOCK_SIZE_M)
mask_m = offsets_m < M

-ragged_start, ragged_end = tl.load(input_ptr_offsets + pid_b), tl.load(
-input_ptr_offsets + (pid_b + 1)
+ragged_start, ragged_end = (
+tl.load(input_ptr_offsets + pid_b),
+tl.load(input_ptr_offsets + (pid_b + 1)),
) # load start and end offsets for current program, similar to offsets[i] and offsets[i + 1]
ragged_len = ragged_end - ragged_start

@@ -288,8 +291,9 @@ def triton_jagged_mean_kernel_variable_length_loop_buffer_then_sum(
offsets_m = block_start_m + tl.arange(0, BLOCK_SIZE_M)
mask_m = offsets_m < M

-ragged_start, ragged_end = tl.load(input_ptr_offsets + pid_ragged), tl.load(
-input_ptr_offsets + (pid_ragged + 1)
+ragged_start, ragged_end = (
+tl.load(input_ptr_offsets + pid_ragged),
+tl.load(input_ptr_offsets + (pid_ragged + 1)),
) # load start and end offsets for current program, similar to offsets[i] and offsets[i + 1]
ragged_len = ragged_end - ragged_start
