Skip to content

Commit

Permalink
No public description
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 624877716
  • Loading branch information
rchen152 authored and copybara-github committed Apr 15, 2024
1 parent 28f0b0b commit 2bc6c0e
Show file tree
Hide file tree
Showing 4 changed files with 4 additions and 4 deletions.
2 changes: 1 addition & 1 deletion vmoe/nn/routing.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ def __call__(self, inputs: Array) -> Tuple[BaseDispatcher, Metrics]:
gates_softmax = self._compute_gates_softmax(inputs, self.num_experts)
dispatcher, metrics = self._create_dispatcher_and_metrics(gates_softmax)
metrics["auxiliary_loss"] = 0.
return dispatcher, metrics
return dispatcher, metrics # pytype: disable=bad-return-type

@nn.nowrap
def _compute_gates_softmax(self, inputs: Array, num_experts: int) -> Array:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ def __call__(self, inputs: Array) -> Tuple[BaseDispatcher, Metrics]:
dispatcher, metrics = self._create_dispatcher_and_metrics(
gates_dispatch=gates_ot)
metrics["auxiliary_loss"] = 0.
return dispatcher, metrics
return dispatcher, metrics # pytype: disable=bad-return-type

@nn.nowrap
def _compute_gates(self, inputs: Array, num_experts: int) -> Array:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ def __call__(self, inputs: Array) -> Tuple[BaseDispatcher, Metrics]:
metrics["num_items_per_expert_max"] = jnp.max(num_items_per_expert, axis=1)
metrics["num_items_per_expert_avg"] = jnp.mean(num_items_per_expert, axis=1)
metrics["auxiliary_loss"] = 0.
return dispatcher, metrics
return dispatcher, metrics # pytype: disable=bad-return-type

@nn.nowrap
def _compute_gates(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def __call__(self, inputs: Array) -> Tuple[BaseDispatcher, Metrics]:
dispatcher, metrics = self._create_dispatcher_and_metrics(
gates_dispatch=gates_ot)
metrics["auxiliary_loss"] = 0.0
return dispatcher, metrics
return dispatcher, metrics # pytype: disable=bad-return-type

@nn.nowrap
def _compute_gates(
Expand Down

0 comments on commit 2bc6c0e

Please sign in to comment.