From 2bc6c0e468ffa480d506ea08647f4d6bcd1abafc Mon Sep 17 00:00:00 2001 From: Rebecca Chen Date: Mon, 15 Apr 2024 01:58:17 -0700 Subject: [PATCH] No public description PiperOrigin-RevId: 624877716 --- vmoe/nn/routing.py | 2 +- vmoe/projects/sparsity_constrained_ot/kl_projection_routing.py | 2 +- .../sparsity_constrained_ot/ksparse_projection_routing.py | 2 +- .../sparsity_constrained_ot/sparse_projection_routing.py | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/vmoe/nn/routing.py b/vmoe/nn/routing.py index 95df4b3..0f7c311 100644 --- a/vmoe/nn/routing.py +++ b/vmoe/nn/routing.py @@ -207,7 +207,7 @@ def __call__(self, inputs: Array) -> Tuple[BaseDispatcher, Metrics]: gates_softmax = self._compute_gates_softmax(inputs, self.num_experts) dispatcher, metrics = self._create_dispatcher_and_metrics(gates_softmax) metrics["auxiliary_loss"] = 0. - return dispatcher, metrics + return dispatcher, metrics # pytype: disable=bad-return-type @nn.nowrap def _compute_gates_softmax(self, inputs: Array, num_experts: int) -> Array: diff --git a/vmoe/projects/sparsity_constrained_ot/kl_projection_routing.py b/vmoe/projects/sparsity_constrained_ot/kl_projection_routing.py index 4c4a48a..8d3be4a 100644 --- a/vmoe/projects/sparsity_constrained_ot/kl_projection_routing.py +++ b/vmoe/projects/sparsity_constrained_ot/kl_projection_routing.py @@ -197,7 +197,7 @@ def __call__(self, inputs: Array) -> Tuple[BaseDispatcher, Metrics]: dispatcher, metrics = self._create_dispatcher_and_metrics( gates_dispatch=gates_ot) metrics["auxiliary_loss"] = 0. - return dispatcher, metrics + return dispatcher, metrics # pytype: disable=bad-return-type @nn.nowrap def _compute_gates(self, inputs: Array, num_experts: int) -> Array: diff --git a/vmoe/projects/sparsity_constrained_ot/ksparse_projection_routing.py b/vmoe/projects/sparsity_constrained_ot/ksparse_projection_routing.py index be03437..cb8f336 100644 --- a/vmoe/projects/sparsity_constrained_ot/ksparse_projection_routing.py +++ b/vmoe/projects/sparsity_constrained_ot/ksparse_projection_routing.py @@ -133,7 +133,7 @@ def __call__(self, inputs: Array) -> Tuple[BaseDispatcher, Metrics]: metrics["num_items_per_expert_max"] = jnp.max(num_items_per_expert, axis=1) metrics["num_items_per_expert_avg"] = jnp.mean(num_items_per_expert, axis=1) metrics["auxiliary_loss"] = 0. - return dispatcher, metrics + return dispatcher, metrics # pytype: disable=bad-return-type @nn.nowrap def _compute_gates( diff --git a/vmoe/projects/sparsity_constrained_ot/sparse_projection_routing.py b/vmoe/projects/sparsity_constrained_ot/sparse_projection_routing.py index 772dcfb..879cde5 100644 --- a/vmoe/projects/sparsity_constrained_ot/sparse_projection_routing.py +++ b/vmoe/projects/sparsity_constrained_ot/sparse_projection_routing.py @@ -62,7 +62,7 @@ def __call__(self, inputs: Array) -> Tuple[BaseDispatcher, Metrics]: dispatcher, metrics = self._create_dispatcher_and_metrics( gates_dispatch=gates_ot) metrics["auxiliary_loss"] = 0.0 - return dispatcher, metrics + return dispatcher, metrics # pytype: disable=bad-return-type @nn.nowrap def _compute_gates(