From 6c6a46bbc84f3f66e5019a3f4c3cce69c27ec80c Mon Sep 17 00:00:00 2001
From: dtamayo <119006120+dtamayo-nlp@users.noreply.github.com>
Date: Fri, 10 May 2024 09:58:49 +0200
Subject: [PATCH] Update transformer.py -> Add `intermediate_size`

---
 megatron/model/transformer.py | 20 +++++++++++++-------
 1 file changed, 13 insertions(+), 7 deletions(-)

diff --git a/megatron/model/transformer.py b/megatron/model/transformer.py
index c154b09f4..1290a59a1 100644
--- a/megatron/model/transformer.py
+++ b/megatron/model/transformer.py
@@ -102,13 +102,19 @@ def __init__(
         self.activation_type = neox_args.activation
         self.bias_gelu_fusion = neox_args.bias_gelu_fusion
 
-        # auto scale so geglu has equal parameters
-        ff_mult = int(4 * 2 / 3) if self.activation_type == "geglu" else 4
-        ff_dim = (
-            int(ff_mult * neox_args.hidden_size) * 2
-            if self.activation_type == "geglu"
-            else ff_mult * neox_args.hidden_size
-        )
+
+        if neox_args.intermediate_size:
+            ff_dim = neox_args.intermediate_size
+
+        else:
+            # auto scale so geglu has equal parameters
+            ff_mult = int(4 * 2 / 3) if self.activation_type == "geglu" else 4
+            ff_dim = (
+                int(ff_mult * neox_args.hidden_size) * 2
+                if self.activation_type == "geglu"
+                else int(ff_mult * neox_args.hidden_size)
+            )
+
         self.dense_h_to_4h = mpu.ColumnParallelLinear(
             neox_args=neox_args,
             input_size=neox_args.hidden_size,
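
Note for reviewers: below is a minimal standalone sketch of the sizing logic this
patch introduces, for sanity-checking the resulting widths outside Megatron. Only
`hidden_size`, `activation`, and the new `intermediate_size` mirror actual
neox_args fields; the `ff_dim` helper, its signature, and the example values
(4096, 11008) are illustrative and not part of the repo.

from typing import Optional


def ff_dim(
    hidden_size: int,
    activation: str = "gelu",
    intermediate_size: Optional[int] = None,
) -> int:
    # Hypothetical helper mirroring the patched block in transformer.py.
    if intermediate_size:
        # An explicit intermediate_size now overrides the auto-scaling,
        # regardless of the activation in use.
        return intermediate_size
    # Auto scale so geglu has equal parameters: geglu splits its projection
    # into two halves, so int(4 * 2 / 3) == 2 is doubled back to an
    # effective 4x multiplier.
    ff_mult = int(4 * 2 / 3) if activation == "geglu" else 4
    if activation == "geglu":
        return int(ff_mult * hidden_size) * 2
    return int(ff_mult * hidden_size)


assert ff_dim(4096) == 16384                           # default 4x width
assert ff_dim(4096, activation="geglu") == 16384       # same projection width
assert ff_dim(4096, intermediate_size=11008) == 11008  # explicit override

In practice this lets a config pin the MLP width directly (e.g. an
"intermediate_size" of 11008, a LLaMA-style value shown here only as an example)
instead of always deriving it as 4 * hidden_size.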