Problem Description
When performing max-, min- and similar operations between an {int16, uint16} tensor and a constant, the constant and the result get up-cast to int32. While this does not cause any inaccuracies, it has negative performance implications for kernels that do a lot of bit manipulation on {int16, uint16} data.
I tried a couple of approaches but failed to keep everything in {int16, uint16}; one such attempt is sketched below.
The issue happens with 3.3.0+git7dc54920 on the MI300X (environment details below), but not with version 3.3.0 installed from pip (`pip install triton`). Triton 3.3.0+git7dc54920 on an H100 also works as expected.
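For illustration, a minimal sketch of the kind of workaround attempted, assuming the constant is materialized as an explicit uint16 tensor with `tl.full` (this is a hypothetical reconstruction, not the exact code that was tried):

```python
import triton
import triton.language as tl


@triton.jit
def max_constant_u16_kernel(x_ptr, z_ptr, M, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(0)
    inds = tl.arange(0, BLOCK_SIZE) + pid * BLOCK_SIZE
    x = tl.load(x_ptr + inds, mask=inds < M, other=0)
    # Materialize the constant as an explicit uint16 tensor so tl.maximum
    # sees two uint16 operands instead of a uint16 tensor and a Python int.
    c = tl.full([BLOCK_SIZE], 156, tl.uint16)
    res = tl.maximum(x, c)
    tl.store(z_ptr + inds, res, mask=inds < M)
```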
Here is a Python script that reproduces the issue:
```python
import torch
import triton
import triton.language as tl
import argparse
import os


def fill_uniform(shape, dtype):
    x = torch.randint(low=0, high=255, size=shape, device="cuda")
    x = x.to(dtype)
    return x


@triton.jit
def add_constant_kernel(x_ptr, z_ptr,
                        M,
                        BLOCK_SIZE: tl.constexpr,
                        OPT: tl.constexpr):
    pid = tl.program_id(0)
    inds = tl.arange(0, BLOCK_SIZE) + pid * BLOCK_SIZE
    x = tl.load(x_ptr + inds, mask=inds < M, other=0.0)
    if OPT == "add":
        res = x + 156
    elif OPT == "constexpr-add":
        CONST: tl.constexpr = 156
        res = x + CONST
    elif OPT == "constuint16-add":
        CONST = 156
        res = x + CONST
    elif OPT == "maximum":
        res = tl.maximum(x, 156)
    tl.store(z_ptr + inds, res, mask=inds < M)


def add_constant(x, opt, print_irs=True, ir_prefix=""):
    z = torch.zeros_like(x)
    BLOCK_SIZE = 256
    M = len(x)
    kernel = add_constant_kernel[(triton.cdiv(M, BLOCK_SIZE),)](x, z, M, BLOCK_SIZE, opt)
    if print_irs:
        print_irs_to_files(kernel, ir_prefix + opt)
    return z


def print_irs_to_files(compiled_kernel, prefix):
    for key in compiled_kernel.asm.keys():
        with open(f"{prefix}_{key}.txt", "w") as fptr:
            print(compiled_kernel.asm[key], file=fptr)


def extract_line(file_path, pattern):
    with open(file_path, 'r') as file:
        lines = file.readlines()
        for line in lines:
            if pattern in line:
                return line.strip()
    return "no match"


parser = argparse.ArgumentParser(description="Reproduce type up-cast")
parser.add_argument('--M', type=int, default=1024, help='size of the vector')
parser.add_argument('--path', type=str, default="add_res", help='')
args = parser.parse_args()

if not os.path.exists(args.path):
    os.makedirs(args.path)

in_dtype = torch.uint16
M = args.M
x = fill_uniform((M,), in_dtype)

for opt in ['add', 'constexpr-add', 'constuint16-add', 'maximum']:
    res = add_constant(x, opt, print_irs=True, ir_prefix=f"{args.path}/")
    match = extract_line(f"{args.path}/{opt}_ttir.txt", "arith.constant dense")
    print(opt, "TTIR", match)
    match = extract_line(f"{args.path}/{opt}_ttgir.txt", "arith.constant dense")
    print(opt, "TTGIR", match)
```
This is what it prints:
```
add TTIR %cst = arith.constant dense<0> : tensor<256xi16> loc(#loc1)
add TTGIR %cst = arith.constant dense<156> : tensor<256xi16, #blocked> loc(#loc1)
constexpr-add TTIR %cst = arith.constant dense<0> : tensor<256xi16> loc(#loc1)
constexpr-add TTGIR %cst = arith.constant dense<156> : tensor<256xi16, #blocked> loc(#loc1)
constuint16-add TTIR %cst = arith.constant dense<0> : tensor<256xi16> loc(#loc1)
constuint16-add TTGIR %cst = arith.constant dense<156> : tensor<256xi32, #blocked> loc(#loc1)
maximum TTIR %cst = arith.constant dense<0> : tensor<256xi16> loc(#loc1)
maximum TTGIR %cst = arith.constant dense<156> : tensor<256xi32, #blocked> loc(#loc1)
```
In TTGIR, the `add` and `constexpr-add` variants keep the `156` constant as `tensor<256xi16>`, while the `constuint16-add` and `maximum` variants promote it to `tensor<256xi32>`.
Triton Version
triton==3.3.0+git7dc54920
Operating System
Ubuntu 22.04.5 LTS (Jammy Jellyfish)
CPU
Intel(R) Xeon(R) Platinum 8480C
GPU
MI300X
ROCm Version
6.3.2