diff --git a/megatron/model/mamba/mamba.py b/megatron/model/mamba/mamba.py
index d5d6b336f..6582a723e 100644
--- a/megatron/model/mamba/mamba.py
+++ b/megatron/model/mamba/mamba.py
@@ -50,6 +50,9 @@ def __init__(
         self.d_state = 16  # state dimensions per channel
         self.d_conv = 4  # convolution width
         self.expand = 2  # linear projection expansion factors
-        self.d_inner = int(self.expand * self.d_model)
+        if neox_args.intermediate_size is None:
+            self.d_inner = int(self.expand * self.d_model)
+        else:
+            self.d_inner = neox_args.intermediate_size
         self.dt_rank = math.ceil(self.d_model / 16)  # rank of dt / Delta parameter
         self.dt_scale = 1.0
diff --git a/megatron/model/rwkv/v6/rwkv.py b/megatron/model/rwkv/v6/rwkv.py
index 5d4e0d144..6e1affb58 100644
--- a/megatron/model/rwkv/v6/rwkv.py
+++ b/megatron/model/rwkv/v6/rwkv.py
@@ -247,11 +247,11 @@ def __init__(self, neox_args, layer_number):
         self.time_maa_k = nn.Parameter(1.0 - torch.pow(ddd, ratio_1_to_almost0))
         self.time_maa_r = nn.Parameter(1.0 - torch.pow(ddd, ratio_1_to_almost0))
 
-        self.key = nn.Linear(neox_args.hidden_size, neox_args.dim_ffn, bias=False)
+        self.key = nn.Linear(neox_args.hidden_size, neox_args.ff_dim, bias=False)
         self.receptance = nn.Linear(
             neox_args.hidden_size, neox_args.hidden_size, bias=False
         )
-        self.value = nn.Linear(neox_args.dim_ffn, neox_args.hidden_size, bias=False)
+        self.value = nn.Linear(neox_args.ff_dim, neox_args.hidden_size, bias=False)
 
     def forward(self, x):
         xx = self.time_shift(x) - x
@@ -277,12 +277,14 @@ def __init__(self, neox_args, layer_number):
         self.bf16 = neox_args.precision == "bfloat16"
         if not hasattr(neox_args, "dim_att"):
             neox_args.dim_att = neox_args.hidden_size
-        if not hasattr(neox_args, "dim_ffn"):
+        if neox_args.intermediate_size is None:
             # Make hidden size 3.5x. Round to nearest multiple of 32 until we add hdim rounding logic
-            neox_args.dim_ffn = int((neox_args.hidden_size * 3.5) // 32 * 32)
+            neox_args.ff_dim = int((neox_args.hidden_size * 3.5) // 32 * 32)
+        else:
+            neox_args.ff_dim = neox_args.intermediate_size
         assert neox_args.hidden_size % 32 == 0
         assert neox_args.dim_att % 32 == 0
-        assert neox_args.dim_ffn % 32 == 0
+        assert neox_args.ff_dim % 32 == 0
         self.neox_args.head_size = neox_args.dim_att // neox_args.num_attention_heads
         self.head_size = self.neox_args.head_size
         self.num_attention_heads = neox_args.num_attention_heads
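
For reviewers, a minimal sketch of the sizing fallback this patch implements. The multipliers (2x `d_model` for Mamba, 3.5x `hidden_size` rounded down to a multiple of 32 for RWKV) come from the hunks above; `hidden_size = 1024` is just an illustrative value, not a default from this PR:

```python
# Sketch of the feed-forward sizing fallback added in this patch.
# hidden_size = 1024 is an arbitrary example, not a project default.
hidden_size = 1024
intermediate_size = None  # i.e. intermediate_size was not set in the config

# Mamba: d_inner falls back to expand * d_model, with expand = 2.
d_inner = intermediate_size if intermediate_size is not None else 2 * hidden_size
assert d_inner == 2048

# RWKV: ff_dim falls back to 3.5x hidden_size, rounded down to a multiple of 32.
ff_dim = (
    intermediate_size
    if intermediate_size is not None
    else int((hidden_size * 3.5) // 32 * 32)
)
assert ff_dim == 3584  # 1024 * 3.5 = 3584.0, already a multiple of 32
```

With this change, both model families honor a single `intermediate_size` config knob and fall back to their model-specific defaults only when it is left unset.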