Skip to content

Commit

Permalink
No public description
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 570964066
  • Loading branch information
Jake VanderPlas authored and copybara-github committed Oct 5, 2023
1 parent 75af72b commit 44906f2
Show file tree
Hide file tree
Showing 4 changed files with 6 additions and 6 deletions.
4 changes: 2 additions & 2 deletions vmoe/evaluate/fewshot.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
DatasetIterator = vmoe.data.input_pipeline.DatasetIterator
PartitionSpec = jax.sharding.PartitionSpec
PyTree = Any
PRNGKey = jax.random.KeyArray
PRNGKey = jax.Array

BIAS_CONSTANT = 100.0
VALID_KEY = vmoe.data.input_pipeline.VALID_KEY
Expand Down Expand Up @@ -333,7 +333,7 @@ def _make_fewshot_step_pjit(
rng_keys: Sequence[str],
):
"""Wraps _fewshot_step with pjit."""
state_axis_resources = FewShotState(
state_axis_resources = FewShotState( # pytype: disable=wrong-arg-types
rngs={key: PartitionSpec() for key in rng_keys})
fewshot_step_pjit = jax.experimental.pjit.pjit(
functools.partial(_fewshot_step, apply_fn=apply_fn),
Expand Down
4 changes: 2 additions & 2 deletions vmoe/train/train_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,13 @@
# limitations under the License.

"""TrainState and other related classes."""
from typing import Any, Callable, Dict, Mapping, Tuple, Union
from typing import Any, Callable, Dict, Mapping, Tuple

import flax.training.train_state
import jax
import optax

PRNGKey = Union[jax.numpy.ndarray, jax.random.KeyArray]
PRNGKey = jax.Array


class TrainState(flax.training.train_state.TrainState):
Expand Down
2 changes: 1 addition & 1 deletion vmoe/train/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@
Mesh = partitioning.Mesh
NamedSharding = jax.sharding.NamedSharding
PartitionSpec = partitioning.PartitionSpec
PRNGKey = Union[jax.numpy.ndarray, jax.random.KeyArray]
PRNGKey = jax.Array
PyTree = Any
ReportProgress = train_periodic_actions.ReportProgress
SingleProcessPeriodicAction = train_periodic_actions.SingleProcessPeriodicAction
Expand Down
2 changes: 1 addition & 1 deletion vmoe/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import jax
import jax.numpy as jnp

PRNGKey = jax.random.KeyArray
PRNGKey = jax.Array


def make_rngs(rng_keys: Tuple[str, ...], seed: int) -> Dict[str, PRNGKey]:
Expand Down

0 comments on commit 44906f2

Please sign in to comment.