
Commit 7386d27
Minor internal changes.
PiperOrigin-RevId: 586948915
jpuigcerver authored and copybara-github committed Dec 1, 2023
1 parent ad20387 commit 7386d27
Showing 1 changed file with 11 additions and 1 deletion.

vmoe/nn/vit_moe.py: 11 additions, 1 deletion
@@ -269,6 +269,8 @@ def add_position_emb(self, inputs):
     # By default, for back-compatibility, we use learned positional embeddings.
     position_emb = self.position_emb or {}
     name = position_emb.get('name', 'learned')
+    if name == 'none':
+      return inputs
     if name == 'learned':
       return AddPositionEmbs(
           posemb_init=nn.initializers.normal(stddev=0.02),  # from BERT.
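The hunk above adds a 'none' option that turns add_position_emb into a no-op, so tokens flow through the encoder without any positional embedding. The diff does not show where position_emb comes from; assuming it is a config dict on the encoder, as the attribute access suggests, selecting the new behavior would look roughly like this (illustrative, not code from the repository):

# Hypothetical encoder configuration: {'name': 'none'} skips
# AddPositionEmbs entirely; omitting the key (or passing None) keeps the
# learned-embedding default shown in the hunk above.
encoder_config = dict(position_emb={'name': 'none'})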
@@ -320,8 +322,16 @@ class VisionTransformerMoe(nn.Module):
   representation_size: Optional[int] = None
   deterministic: bool = False
   head_bias_init: float = 0.0
+  head_kernel_zero_init: bool = True
   encoder_cls: Type[nn.Module] = EncoderMoe
 
+  @property
+  def kernel_init(self) -> nn.initializers.Initializer:
+    if self.head_kernel_zero_init:
+      return nn.initializers.zeros
+    else:
+      return nn.linear.default_kernel_init
+
   @nn.compact
   def __call__(self, inputs):
     # Encode patches into tokens of hidden_size.
@@ -367,7 +377,7 @@ def __call__(self, inputs):
       logits = nn.Dense(
           features=self.num_classes,
           name='head',
-          kernel_init=nn.initializers.zeros,
+          kernel_init=self.kernel_init,
          bias_init=nn.initializers.constant(self.head_bias_init))(x)
       return logits, metrics
     else:
