@@ -371,10 +371,11 @@ def get_or_create_triton_kernel(
371
371
# `backend.get_attrs_descriptor` to get the specialization_attr.
372
372
mock_torch_tensor = types .SimpleNamespace (data_ptr = lambda : 16 )
373
373
args_for_specialization_attr = [mock_torch_tensor ] * len (arg_dtypes )
374
+ backend = backend_init_func (device , compute_capability )
374
375
for i , _ , v in scalar_args :
375
376
args_for_specialization_attr [i ] = v
376
- specialization_attr = fn ._get_config (* args_for_specialization_attr ) # pylint: disable=protected-access
377
377
378
+ specialization_attr = backend .get_attrs_descriptor (fn .params [:len (args_for_specialization_attr )], args_for_specialization_attr ) # pylint: disable=protected-access
378
379
constants = {k : v for k , v in metaparams .items ()}
379
380
constants .update ({k : None for _ , k , v in scalar_args if v is None })
380
381
constants .update ({fn .arg_names [i ]: 1 for i in specialization_attr .equal_to_1 })
@@ -383,7 +384,7 @@ def get_or_create_triton_kernel(
383
384
cache_key = (
384
385
fn ,
385
386
tuple (signature .items ()),
386
- tuple (vars ( specialization_attr ). values () ),
387
+ tuple (specialization_attr . arg_properties ),
387
388
tuple (constants .items ()),
388
389
num_warps ,
389
390
num_stages ,
@@ -403,7 +404,6 @@ def get_or_create_triton_kernel(
403
404
"enable_fp_fusion" : enable_fp_fusion ,
404
405
}
405
406
406
- backend = backend_init_func (device , compute_capability )
407
407
options = backend .parse_options (opts )
408
408
409
409
kernel_hash = abs (hash (cache_key ))
@@ -643,7 +643,7 @@ def prune_configs(configs, named_args, **kwargs):
643
643
kernel_params .append (
644
644
triton_kernel_call_lib .create_array_parameter (
645
645
zeroed_params_with_sizes .get (i , 0 ),
646
- 16 if (i in specialization_attr .divisible_by_16 ) else 0 ,
646
+ 16 if (i in specialization_attr .divisibility_16 ) else 0 ,
647
647
)
648
648
)
649
649
elif i not in specialization_attr .equal_to_1 :
0 commit comments