Skip to content

Commit e7230ca

Browse files
authored
Disable JITs in JAX and PyTorch to enable effective profiling (#892)
* Disable JIT in PyTorch and JAX to improve profiling (but don't override existing environment variables). * Disable JIT. * Documented.
1 parent 79d7f61 commit e7230ca

File tree

1 file changed

+10
-0
lines changed

1 file changed

+10
-0
lines changed

scalene/scalene_preload.py

+10
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,16 @@ def get_preload_environ(args: argparse.Namespace) -> Dict[str, str]:
1919
)
2020
}
2121

22+
# Disable JITting in PyTorch and JAX to improve profiling,
23+
# unless the environment variables are already set.
24+
# JAX_DISABLE_JIT: https://jax.readthedocs.io/en/latest/debugging/flags.html#id1
25+
# PYTORCH_JIT: https://pytorch.org/docs/stable/jit.html#disable-jit-for-debugging
26+
jit_flags = [ ('JAX_DISABLE_JIT', '1'), # truthy => disable JIT
27+
('PYTORCH_JIT', '0') ] # falsy => disable JIT
28+
for name, val in jit_flags:
29+
if name not in os.environ:
30+
env[name] = val
31+
2232
# Set environment variables for loading the Scalene dynamic library,
2333
# which interposes on allocation and copying functions.
2434
if sys.platform == "darwin":

0 commit comments

Comments
 (0)