Skip to content

Commit

Permalink
refactor: add extra mlp arg
Browse files Browse the repository at this point in the history
  • Loading branch information
spravil committed Mar 25, 2024
1 parent 18a573f commit cc1754e
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion src/modalities/models/coca/multi_modal_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def __init__(
):
super().__init__()
self.with_context = with_context
self.add_extra_mlp = add_extra_mlp or not with_context
self.add_extra_mlp = add_extra_mlp

if activation == ActivationType.GELU:
mlp = partial(MLP, in_features=n_embd, hidden_features=ffn_hidden, bias=bias, dropout=dropout)
Expand Down Expand Up @@ -105,6 +105,7 @@ def __init__(
with_context=True,
attention_type=AttentionType.CAUSAL_SELF_ATTENTION,
attention_config=attention_config,
add_extra_mlp=False,
)
for _ in range(n_layer)
]
Expand Down

0 comments on commit cc1754e

Please sign in to comment.