|
11 | 11 |
|
12 | 12 | from vllm.triton_utils import tl, tldevice, triton |
13 | 13 |
|
| 14 | +from .utils import is_gather_supported |
| 15 | + |
# Choose the elementwise math primitives exposed by this module.
# When the environment variable FLA_USE_FAST_OPS is exactly "1", the
# `tldevice.fast_*` variants are bound; for any other value (or when the
# variable is unset) the standard `tl` functions are used.
# NOTE(review): the fast variants presumably trade accuracy for speed --
# confirm before enabling in numerically sensitive code.
if os.environ.get("FLA_USE_FAST_OPS", "0") == "1":
    exp, log, log2 = (
        tldevice.fast_expf,
        tldevice.fast_logf,
        tldevice.fast_log2f,
    )
else:
    exp, log, log2 = tl.exp, tl.log, tl.log2
29 | 24 |
|
30 | 25 |
|
if is_gather_supported:
    # Gather is available: expose Triton's own primitive under the local name.
    gather = tl.gather
else:

    @triton.jit
    def gather(src, index, axis, _builder=None):
        """Stand-in for ``tl.gather`` when gather is not supported.

        No gather is performed: every argument is ignored and ``None`` is
        returned. The function exists only so that code referring to
        ``gather`` still gets through the Triton compiler.
        """
        return None
| 38 | + |
# Bind `make_tensor_descriptor` to whichever tensor-descriptor factory the
# installed Triton exposes; the experimental name is checked first.
if hasattr(triton.language, "_experimental_make_tensor_descriptor"):
    # For Triton 3.3.x
    make_tensor_descriptor = triton.language._experimental_make_tensor_descriptor
elif hasattr(triton.language, "make_tensor_descriptor"):
    # For Triton 3.4.x and later
    make_tensor_descriptor = triton.language.make_tensor_descriptor
else:
    # Neither attribute exists, so TMA descriptors are unavailable. Define a
    # placeholder with the same name so that code referring to
    # `make_tensor_descriptor` still gets through the Triton compiler.

    @triton.jit
    def make_tensor_descriptor(base, shape, strides, block_shape, _builder=None):
        """Fallback used when TMA is not supported.

        All arguments are ignored. Returns ``None`` to indicate that TMA
        descriptors are unavailable.
        """
        return None
0 commit comments