
Jax distributed deadlock #31870

@jwnys

Description


I'm running into deadlocks when using JAX distributed sharding.
My code works for single-node, multi-GPU sharding with a single rank, but gives random deadlocks inside the code below when I run it on 2 nodes with 4 GPUs each, using 8 tasks under SLURM.
No error is produced; it just hangs at random, typically after the function has been called successfully many times.
I can't figure out how to debug it either, since I can only reproduce it in this setting.
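
(The actual initialization code isn't included above; a typical setup for this topology, 2 nodes × 4 GPUs with one GPU per SLURM task, would look roughly like the sketch below.)

import jax

# Assumed setup, not taken from this report: under SLURM, initialize() can
# auto-detect the coordinator address, process count and process index from
# the SLURM environment variables.
jax.distributed.initialize()

assert jax.process_count() == 8
assert jax.local_device_count() == 1  # one GH200 per task (see system info below)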

I've inspected the sharding of the arrays, and everything looks as expected.
I'm using the sharding and resharding as suggested in the documentation, and again: it works for sharding with a single rank.
(Part of this code is adapted from the NetKet library.)
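
(The exact inspection code isn't shown; a check of this kind might look like the following.)

import jax

def check_sharding(arr: jax.Array) -> None:
    # Print the sharding object and, for each locally addressable shard,
    # the device it lives on and the slice of the global array it holds.
    print(arr.sharding)
    for shard in arr.addressable_shards:
        print(shard.device, shard.index)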

So in this code:

  • tensors are resharded a couple of times
  • there's also a dense linear solve happening in parallel on all devices (a distilled sketch of both steps follows this list)
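
A distilled sketch of those two steps (not the actual reproducer; it assumes an active mesh with a sharded axis named "S" as in the code below, and uses jnp.linalg.solve as a stand-in for the generic solver_fn):

import jax
import jax.numpy as jnp
from jax.sharding import NamedSharding, PartitionSpec as P

@jax.jit
def solve_step(ntk_rows, rhs):
    mesh = jax.sharding.get_abstract_mesh()
    # reshard: gather the row-sharded matrix and the rhs onto every device
    ntk = jax.lax.with_sharding_constraint(ntk_rows, NamedSharding(mesh, P()))
    rhs = jax.lax.with_sharding_constraint(rhs, NamedSharding(mesh, P()))
    # dense linear solve, replicated across all devices
    return jnp.linalg.solve(ntk + 1e-3 * jnp.eye(ntk.shape[0]), rhs)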

Any idea what might be going on?

EDIT: The docs appear to warn about indefinite hangs when jax.process_index() is used. Elsewhere in my code I use it a couple of times (mainly for printing and writing info).
From what I understand now, the problem might not be induced by the code below, but somewhere else. However, I can't find any way to debug this kind of problem (for now I was simply checking that all ranks arrive in this function), so any help on how to proceed or debug such errors would still be very welcome. It remains confusing to me that all ranks enter the function but never get out, and I don't see any place in the code below where I run rank-specific instructions.
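
For reference, the pattern the docs warn about looks schematically like this (illustration only, not code from this report):

import jax
import jax.numpy as jnp

def log_energy(energy):
    # `energy` is assumed to be a globally sharded array.
    if jax.process_index() == 0:
        # Only process 0 launches this compiled global reduction. The
        # cross-process collective it contains is never joined by the other
        # ranks, so the job hangs with no error message.
        print(jax.jit(jnp.mean)(energy))

One low-tech way to see where each rank is stuck is to arm faulthandler.dump_traceback_later(600, repeat=True, exit=False) (Python stdlib) at startup in every process, which periodically prints each rank's Python stack.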

from collections.abc import Callable
from functools import partial

from einops import rearrange

import jax
import jax.numpy as jnp
from jax.tree_util import tree_map

from jax.sharding import NamedSharding, PartitionSpec as P

from netket import jax as nkjax
from netket import config
from netket.jax._jacobian.default_mode import JacobianMode
from netket.utils import timing
from netket.utils.types import Array
from netket.jax import _ntk as nt

@timing.timed
@partial(
    jax.jit,
    static_argnames=(
        "log_psi",
        "solver_fn",
        "chunk_size",
        "mode",
        "store"
    ),
)
def srt_onthefly(
    log_psi,
    d_local_energies_uw,
    parameters,
    model_state,
    samples,
    *,
    diag_shift: float | Array,
    solver_fn: Callable[[Array, Array], Array],
    mode: JacobianMode,
    proj_reg: float | Array | None = None,
    momentum: float | Array | None = None,
    adaptive: float | Array | None = None,
    old_updates: Array | None = None,
    old_second_moment: Array | None = None,
    chunk_size: int | None = None,
    log_w: Array | None = None,
):
    
    
    if proj_reg is not None:
        raise NotImplementedError("TODO")
    # if momentum is not None:
    #     raise NotImplementedError("TODO")
    # if old_updates is not None:
    #     raise NotImplementedError("TODO")
    
    N_mc = d_local_energies_uw.size

    # Split all parameters into real and imaginary parts separately
    parameters_real, rss = nkjax.tree_to_real(parameters)

    # complex: (Nmc) -> (Nmc,2) - splitting real and imaginary output like 2 classes
    # real:    (Nmc) -> (Nmc,)  - no splitting
    def _apply_fn(parameters_real, samples):
        variables = {"params": rss(parameters_real), **model_state}
        log_amp = log_psi(variables, samples)

        if mode == "complex":
            re, im = log_amp.real, log_amp.imag
            return jnp.concatenate(
                (re[:, None], im[:, None]), axis=-1
            )  # shape [N_mc,2]
        else:
            return log_amp.real  # shape [N_mc, ]

    def jvp_f_chunk(parameters, vector, samples):
        r"""
        Creates the jvp of the function `_apply_fn` with respect to the parameters.
        This jvp is then evaluated in chunks of `chunk_size` samples.
        """
        f = lambda params: _apply_fn(params, samples)
        _, acc = jax.jvp(f, (parameters,), (vector,))
        return acc

    # compute rhs of the linear system
    d_local_energies_uw = d_local_energies_uw.flatten()
    if config.netket_experimental_sharding:
        # make sure we have the correct sharding
        d_local_energies_uw = jax.lax.with_sharding_constraint(
            d_local_energies_uw, NamedSharding(jax.sharding.get_abstract_mesh(), P("S"))
        )
    
    # compute extra quantities
    # convention: sum pdf to 1
    if log_w is None:
        pdf = jnp.full(d_local_energies_uw.shape, 1/N_mc)
    else:
        pdf = jnp.exp(log_w.flatten()) # already sharded
    if config.netket_experimental_sharding:
        # same sharding as de
        pdf = jax.lax.with_sharding_constraint(
            pdf, NamedSharding(jax.sharding.get_abstract_mesh(), P("S"))
        )
    de = d_local_energies_uw # since we already centered it
    sqrt_pdf = jnp.sqrt(pdf)        

    dv = 2.0 * de * sqrt_pdf            # shape [N_mc,]
    if mode == "complex":
        dv = jnp.stack([jnp.real(dv), jnp.imag(dv)], axis=-1)  # shape [N_mc,2]
    else:
        dv = jnp.real(dv)  # shape [N_mc,]


    if momentum is not None:
        if old_updates is None:
            old_updates = tree_map(jnp.zeros_like, parameters_real)
        else:
            # this sum runs over the parameters
            # computes sum_k O_k(x) phi_k for all x
            acc = nkjax.apply_chunked(
                jvp_f_chunk, in_axes=(None, None, 0), chunk_size=chunk_size
            )(parameters_real, old_updates, samples)
            assert acc.shape == dv.shape, f"shape = {acc.shape}" # [N_mc, 2]
            
            # let's make our life a bit easier
            if dv.ndim == 1:
                acc = acc[:,None]
            
            # this takes <O> over the samples
            # avg = jnp.mean(acc, axis=0)
            # acc = (acc - avg) / jnp.sqrt(N_mc)
            avg = jnp.sum(pdf[:,None] * acc, axis=0) # not over cmplx
            acc = (acc - avg) * sqrt_pdf[:,None]
            
            if dv.ndim == 1:
                acc = acc.squeeze(-1)

            dv -= momentum * acc

    if mode == "complex":
        dv = jax.lax.collapse(dv, 0, 2)  # shape [2*N_mc,] or [N_mc, ] if not complex

    # Collect all samples on all MPI ranks, those label the columns of the T matrix
    all_samples = samples
    if config.netket_experimental_sharding:
        samples = jax.lax.with_sharding_constraint(
            samples, NamedSharding(jax.sharding.get_abstract_mesh(), P("S", None))
        )
        all_samples = jax.lax.with_sharding_constraint(
            samples, NamedSharding(jax.sharding.get_abstract_mesh(), P())
        )
        
    if adaptive is not None:
        if old_second_moment is None:
            factor = None
        else:        
            # let's just wrap it to rescale the Ok's
            factor = jax.tree.map(lambda p: 1/(jnp.power(p, 1/4) + 1e-8), old_second_moment)
            _apply_fn = with_grad_multipliers(_apply_fn, argnums=0, grad_mul=factor)

    _jacobian_contraction = nt.empirical_ntk_by_jacobian(
        f=_apply_fn,
        trace_axes=(),
        vmap_axes=0,
    )

    def jacobian_contraction(samples, all_samples, parameters_real):
        if config.netket_experimental_sharding:
            parameters_real = jax.lax.pvary(parameters_real, "S")
        if chunk_size is None:
            # STRUCTURED_DERIVATIVES returns a complex array, but the imaginary part is zero
            # shape [N_mc/p.size, N_mc, 2, 2]
            return _jacobian_contraction(samples, all_samples, parameters_real).real
        else:
            _all_samples, _ = nkjax.chunk(all_samples, chunk_size=chunk_size)
            ntk_local = jax.lax.map(
                lambda batch_lattice: _jacobian_contraction(
                    samples, batch_lattice, parameters_real
                ).real,
                _all_samples,
            )
            if mode == "complex":
                return rearrange(ntk_local, "nbatches i j z w -> i (nbatches j) z w")
            else:
                return rearrange(ntk_local, "nbatches i j -> i (nbatches j)")

    # If we are sharding, use shard_map manually
    if config.netket_experimental_sharding:
        mesh = jax.sharding.get_abstract_mesh()
        # SAMPLES, ALL_SAMPLES, PARAMETERS_REAL
        in_specs = (P("S", None), P(), P())
        out_specs = P("S", None)

        jacobian_contraction = jax.shard_map(
            jacobian_contraction,
            mesh=mesh,
            in_specs=in_specs,
            out_specs=out_specs,
        )

    # This disables the nkjax.sharding_decorator in here, which might appear
    # in the apply function inside.
    with nkjax.sharding._increase_SHARD_MAP_STACK_LEVEL():
        ntk_local = jacobian_contraction(samples, all_samples, parameters_real).real

    # shape [N_mc, N_mc, 2, 2] or [N_mc, N_mc]
    if config.netket_experimental_sharding:
        # make sure every device has the ntk copy (gather the rows)
        ntk = jax.lax.with_sharding_constraint(
            ntk_local, NamedSharding(jax.sharding.get_abstract_mesh(), P())
        )
        pdf = jax.lax.with_sharding_constraint(
            pdf, NamedSharding(jax.sharding.get_abstract_mesh(), P())
        )
        sqrt_pdf = jax.lax.with_sharding_constraint(
            sqrt_pdf, NamedSharding(jax.sharding.get_abstract_mesh(), P())
        )
    else:
        ntk = ntk_local
                
    if mode == "complex":
        # shape [2*N_mc, 2*N_mc] checked with direct calculation of J^T J
        ntk = rearrange(ntk, "i j z w -> (i z) (j w)")

    delta = jnp.eye(N_mc) - pdf[:,None] # this form averages over the rows when used as delta^T v on a vector (!!!)
    if mode == "complex":
        # shape [2*N_mc, 2*N_mc]
        # Gets applied to the sub-blocks corresponding to the real part and imaginary part
        delta_conc = jnp.zeros((2 * N_mc, 2 * N_mc)).at[0::2, 0::2].set(delta)
        delta_conc = delta_conc.at[1::2, 1::2].set(delta)
        delta_conc = delta_conc.at[0::2, 1::2].set(0.0)
        delta_conc = delta_conc.at[1::2, 0::2].set(0.0)
    else:
        delta_conc = delta

    # shape [2*N_mc, 2*N_mc] centering the jacobian
    # ntk = (delta_conc @ (ntk @ delta_conc)) / N_mc
    ntk = delta_conc.T @ (ntk @ delta_conc)
    
    # assert ntk.shape[0] == 2 * sqrt_pdf.size
    # sqrt_pdf_double = jnp.stack([sqrt_pdf]*2, axis=-1) # (N_mc, 2)
    # sqrt_pdf_double = rearrange(sqrt_pdf_double, "i z -> (i z)") # use same ordering as the NTK!    
    
    if mode == "complex":
        assert ntk.shape[0] == 2 * sqrt_pdf.size
        sqrt_pdf_vec = jnp.stack([sqrt_pdf]*2, axis=-1)
        sqrt_pdf_vec = rearrange(sqrt_pdf_vec, "i z -> (i z)")  # (2N,) use same ordering as the NTK!    
    else:
        assert ntk.shape[0] == sqrt_pdf.size
        sqrt_pdf_vec = sqrt_pdf  # (N,)
    
    ntk = ntk * sqrt_pdf_vec[:,None] * sqrt_pdf_vec[None,:]
    
    # add diag shift
    ntk_shifted = ntk + diag_shift * jnp.eye(ntk.shape[0])

    # add projection regularization
    if proj_reg is not None:
        raise NotImplementedError("TODO")
        ntk_shifted = ntk_shifted + proj_reg / N_mc


    # after you set ntk, pdf, sqrt_pdf to P()
    if config.netket_experimental_sharding:
        mesh = jax.sharding.get_abstract_mesh()
        dv = jax.lax.with_sharding_constraint(dv, NamedSharding(mesh, P()))  # <— replicate RHS
    # some solvers return a tuple, some others do not.
    aus_vector = solver_fn(ntk_shifted, dv)
    if isinstance(aus_vector, tuple):
        aus_vector, info = aus_vector
    else:
        info = {}
    # aus_vector is available on every rank

    if info is None:
        info = {}

    aus_vector = aus_vector * sqrt_pdf_vec
    aus_vector = delta_conc @ aus_vector # note: no .T (!!!)

    # shape [N_mc,2]
    if mode == "complex":
        aus_vector = aus_vector.reshape(-1, 2)
    # shape [N_mc // p.size,2]
    if config.netket_experimental_sharding:
        aus_vector = jax.lax.with_sharding_constraint(
            aus_vector,
            NamedSharding(
                jax.sharding.get_abstract_mesh(),
                P("S", *(None,) * (aus_vector.ndim - 1)),
            ),
        )

    # _, vjp_fun = jax.vjp(f, parameters_real)
    vjp_fun = nkjax.vjp_chunked(
        _apply_fn,
        parameters_real,
        samples,
        chunk_size=chunk_size,
        chunk_argnums=1,
        nondiff_argnums=1,
    )

    (updates,) = vjp_fun(aus_vector)  # pytree [N_params,]

    if adaptive is not None:
        if old_updates is None:
            # we don't have any information yet, so just skip for now
            pass
        elif factor is not None: 
            # we are still missing a factor vk^-1/4
            updates = tree_map(lambda v, g: g*v, factor, updates)
                       
    if momentum is not None:
        updates = tree_map(lambda x, y: x + momentum * y, updates, old_updates)
    
    if adaptive is not None:
        if old_second_moment is None:
            # this multiplies, so ones to have no effect initially
            old_second_moment = tree_map(jnp.ones_like, parameters_real)
        else:
            # now also store the information for the next round
            old_second_moment = tree_map(
                lambda vprev, theta, thetaprev: adaptive*vprev + (theta - thetaprev)**2, 
                old_second_moment, updates, old_updates,
            )

        
    if momentum is not None or adaptive is not None:
        # store for next round (!)
        old_updates = updates
        
        

    return rss(updates), old_updates, old_second_moment, info

import jax
import jax.numpy as jnp
from jax import tree_util as jtu
from typing import Callable

@jax.custom_vjp
def scale_grad(x, scale):
    """Forward: returns x unchanged.
       Backward: multiplies dL/dx by `scale`. No grads w.r.t. `scale` (returns zeros)."""
    return x

def _scale_grad_fwd(x, scale):
    return x, (scale,)

def _scale_grad_bwd(res, g):
    (scale,) = res
    # cast/broadcast scale to g's dtype so it multiplies cleanly
    scale = jnp.asarray(scale, dtype=g.dtype)
    dx = g * scale
    dscale = jnp.zeros_like(scale)  # no gradient w.r.t. the multiplier
    return dx, dscale

scale_grad.defvjp(_scale_grad_fwd, _scale_grad_bwd)


def _broadcast_like(mul, tree):
    """If `mul` is a scalar/array, broadcast it to the structure of `tree`."""
    if isinstance(mul, (int, float, jnp.ndarray)):
        return jtu.tree_map(lambda _: mul, tree)
    return mul

def _apply_scaled(tree, mul_tree):
    """Apply scale_grad leafwise (mul can be scalar or a pytree matching `tree`)."""
    mul_tree = _broadcast_like(mul_tree, tree)
    return jtu.tree_map(lambda x, m: scale_grad(x, m), tree, mul_tree)

# ---------- 2) Wrapper: apply scaling to chosen positional args ----------

def with_grad_multipliers(f: Callable, grad_mul, argnums=(0,)):
    """Wrap `f` so the cotangents of the selected positional args are scaled leaf-wise."""
    if grad_mul is None:
        # no multipliers given: return f unchanged
        return f
    if isinstance(argnums, int):
        argnums = (argnums,)
    gm = (grad_mul,) if len(argnums) == 1 else tuple(grad_mul)
    def wrapped(*args, **kwargs):
        a = list(args)
        for idx, mul in zip(argnums, gm):
            a[idx] = _apply_scaled(a[idx], mul)
        return f(*a, **kwargs)
    return wrapped
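
A hypothetical usage sketch for the helpers above (loss, params, and the 0.5 multipliers are invented for illustration): the forward value is unchanged, while the gradient of the selected argument is scaled leaf-wise and the multipliers themselves receive no gradient.

import jax
import jax.numpy as jnp

def loss(params, x):
    return jnp.sum(params["w"] * x) + params["b"]

params = {"w": jnp.ones(3), "b": jnp.array(0.0)}
mults = jax.tree.map(lambda p: 0.5 * jnp.ones_like(p), params)

scaled_loss = with_grad_multipliers(loss, mults, argnums=0)
grads = jax.grad(scaled_loss)(params, jnp.arange(3.0))
# each leaf of `grads` is 0.5x the gradient of the unscaled loss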

System info (python version, jaxlib version, accelerator, etc.)

0: jax:    0.7.1
0: jaxlib: 0.7.1
0: numpy:  2.2.6
0: python: 3.13.5 (main, Jan  1 1980, 12:01:00) [GCC 14.2.0]
0: device info: NVIDIA GH200 120GB-8, 1 local devices
0: process_count: 8
0: platform: uname_result(system='Linux', node='nid006545', release='5.14.21-150500.55.65_13.0.74-cray_shasta_c_64k', version='#1 SMP Mon Sep 9 09:48:48 UTC 2024 (ba86e71)', machine='aarch64')
0: JAX_PLATFORM_NAME=gpu
0: 
0: $ nvidia-smi
0: Wed Sep 17 13:49:29 2025       
0: +-----------------------------------------------------------------------------------------+
0: | NVIDIA-SMI 550.54.15              Driver Version: 550.54.15      CUDA Version: 12.4     |
0: |-----------------------------------------+------------------------+----------------------+
0: | GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
0: | Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
0: |                                         |                        |               MIG M. |
0: |=========================================+========================+======================|
0: |   0  NVIDIA GH200 120GB             On  |   00000009:01:00.0 Off |                    0 |
0: | N/A   26C    P0            134W /  900W |     814MiB /  97871MiB |      0%      Default |
0: |                                         |                        |             Disabled |
0: +-----------------------------------------+------------------------+----------------------+
0: |   1  NVIDIA GH200 120GB             On  |   00000019:01:00.0 Off |                    0 |
0: | N/A   27C    P0            143W /  900W |     815MiB /  97871MiB |      0%      Default |
0: |                                         |                        |             Disabled |
0: +-----------------------------------------+------------------------+----------------------+
0: |   2  NVIDIA GH200 120GB             On  |   00000029:01:00.0 Off |                    0 |
0: | N/A   28C    P0            124W /  900W |     816MiB /  97871MiB |      0%      Default |
0: |                                         |                        |             Disabled |
0: +-----------------------------------------+------------------------+----------------------+
0: |   3  NVIDIA GH200 120GB             On  |   00000039:01:00.0 Off |                    0 |
0: | N/A   26C    P0            124W /  900W |     816MiB /  97871MiB |      0%      Default |
0: |                                         |                        |             Disabled |
0: +-----------------------------------------+------------------------+----------------------+
0:                                                                                          
0: +-----------------------------------------------------------------------------------------+
0: | Processes:                                                                              |
0: |  GPU   GI   CI        PID   Type   Process name                              GPU Memory |
0: |        ID   ID                                                               Usage      |
0: |=========================================================================================|
0: |    0   N/A  N/A    156593      C   python                                        584MiB |
0: |    1   N/A  N/A    156592      C   python                                        584MiB |
0: |    2   N/A  N/A    156594      C   python                                        584MiB |
0: |    3   N/A  N/A    156595      C   python                                        584MiB |
0: +-----------------------------------------------------------------------------------------+
