softmax.py — 65 lines (52 loc) · 2.19 KB
import torch
import triton
import ops.ninetoothed.torch
import ops.triton.torch
if __name__ == "__main__":
    # One-off correctness check: compare the NineToothed softmax against
    # PyTorch's reference implementation and the Triton implementation.
    torch.manual_seed(0)

    dtype = torch.float16
    device = "cuda"

    # Deliberately odd (non-power-of-two) dimensions to exercise kernel
    # boundary/masking paths — TODO confirm that is the intent.
    x = torch.randn(1823, 781, dtype=dtype, device=device)

    ninetoothed_output = ops.ninetoothed.torch.softmax(x)
    # `dim` is the documented PyTorch keyword (`axis` is a numpy-compat alias).
    torch_output = torch.softmax(x, dim=-1)
    triton_output = ops.triton.torch.softmax(x)

    print(ninetoothed_output)
    print(torch_output)
    print(triton_output)

    # Loose float16 tolerance versus the PyTorch reference; exact bitwise
    # equality expected versus the Triton kernel (atol=0, rtol=0).
    if torch.allclose(ninetoothed_output, torch_output, atol=0.001):
        print("✅ NineToothed and PyTorch match.")
    else:
        print("❌ NineToothed and PyTorch differ.")
    if torch.allclose(ninetoothed_output, triton_output, atol=0, rtol=0):
        print("✅ NineToothed and Triton match.")
    else:
        print("❌ NineToothed and Triton differ.")
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["n"],
        x_vals=[2**i for i in range(5, 15)],  # sweep row width n = 32 .. 16384
        x_log=True,
        line_arg="provider",
        line_vals=["ninetoothed", "torch", "triton"],
        line_names=["NineToothed", "PyTorch", "Triton"],
        styles=[("blue", "-"), ("green", "-"), ("orange", "-")],
        ylabel="ms",
        plot_name="softmax-performance",
        args={"m": 4096},  # fixed number of rows for every measurement
    )
)
def benchmark(m, n, provider):
    """Time softmax on an (m, n) float16 CUDA tensor for one provider.

    Parameters
    ----------
    m, n : int
        Input matrix dimensions; ``n`` is swept by the perf report.
    provider : str
        One of ``"ninetoothed"``, ``"torch"``, or ``"triton"``.

    Returns
    -------
    float
        Runtime in milliseconds as reported by ``triton.testing.do_bench``.
    """
    # Defined locally so the function is usable even when this module is
    # imported (the originals lived only inside the __main__ block above,
    # which made `benchmark` raise NameError on import).
    dtype = torch.float16
    device = "cuda"

    x = torch.randn(m, n, dtype=dtype, device=device)

    # Sanity-check all implementations agree before timing this configuration.
    ninetoothed_output = ops.ninetoothed.torch.softmax(x)
    torch_output = torch.softmax(x, dim=-1)
    triton_output = ops.triton.torch.softmax(x)
    assert torch.allclose(ninetoothed_output, torch_output, atol=0.001)
    assert torch.allclose(ninetoothed_output, triton_output, atol=0, rtol=0)

    if provider == "ninetoothed":
        ms = triton.testing.do_bench(lambda: ops.ninetoothed.torch.softmax(x))
    elif provider == "torch":
        ms = triton.testing.do_bench(lambda: torch.softmax(x, dim=-1))
    elif provider == "triton":
        ms = triton.testing.do_bench(lambda: ops.triton.torch.softmax(x))
    return ms


if __name__ == "__main__":
    # Guarded so that importing this module does not launch the full sweep.
    benchmark.run(show_plots=True, print_data=True, save_path=".")