diff --git a/vmoe/evaluate/fewshot.py b/vmoe/evaluate/fewshot.py index 62d2f4c..2829b1b 100644 --- a/vmoe/evaluate/fewshot.py +++ b/vmoe/evaluate/fewshot.py @@ -39,7 +39,7 @@ DatasetIterator = vmoe.data.input_pipeline.DatasetIterator PartitionSpec = jax.sharding.PartitionSpec PyTree = Any -PRNGKey = jax.random.KeyArray +PRNGKey = jax.Array BIAS_CONSTANT = 100.0 VALID_KEY = vmoe.data.input_pipeline.VALID_KEY @@ -333,7 +333,7 @@ def _make_fewshot_step_pjit( rng_keys: Sequence[str], ): """Wraps _fewshot_step with pjit.""" - state_axis_resources = FewShotState( + state_axis_resources = FewShotState( # pytype: disable=wrong-arg-types rngs={key: PartitionSpec() for key in rng_keys}) fewshot_step_pjit = jax.experimental.pjit.pjit( functools.partial(_fewshot_step, apply_fn=apply_fn), diff --git a/vmoe/train/train_state.py b/vmoe/train/train_state.py index 455ba3c..41d80bd 100644 --- a/vmoe/train/train_state.py +++ b/vmoe/train/train_state.py @@ -13,13 +13,13 @@ # limitations under the License. """TrainState and other related classes.""" -from typing import Any, Callable, Dict, Mapping, Tuple, Union +from typing import Any, Callable, Dict, Mapping, Tuple import flax.training.train_state import jax import optax -PRNGKey = Union[jax.numpy.ndarray, jax.random.KeyArray] +PRNGKey = jax.Array class TrainState(flax.training.train_state.TrainState): diff --git a/vmoe/train/trainer.py b/vmoe/train/trainer.py index e16143a..f44eaa4 100644 --- a/vmoe/train/trainer.py +++ b/vmoe/train/trainer.py @@ -62,7 +62,7 @@ Mesh = partitioning.Mesh NamedSharding = jax.sharding.NamedSharding PartitionSpec = partitioning.PartitionSpec -PRNGKey = Union[jax.numpy.ndarray, jax.random.KeyArray] +PRNGKey = jax.Array PyTree = Any ReportProgress = train_periodic_actions.ReportProgress SingleProcessPeriodicAction = train_periodic_actions.SingleProcessPeriodicAction diff --git a/vmoe/utils.py b/vmoe/utils.py index eedbe65..8e8a845 100644 --- a/vmoe/utils.py +++ b/vmoe/utils.py @@ -21,7 +21,7 @@ import jax import jax.numpy as jnp -PRNGKey = jax.random.KeyArray +PRNGKey = jax.Array def make_rngs(rng_keys: Tuple[str, ...], seed: int) -> Dict[str, PRNGKey]: