Merged

69 commits
0d220a5
Replace auto-tuning with hard-coding
voltjia Jan 10, 2025
5ee0bb9
Use the half-precision floating-point format as the data type for arg…
voltjia Jan 10, 2025
7951329
Improve the performance of the 2D convolution compute kernel written …
voltjia Jan 11, 2025
e2e0112
Update the benchmark input dimensions to vary based on batch size ins…
voltjia Jan 11, 2025
b904705
Remove the `other=0.0` argument from the `tl.load` calls in the `trit…
voltjia Jan 11, 2025
ec7bc2d
Add `if __name__ == "__main__"` to `add.py`
voltjia Jan 11, 2025
a8bfbd6
Add `if __name__ == "__main__"` to `softmax.py`
voltjia Jan 11, 2025
f07b748
Add code size comparison
voltjia Jan 11, 2025
33a29fe
Add an example for Root Mean Square Layer Normalization
voltjia Jan 11, 2025
9275994
Remove the unused `n_rows` parameter in the `triton_softmax_kernel` f…
voltjia Jan 11, 2025
83f7596
Add `rms_norm` data into `code_size_comparison.py`
voltjia Jan 11, 2025
9d006c9
Add performance comparison
voltjia Jan 11, 2025
0b12eb8
Fix a precision issue by casting the loaded data to `tl.float32`
voltjia Jan 12, 2025
be4f8c2
Use `atol=0` and `rtol=0` in `when comparing NineToothed and Triton o…
voltjia Jan 12, 2025
ee28ccb
Add `requirements.txt`
voltjia Jan 12, 2025
11d09dd
Add correctness verification during performance testing
voltjia Jan 12, 2025
9d5c731
Update benchmark `x_vals` ranges and use log scaling for performance …
voltjia Jan 12, 2025
60c9a6e
Add statistics into `performance_comparison.py`
voltjia Jan 13, 2025
69b626d
Add statistics into `code_size_comparison.py`
voltjia Jan 13, 2025
13dc6be
Add overall comparison into `code_size_comparison.py`
voltjia Jan 13, 2025
b5674ee
Remove the stream setting in `softmax.py`
voltjia Jan 14, 2025
aafa34a
Use `torch.randn` instead of `torch.rand` in `add.py`
voltjia Jan 14, 2025
0dfaeb0
Use time as the metric for measuring performance
voltjia Jan 14, 2025
bec9334
Add PyTorch data into performance comparison
voltjia Jan 14, 2025
ed92844
Merge branch 'master' of github.com:InfiniTensor/ninetoothed-examples…
voltjia May 13, 2025
0cb93e5
Separate `add` kernels into modular packages
voltjia May 13, 2025
d278433
Rename `matmul` to `mm` and separate the kernels into modular packages
voltjia May 13, 2025
932a1d5
Separate `addmm` kernels into modular packages
voltjia May 13, 2025
97b2842
Separate `conv2d` kernels into modular packages
voltjia May 13, 2025
4db4e5f
Separate `softmax` kernels into modular packages
voltjia May 13, 2025
236add4
Separate `rms_norm` kernels into modular packages
voltjia May 14, 2025
a33a14c
Rename `ops.triton.torch.triton_conv2d` to `ops.triton.torch.conv2d`
voltjia May 14, 2025
be1f694
Rename `attention` to `scaled_dot_product_attention` and separate the…
voltjia May 14, 2025
6952672
Use `dtype` access instead of hardcoding
voltjia May 14, 2025
fc45cea
Fix the boundary issues
voltjia May 14, 2025
db86285
Improve the Triton `conv2d` implementation
voltjia May 14, 2025
29bf2f9
Refactor `import mm` to `import ops.ninetoothed.kernels.mm as mm` in …
voltjia May 14, 2025
f87582f
Add `compare_code_metrics.py`
voltjia May 14, 2025
406c18e
Separate `bmm` kernels into modular packages
Ziminli May 15, 2025
114f12f
Separate `fused_rms_norm` kernels into modular packages
Ziminli May 15, 2025
da3e0b5
Update the `bmm` function call in `linear.py`
Ziminli May 15, 2025
dbc429a
Separate `silu` kernels into modular packages
Ziminli May 15, 2025
ebe1681
Separate `swiglu` kernels into modular packages
Ziminli May 15, 2025
596455f
Fix the Triton implementation in `scaled_dot_product_attention.py`
Ziminli May 15, 2025
3585cfa
Add inference profiling support
Ziminli May 15, 2025
db05592
Add Triton and PyTorch implementations of non-interleaved RoPE
voltjia May 15, 2025
9db2eeb
Separate `rope` kernels into modular packages
voltjia May 15, 2025
8086ccc
Add context managers to select the backend to use for inference
voltjia May 15, 2025
d711d57
Relax tolerance in `scaled_dot_product_attention.py`
voltjia May 16, 2025
379ca83
Improve inference logging with JSON output and token-level performanc…
voltjia May 16, 2025
14b8582
Extract the backslash character into a constant
voltjia May 16, 2025
a5319d2
Update `.gitignore` to exclude evaluation result files
voltjia May 19, 2025
72d552c
Remove upper bound constraint in `conv2d.py`
voltjia May 19, 2025
c8ae164
Rename `rope` to `rotary_position_embedding`
voltjia May 19, 2025
4ccc2bb
Generate a single table instead of multiple tables
voltjia May 19, 2025
f3a8db1
Replace the manually specified lists of Triton `Config` objects with …
voltjia May 21, 2025
a9cefab
Replace `torch.rand` with `torch.randn`
voltjia May 21, 2025
276c5ab
Standardize `plot_name` values across scripts
voltjia May 21, 2025
f75f841
Add `compare_performance_metrics.py`
voltjia May 21, 2025
123bd37
Add `run_experiments.py`
voltjia May 21, 2025
0f8477c
Run tasks in `run_experiments.py` instead of `compare_performance_met…
voltjia May 21, 2025
27ae4b9
Add end-to-end model inference throughput comparison plot
voltjia May 21, 2025
0517108
Add `transformers` and `radon` to `requirements.txt`
voltjia May 22, 2025
bd6a730
Refactor CSV export to occur within task-processing loop
voltjia May 23, 2025
772cae8
Rename the output CSV file from `performance-metrics.csv` to `microbe…
voltjia May 23, 2025
287fc78
Rename output image files
voltjia May 23, 2025
b04acf8
Rename evaluation scripts and output files for consistency and clarit…
voltjia May 23, 2025
e8a0d1a
Use `torch.utils.collect_env` in `run_experiments.py`
voltjia May 23, 2025
7798e39
Update `README.md`
voltjia Jul 2, 2025
7 changes: 7 additions & 0 deletions .gitignore
@@ -160,3 +160,10 @@ cython_debug/
 # and can be added to the global gitignore or merged into this file. For a more nuclear
 # option (not recommended) you can uncomment the following to ignore the entire idea folder.
 #.idea/
+
+# Evaluation results
+*.csv
+*.html
+*.json
+*.png
+*.tex
10 changes: 7 additions & 3 deletions README.md
@@ -1,18 +1,20 @@
 # NineToothed Examples
 
-This repository contains examples for [NineToothed](https://github.com/InfiniTensor/ninetoothed), including implementations of several common compute kernels written using NineToothed.
+This repository contains examples of [NineToothed](https://github.com/InfiniTensor/ninetoothed), including implementations of several common compute kernels written using NineToothed.
 
 ## Usage
 
 After cloning this repository, you can run any of the examples using Python. For instance, to run the matrix multiplication example, execute the following command:
 
 ```bash
-python matmul.py
+python mm.py
 ```
 
 ### Autotuning Behavior
 
-By default, the examples apply autotuning, which may take several minutes or longer to complete for complex kernels. If you wish to disable autotuning, you can replace symbol definitions with concrete values. Consider the following example:
+Some examples apply autotuning, which may take several minutes or longer to complete for complex kernels. If you wish to disable autotuning, you can replace symbol definitions with concrete values.
+
+Consider the following example:
 
 ```python
 BLOCK_SIZE = Symbol("BLOCK_SIZE", meta=True)
@@ -29,6 +31,8 @@ BLOCK_SIZE = 1024

 These approaches allow you to obtain results in seconds. However, selecting optimal values is crucial for good performance. Experiment with different values to determine the best configuration.
 
+Note: Please don't forget to also disable the autotuning of the corresponding Triton compute kernels.
+
 ## Third-Party Code and Licenses
 
 This project includes code modified or inspired from the following open-source repositories:
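As an aside on the README's block-size discussion above: the tiled vector-add pattern these kernels implement can be sketched in plain NumPy, using ceiling-division grid sizing and a mask that guards the final partial block. NumPy here is only a stand-in for the GPU kernel, and `cdiv`/`blockwise_add` are illustrative names, not part of the NineToothed or Triton APIs:

```python
import numpy as np


def cdiv(a, b):
    # Ceiling division: how many blocks of size b are needed to cover a elements.
    return (a + b - 1) // b


def blockwise_add(lhs, rhs, block_size):
    # Emulates a 1D tiled add kernel: each "program" pid handles one
    # block of block_size contiguous elements.
    n = lhs.size
    out = np.empty_like(lhs)
    for pid in range(cdiv(n, block_size)):
        offsets = pid * block_size + np.arange(block_size)
        mask = offsets < n  # guard the final, possibly partial block
        idx = offsets[mask]
        out[idx] = lhs[idx] + rhs[idx]
    return out


x = np.arange(10, dtype=np.float32)
y = np.ones(10, dtype=np.float32)
print(blockwise_add(x, y, block_size=4))  # same values as x + y
```

Whatever block size is chosen, the result is identical; only how the work is partitioned changes, which is why the value can be autotuned or hard-coded purely for performance.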
170 changes: 61 additions & 109 deletions add.py
@@ -1,118 +1,70 @@
-import ninetoothed
 import torch
 import triton
-import triton.language as tl
-from ninetoothed import Symbol, Tensor
 
-BLOCK_SIZE = Symbol("BLOCK_SIZE", meta=True)
-
-
-@ninetoothed.jit
-def add_kernel(
-    lhs: Tensor(1).tile((BLOCK_SIZE,)),
-    rhs: Tensor(1).tile((BLOCK_SIZE,)),
-    output: Tensor(1).tile((BLOCK_SIZE,)),
-):
-    output = lhs + rhs  # noqa: F841
-
-
-def add(lhs, rhs):
-    output = torch.empty_like(lhs)
-
-    add_kernel(lhs, rhs, output)
-
-    return output
-
-
-@triton.jit
-def triton_add_kernel(
-    lhs_ptr,
-    rhs_ptr,
-    output_ptr,
-    n_elements,
-    BLOCK_SIZE: tl.constexpr,
-):
-    pid = tl.program_id(0)
-
-    block_start = pid * BLOCK_SIZE
-    offsets = block_start + tl.arange(0, BLOCK_SIZE)
-    mask = offsets < n_elements
-
-    lhs = tl.load(lhs_ptr + offsets, mask=mask)
-    rhs = tl.load(rhs_ptr + offsets, mask=mask)
-    output = lhs + rhs
-
-    tl.store(output_ptr + offsets, output, mask=mask)
-
-
-def triton_add(lhs, rhs):
-    output = torch.empty_like(lhs)
-    n_elements = output.numel()
-
-    def grid(meta):
-        return (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
-
-    triton_add_kernel[grid](lhs, rhs, output, n_elements, BLOCK_SIZE=1024)
-
-    return output
-
-
-torch.manual_seed(0)
-size = 98432
-lhs = torch.rand(size, device="cuda")
-rhs = torch.rand(size, device="cuda")
-ninetoothed_output = add(lhs, rhs)
-torch_output = lhs + rhs
-triton_output = triton_add(lhs, rhs)
-print(ninetoothed_output)
-print(torch_output)
-print(triton_output)
-if torch.allclose(ninetoothed_output, torch_output):
-    print("✅ NineToothed and PyTorch match.")
-else:
-    print("❌ NineToothed and PyTorch differ.")
-if torch.allclose(ninetoothed_output, triton_output):
-    print("✅ NineToothed and Triton match.")
-else:
-    print("❌ NineToothed and Triton differ.")
-
-
-@triton.testing.perf_report(
-    triton.testing.Benchmark(
-        x_names=["size"],
-        x_vals=[2**i for i in range(12, 28, 1)],
-        x_log=True,
-        line_arg="provider",
-        line_vals=["ninetoothed", "torch", "triton"],
-        line_names=["NineToothed", "PyTorch", "Triton"],
-        styles=[("blue", "-"), ("green", "-"), ("orange", "-")],
-        ylabel="GB/s",
-        plot_name="vector-addition-performance",
-        args={},
+import ops.ninetoothed.torch
+import ops.triton.torch
+
+if __name__ == "__main__":
+    torch.manual_seed(0)
+
+    size = 98432
+    dtype = torch.float16
+    device = "cuda"
+
+    input = torch.randn(size, dtype=dtype, device=device)
+    other = torch.randn(size, dtype=dtype, device=device)
+
+    ninetoothed_output = ops.ninetoothed.torch.add(input, other)
+    torch_output = input + other
+    triton_output = ops.triton.torch.add(input, other)
+
+    print(ninetoothed_output)
+    print(torch_output)
+    print(triton_output)
+
+    if torch.allclose(ninetoothed_output, torch_output):
+        print("✅ NineToothed and PyTorch match.")
+    else:
+        print("❌ NineToothed and PyTorch differ.")
+    if torch.allclose(ninetoothed_output, triton_output, atol=0, rtol=0):
+        print("✅ NineToothed and Triton match.")
+    else:
+        print("❌ NineToothed and Triton differ.")
+
+    @triton.testing.perf_report(
+        triton.testing.Benchmark(
+            x_names=["size"],
+            x_vals=[2**i for i in range(18, 28)],
+            x_log=True,
+            line_arg="provider",
+            line_vals=["ninetoothed", "torch", "triton"],
+            line_names=["NineToothed", "PyTorch", "Triton"],
+            styles=[("blue", "-"), ("green", "-"), ("orange", "-")],
+            ylabel="ms",
+            plot_name="add-performance",
+            args={},
+        )
     )
-)
-def benchmark(size, provider):
-    lhs = torch.rand(size, device="cuda", dtype=torch.float32)
-    rhs = torch.rand(size, device="cuda", dtype=torch.float32)
-    quantiles = [0.5, 0.2, 0.8]
+    def benchmark(size, provider):
+        input = torch.randn(size, dtype=dtype, device=device)
+        other = torch.randn(size, dtype=dtype, device=device)
 
-    if provider == "ninetoothed":
-        ms, min_ms, max_ms = triton.testing.do_bench(
-            lambda: add(lhs, rhs), quantiles=quantiles
-        )
-    elif provider == "torch":
-        ms, min_ms, max_ms = triton.testing.do_bench(
-            lambda: lhs + rhs, quantiles=quantiles
-        )
-    elif provider == "triton":
-        ms, min_ms, max_ms = triton.testing.do_bench(
-            lambda: triton_add(lhs, rhs), quantiles=quantiles
-        )
+        ninetoothed_output = ops.ninetoothed.torch.add(input, other)
+        torch_output = torch.add(input, other)
+        triton_output = ops.triton.torch.add(input, other)
 
-    def gbps(ms):
-        return 3 * lhs.numel() * lhs.element_size() / ms * 1e-6
+        assert torch.allclose(ninetoothed_output, torch_output)
+        assert torch.allclose(ninetoothed_output, triton_output, atol=0, rtol=0)
 
-    return gbps(ms), gbps(max_ms), gbps(min_ms)
+        if provider == "ninetoothed":
+            ms = triton.testing.do_bench(
+                lambda: ops.ninetoothed.torch.add(input, other)
+            )
+        elif provider == "torch":
+            ms = triton.testing.do_bench(lambda: torch.add(input, other))
+        elif provider == "triton":
+            ms = triton.testing.do_bench(lambda: ops.triton.torch.add(input, other))
 
+        return ms
 
-benchmark.run(print_data=True, show_plots=True, save_path=".")
+    benchmark.run(print_data=True, show_plots=True, save_path=".")
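One detail worth noting in the diff above: the Triton comparison passes `atol=0` and `rtol=0`, which turns `allclose` into an exact, element-wise equality check, appropriate when two kernels are expected to produce identical results. A small NumPy sketch of the same idea, with illustrative values:

```python
import numpy as np

a = np.array([1.0, 2.0, 3.0])
b = a.copy()  # identical copy
c = a + 1e-9  # tiny perturbation, below the default tolerances

# With both tolerances zeroed, allclose accepts only exact equality.
assert np.allclose(a, b, atol=0, rtol=0)
assert not np.allclose(a, c, atol=0, rtol=0)

# Default tolerances (rtol=1e-5, atol=1e-8) still accept the perturbation.
assert np.allclose(a, c)
print("tolerance checks passed")
```

The PyTorch comparison keeps the default tolerances, presumably because `torch.add` may take a different rounding path than the custom kernels, whereas the NineToothed and Triton kernels are expected to match exactly.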