@@ -572,35 +572,3 @@ def device_scope(device_name):
572572 else :
573573 jax_device = device_name
574574 return jax .default_device (jax_device )
575-
576-
577- def convert_checkpoint_value (value , dtype , shape ):
578- """Convert a value for checkpoint restoration, preserving JAX arrays for
579- sharding.
580-
581- This function handles the special case of checkpoint restoration where JAX
582- arrays should be preserved for sharding support, while other values are
583- converted to JAX arrays with the specified dtype and shape.
584-
585- Args:
586- value: The value to convert (can be JAX array, numpy array, or other
587- types)
588- dtype: The target dtype
589- shape: The target shape
590-
591- Returns:
592- A JAX array with the specified dtype and shape, or the original JAX
593- array if it was already a JAX array.
594- """
595- # For JAX backend, preserve JAX arrays for sharding support
596- if hasattr (value , "__array_namespace__" ) or str (type (value )).startswith (
597- "<class 'jax"
598- ):
599- # value is already a JAX array, return as-is to preserve sharding
600- return value
601- elif isinstance (value , np .ndarray ):
602- # Convert numpy array to JAX array
603- return jnp .array (value ).astype (dtype ).reshape (shape )
604- else :
605- # Convert other types to JAX array
606- return jnp .array (value , dtype = dtype ).reshape (shape )
0 commit comments