File tree Expand file tree Collapse file tree 4 files changed +7
-7
lines changed
keras_rs/src/layers/embedding/jax Expand file tree Collapse file tree 4 files changed +7
-7
lines changed Original file line number Diff line number Diff line change 3636def _get_partition_spec (
3737 layout : (
3838 keras .distribution .TensorLayout
39- | jax_layout .Layout
39+ | jax_layout .Format
4040 | jax .sharding .NamedSharding
4141 | jax .sharding .PartitionSpec
4242 ),
@@ -45,7 +45,7 @@ def _get_partition_spec(
4545 if isinstance (layout , keras .distribution .TensorLayout ):
4646 layout = layout .backend_layout
4747
48- if isinstance (layout , jax_layout .Layout ):
48+ if isinstance (layout , jax_layout .Format ):
4949 layout = layout .sharding
5050
5151 if isinstance (layout , jax .sharding .NamedSharding ):
@@ -217,7 +217,7 @@ def _create_sparsecore_distribution(
217217 sparsecore_layout = keras .distribution .TensorLayout (axes , device_mesh )
218218 # Custom sparsecore layout with tiling.
219219 # pylint: disable-next=protected-access
220- sparsecore_layout ._backend_layout = jax_layout .Layout (
220+ sparsecore_layout ._backend_layout = jax_layout .Format (
221221 jax_layout .DeviceLocalLayout (
222222 major_to_minor = (0 , 1 ),
223223 _tiling = ((8 ,),),
Original file line number Diff line number Diff line change @@ -41,7 +41,7 @@ def _create_sparsecore_layout(
4141 )
4242 sparsecore_layout = keras .distribution .TensorLayout (axes , device_mesh )
4343 # Custom sparsecore layout with tiling.
44- sparsecore_layout ._backend_layout = jax_layout .Layout ( # pylint: disable=protected-access
44+ sparsecore_layout ._backend_layout = jax_layout .Format ( # pylint: disable=protected-access
4545 jax_layout .DeviceLocalLayout (
4646 major_to_minor = (0 , 1 ),
4747 _tiling = ((8 ,),),
Original file line number Diff line number Diff line change 88
99import jax
1010import numpy as np
11- from jax .experimental import layout
11+ from jax .experimental import layout as jax_layout
1212from jax_tpu_embedding .sparsecore .lib .nn import embedding
1313from jax_tpu_embedding .sparsecore .lib .nn import embedding_spec
1414from jax_tpu_embedding .sparsecore .utils import utils as jte_utils
2020shard_map = jax .experimental .shard_map .shard_map # type: ignore[attr-defined]
2121
2222ArrayLike : TypeAlias = jax .Array | np .ndarray [Any , Any ]
23- JaxLayout : TypeAlias = jax .sharding .NamedSharding | layout . Layout
23+ JaxLayout : TypeAlias = jax .sharding .NamedSharding | jax_layout . Format
2424
2525
2626class EmbeddingLookupConfiguration :
Original file line number Diff line number Diff line change @@ -8,7 +8,7 @@ torch>=2.1.0
88# Jax with cuda support.
99# Keep same version as Keras repo.
1010--find-links https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
11- jax[cuda12_pip]==0.6.0
11+ jax[cuda12_pip]==0.6.2
1212
1313# Support for large embeddings.
1414jax-tpu-embedding;sys_platform == 'linux' and platform_machine == 'x86_64'
You can’t perform that action at this time.
0 commit comments