diff --git a/jax_triton/triton_lib.py b/jax_triton/triton_lib.py index 6740daf..6b49675 100644 --- a/jax_triton/triton_lib.py +++ b/jax_triton/triton_lib.py @@ -481,15 +481,22 @@ def get_or_create_triton_kernel( f" {compilation_result.cluster_dims}\n" ) - kernel = triton_kernel_call_lib.TritonKernel( - kernel_name, - num_warps, - compilation_result.shared_mem_bytes, - compilation_result.binary, - ttir, - compute_capability, - *compilation_result.cluster_dims, - ) + try: + kernel = triton_kernel_call_lib.TritonKernel( + kernel_name, + num_warps, + compilation_result.shared_mem_bytes, + compilation_result.binary, + ttir, + compute_capability, + *compilation_result.cluster_dims, + ) + + finally: + if platform == 'rocm': + # the hsaco path is a temporary file that should be removed. + # it's created in `compile_ttir_to_hsaco_inplace`. + os.remove(compilation_result.binary) _COMPILED_KERNEL_CACHE[cache_key] = kernel diff --git a/pyproject.toml b/pyproject.toml index b7d3ea0..71c595f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -15,6 +15,9 @@ tests = [ "pytest" ] +experimental = [ + "oryx" +] [build-system] requires = ["setuptools", "setuptools-scm"]