Skip to content

[Issue]: Unnecessary up casting uint16 to int32 #800

@cagrikymk

Description

@cagrikymk

Problem Description

When performing operations like max, min, etc. between an {int16, uint16} tensor and a constant, the constant and the result get up-cast to int32. While this doesn't cause any inaccuracies, it has negative performance implications for kernels doing a lot of bit manipulation with {int16, uint16}.

I tried a couple of approaches but failed to keep everything in {int16, uint16}.
The issue occurs with 3.3.0+git7dc54920 but not with version 3.3.0 (from pip install triton).

The same version, Triton 3.3.0+git7dc54920, works as expected on H100 — the up-cast appears only on MI300X.

Here is a python script that shows the issue:

import torch
import triton
import triton.language as tl
import argparse
import os

def fill_uniform(shape, dtype, device="cuda"):
    """Return a tensor of the given shape filled with uniform random integers.

    Values are drawn from [0, 255) and then cast to *dtype*.

    Args:
        shape: tuple of ints — shape of the output tensor.
        dtype: torch dtype the result is cast to (e.g. torch.uint16).
        device: device string for the allocation. Defaults to "cuda",
            matching the previously hard-coded value, but can be overridden
            (e.g. "cpu") for environments without a GPU.

    Returns:
        A tensor of *shape* on *device* with dtype *dtype*.
    """
    x = torch.randint(low=0, high=255, size=shape, device=device)
    return x.to(dtype)

@triton.jit
def add_constant_kernel(x_ptr, z_ptr,
              M,
              BLOCK_SIZE: tl.constexpr,
              OPT: tl.constexpr):
    # Combine a BLOCK_SIZE chunk of x with the constant 156 and store into z.
    # OPT selects the syntactic form in which the constant is written; the
    # four variants exist to probe which forms trigger the reported
    # uint16 -> int32 up-cast of the constant (see the TTIR/TTGIR output
    # below: 'add' and 'constexpr-add' stay i16, the other two become i32).
    pid = tl.program_id(0)
    inds = tl.arange(0, BLOCK_SIZE) + pid * BLOCK_SIZE
    x = tl.load(x_ptr + inds, mask=inds<M, other=0.0)
    if OPT == "add":
        # Integer literal used directly in the expression.
        res = x + 156
    elif OPT == "constexpr-add":
        # Constant declared explicitly as a tl.constexpr.
        CONST: tl.constexpr = 156
        res = x + CONST
    elif OPT == "constuint16-add":
        # Plain Python int assigned to a local, then added.
        CONST = 156
        res = x + CONST
    elif OPT == "maximum":
        # Same constant passed through tl.maximum instead of '+'.
        res = tl.maximum(x, 156)

    # NOTE(review): if OPT matches none of the branches above, 'res' is
    # undefined here; callers only pass the four strings listed.
    tl.store(z_ptr + inds, res, mask=inds<M)

def add_constant(x, opt, print_irs=True, ir_prefix=""):
    """Launch add_constant_kernel on *x* with variant *opt*.

    Args:
        x: 1-D input tensor; the output tensor is allocated to match it.
        opt: constant-expression variant name forwarded to the kernel.
        print_irs: when True, dump every IR stage of the compiled kernel
            to files named '<ir_prefix><opt>_<stage>.txt'.
        ir_prefix: path prefix for the IR dump files.

    Returns:
        The output tensor produced by the kernel.
    """
    out = torch.zeros_like(x)

    block_size = 256
    n_elements = len(x)
    grid = (triton.cdiv(n_elements, block_size),)

    compiled = add_constant_kernel[grid](x, out, n_elements, block_size, opt)
    if print_irs:
        print_irs_to_files(compiled, ir_prefix + opt)
    return out


def print_irs_to_files(compiled_kernel, prefix):
    """Write each IR stage of *compiled_kernel* to '<prefix>_<stage>.txt'.

    Iterates the compiled kernel's ``asm`` mapping (stage name -> IR text)
    and prints each entry into its own file.
    """
    for stage, ir_text in compiled_kernel.asm.items():
        with open(f"{prefix}_{stage}.txt", "w") as out:
            print(ir_text, file=out)


def extract_line(file_path, pattern):
    """Return the first line of *file_path* containing *pattern*.

    The matching line is returned with surrounding whitespace stripped.
    If no line matches, the string "no match" is returned.

    Args:
        file_path: path of the text file to scan.
        pattern: substring to search for.
    """
    with open(file_path, 'r') as file:
        # Iterate the file object lazily instead of materializing the whole
        # file with readlines() — same result, constant memory, and stops
        # reading at the first match.
        for line in file:
            if pattern in line:
                return line.strip()
    return "no match"


# Script entry: build a uint16 input, run every constant-expression variant,
# and report the 'arith.constant dense' line from the TTIR and TTGIR dumps.
parser = argparse.ArgumentParser(description="Reproduce type up-cast")
parser.add_argument('--M', type=int, default=1024, help='size of the vector')
parser.add_argument('--path', type=str, default="add_res", help='')
args = parser.parse_args()

# Ensure the IR dump directory exists.
os.makedirs(args.path, exist_ok=True)

in_dtype = torch.uint16
M = args.M
x = fill_uniform((M,), in_dtype)

for opt in ('add', 'constexpr-add', 'constuint16-add', 'maximum'):
    res = add_constant(x, opt, print_irs=True, ir_prefix=f"{args.path}/")
    # Report the constant line at both IR stages; the up-cast shows up as
    # the dense constant changing from i16 to i32 between them.
    for stage in ("ttir", "ttgir"):
        line = extract_line(f"{args.path}/{opt}_{stage}.txt", "arith.constant dense")
        print(opt, stage.upper(), line)

This is what it prints:

add TTIR %cst = arith.constant dense<0> : tensor<256xi16> loc(#loc1)
add TTGIR %cst = arith.constant dense<156> : tensor<256xi16, #blocked> loc(#loc1)
constexpr-add TTIR %cst = arith.constant dense<0> : tensor<256xi16> loc(#loc1)
constexpr-add TTGIR %cst = arith.constant dense<156> : tensor<256xi16, #blocked> loc(#loc1)
constuint16-add TTIR %cst = arith.constant dense<0> : tensor<256xi16> loc(#loc1)
constuint16-add TTGIR %cst = arith.constant dense<156> : tensor<256xi32, #blocked> loc(#loc1)
maximum TTIR %cst = arith.constant dense<0> : tensor<256xi16> loc(#loc1)
maximum TTGIR %cst = arith.constant dense<156> : tensor<256xi32, #blocked> loc(#loc1)

Triton Version

triton==3.3.0+git7dc54920

Operating System

22.04.5 LTS (Jammy Jellyfish)

CPU

Intel(R) Xeon(R) Platinum 8480C

GPU

MI300X

ROCm Version

6.3.2

Metadata

Metadata

Assignees

Type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests

Issue actions