
Commit 359dfb4

xuzhao9 authored and facebook-github-bot committed
Add ufmt linter for pyproject (#47)
Summary: Add a linter to make sure PRs are consistent with the internal `arc lint`. To format the code, run the following in the repo directory:

```
ufmt format .
```

We also test 25 operators in our H100 CI. Note that many of the flash_attention backends do not work right now and we need to fix them.

Pull Request resolved: #47
Reviewed By: FindHao
Differential Revision: D65709181
Pulled By: xuzhao9
fbshipit-source-id: 5b013906e7b04c8ee41d74db5756de08eec5b5b2
1 parent 7b4a0eb commit 359dfb4
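
For contributors, here is a minimal local helper (hypothetical, not part of this commit) that mirrors what the new CI does: format the tree with ufmt, then run the same GPU test entry point that `.ci/tritonbench/test-gpu.sh` invokes. It assumes `ufmt` (with `ruff-api`) and the test dependencies (`psutil`, `tabulate`) are already installed.

```python
# Hypothetical convenience script; not part of this commit.
# Assumes `ufmt` + `ruff-api` and the test deps (psutil, tabulate) are installed.
import subprocess
import sys


def lint_and_test() -> int:
    # Format the repo in place, as the summary suggests: `ufmt format .`
    fmt = subprocess.run(["ufmt", "format", "."])
    if fmt.returncode != 0:
        return fmt.returncode
    # Run the same GPU test entry point the H100 CI uses.
    tests = subprocess.run([sys.executable, "-m", "unittest", "test.test_gpu.main"])
    return tests.returncode


if __name__ == "__main__":
    sys.exit(lint_and_test())
```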

File tree: 29 files changed (+206, -106 lines)


.ci/tritonbench/test-gpu.sh (+3)

@@ -8,4 +8,7 @@ fi
 
 . "${SETUP_SCRIPT}"
 
+# install deps
+pip install psutil tabulate
+
 python -m unittest test.test_gpu.main

.github/workflows/linter.yaml (new file, +29)

@@ -0,0 +1,29 @@
+name: Linter
+on:
+  pull_request:
+  push:
+    branches:
+      - main
+  workflow_dispatch:
+
+jobs:
+  pylint:
+    permissions:
+      contents: read
+    runs-on: ubuntu-latest
+    steps:
+      - name: Checkout
+        uses: actions/checkout@v3
+        with:
+          path: tritonbench
+      - name: Install deps
+        run: |
+          pip install ruff-api==0.1.0
+      - name: Check Formatting
+        uses: omnilib/ufmt@action-v1
+        with:
+          path: tritonbench
+
+concurrency:
+  group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }}-${{ github.event_name == 'workflow_dispatch' }}
+  cancel-in-progress: true

.github/workflows/pr.yaml (+9, -2)

@@ -5,6 +5,9 @@ on:
       - .ci/*
       - tritonbench/*
       - .github/workflows/pr.yaml
+  push:
+    branches:
+      - main
 
 jobs:
   h100-pytorch-test:
@@ -27,6 +30,10 @@ jobs:
           sudo nvidia-smi -pm 1
           sudo ldconfig
           nvidia-smi
-      - name: Test Tritonbench operators
+      - name: Test Tritonbench operators on H100 GPU
        run: |
-          bash ./.ci/tritonbench/test-operators.sh
+          bash ./.ci/tritonbench/test-gpu.sh
+
+concurrency:
+  group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }}-${{ github.event_name == 'workflow_dispatch' }}
+  cancel-in-progress: true

pyproject.toml (new file, +7)

@@ -0,0 +1,7 @@
+[tool.ufmt]
+formatter = "ruff-api"
+excludes = ["submodules/"]
+
+[tool.black]
+line-length = 88
+target-version = ["py312"]

requirements.txt (+2)

@@ -1,3 +1,5 @@
 packaging
 pynvml
+psutil
+tabulate
 transformers==4.46.1

test/test_cpu/main.py (-1)

@@ -7,7 +7,6 @@
 
 
 class TestTritonbenchCpu(unittest.TestCase):
-
     def _get_test_op(self):
         parser = get_parser(["--device", "cpu", "--op", "test_op"])
         tb_args, extra_args = parser.parse_known_args(

test/test_gpu/main.py (+6, -2)

@@ -18,7 +18,11 @@
     fbcode_skip_file_path = "fb/skip_tests_h100_fbcode.yaml"
     SKIP_FILE = importlib.resources.files(__package__).joinpath(fbcode_skip_file_path)
 else:
-    SKIP_FILE = "skip_tests_h100_pytorch.yaml"
+    import os
+
+    SKIP_FILE = os.path.abspath(
+        os.path.join(os.path.dirname(__file__), "skip_tests_h100_pytorch.yaml")
+    )
 
 with open(SKIP_FILE, "r") as f:
     skip_tests = yaml.safe_load(f)
@@ -55,7 +59,7 @@ def _run_one_operator(
 ):
     if tb_args.op in skip_tests:
         # If the op itself is in the skip list, skip all tests
-        if skip_tests[tb_args.op] is None:
+        if not skip_tests[tb_args.op]:
             return
         tb_args.skip = ",".join(skip_tests[tb_args.op])
     Operator = load_opbench_by_name(tb_args.op)
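
For context on the `is None` to `not skip_tests[tb_args.op]` change above, here is a small illustrative sketch (not part of the diff) of what `yaml.safe_load` produces for the two entry styles used in the skip file: a bare operator key parses to `None`, while a key with nested items parses to a list of impl names, so a falsiness check covers both the `None` and empty-list cases.

```python
import yaml

# Illustrative only: mirrors the two entry styles in the skip file.
doc = """
fp8_gemm_rowwise:   # bare key -> parsed as None -> skip the whole operator
fp8_gemm:           # nested list -> skip only these impls
  - triton_persistent_fp8_gemm
  - triton_tma_persistent_fp8_gemm
"""

skip_tests = yaml.safe_load(doc)
assert skip_tests["fp8_gemm_rowwise"] is None  # falsy -> entire operator skipped
assert skip_tests["fp8_gemm"] == [
    "triton_persistent_fp8_gemm",
    "triton_tma_persistent_fp8_gemm",
]
# The test runner joins the impl list into tb_args.skip:
print(",".join(skip_tests["fp8_gemm"]))
```
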
test/test_gpu/skip_tests_h100_pytorch.yaml (+35, -9)

@@ -1,11 +1,37 @@
 # Tests we skip in OSS CI
 # This file is regarding to the Triton version bundled with pytorch
-# Use <op-name> to skip an entire operator
-# Use <op-name/impl-name> to skip an impl
-- test_op
-- bf16xint16_gemm/bf16xint16
-- fp8_attention/colfax_fmha
-- fp8_fused_quant_gemm_rowwise
-- fp8_gemm/triton_persistent_fp8_gemm
-- fp8_gemm/triton_tma_persistent_fp8_gemm
-- fp8_gemm_rowwise
+# Use <op-name:> to skip an entire operator
+# Use <op-name:\n - impl-name> to skip an impl
+bf16xint16_gemm:
+  - bf16xint16
+# TODO: we have many buggy backends for flash_attention
+# Need to fix them in the CI
+flash_attention:
+  # - triton_tutorial_flash_v2_tma
+  # - triton_op_flash_v2
+  # - xformers_splitk
+  # - colfax_cutlass
+  # - tk
+  # - sdpa
+  # - cudnn
+  # - flex_attention
+fp8_attention:
+  - colfax_fmha
+fp8_fused_quant_gemm_rowwise:
+fp8_gemm:
+  - triton_persistent_fp8_gemm
+  - triton_tma_persistent_fp8_gemm
+fp8_gemm_rowwise:
+gemm:
+grouped_gemm:
+int4_gemm:
+jagged_layer_norm:
+jagged_mean:
+jagged_softmax:
+jagged_sum:
+layer_norm:
+low_mem_dropout:
+rms_norm:
+rope:
+template_attention:
+test_op:
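
To make the `<op-name:>` vs `<op-name:\n - impl-name>` convention above concrete, here is a small hypothetical helper (the real logic lives in test/test_gpu/main.py, shown earlier) that turns a parsed skip file into a per-operator decision: skip the whole operator, or skip only the listed impls via a comma-separated skip string.

```python
from typing import Dict, List, Optional, Tuple


def resolve_skip(
    op: str, skip_tests: Dict[str, Optional[List[str]]]
) -> Tuple[bool, str]:
    """Return (skip_entire_op, comma_separated_impls_to_skip)."""
    if op not in skip_tests:
        return False, ""  # operator not listed: run every impl
    impls = skip_tests[op]
    if not impls:  # bare `op:` entry parses to None (or an empty list)
        return True, ""  # skip the whole operator
    return False, ",".join(impls)  # skip only the listed impls


# e.g. resolve_skip("fp8_gemm", skip_tests)
# -> (False, "triton_persistent_fp8_gemm,triton_tma_persistent_fp8_gemm")
```
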

tools/cuda_utils.py (+1, -3)

@@ -238,7 +238,5 @@ def check_torch_nightly_version(force_date: Optional[str] = None):
     if args.install_torch_nightly:
         install_pytorch_nightly(cuda_version=args.cudaver, env=os.environ)
     if args.check_torch_nightly_version:
-        assert (
-            not args.install_torch_nightly
-        ), "Error: Can't run install torch nightly and check version in the same command."
+        assert not args.install_torch_nightly, "Error: Can't run install torch nightly and check version in the same command."
     check_torch_nightly_version(args.force_date)

tritonbench/components/workers/subprocess_rpc.py (+3, -1)

@@ -274,8 +274,10 @@ def write(self, msg: bytes) -> None:
     def get_writer_pid(self) -> int:
         assert (
             self._writer_pid is not None
-        ), "Writer pid is not specified. Maybe calling from child process or input pipe.\
+        ), (
+            "Writer pid is not specified. Maybe calling from child process or input pipe.\
 Please report a bug."
+        )
         return self._writer_pid
 
     def set_writer_pid(self, writer_pid: int) -> None:

tritonbench/kernels/triton_fused_attention.py (+11, -9)

@@ -37,7 +37,6 @@
 
 
 class TmaAutoTuneHelper:
-
     # duck typing wrapper to implement the same interface as TmaDescKernelParam in Triton PR #4498
     class KernelParamWrapper:
         def __init__(self, desc):
@@ -734,7 +733,6 @@ def _attn_fwd_tma( # Q, V, desc_k, desc_v, sm_scale, M, Out, #
     HEAD_DIM: tl.constexpr,  #
     STAGE: tl.constexpr,  #
 ):
-
     tl.static_assert(BLOCK_N <= HEAD_DIM)
     start_m = tl.program_id(0)
     off_hz = tl.program_id(1)
@@ -848,7 +846,14 @@ def _attn_fwd_tma( # Q, V, desc_k, desc_v, sm_scale, M, Out, #
 
 @triton.jit
 def _attn_bwd_preprocess(
-    O, DO, Delta, Z, H, N_CTX, BLOCK_M: tl.constexpr, HEAD_DIM: tl.constexpr  # # # #
+    O,
+    DO,
+    Delta,
+    Z,
+    H,
+    N_CTX,
+    BLOCK_M: tl.constexpr,
+    HEAD_DIM: tl.constexpr,  # # # #
 ):
     off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
     off_hz = tl.program_id(1)
@@ -1179,7 +1184,6 @@ def _attn_bwd(
 
 
 class _attention_ws(torch.autograd.Function):
-
     @staticmethod
     def forward(ctx, q, k, v, causal, sm_scale):
         # shape constraints
@@ -1232,7 +1236,7 @@ def forward(ctx, q, k, v, causal, sm_scale):
             N_CTX=q.shape[2],  #
             HEAD_DIM=HEAD_DIM_K,  #
             STAGE=stage,  #
-            **extra_kern_args
+            **extra_kern_args,
         )
 
         ctx.save_for_backward(q, k, v, o, M)
@@ -1304,7 +1308,6 @@ def backward(ctx, do):
 
 
 class _attention(torch.autograd.Function):
-
     @staticmethod
     def forward(ctx, q, k, v, causal, sm_scale):
         # shape constraints
@@ -1355,7 +1358,7 @@ def forward(ctx, q, k, v, causal, sm_scale):
             N_CTX=q.shape[2],  #
             HEAD_DIM=HEAD_DIM_K,  #
             STAGE=stage,  #
-            **extra_kern_args
+            **extra_kern_args,
         )
 
         ctx.save_for_backward(q, k, v, o, M)
@@ -1427,7 +1430,6 @@ def backward(ctx, do):
 
 
 class _attention_tma(torch.autograd.Function):
-
     @staticmethod
     def forward(ctx, q, k, v, causal, sm_scale):
         # shape constraints
@@ -1587,7 +1589,7 @@ def grid_tma(META):
             N_CTX=q.shape[2],  #
             HEAD_DIM=HEAD_DIM_K,  #
             STAGE=stage,  #
-            **extra_kern_args
+            **extra_kern_args,
         )
 
         ctx.save_for_backward(q, k, v, o, M)

tritonbench/operators/gather_gemv/operator.py (-1)

@@ -26,7 +26,6 @@
 
 
 class Operator(BenchmarkOperator):
-
     @register_metric()
     def gbps(self, fn_name, example_inputs, metrics: BenchmarkOperatorMetrics):
         arg0_1, arg1_1, arg2_1 = example_inputs

tritonbench/operators/gather_gemv/triton_gather_gemv.py (+1, -1)

@@ -82,7 +82,7 @@ def triton_red_fused_mv_0(
     rbase = tl.arange(0, RBLOCK)[None, :].to(tl.int64)
     x0 = xindex
     # x0 // rnumel should have the same value of either 0 or 1
-    tmp0 = tl.load(in_ptr0 + ((x0 // rnumel)), None, eviction_policy="evict_last")
+    tmp0 = tl.load(in_ptr0 + (x0 // rnumel), None, eviction_policy="evict_last")
     _tmp11 = tl.full([XBLOCK, RBLOCK], 0, tl.float32)
     for roffset in range(0, rnumel, RBLOCK):
         rindex = roffset + rbase

tritonbench/operators/jagged_layer_norm/operator.py (+4, -5)

@@ -36,7 +36,6 @@ def parse_op_args(args: List[str]):
 
 
 class Operator(BenchmarkOperator):
-
     DEFAULT_METRICS = ["latency", "accuracy"]
     DEFAULT_PRECISION = "fp32"
 
@@ -48,8 +47,8 @@
         self, tb_args: argparse.Namespace, extra_args: Optional[List[str]] = None
     ):
         super().__init__(tb_args, extra_args)
-        self.sizes = list(range(2, 12, 4)) + list(
-            range(12, 23, 3)
+        self.sizes = (
+            list(range(2, 12, 4)) + list(range(12, 23, 3))
         )  # bias towards larger sizes, which are more representative of real-world shapes
 
         args = parse_op_args(self.extra_args)
@@ -105,8 +104,8 @@ def _inner():
             )  # a sum over (1, 2) ensures layer norm, whereas a sum over (1) would be an instance norm
 
             padded_normalized = (
-                padded_values - mean
-            ) * padded_mask_values  # mask elements outside of the ragged dimension size for correct variance calculation
+                (padded_values - mean) * padded_mask_values
+            )  # mask elements outside of the ragged dimension size for correct variance calculation
 
             variance = (
                 torch.sum(

tritonbench/operators/jagged_mean/kernels.py (+12, -8)

@@ -53,8 +53,9 @@ def triton_jagged_mean_kernel_simple_fused_sum_then_buffer(
     offsets_m = block_start_m + tl.arange(0, BLOCK_SIZE_M)
     mask_m = offsets_m < M
 
-    ragged_start, ragged_end = tl.load(input_ptr_offsets + pid_b), tl.load(
-        input_ptr_offsets + (pid_b + 1)
+    ragged_start, ragged_end = (
+        tl.load(input_ptr_offsets + pid_b),
+        tl.load(input_ptr_offsets + (pid_b + 1)),
     )  # load start and end offsets for current program, similar to offsets[i] and offsets[i + 1]
     ragged_len = ragged_end - ragged_start
 
@@ -133,8 +134,9 @@ def triton_jagged_mean_kernel_simple_fused_buffer_then_sum(
     offsets_m = block_start_m + tl.arange(0, BLOCK_SIZE_M)
     mask_m = offsets_m < M
 
-    ragged_start, ragged_end = tl.load(input_ptr_offsets + pid_b), tl.load(
-        input_ptr_offsets + (pid_b + 1)
+    ragged_start, ragged_end = (
+        tl.load(input_ptr_offsets + pid_b),
+        tl.load(input_ptr_offsets + (pid_b + 1)),
     )  # load start and end offsets for current program, similar to offsets[i] and offsets[i + 1]
     ragged_len = ragged_end - ragged_start
 
@@ -212,8 +214,9 @@ def triton_jagged_mean_kernel_variable_length_loop_sum_then_buffer(
     offsets_m = block_start_m + tl.arange(0, BLOCK_SIZE_M)
     mask_m = offsets_m < M
 
-    ragged_start, ragged_end = tl.load(input_ptr_offsets + pid_b), tl.load(
-        input_ptr_offsets + (pid_b + 1)
+    ragged_start, ragged_end = (
+        tl.load(input_ptr_offsets + pid_b),
+        tl.load(input_ptr_offsets + (pid_b + 1)),
     )  # load start and end offsets for current program, similar to offsets[i] and offsets[i + 1]
     ragged_len = ragged_end - ragged_start
 
@@ -288,8 +291,9 @@ def triton_jagged_mean_kernel_variable_length_loop_buffer_then_sum(
     offsets_m = block_start_m + tl.arange(0, BLOCK_SIZE_M)
     mask_m = offsets_m < M
 
-    ragged_start, ragged_end = tl.load(input_ptr_offsets + pid_ragged), tl.load(
-        input_ptr_offsets + (pid_ragged + 1)
+    ragged_start, ragged_end = (
+        tl.load(input_ptr_offsets + pid_ragged),
+        tl.load(input_ptr_offsets + (pid_ragged + 1)),
     )  # load start and end offsets for current program, similar to offsets[i] and offsets[i + 1]
     ragged_len = ragged_end - ragged_start
 
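
For readers unfamiliar with the jagged layout these kernels assume: a ragged batch is stored as one flat values buffer plus an offsets array, and row i spans values[offsets[i]:offsets[i + 1]]; the `ragged_start` / `ragged_end` loads above fetch exactly those two offsets per program. A minimal plain-Python sketch (illustrative only, hypothetical data):

```python
# Illustrative only: the jagged/ragged layout the kernels above operate on.
# Row i of the ragged batch is values[offsets[i]:offsets[i + 1]].
values = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]  # flattened rows of lengths 3, 1, 2
offsets = [0, 3, 4, 6]                   # len(offsets) == num_rows + 1


def row_mean(i: int) -> float:
    ragged_start, ragged_end = offsets[i], offsets[i + 1]  # what tl.load fetches per program
    row = values[ragged_start:ragged_end]
    return sum(row) / len(row)


print([row_mean(i) for i in range(len(offsets) - 1)])  # [2.0, 4.0, 5.5]
```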
