@@ -272,24 +272,53 @@ def _distribute_initializer(
272272 Raises:
273273 ValueError: If init_func or seed is None.
274274 If init_func.func is not a supported random function.
275+ Supported jax.random func: normal, truncated_normal, uniform
275276 TypeError: If init_func is not a functools.partial object.
276277 """
277278 import warnings
278279 from functools import partial
279280
280- # Validate all required arguments
281- if seed is None :
282- raise ValueError ("seed cannot be None. Use keras.random.SeedGenerator." )
281+ # Create SeedGenerator to ensure backend variable exists
282+ # For future state tracking for distributed keys, add
283+ # attributes for base/split keys and number of devices sharded.
284+ if isinstance (seed , jax .Array ):
285+ seed_gen = seed_generator .SeedGenerator (seed = int (seed [0 ]))
286+ elif isinstance (seed , int ):
287+ seed_gen = seed_generator .SeedGenerator (seed = seed )
288+ elif isinstance (seed , seed_generator .SeedGenerator ):
289+ seed_gen = seed
290+ else :
291+ raise ValueError (
292+ f"seed must be int, JAX array, or SeedGenerator, got { type (seed )} "
293+ )
283294
284- if init_func is None :
295+ # Extract the state value as JAX array
296+ jax_seed = seed_gen .state .value
297+
298+ # Convert to JAX PRNG key format (swap counter and seed value)
299+ jax_compatible_seed = jax .numpy .array (
300+ [jax_seed [1 ], jax_seed [0 ]], dtype = jax .numpy .uint32
301+ )
302+
303+ # Validate all required arguments
304+ if init_func is None or init_func .func .__name__ not in [
305+ "normal" ,
306+ "truncated_normal" ,
307+ "uniform" ,
308+ ]:
285309 raise ValueError (
286- "init_func cannot be None. Shape and dtype info are required."
310+ "init_func cannot be None or "
311+ "Unsupported initializer: {init_func.func.__name__}."
312+ "only JAX-compatible random initializers are supported. "
313+ "Supported jax.random funcs: normal, truncated_normal, uniform"
287314 )
288315
289316 # Ensure init_func is a partial
290317 if not isinstance (init_func , partial ):
291318 raise TypeError (
292319 f"init_func must be functools.partial object, got { type (init_func )} "
320+ "init_func is a jax.random.* function with shape and "
321+ "dtype bound via partial"
293322 )
294323
295324 # Shard based on tensor layout
@@ -301,12 +330,28 @@ def _distribute_initializer(
301330 else :
302331 sharding = _to_backend_layout (layout )
303332
304- # The init_func has static arguments baked in as per initializer.
305- compiled_init = jax .jit (
306- lambda seed : init_func (seed ), out_shardings = sharding
307- )
333+ # JAX PRNG key handling within JIT:
334+ # The key is passed directly to jax.random.* functions which are
335+ # JIT-compatible and functional. JAX automatically ensures different
336+ # random values per shard when out_shardings is specified.
337+ try :
338+ compiled_init = jax .jit (
339+ lambda jax_compatible_seed : init_func (jax_compatible_seed ),
340+ out_shardings = sharding ,
341+ )
342+ sample = compiled_init (jax_compatible_seed )
343+ except RuntimeError as e :
344+ warnings .warn (
345+ f"Sharding failed due to: { e } , falling back to single device"
346+ )
347+ compiled_init = jax .jit (
348+ lambda jax_compatible_seed : init_func (jax_compatible_seed ),
349+ out_shardings = None ,
350+ )
351+ sample = compiled_init (jax_compatible_seed )
308352
309- sample = compiled_init (seed )
353+ # Store the SeedGenerator for state tracking
354+ seed = seed_gen .next ()
310355
311356 # Apply mean/stddev only for distributions where it makes sense
312357 if init_func .func in (jax .random .normal , jax .random .truncated_normal ):
@@ -318,8 +363,3 @@ def _distribute_initializer(
318363 "mean and stddev are ignored for uniform distribution"
319364 )
320365 return sample
321- else :
322- raise ValueError (
323- f"Unsupported initializer: { init_func .func .__name__ } . "
324- f"Supported: normal, truncated_normal, uniform"
325- )
0 commit comments