Description
I'm running into deadlocks when using JAX distributed sharding.
My code works on a single node with multi-GPU sharding and a single process (rank), but randomly deadlocks inside the code below when I run it on 2 nodes with 4 GPUs each, using 8 SLURM tasks.
No error is produced; it just hangs at random, typically after the function has already been called successfully many times.
I can't figure out how to debug it either, since I can only reproduce it in this setting.
I've inspected the sharding of the arrays, and everything looks as expected.
I'm using the sharding and resharding as suggested in the documentation, and again: it works with single-process sharding.
(Part of this code is adapted from the NetKet library.)
In the code below:
- tensors are resharded a couple of times
- there's also a dense linear solve happening in parallel on all devices
Any idea what might be going on?
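For context, here is a minimal sketch of the distributed setup I'm assuming. The initialization itself is not part of the snippet below (in my actual runs NetKet's netket_experimental_sharding config installs the mesh), so the names here are only illustrative of the intended layout:

```python
# Sketch of the assumed setup: 8 SLURM tasks, one per GPU, over 2 nodes x 4 GH200s.
# This is not part of srt_onthefly below; NetKet normally installs the mesh for me.
import jax
from jax.sharding import NamedSharding, PartitionSpec as P

jax.distributed.initialize()  # auto-detects the SLURM coordinator, rank and size

# One global mesh with a single "S" axis over all 8 devices, matching the
# P("S") partition specs used in srt_onthefly below.
mesh = jax.make_mesh((jax.device_count(),), ("S",))
sample_sharding = NamedSharding(mesh, P("S", None))  # samples split over all GPUs
```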
EDIT: The docs appear to warn about indefinite hangs when jax.process_index() is used. Elsewhere in my code I use it a couple of times (mainly for printing and writing info).
From what I understand now, the problem might not be induced by the code below, but rather elsewhere. However, I can't find any way to debug this kind of problem (for now I was simply checking that all ranks arrive in this function), so any help on how to proceed or how to debug these kinds of errors would still be very welcome. It remains confusing to me that all ranks enter the function but never leave it, and I don't see any place in the code below where I run rank-specific instructions.
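For reference, this is a sketch of the kind of checks I'm adding to narrow this down (the helper names are mine), built on jax.experimental.multihost_utils:

```python
# Sketch of how I try to localize where the processes stop agreeing.
# sync_global_devices is a named barrier over all processes: if one process
# skips it (or reaches a barrier with a different name), everyone hangs there,
# which at least tells me which checkpoint was reached last.
import jax
from jax.experimental import multihost_utils

def checkpoint(name: str):
    # Flush per-process progress before the barrier so the last printed
    # checkpoint survives a hang.
    print(f"[proc {jax.process_index()}] reached {name}", flush=True)
    multihost_utils.sync_global_devices(name)

# Pattern for rank-0-only logging of a globally sharded array: gather on *all*
# processes first (a collective everyone joins), and only then branch on the
# process index for the pure-Python print.
def print_on_rank0(tag, x):
    x_np = multihost_utils.process_allgather(x)  # all processes participate
    if jax.process_index() == 0:
        print(tag, x_np, flush=True)
```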
from collections.abc import Callable
from functools import partial
from einops import rearrange
import jax
import jax.numpy as jnp
from jax.tree_util import tree_map
from jax.sharding import NamedSharding, PartitionSpec as P
from netket import jax as nkjax
from netket import config
from netket.jax._jacobian.default_mode import JacobianMode
from netket.utils import timing
from netket.utils.types import Array
from netket.jax import _ntk as nt
@timing.timed
@partial(
jax.jit,
static_argnames=(
"log_psi",
"solver_fn",
"chunk_size",
"mode",
"store"
),
)
def srt_onthefly(
log_psi,
d_local_energies_uw,
parameters,
model_state,
samples,
*,
diag_shift: float | Array,
solver_fn: Callable[[Array, Array], Array],
mode: JacobianMode,
proj_reg: float | Array | None = None,
momentum: float | Array | None = None,
adaptive: float | Array | None = None,
old_updates: Array | None = None,
old_second_moment: Array | None = None,
chunk_size: int | None = None,
log_w: Array | None = None,
):
if proj_reg is not None:
raise NotImplementedError("TODO")
# if momentum is not None:
# raise NotImplementedError("TODO")
# if old_updates is not None:
# raise NotImplementedError("TODO")
N_mc = d_local_energies_uw.size
# Split all parameters into real and imaginary parts separately
parameters_real, rss = nkjax.tree_to_real(parameters)
# complex: (Nmc) -> (Nmc,2) - splitting real and imaginary output like 2 classes
# real: (Nmc) -> (Nmc,) - no splitting
def _apply_fn(parameters_real, samples):
variables = {"params": rss(parameters_real), **model_state}
log_amp = log_psi(variables, samples)
if mode == "complex":
re, im = log_amp.real, log_amp.imag
return jnp.concatenate(
(re[:, None], im[:, None]), axis=-1
) # shape [N_mc,2]
else:
return log_amp.real # shape [N_mc, ]
def jvp_f_chunk(parameters, vector, samples):
r"""
Creates the jvp of the function `_apply_fn` with respect to the parameters.
This jvp is then evaluated in chunks of `chunk_size` samples.
"""
f = lambda params: _apply_fn(params, samples)
_, acc = jax.jvp(f, (parameters,), (vector,))
return acc
# compute rhs of the linear system
d_local_energies_uw = d_local_energies_uw.flatten()
if config.netket_experimental_sharding:
# make sure we have the correct sharding
d_local_energies_uw = jax.lax.with_sharding_constraint(
d_local_energies_uw, NamedSharding(jax.sharding.get_abstract_mesh(), P("S"))
)
# compute extra quantities
# convention: sum pdf to 1
if log_w is None:
pdf = jnp.full(d_local_energies_uw.shape, 1/N_mc)
else:
pdf = jnp.exp(log_w.flatten()) # already sharded
if config.netket_experimental_sharding:
# same sharding as de
pdf = jax.lax.with_sharding_constraint(
pdf, NamedSharding(jax.sharding.get_abstract_mesh(), P("S"))
)
de = d_local_energies_uw # since we already centered it
sqrt_pdf = jnp.sqrt(pdf)
dv = 2.0 * de * sqrt_pdf # shape [N_mc,]
if mode == "complex":
dv = jnp.stack([jnp.real(dv), jnp.imag(dv)], axis=-1) # shape [N_mc,2]
else:
dv = jnp.real(dv) # shape [N_mc,]
if momentum is not None:
if old_updates is None:
old_updates = tree_map(jnp.zeros_like, parameters_real)
else:
# this sum runs over the parameters
# computes sum_k O_k(x) phi_k for all x
acc = nkjax.apply_chunked(
jvp_f_chunk, in_axes=(None, None, 0), chunk_size=chunk_size
)(parameters_real, old_updates, samples)
assert acc.shape == dv.shape, f"shape = {acc.shape}" # [N_mc, 2]
# let's make our life a bit easier
if dv.ndim == 1:
acc = acc[:,None]
# this takes <O> over the samples
# avg = jnp.mean(acc, axis=0)
# acc = (acc - avg) / jnp.sqrt(N_mc)
avg = jnp.sum(pdf[:,None] * acc, axis=0) # not over cmplx
acc = (acc - avg) * sqrt_pdf[:,None]
if dv.ndim == 1:
acc = acc.squeeze(-1)
dv -= momentum * acc
if mode == "complex":
dv = jax.lax.collapse(dv, 0, 2) # shape [2*N_mc,] or [N_mc, ] if not complex
# Collect all samples on all MPI ranks, those label the columns of the T matrix
all_samples = samples
if config.netket_experimental_sharding:
samples = jax.lax.with_sharding_constraint(
samples, NamedSharding(jax.sharding.get_abstract_mesh(), P("S", None))
)
all_samples = jax.lax.with_sharding_constraint(
samples, NamedSharding(jax.sharding.get_abstract_mesh(), P())
)
if adaptive is not None:
if old_second_moment is None:
factor = None
else:
# let's just wrap it to rescale the Ok's
factor = jax.tree.map(lambda p: 1/(jnp.power(p, 1/4) + 1e-8), old_second_moment)
_apply_fn = with_grad_multipliers(_apply_fn, argnums=0, grad_mul=factor)
_jacobian_contraction = nt.empirical_ntk_by_jacobian(
f=_apply_fn,
trace_axes=(),
vmap_axes=0,
)
def jacobian_contraction(samples, all_samples, parameters_real):
if config.netket_experimental_sharding:
parameters_real = jax.lax.pvary(parameters_real, "S")
if chunk_size is None:
# STRUCTURED_DERIVATIVES returns a complex array, but the imaginary part is zero
# shape [N_mc/p.size, N_mc, 2, 2]
return _jacobian_contraction(samples, all_samples, parameters_real).real
else:
_all_samples, _ = nkjax.chunk(all_samples, chunk_size=chunk_size)
ntk_local = jax.lax.map(
lambda batch_lattice: _jacobian_contraction(
samples, batch_lattice, parameters_real
).real,
_all_samples,
)
if mode == "complex":
return rearrange(ntk_local, "nbatches i j z w -> i (nbatches j) z w")
else:
return rearrange(ntk_local, "nbatches i j -> i (nbatches j)")
# If we are sharding, use shard_map manually
if config.netket_experimental_sharding:
mesh = jax.sharding.get_abstract_mesh()
# SAMPLES, ALL_SAMPLES, PARAMETERS_REAL
in_specs = (P("S", None), P(), P())
out_specs = P("S", None)
jacobian_contraction = jax.shard_map(
jacobian_contraction,
mesh=mesh,
in_specs=in_specs,
out_specs=out_specs,
)
# This disables the nkjax.sharding_decorator in here, which might appear
# in the apply function inside.
with nkjax.sharding._increase_SHARD_MAP_STACK_LEVEL():
ntk_local = jacobian_contraction(samples, all_samples, parameters_real).real
# shape [N_mc, N_mc, 2, 2] or [N_mc, N_mc]
if config.netket_experimental_sharding:
# make sure every device has the ntk copy (gather the rows)
ntk = jax.lax.with_sharding_constraint(
ntk_local, NamedSharding(jax.sharding.get_abstract_mesh(), P())
)
pdf = jax.lax.with_sharding_constraint(
pdf, NamedSharding(jax.sharding.get_abstract_mesh(), P())
)
sqrt_pdf = jax.lax.with_sharding_constraint(
sqrt_pdf, NamedSharding(jax.sharding.get_abstract_mesh(), P())
)
else:
ntk = ntk_local
if mode == "complex":
# shape [2*N_mc, 2*N_mc] checked with direct calculation of J^T J
ntk = rearrange(ntk, "i j z w -> (i z) (j w)")
delta = jnp.eye(N_mc) - pdf[:,None] # this form averages over the rows when used as delta^T v on a vector (!!!)
if mode == "complex":
# shape [2*N_mc, 2*N_mc]
# Gets applied to the sub-blocks corresponding to the real part and imaginary part
delta_conc = jnp.zeros((2 * N_mc, 2 * N_mc)).at[0::2, 0::2].set(delta)
delta_conc = delta_conc.at[1::2, 1::2].set(delta)
delta_conc = delta_conc.at[0::2, 1::2].set(0.0)
delta_conc = delta_conc.at[1::2, 0::2].set(0.0)
else:
delta_conc = delta
# shape [2*N_mc, 2*N_mc] centering the jacobian
# ntk = (delta_conc @ (ntk @ delta_conc)) / N_mc
ntk = delta_conc.T @ (ntk @ delta_conc)
# assert ntk.shape[0] == 2 * sqrt_pdf.size
# sqrt_pdf_double = jnp.stack([sqrt_pdf]*2, axis=-1) # (N_mc, 2)
# sqrt_pdf_double = rearrange(sqrt_pdf_double, "i z -> (i z)") # use same ordering as the NTK!
if mode == "complex":
assert ntk.shape[0] == 2 * sqrt_pdf.size
sqrt_pdf_vec = jnp.stack([sqrt_pdf]*2, axis=-1)
sqrt_pdf_vec = rearrange(sqrt_pdf_vec, "i z -> (i z)") # (2N,) use same ordering as the NTK!
else:
assert ntk.shape[0] == sqrt_pdf.size
sqrt_pdf_vec = sqrt_pdf # (N,)
ntk = ntk * sqrt_pdf_vec[:,None] * sqrt_pdf_vec[None,:]
# add diag shift
ntk_shifted = ntk + diag_shift * jnp.eye(ntk.shape[0])
# add projection regularization
if proj_reg is not None:
raise NotImplementedError("TODO")
ntk_shifted = ntk_shifted + proj_reg / N_mc
# after you set ntk, pdf, sqrt_pdf to P()
if config.netket_experimental_sharding:
mesh = jax.sharding.get_abstract_mesh()
dv = jax.lax.with_sharding_constraint(dv, NamedSharding(mesh, P())) # <— replicate RHS
# some solvers return a tuple, some others do not.
aus_vector = solver_fn(ntk_shifted, dv)
if isinstance(aus_vector, tuple):
aus_vector, info = aus_vector
else:
info = {}
# aus_vector is available on every rank
if info is None:
info = {}
aus_vector = aus_vector * sqrt_pdf_vec
aus_vector = delta_conc @ aus_vector # note: no .T (!!!)
# shape [N_mc,2]
if mode == "complex":
aus_vector = aus_vector.reshape(-1, 2)
# shape [N_mc // p.size,2]
if config.netket_experimental_sharding:
aus_vector = jax.lax.with_sharding_constraint(
aus_vector,
NamedSharding(
jax.sharding.get_abstract_mesh(),
P("S", *(None,) * (aus_vector.ndim - 1)),
),
)
# _, vjp_fun = jax.vjp(f, parameters_real)
vjp_fun = nkjax.vjp_chunked(
_apply_fn,
parameters_real,
samples,
chunk_size=chunk_size,
chunk_argnums=1,
nondiff_argnums=1,
)
(updates,) = vjp_fun(aus_vector) # pytree [N_params,]
if adaptive is not None:
if old_updates is None:
# we don't have any information yet, so just skip for now
pass
elif factor is not None:
# we are still missing a factor vk^-1/4
updates = tree_map(lambda v, g: g*v, factor, updates)
if momentum is not None:
updates = tree_map(lambda x, y: x + momentum * y, updates, old_updates)
if adaptive is not None:
if old_second_moment is None:
# this multiplies, so ones to have no effect initially
old_second_moment = tree_map(jnp.ones_like, parameters_real)
else:
# now also store the information for the next round
old_second_moment = tree_map(
lambda vprev, theta, thetaprev: adaptive*vprev + (theta - thetaprev)**2,
old_second_moment, updates, old_updates,
)
if momentum is not None or adaptive is not None:
# store for next round (!)
old_updates = updates
return rss(updates), old_updates, old_second_moment, info
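Since the function itself looks rank-agnostic to me, one thing I still want to rule out is that the processes end up compiling different programs (which, as far as I understand, would also make the collectives mismatch and hang). A sketch of the check I have in mind, applied to a jax.jit-wrapped function with the same arguments I pass to srt_onthefly (the timing decorator would need to be bypassed to call .lower()):

```python
# Sketch: verify that all processes lower the jitted function to the same
# program. If the fingerprints differ, the collectives baked into the compiled
# code differ between processes, which can show up as a silent hang.
import hashlib
import numpy as np
import jax
from jax.experimental import multihost_utils

def program_fingerprint(jitted_fn, *args, **kwargs):
    text = jitted_fn.lower(*args, **kwargs).as_text()
    # First 4 bytes of the digest as an int32, so it can be gathered as an array.
    return np.frombuffer(hashlib.sha256(text.encode()).digest()[:4], dtype=np.int32)

def check_same_program(jitted_fn, *args, **kwargs):
    local = program_fingerprint(jitted_fn, *args, **kwargs)
    gathered = multihost_utils.process_allgather(local)  # every process joins
    if jax.process_index() == 0:
        print("unique program fingerprints:", np.unique(gathered), flush=True)
```

(The scale_grad / with_grad_multipliers helpers used in srt_onthefly are defined below.)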
import jax
import jax.numpy as jnp
from jax import tree_util as jtu
from typing import Callable, Iterable, Tuple, Union, Any
@jax.custom_vjp
def scale_grad(x, scale):
"""Forward: returns x unchanged.
Backward: multiplies dL/dx by `scale`. No grads w.r.t. `scale` (returns zeros)."""
return x
def _scale_grad_fwd(x, scale):
return x, (scale,)
def _scale_grad_bwd(res, g):
(scale,) = res
# cast/broadcast scale to g's dtype so it multiplies cleanly
scale = jnp.asarray(scale, dtype=g.dtype)
dx = g * scale
dscale = jnp.zeros_like(scale) # no gradient w.r.t. the multiplier
return dx, dscale
scale_grad.defvjp(_scale_grad_fwd, _scale_grad_bwd)
def _broadcast_like(mul, tree):
"""If `mul` is a scalar/array, broadcast it to the structure of `tree`."""
if isinstance(mul, (int, float, jnp.ndarray)):
return jtu.tree_map(lambda _: mul, tree)
return mul
def _apply_scaled(tree, mul_tree):
"""Apply scale_grad leafwise (mul can be scalar or a pytree matching `tree`)."""
mul_tree = _broadcast_like(mul_tree, tree)
return jtu.tree_map(lambda x, m: scale_grad(x, m), tree, mul_tree)
# ---------- 2) Wrapper: apply scaling to chosen positional args ----------
def with_grad_multipliers(f: Callable, grad_mul, argnums=(0,)):
if isinstance(argnums, int):
argnums = (argnums,)
gm = (grad_mul,) if len(argnums) == 1 else tuple(grad_mul)
def wrapped(*args, **kwargs):
if gm is None:
return f(*args, **kwargs)
a = list(args)
for idx, mul in zip(argnums, gm):
a[idx] = _apply_scaled(a[idx], mul)
return f(*a, **kwargs)
return wrapped
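A quick sanity check of the gradient-scaling helper (a sketch, assuming the scale_grad definition above is in scope): the forward value is unchanged, while the gradient with respect to x is multiplied by the given factor.

```python
# Minimal check of scale_grad (defined above): the forward pass is the
# identity, the backward pass multiplies the incoming cotangent by `scale`.
import jax
import jax.numpy as jnp

def loss(x):
    return jnp.sum(scale_grad(x, 0.5) ** 2)

x = jnp.arange(3.0)
print(loss(x))            # 5.0, same as jnp.sum(x**2)
print(jax.grad(loss)(x))  # [0., 1., 2.], i.e. 0.5 * 2x instead of 2x
```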
System info (python version, jaxlib version, accelerator, etc.)
0: jax: 0.7.1
0: jaxlib: 0.7.1
0: numpy: 2.2.6
0: python: 3.13.5 (main, Jan 1 1980, 12:01:00) [GCC 14.2.0]
0: device info: NVIDIA GH200 120GB-8, 1 local devices
0: process_count: 8
0: platform: uname_result(system='Linux', node='nid006545', release='5.14.21-150500.55.65_13.0.74-cray_shasta_c_64k', version='#1 SMP Mon Sep 9 09:48:48 UTC 2024 (ba86e71)', machine='aarch64')
0: JAX_PLATFORM_NAME=gpu
0:
0: $ nvidia-smi
0: Wed Sep 17 13:49:29 2025
0: +-----------------------------------------------------------------------------------------+
0: | NVIDIA-SMI 550.54.15 Driver Version: 550.54.15 CUDA Version: 12.4 |
0: |-----------------------------------------+------------------------+----------------------+
0: | GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |
0: | Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |
0: | | | MIG M. |
0: |=========================================+========================+======================|
0: | 0 NVIDIA GH200 120GB On | 00000009:01:00.0 Off | 0 |
0: | N/A 26C P0 134W / 900W | 814MiB / 97871MiB | 0% Default |
0: | | | Disabled |
0: +-----------------------------------------+------------------------+----------------------+
0: | 1 NVIDIA GH200 120GB On | 00000019:01:00.0 Off | 0 |
0: | N/A 27C P0 143W / 900W | 815MiB / 97871MiB | 0% Default |
0: | | | Disabled |
0: +-----------------------------------------+------------------------+----------------------+
0: | 2 NVIDIA GH200 120GB On | 00000029:01:00.0 Off | 0 |
0: | N/A 28C P0 124W / 900W | 816MiB / 97871MiB | 0% Default |
0: | | | Disabled |
0: +-----------------------------------------+------------------------+----------------------+
0: | 3 NVIDIA GH200 120GB On | 00000039:01:00.0 Off | 0 |
0: | N/A 26C P0 124W / 900W | 816MiB / 97871MiB | 0% Default |
0: | | | Disabled |
0: +-----------------------------------------+------------------------+----------------------+
0:
0: +-----------------------------------------------------------------------------------------+
0: | Processes: |
0: | GPU GI CI PID Type Process name GPU Memory |
0: | ID ID Usage |
0: |=========================================================================================|
0: | 0 N/A N/A 156593 C python 584MiB |
0: | 1 N/A N/A 156592 C python 584MiB |
0: | 2 N/A N/A 156594 C python 584MiB |
0: | 3 N/A N/A 156595 C python 584MiB |
0: +-----------------------------------------------------------------------------------------+