From 7386d27d38d575e910120fd5dac13745cb92e7de Mon Sep 17 00:00:00 2001
From: Joan Puigcerver
Date: Fri, 1 Dec 2023 02:39:45 -0800
Subject: [PATCH] Minor internal changes.

PiperOrigin-RevId: 586948915
---
 vmoe/nn/vit_moe.py | 12 +++++++++++-
 1 file changed, 11 insertions(+), 1 deletion(-)

diff --git a/vmoe/nn/vit_moe.py b/vmoe/nn/vit_moe.py
index aeedae4..1a9afed 100644
--- a/vmoe/nn/vit_moe.py
+++ b/vmoe/nn/vit_moe.py
@@ -269,6 +269,8 @@ def add_position_emb(self, inputs):
     # By default, for back-compatibility, we use learned positional embeddings.
     position_emb = self.position_emb or {}
     name = position_emb.get('name', 'learned')
+    if name == 'none':
+      return inputs
     if name == 'learned':
       return AddPositionEmbs(
           posemb_init=nn.initializers.normal(stddev=0.02),  # from BERT.
@@ -320,8 +322,16 @@ class VisionTransformerMoe(nn.Module):
   representation_size: Optional[int] = None
   deterministic: bool = False
   head_bias_init: float = 0.0
+  head_kernel_zero_init: bool = True
   encoder_cls: Type[nn.Module] = EncoderMoe
 
+  @property
+  def kernel_init(self) -> nn.initializers.Initializer:
+    if self.head_kernel_zero_init:
+      return nn.initializers.zeros
+    else:
+      return nn.linear.default_kernel_init
+
   @nn.compact
   def __call__(self, inputs):
     # Encode patches into tokens of hidden_size.
@@ -367,7 +377,7 @@ def __call__(self, inputs):
       logits = nn.Dense(
           features=self.num_classes,
           name='head',
-          kernel_init=nn.initializers.zeros,
+          kernel_init=self.kernel_init,
           bias_init=nn.initializers.constant(self.head_bias_init))(x)
       return logits, metrics
     else:
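
Usage note (not part of the patch): a minimal standalone sketch of the
head-initializer selection this change makes configurable. The helper name
head_kernel_init and the (768, 1000) kernel shape are illustrative
assumptions; only nn.initializers.zeros and nn.linear.default_kernel_init
come from the diff itself.

    import jax
    import jax.numpy as jnp
    import flax.linen as nn

    def head_kernel_init(zero_init: bool):
      # Mirrors the new VisionTransformerMoe.kernel_init property:
      # zero-init the classification head by default, otherwise fall
      # back to Flax's stock Dense kernel initializer.
      if zero_init:
        return nn.initializers.zeros
      return nn.linear.default_kernel_init

    key = jax.random.PRNGKey(0)
    shape = (768, 1000)  # (hidden_size, num_classes); values are illustrative.
    w_zero = head_kernel_init(True)(key, shape, jnp.float32)
    w_default = head_kernel_init(False)(key, shape, jnp.float32)
    print(float(jnp.abs(w_zero).max()))  # 0.0: initial logits equal the bias
    print(float(w_default.std()))        # ~sqrt(1/768): Flax's lecun_normal

Keeping head_kernel_zero_init=True preserves the previously hard-coded
behavior (a zero-initialized head, so initial logits equal head_bias_init),
while False opts into Flax's default lecun_normal kernel init. The new
'none' option in add_position_emb likewise lets configs skip positional
embeddings entirely by returning the inputs unchanged.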