diff --git a/megatron/model/transformer.py b/megatron/model/transformer.py
index c154b09f4..1290a59a1 100644
--- a/megatron/model/transformer.py
+++ b/megatron/model/transformer.py
@@ -102,13 +102,19 @@ def __init__(
         self.activation_type = neox_args.activation
         self.bias_gelu_fusion = neox_args.bias_gelu_fusion
 
-        # auto scale so geglu has equal parameters
-        ff_mult = int(4 * 2 / 3) if self.activation_type == "geglu" else 4
-        ff_dim = (
-            int(ff_mult * neox_args.hidden_size) * 2
-            if self.activation_type == "geglu"
-            else ff_mult * neox_args.hidden_size
-        )
+
+        if neox_args.intermediate_size:
+            ff_dim = neox_args.intermediate_size
+
+        else:
+            # auto scale so geglu has equal parameters
+            ff_mult = int(4 * 2 / 3) if self.activation_type == "geglu" else 4
+            ff_dim = (
+                int(ff_mult * neox_args.hidden_size) * 2
+                if self.activation_type == "geglu"
+                else int(ff_mult * neox_args.hidden_size)
+            )
+
         self.dense_h_to_4h = mpu.ColumnParallelLinear(
             neox_args=neox_args,
             input_size=neox_args.hidden_size,
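
For reviewers, a minimal standalone sketch of the ff_dim selection this patch introduces: an explicit neox_args.intermediate_size overrides everything; otherwise geglu gets a truncated 8/3 multiplier (int(4 * 2 / 3) == 2) that is then doubled so one projection holds both value and gate. compute_ff_dim and the SimpleNamespace stand-in for neox_args are hypothetical, illustration-only names; the real logic runs inline in the patched __init__.

from types import SimpleNamespace

def compute_ff_dim(neox_args):
    # Explicit override wins, matching the new branch in the patch.
    if neox_args.intermediate_size:
        return neox_args.intermediate_size
    # Auto scale so geglu has (roughly) equal parameters.
    ff_mult = int(4 * 2 / 3) if neox_args.activation == "geglu" else 4
    if neox_args.activation == "geglu":
        # Doubled: the h-to-4h projection produces value and gate halves.
        return int(ff_mult * neox_args.hidden_size) * 2
    return int(ff_mult * neox_args.hidden_size)

args = SimpleNamespace(hidden_size=1024, activation="gelu", intermediate_size=None)
assert compute_ff_dim(args) == 4096  # default 4x widening
args.activation = "geglu"
assert compute_ff_dim(args) == 4096  # int(8/3) == 2, times hidden, doubled for the gate
args.intermediate_size = 8192
assert compute_ff_dim(args) == 8192  # configured intermediate_size taken as-is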