
[WIP] Migrate JAX workloads from pmap to jit #848

Open · wants to merge 56 commits into base: dev
Conversation

@priyakasimbeg (Contributor) commented on Mar 6, 2025

Purpose

The goal of this PR is to enable sharding of model parameters and optimizer state, and to migrate the JAX code from jax.pmap to jax.jit.
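In broad strokes, the migration follows the pattern below (a minimal sketch with illustrative names, not the PR's actual code; it assumes single-process execution and a batch size that is a multiple of the device count):

```python
# Minimal sketch of the pmap -> jit pattern (illustrative names, not the
# PR's actual code). Params are replicated; the batch is split across devices.
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()), axis_names=('batch',))
replicated = NamedSharding(mesh, P())            # full copy on every device
data_sharding = NamedSharding(mesh, P('batch'))  # leading dim split over devices

def loss_fn(params, batch):
  preds = batch['x'] @ params['w']
  return jnp.mean((preds - batch['y']) ** 2)

# Under pmap this would be jax.pmap(update, axis_name='batch') plus an
# explicit jax.lax.pmean over the gradients. Under jit, the mean over the
# sharded batch already implies the cross-device reduction.
@jax.jit
def update(params, batch, lr=0.1):
  grads = jax.grad(loss_fn)(params, batch)
  return jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)

n = jax.device_count() * 2  # batch size must be a multiple of the device count
params = jax.device_put({'w': jnp.zeros((8, 1))}, replicated)
batch = jax.device_put(
    {'x': jnp.ones((n, 8)), 'y': jnp.ones((n, 1))}, data_sharding)
params = update(params, batch)
```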

TODOs:

  • Migrate reference optimizers to use jax.jit (see the optimizer sketch after this list)
    • Nesterov
    • AdamW
    • Others
  • Migrate workloads to use jax.jit
    • (Test workload) MNIST
    • (Test workload) CIFAR
    • WMT
    • Criteo1TB
    • FastMRI
    • Librispeech
    • OGBG
    • ImageNet
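
For the optimizer items above, a jitted reference optimizer step could look roughly like this (a hedged sketch using optax.adamw as a stand-in; the function and variable names are illustrative, not taken from the PR):

```python
# Hedged sketch of a jitted reference optimizer step, using optax.adamw as a
# stand-in for the PR's AdamW implementation.
import jax
import jax.numpy as jnp
import optax

tx = optax.adamw(learning_rate=1e-3, weight_decay=1e-2)

def loss_fn(params, batch):
  preds = batch['x'] @ params['w']
  return jnp.mean((preds - batch['y']) ** 2)

@jax.jit
def train_step(params, opt_state, batch):
  grads = jax.grad(loss_fn)(params, batch)
  # adamw takes params as well, for decoupled weight decay.
  updates, opt_state = tx.update(grads, opt_state, params)
  return optax.apply_updates(params, updates), opt_state

params = {'w': jnp.zeros((8, 1))}
opt_state = tx.init(params)  # moment buffers mirror the params' structure
batch = {'x': jnp.ones((4, 8)), 'y': jnp.ones((4, 1))}
params, opt_state = train_step(params, opt_state, batch)
```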

Changelog

  • Added sharding utilities to handle distributed data (see the sketch after this list)
  • Replaced pmap code for CIFAR/MNIST with jit
  • Modified AdamW and Nesterov accordingly
  • Updated checkpoint and data_utils to support the new approach (mostly removing explicit jax_utils.replicate calls)
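
The sharding utilities are roughly of this shape (an illustrative sketch; the PR's actual helper names and signatures may differ):

```python
# Illustrative sketch of sharding utilities of this kind (the actual helper
# names in the PR may differ).
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

MESH = Mesh(np.array(jax.devices()), axis_names=('batch',))

def shard_along_batch(batch):
  """Split each array's leading (batch) dimension across devices."""
  return jax.device_put(batch, NamedSharding(MESH, P('batch')))

def replicate(tree):
  """jit-era stand-in for flax.jax_utils.replicate: a full copy per device,
  without the extra leading device axis that pmap required."""
  return jax.device_put(tree, NamedSharding(MESH, P()))
```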

Issues

  • Prefetching in CIFAR is temporarily disabled (marked with FIXME); it is unclear how best to support it under jit (one possible approach is sketched below).
  • I haven't edited any of the PyTorch code; we will need to verify that the PyTorch workloads still perform comparably.
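
For the prefetching gap, one option (a sketch, not necessarily what this PR will adopt) is to eagerly device_put upcoming batches with the target sharding and rely on JAX's asynchronous dispatch to overlap host-to-device transfer with compute, much as flax.jax_utils.prefetch_to_device does for pmap:

```python
# Sketch: prefetch batches onto devices with a target sharding under jit.
import collections
import itertools
import jax

def prefetch_to_sharding(iterator, sharding, size=2):
  """Yield batches from `iterator`, keeping `size` of them already
  device_put with `sharding` so transfer overlaps with compute."""
  queue = collections.deque()

  def enqueue(n):
    for batch in itertools.islice(iterator, n):
      queue.append(jax.device_put(batch, sharding))

  enqueue(size)
  while queue:
    yield queue.popleft()
    enqueue(1)
```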

@priyakasimbeg requested a review from a team as a code owner on March 6, 2025 21:47

github-actions bot commented Mar 6, 2025

MLCommons CLA bot All contributors have signed the MLCommons CLA ✍️ ✅

@priyakasimbeg changed the title from Jit switch to [WIP] Migrate JAX workloads from pmap to jit on Mar 6, 2025
@priyakasimbeg changed the base branch from main to dev on March 7, 2025 00:17
@@ -154,13 +156,12 @@ def _eval_model_on_split(self,
num_batches=num_batches)

total_metrics = {'ssim': 0., 'loss': 0.}
@priyakasimbeg (Contributor, Author) commented:

Why did we swap out the eval_rngs with the model rng?
