Skip to content

Commit 9fce7be

Browse files
authored
[Kernel] Accelerate solve_tril with TMA (vllm-project#26746)
Signed-off-by: zjy0516 <[email protected]>
1 parent b63f214 commit 9fce7be

File tree

3 files changed

+412
-301
lines changed

3 files changed

+412
-301
lines changed

vllm/model_executor/layers/fla/ops/op.py

Lines changed: 32 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -11,29 +11,50 @@
1111

1212
from vllm.triton_utils import tl, tldevice, triton
1313

14+
from .utils import is_gather_supported
15+
1416
if os.environ.get("FLA_USE_FAST_OPS", "0") == "1":
15-
div = tldevice.fast_dividef
1617
exp = tldevice.fast_expf
1718
log = tldevice.fast_logf
1819
log2 = tldevice.fast_log2f
1920
else:
20-
21-
@triton.jit
22-
def div_normal(x, y):
23-
return x / y
24-
25-
div = div_normal
2621
exp = tl.exp
2722
log = tl.log
2823
log2 = tl.log2
2924

3025

31-
if not hasattr(tl, "gather"):
26+
if not is_gather_supported:
3227

3328
@triton.jit
3429
def gather(src, index, axis, _builder=None):
35-
# This is a fallback implementation when tl.gather is not supported
36-
# In order to pass triton compiler, there is no actual gather operation
37-
return src
30+
"""
31+
Gather operation that works when tl.gather is not supported.
32+
This is a fallback implementation that returns None.
33+
It exists only to keep the Triton compiler happy.
34+
"""
35+
return None
3836
else:
3937
gather = tl.gather
38+
39+
if hasattr(triton.language, "_experimental_make_tensor_descriptor"):
40+
# For Triton 3.3.x
41+
make_tensor_descriptor = triton.language._experimental_make_tensor_descriptor
42+
elif hasattr(triton.language, "make_tensor_descriptor"):
43+
# For Triton 3.4.x and later
44+
make_tensor_descriptor = triton.language.make_tensor_descriptor
45+
else:
46+
"""
47+
Fallback implementation when TMA is not supported.
48+
Returns None to indicate TMA descriptors are unavailable.
49+
It exists only to keep the Triton compiler happy.
50+
"""
51+
52+
@triton.jit
53+
def make_tensor_descriptor(
54+
base,
55+
shape,
56+
strides,
57+
block_shape,
58+
_builder=None,
59+
):
60+
return None

0 commit comments

Comments
 (0)