
Commit 0722d73

Aliia Khasanova authored and Google-ML-Automation committed
Integrate Triton up to [68aa962e67baa191cec5aac173255abdba80db1a](https://github.com/openai/triton/commits/68aa962e67baa191cec5aac173255abdba80db1a)
PiperOrigin-RevId: 684403022
1 parent c7b8cd5 commit 0722d73

File tree

2 files changed: +6 -6 lines changed

jax_triton/triton_lib.py

Lines changed: 4 additions & 4 deletions
@@ -371,10 +371,11 @@ def get_or_create_triton_kernel(
   # `JITFunction._get_config` to get the specialization_attr.
   mock_torch_tensor = types.SimpleNamespace(data_ptr=lambda: 16)
   args_for_specialization_attr = [mock_torch_tensor] * len(arg_dtypes)
+  backend = backend_init_func(device, compute_capability)
   for i, _, v in scalar_args:
     args_for_specialization_attr[i] = v
-  specialization_attr = fn._get_config(*args_for_specialization_attr)  # pylint: disable=protected-access
 
+  specialization_attr = backend.get_attrs_descriptor(fn.params[:len(args_for_specialization_attr)], args_for_specialization_attr)  # pylint: disable=protected-access
   constants = {k: v for k, v in metaparams.items()}
   constants.update({k: None for _, k, v in scalar_args if v is None})
   constants.update({fn.arg_names[i]: 1 for i in specialization_attr.equal_to_1})
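
The substantive change in this hunk: `specialization_attr` now comes from the backend's `get_attrs_descriptor` instead of the private `JITFunction._get_config`, which is why `backend` must now exist this early. The mock-tensor trick in the surrounding context is worth spelling out; below is a minimal sketch of the idea, where `AttrsDescriptorStub` and its selection logic are illustrative stand-ins, not Triton's real `AttrsDescriptor` implementation:

import types

# Stand-in for the descriptor returned by `backend.get_attrs_descriptor`;
# the field names mirror the ones this diff uses, nothing more.
class AttrsDescriptorStub:
  def __init__(self, divisibility_16, equal_to_1):
    self.divisibility_16 = divisibility_16  # arg indices divisible by 16
    self.equal_to_1 = equal_to_1            # arg indices equal to 1

# jax-triton has no real torch tensors at this point, so it fakes one:
# the specialization logic only needs `.data_ptr()`, and a pointer value
# of 16 makes every buffer argument look 16-byte aligned.
mock_torch_tensor = types.SimpleNamespace(data_ptr=lambda: 16)

args = [mock_torch_tensor, mock_torch_tensor, 1]  # two buffers, one scalar
attrs = AttrsDescriptorStub(
    divisibility_16=[i for i, a in enumerate(args)
                     if not isinstance(a, int) and a.data_ptr() % 16 == 0],
    equal_to_1=[i for i, a in enumerate(args)
                if isinstance(a, int) and a == 1],
)
assert attrs.divisibility_16 == [0, 1]
assert attrs.equal_to_1 == [2]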
@@ -383,7 +384,7 @@ def get_or_create_triton_kernel(
   cache_key = (
       fn,
       tuple(signature.items()),
-      tuple(vars(specialization_attr).values()),
+      tuple(specialization_attr.arg_properties),
       tuple(constants.items()),
       num_warps,
       num_stages,
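
The cache key is hashed further down (`kernel_hash = abs(hash(cache_key))`, visible in the next hunk), so every component must be hashable. A hedged guess at why `vars(...)` had to go, based only on the list-valued fields visible in the test update below:

# The test below shows the new descriptor reporting lists ([1, 3, 9],
# [8, 10]). A tuple of lists is unhashable, so the old
# `tuple(vars(specialization_attr).values())` component would fail once
# those fields became lists; `arg_properties` is used instead.
fields = {"divisibility_16": [1, 3, 9], "equal_to_1": [8, 10]}
try:
  hash(tuple(fields.values()))
except TypeError as err:
  print("unhashable cache-key component:", err)  # unhashable type: 'list'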
@@ -403,7 +404,6 @@ def get_or_create_triton_kernel(
       "enable_fp_fusion": enable_fp_fusion,
   }
 
-  backend = backend_init_func(device, compute_capability)
   options = backend.parse_options(opts)
 
   kernel_hash = abs(hash(cache_key))
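
This hunk is the counterpart of the first: `backend` is now constructed before the specialization attributes are computed, so its original construction site just above `backend.parse_options(opts)` is removed rather than duplicated.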
@@ -643,7 +643,7 @@ def prune_configs(configs, named_args, **kwargs):
       kernel_params.append(
           triton_kernel_call_lib.create_array_parameter(
               zeroed_params_with_sizes.get(i, 0),
-              16 if (i in specialization_attr.divisible_by_16) else 0,
+              16 if (i in specialization_attr.divisibility_16) else 0,
           )
       )
     elif i not in specialization_attr.equal_to_1:
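
The rename `divisible_by_16` -> `divisibility_16` is mechanical; the call site still turns membership into a 16-byte alignment hint. A small isolated sketch of that selection logic, where `make_param` is a hypothetical stand-in for `triton_kernel_call_lib.create_array_parameter` (whose real semantics are jax-triton internals, read here only from this call site):

def make_param(zeroed_size, alignment):
  # Stand-in: records bytes-to-zero and an alignment hint per argument.
  return (zeroed_size, alignment)

divisibility_16 = [1, 3, 9]          # indices borrowed from the test below
zeroed_params_with_sizes = {3: 128}  # hypothetical zeroed parameter

params = [make_param(zeroed_params_with_sizes.get(i, 0),
                     16 if i in divisibility_16 else 0)
          for i in (1, 2, 3)]
assert params == [(0, 16), (0, 0), (128, 16)]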

tests/triton_call_test.py

Lines changed: 2 additions & 2 deletions
@@ -564,10 +564,10 @@ def test_specialization(self):
     # Pointers are assumed to divide by 16, as do `M`, `N`, `stride_{bk,cm}`.
     # However, we've marked `a_ptr`, `M`, `stride_bk`, and `c_ptr` as "do not
     # specialize", leaving `b_ptr`, `N`, and `stride_cm`.
-    self.assertEqual(specialization.attrs.divisible_by_16, (1, 3, 9))
+    self.assertEqual(specialization.attrs.divisibility_16, [1, 3, 9])
     # `stride_{ak,bn,cn}` equal 1, but we've marked `stride_ak` as "do not
     # specialize" leaving `stride_{bn,cn}`.
-    self.assertEqual(specialization.attrs.equal_to_1, (8, 10))
+    self.assertEqual(specialization.attrs.equal_to_1, [8, 10])
 
 
 if __name__ == "__main__":
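
Note that swapping the expected tuples for lists is required, not cosmetic: `assertEqual` compares with `==`, and a tuple never compares equal to a list, whatever the contents. A quick check:

# `assertEqual(a, b)` passes iff `a == b`; tuple-vs-list comparison is
# always False, so `(1, 3, 9)` would fail against the list the new
# descriptor reports.
assert (1, 3, 9) != [1, 3, 9]
assert (8, 10) != [8, 10]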
