add support for rwkv and mamba to use intermediate_size
jahatef committed Jun 19, 2024
1 parent 6c6a46b commit ecb02f5
Showing 2 changed files with 11 additions and 5 deletions.
4 changes: 4 additions & 0 deletions megatron/model/mamba/mamba.py
@@ -50,6 +50,10 @@ def __init__(
         self.d_state = 16 # state dimensions per channel
         self.d_conv = 4 # convolution width
         self.expand = 2 # linear projection expansion factors
+        if neox_args.intermediate_size == None:
+            neox_args.d_inner = self.expand * self.d_model
+        else:
+            neox_args.d_inner = neox_args.intermediate_size
         self.d_inner = int(self.expand * self.d_model)
         self.dt_rank = math.ceil(self.d_model / 16) # rank of dt / Delta parameter
         self.dt_scale = 1.0
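For orientation, a minimal sketch of the inner-width fallback this hunk introduces, assuming a plain namespace standing in for neox_args; the helper name resolve_d_inner is illustrative only and not part of the repository:

# Illustrative sketch (not part of the commit): the Mamba inner projection
# width falls back to expand * d_model when no intermediate_size is configured.
from types import SimpleNamespace

def resolve_d_inner(neox_args, d_model, expand=2):
    # Mirrors the added branch: use intermediate_size when set,
    # otherwise default to expand * d_model.
    if neox_args.intermediate_size is None:
        return expand * d_model
    return neox_args.intermediate_size

args = SimpleNamespace(intermediate_size=None)
print(resolve_d_inner(args, d_model=1024))  # 2048 (expand * d_model)
args.intermediate_size = 3072
print(resolve_d_inner(args, d_model=1024))  # 3072 (explicit override)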
12 changes: 7 additions & 5 deletions megatron/model/rwkv/v6/rwkv.py
@@ -247,11 +247,11 @@ def __init__(self, neox_args, layer_number):
         self.time_maa_k = nn.Parameter(1.0 - torch.pow(ddd, ratio_1_to_almost0))
         self.time_maa_r = nn.Parameter(1.0 - torch.pow(ddd, ratio_1_to_almost0))
 
-        self.key = nn.Linear(neox_args.hidden_size, neox_args.dim_ffn, bias=False)
+        self.key = nn.Linear(neox_args.hidden_size, neox_args.ff_dim, bias=False)
         self.receptance = nn.Linear(
             neox_args.hidden_size, neox_args.hidden_size, bias=False
         )
-        self.value = nn.Linear(neox_args.dim_ffn, neox_args.hidden_size, bias=False)
+        self.value = nn.Linear(neox_args.ff_dim, neox_args.hidden_size, bias=False)
 
     def forward(self, x):
         xx = self.time_shift(x) - x
@@ -277,12 +277,14 @@ def __init__(self, neox_args, layer_number):
         self.bf16 = neox_args.precision == "bfloat16"
         if not hasattr(neox_args, "dim_att"):
             neox_args.dim_att = neox_args.hidden_size
-        if not hasattr(neox_args, "dim_ffn"):
+        if neox_args.intermediate_size == None:
             # Make hidden size 3.5x. Round to nearest multiple of 32 until we add hdim rounding logic
-            neox_args.dim_ffn = int((neox_args.hidden_size * 3.5) // 32 * 32)
+            neox_args.ff_dim = int((neox_args.hidden_size * 3.5) // 32 * 32)
+        else:
+            neox_args.ff_dim = neox_args.intermediate_size
         assert neox_args.hidden_size % 32 == 0
         assert neox_args.dim_att % 32 == 0
-        assert neox_args.dim_ffn % 32 == 0
+        assert neox_args.ff_dim % 32 == 0
         self.neox_args.head_size = neox_args.dim_att // neox_args.num_attention_heads
         self.head_size = self.neox_args.head_size
         self.num_attention_heads = neox_args.num_attention_heads
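For the RWKV side, a similarly hedged sketch of the ff_dim resolution introduced above, again with a plain namespace standing in for neox_args and an illustrative helper name resolve_ff_dim:

# Illustrative sketch (not part of the commit): the RWKV feed-forward width
# defaults to 3.5x hidden_size rounded down to a multiple of 32, unless
# intermediate_size is configured.
from types import SimpleNamespace

def resolve_ff_dim(neox_args):
    if neox_args.intermediate_size is None:
        # Default: 3.5x hidden size, rounded down to a multiple of 32.
        return int((neox_args.hidden_size * 3.5) // 32 * 32)
    return neox_args.intermediate_size

args = SimpleNamespace(hidden_size=768, intermediate_size=None)
ff_dim = resolve_ff_dim(args)
assert ff_dim % 32 == 0
print(ff_dim)  # 2688 = (768 * 3.5) // 32 * 32

args.intermediate_size = 3072
print(resolve_ff_dim(args))  # 3072 (explicit override)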
