Add JAX_COMPILATION_CACHE_EXPECT_PGLE option #24910

Open
wants to merge 1 commit into
base: main

@olupton olupton commented Nov 15, 2024

This aims to provide a better PGLE workflow that is compatible with profiling with Nsight Systems on GPU.
This is non-trivial because CUPTI, the interface used by CUDA profiling tools, only supports profiling by one tool at a time, meaning that the JAX profiler used by PGLE and Nsight Systems conflict with one another.
The PR adds a new JAX config option, compilation_cache_expect_pgle, that tells JAX to attempt to load PGLE-optimised entries from the compilation cache even if PGLE is disabled, and to print warnings when certain unexpected outcomes occur.
With this, a workflow like:

$ rm -rf /root/jax_cache/
$ export JAX_ENABLE_COMPILATION_CACHE=yes          # not strictly needed, on by default
$ export JAX_COMPILATION_CACHE_DIR=/root/jax_cache # not needed in this example because MaxText configures it
$ JAX_ENABLE_PGLE=yes test-maxtext.sh --model-name=gemma-2b
$ ls -1 /root/jax_cache/*-cache
/root/jax_cache/jit_initialize_state-42f5c604b3add9a249cc00624720755475e29b9ab7007d8f5b781abb34061775-cache
/root/jax_cache/jit_train_step-0ec4a202e1bd3117ab79a0585a981305f2d5b2dfab3b0741445747ba463f2a20-cache
/root/jax_cache/jit_train_step-4a215955e445e4383b9ccf1f283143331a1118edfb9b281f53ae02c793b77ccf-cache
$ JAX_COMPILATION_CACHE_EXPECT_PGLE=yes nsys profile test-maxtext.sh --model-name=gemma-2b
...
W1114 02:44:52.551798 140529154335808 compiler.py:381] PERSISTENT CACHE MISS for PGLE-optimized jit_initialize_state despite non-PGLE hit; it may not have been executed enough times when the cache was populated
...

is possible.
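To check what the first run left in the cache before starting the Nsight Systems run, the `ls` step can be summarised per module. The sketch below assumes cache entries are named `<module>-<hex key>-cache`, which is only inferred from the listing above; the helper names are hypothetical and not part of JAX.

```python
import re
from collections import Counter
from pathlib import Path

# Assumed file name layout, inferred from the `ls` output above:
#   <module name>-<hex cache key>-cache
_ENTRY_RE = re.compile(r"^(?P<module>.+)-(?P<key>[0-9a-f]+)-cache$")


def count_cache_entries(names):
    """Count cache entries per module name; names that do not match are skipped."""
    counts = Counter()
    for name in names:
        match = _ENTRY_RE.match(name)
        if match:
            counts[match.group("module")] += 1
    return dict(counts)


def count_cache_entries_in_dir(cache_dir):
    """Same as count_cache_entries, applied to the files in cache_dir."""
    return count_cache_entries(p.name for p in Path(cache_dir).glob("*-cache"))
```

For the listing above this would report two entries for jit_train_step and one for jit_initialize_state, which is consistent with the warning that no PGLE-optimised entry was found for jit_initialize_state.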
Warnings are added in three cases if JAX_COMPILATION_CACHE_EXPECT_PGLE is enabled:

  • If a module is cached without PGLE optimisations but not with them. That is typical of modules that were not executed enough times in the first (cache-populating) run to reach the threshold for recompilation with profile data. The warning in the example output above is of this kind. We rely on the user to read "initialize ... not executed enough times" and conclude that this is harmless.
  • If a module is written to the cache. This is typical of a cache-populating run that did not hit as many code paths as the second run with Nsight Systems + JAX_COMPILATION_CACHE_EXPECT_PGLE.
  • If the PGLE profiler returns an empty profile. This is typical of trying to enable PGLE under Nsight Systems.
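The three cases above can be summarised as a small decision function. This is only an illustrative sketch of the behaviour described in this PR text, not the code added to compiler.py; the enum, the function, and all parameter names are hypothetical.

```python
from enum import Enum


class ExpectPgleWarning(Enum):
    """Hypothetical labels for the three warning cases listed above."""
    PGLE_MISS_DESPITE_NON_PGLE_HIT = "pgle-miss-despite-non-pgle-hit"
    CACHE_WRITE = "cache-write"
    EMPTY_PROFILE = "empty-profile"


def expect_pgle_warnings(*, expect_pgle, non_pgle_hit=False, pgle_hit=False,
                         wrote_entry=False, empty_profile=False):
    """Return the warning cases that apply to one module, in list order."""
    if not expect_pgle:
        # All three warnings are described as gated on the new option.
        return []
    found = []
    if non_pgle_hit and not pgle_hit:
        # Case 1: cached without PGLE optimisations but not with them.
        found.append(ExpectPgleWarning.PGLE_MISS_DESPITE_NON_PGLE_HIT)
    if wrote_entry:
        # Case 2: the profiling run compiled something the first run did not cache.
        found.append(ExpectPgleWarning.CACHE_WRITE)
    if empty_profile:
        # Case 3: the PGLE profiler produced no data.
        found.append(ExpectPgleWarning.EMPTY_PROFILE)
    return found
```

In this model, the jit_initialize_state warning from the example output corresponds to `expect_pgle_warnings(expect_pgle=True, non_pgle_hit=True)`.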

Note that many more modules are not cached at all, with or without PGLE optimisations, because they are too small or too fast to compile to be cached with the default settings.
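For reference, the thresholds that decide whether a module is cached can be lowered through the existing persistent-cache options. The sketch below is not part of this PR; it assumes the `jax_persistent_cache_min_compile_time_secs` and `jax_persistent_cache_min_entry_size_bytes` options behave as documented for the persistent compilation cache, and the helper name is hypothetical.

```python
# Settings that, to my understanding, make JAX cache every compiled module
# regardless of compile time or entry size (an assumption, not from this PR).
CACHE_EVERYTHING = {
    "jax_persistent_cache_min_compile_time_secs": 0,
    "jax_persistent_cache_min_entry_size_bytes": -1,
}


def apply_cache_settings(settings=CACHE_EVERYTHING):
    """Apply the given config values; call this before any jit compilation."""
    import jax  # imported lazily so the module loads without JAX installed

    for name, value in settings.items():
        jax.config.update(name, value)
```

Whether caching the small modules is worthwhile is a separate question; the defaults exist because such entries are cheap to recompile.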

This makes it possible to combine external profiling tools, such as
Nsight Systems, with the automatic PGLE workflow supported by JAX in
two steps:

export JAX_COMPILATION_CACHE_DIR=...
JAX_ENABLE_PGLE=yes python model.py
JAX_COMPILATION_CACHE_EXPECT_PGLE=yes nsys profile python model.py
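
The same two steps can be driven from a script. This is a minimal sketch that uses only the environment variables and commands shown above; the function names are hypothetical, and `python model.py` stands in for the real training command.

```python
import os
import subprocess


def two_step_plan(cache_dir, command=("python", "model.py")):
    """Return (env overrides, argv) for the cache-populating and profiling runs."""
    base = {"JAX_COMPILATION_CACHE_DIR": cache_dir}
    return [
        # Step 1: run with PGLE enabled so optimised entries land in the cache.
        ({**base, "JAX_ENABLE_PGLE": "yes"}, list(command)),
        # Step 2: profile under Nsight Systems, expecting PGLE-optimised cache hits.
        ({**base, "JAX_COMPILATION_CACHE_EXPECT_PGLE": "yes"},
         ["nsys", "profile", *command]),
    ]


def run_plan(plan):
    """Run each step in order and stop at the first failure."""
    for overrides, argv in plan:
        subprocess.run(argv, env={**os.environ, **overrides}, check=True)
```

Keeping the plan separate from its execution makes it easy to print or inspect the commands before launching a long run.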