Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

nsys-jax: re-work to be more pip-install-able #1165

Open
wants to merge 37 commits into
base: main
Choose a base branch
from

Conversation

olupton
Copy link
Collaborator

@olupton olupton commented Nov 21, 2024

The overarching goal of this PR is to get closer to a world where the nsys-jax tooling is straightforwardly pip install-able. While the diff looks scary, it's mostly re-organisation.

Substantive changes:

  • nsys-jax no longer bundles Python code in the output archives, the install.sh script provided for users to run on local machines becomes, loosely, install 'pip nsys-jax[jupyter] @ git+https://github.com/NVIDIA/JAX-Toolbox.git@COMMIT#subdirectory=.github/container/nsys_jax', where COMMIT corresponds to the nsys-jax command that produced the archive. For the ghcr.io/nvidia/jax containers, this is the commit of JAX-Toolbox that triggered the container build.

Changes included:

  • Introduce /opt/pip-tools-post-install.d, which pip-finalize.sh will execute the contents of after installing the pip-managed world
  • Move nsys-jax installation (specifically for the containers) into install-nsys-jax.sh and thereby clean up install-nsight.sh. The new script has to be told the git commit hash of JAX-Toolbox that is being built, because nsys-jax bakes this into an installation script in its output .zip archives to ensure the local environment matches the profile-collection environment.
  • The CLI tools like nsys-jax, nsys-jax-combine and install-protoc are now handled via [project.scripts] in pyproject.toml instead of being standalone Python scripts. This is "more standard", and also makes it easier to share code between nsys-jax and nsys-jax-combine.
  • The Python library is renamed from jax_nsys to nsys_jax for consistency.
  • It's now possible to set the default data loading path via the NSYS_JAX_DEFAULT_PREFIX environment variable; previously the default was the current working directory, but that can be inconvenient to steer in Jupyter environments.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant