Skip to content

Commit b36b051

Browse files
Address review feedback: improve error messages and add PRNG key handling comments
1 parent ac871f1 commit b36b051

File tree

1 file changed

+55
-15
lines changed

1 file changed

+55
-15
lines changed

keras/src/backend/jax/distribution_lib.py

Lines changed: 55 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -272,24 +272,53 @@ def _distribute_initializer(
272272
Raises:
273273
ValueError: If init_func or seed is None.
274274
If init_func.func is not a supported random function.
275+
Supported jax.random functions: normal, truncated_normal, uniform.
275276
TypeError: If init_func is not a functools.partial object.
276277
"""
277278
import warnings
278279
from functools import partial
279280

280-
# Validate all required arguments
281-
if seed is None:
282-
raise ValueError("seed cannot be None. Use keras.random.SeedGenerator.")
281+
# Create SeedGenerator to ensure backend variable exists
282+
# For future state tracking for distributed keys, add
283+
# attributes for base/split keys and number of devices sharded.
284+
if isinstance(seed, jax.Array):
285+
seed_gen = seed_generator.SeedGenerator(seed=int(seed[0]))
286+
elif isinstance(seed, int):
287+
seed_gen = seed_generator.SeedGenerator(seed=seed)
288+
elif isinstance(seed, seed_generator.SeedGenerator):
289+
seed_gen = seed
290+
else:
291+
raise ValueError(
292+
f"seed must be int, JAX array, or SeedGenerator, got {type(seed)}"
293+
)
283294

284-
if init_func is None:
295+
# Extract the state value as JAX array
296+
jax_seed = seed_gen.state.value
297+
298+
# Convert to JAX PRNG key format (swap counter and seed value)
299+
jax_compatible_seed = jax.numpy.array(
300+
[jax_seed[1], jax_seed[0]], dtype=jax.numpy.uint32
301+
)
302+
303+
# Validate all required arguments
304+
if init_func is None or init_func.func.__name__ not in [
305+
"normal",
306+
"truncated_normal",
307+
"uniform",
308+
]:
285309
raise ValueError(
286-
"init_func cannot be None. Shape and dtype info are required."
310+
"init_func cannot be None or "
311+
"Unsupported initializer: {init_func.func.__name__}."
312+
"only JAX-compatible random initializers are supported. "
313+
"Supported jax.random funcs: normal, truncated_normal, uniform"
287314
)
288315

289316
# Ensure init_func is a partial
290317
if not isinstance(init_func, partial):
291318
raise TypeError(
292319
f"init_func must be functools.partial object, got {type(init_func)}"
320+
"init_func is a jax.random.* function with shape and "
321+
"dtype bound via partial"
293322
)
294323

295324
# Shard based on tensor layout
@@ -301,12 +330,28 @@ def _distribute_initializer(
301330
else:
302331
sharding = _to_backend_layout(layout)
303332

304-
# The init_func has static arguments baked in as per initializer.
305-
compiled_init = jax.jit(
306-
lambda seed: init_func(seed), out_shardings=sharding
307-
)
333+
# JAX PRNG key handling within JIT:
334+
# The key is passed directly to jax.random.* functions which are
335+
# JIT-compatible and functional. The sample is one logical array drawn
336+
# from a single key; out_shardings only sets how it is split per device.
337+
try:
338+
compiled_init = jax.jit(
339+
lambda jax_compatible_seed: init_func(jax_compatible_seed),
340+
out_shardings=sharding,
341+
)
342+
sample = compiled_init(jax_compatible_seed)
343+
except RuntimeError as e:
344+
warnings.warn(
345+
f"Sharding failed due to: {e}, falling back to single device"
346+
)
347+
compiled_init = jax.jit(
348+
lambda jax_compatible_seed: init_func(jax_compatible_seed),
349+
out_shardings=None,
350+
)
351+
sample = compiled_init(jax_compatible_seed)
308352

309-
sample = compiled_init(seed)
353+
# Advance the SeedGenerator state so a later call does not reuse this key
354+
seed = seed_gen.next()
310355

311356
# Apply mean/stddev only for distributions where it makes sense
312357
if init_func.func in (jax.random.normal, jax.random.truncated_normal):
@@ -318,8 +363,3 @@ def _distribute_initializer(
318363
"mean and stddev are ignored for uniform distribution"
319364
)
320365
return sample
321-
else:
322-
raise ValueError(
323-
f"Unsupported initializer: {init_func.func.__name__}. "
324-
f"Supported: normal, truncated_normal, uniform"
325-
)

0 commit comments

Comments
 (0)