Add JAX_COMPILATION_CACHE_EXPECT_PGLE option #24910
Open
+213
−21
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
This aims to provide a better PGLE workflow that is compatible with profiling with Nsight Systems on GPU.
This is non-trivial because CUPTI, the interface used by CUDA profiling tools, only supports profiling by one tool at a time, meaning that the JAX profiler used by PGLE and Nsight Systems conflict with one another.
The PR adds a new JAX config option
compilation_cache_expect_pgle
that tells JAX to attempt to load PGLE-optimised entries from the compilation cache even if PGLE is disabled, and print warnings on certain unexpected results.With this, a workflow like:
is possible.
Warnings are added in three cases if
JAX_COMPILATION_CACHE_EXPECT_PGLE
is enabled:JAX_COMPILATION_CACHE_EXPECT_PGLE
.Note there are many more modules that are cached neither with nor without PGLE optimisations because they are too small or fast to compile to be cached with the default settings.