Commit 66c5ab6

dfm authored and Google-ML-Automation committed
[JAX] Wrap triton_call custom call using the FFI.
We've been putting off porting the `triton_call` kernel to the FFI because it's possible that we could make better use of the FFI features to improve the design. However, XLA is planning to remove support for legacy custom calls in a few months, so we need to update to the newer API.

As a stopgap solution, this change wraps the existing kernel so that it can be directly called using the new API version without any major refactoring. This will support XLA's deprecation plan without holding them up with our design discussions!

PiperOrigin-RevId: 766642128
1 parent d0cfcec commit 66c5ab6

1 file changed, +17 -7 lines changed

jax_triton/triton_lib.py

Lines changed: 17 additions & 7 deletions
@@ -35,6 +35,7 @@
 from jax._src import state
 from jax._src import util
 from jax._src.lib import gpu_triton as triton_kernel_call_lib
+from jax._src.lib import jaxlib_extension_version
 import jax.dlpack
 import jax.extend as jex
 from jax.interpreters import ad
@@ -671,13 +672,22 @@ def prune_configs(configs, named_args, **kwargs):
     kernel_call = kernel_calls[0]
 
   call_proto = kernel_call.to_proto(kernel_call_name, serialized_metadata)
-  rule = jax.ffi.ffi_lowering(
-      custom_call_target_name,
-      api_version=2,
-      backend_config=zlib.compress(call_proto),
-      operand_output_aliases=dict(input_output_aliases)
-  )
-  return rule(ctx, *array_args)
+  opaque = zlib.compress(call_proto)
+  if jaxlib_extension_version < 347:
+    rule = jax.ffi.ffi_lowering(
+        "triton_kernel_call",
+        api_version=2,
+        backend_config=opaque,
+        operand_output_aliases=dict(input_output_aliases)
+    )
+    attrs = {}
+  else:
+    rule = jax.ffi.ffi_lowering(
+        "triton_kernel_call",
+        operand_output_aliases=dict(input_output_aliases)
+    )
+    attrs = dict(opaque=opaque)
+  return rule(ctx, *array_args, **attrs)
 
 mlir.register_lowering(
     triton_kernel_call_p,
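
The version gate in this change can be illustrated in isolation. The sketch below is not part of the commit: `select_ffi_call_args` is a hypothetical helper that only mirrors the branching in the diff, and the split into "lowering keyword arguments" and "call attributes" is an assumption made for illustration. It uses only the standard library, so it runs without JAX installed.

```python
import zlib


def select_ffi_call_args(call_proto: bytes, extension_version: int,
                         threshold: int = 347):
    """Hypothetical helper mirroring the version gate in the diff.

    Returns (lowering_kwargs, attrs):
      * lowering_kwargs: extra keyword arguments for the lowering rule factory
      * attrs: keyword arguments passed when the rule itself is invoked
    """
    # The serialized kernel call is compressed once and reused in both branches.
    opaque = zlib.compress(call_proto)
    if extension_version < threshold:
        # Older jaxlib: legacy custom call, payload goes in backend_config.
        lowering_kwargs = dict(api_version=2, backend_config=opaque)
        attrs = {}
    else:
        # Newer jaxlib: FFI-wrapped handler, payload passed as an attribute.
        lowering_kwargs = {}
        attrs = dict(opaque=opaque)
    return lowering_kwargs, attrs


old_kwargs, old_attrs = select_ffi_call_args(b"kernel", 346)
new_kwargs, new_attrs = select_ffi_call_args(b"kernel", 347)
```

Either way the compressed payload is identical; only the channel that carries it to the handler differs, which is why the existing kernel can be wrapped without major refactoring.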
