
Commit 0b25630

pi314ever, tjruwase, and loadams authored
Add arctic model support by adding w2 to all_reduce (#6856)
As the title says. With AutoTP, the Arctic model's default behavior produces shape mismatches because its MLP layer computes `w2 * act(w1*w3)`. The method used to fix Mixtral-8x7B in #5257 does not work here, because the Arctic MLP is also used inside a ModuleList for the MoE. As a result, the MLP weights hide behind the individual experts as layers named `#.w#`, which the fix in #5257 does not catch. This commit adds the check directly in `_replace`, where it can inspect the actual layer name for the `w2` key and patch that layer with `all_reduce`.

---------

Signed-off-by: Daniel Huang <[email protected]>
Co-authored-by: Olatunji Ruwase <[email protected]>
Co-authored-by: Logan Adams <[email protected]>
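The name check described above can be sketched in isolation. This is a minimal illustration, not DeepSpeed code: the helper `needs_all_reduce`, the example layer names (`0.w2`, `0.w1`), the model repr strings, and the contents of `all_reduce_linears` are all hypothetical stand-ins chosen to mirror the condition in the diff.

```python
def needs_all_reduce(name, module_repr, all_reduce_linears):
    """Hypothetical helper mirroring the condition added to _replace.

    name: layer name as seen during replacement (e.g. "0.w2" for an expert).
    module_repr: stand-in for str(self.module).
    all_reduce_linears: stand-in for the names AutoTP already collected.
    """
    # Substring checks: the model must be an Arctic variant and the layer
    # name must contain "w2", even when it is prefixed by an expert index.
    arctic_w2 = 'Arctic' in module_repr and 'w2' in name
    return name in all_reduce_linears or arctic_w2


# Illustrative inputs only; real names depend on the model definition.
collected = ["o_proj"]
print(needs_all_reduce("0.w2", "ArcticForCausalLM(...)", collected))
print(needs_all_reduce("0.w1", "ArcticForCausalLM(...)", collected))
print(needs_all_reduce("0.w2", "MixtralForCausalLM(...)", collected))
```

Because the test is a substring match, any Arctic layer whose name contains `w2` would take the `all_reduce` path, which is what lets it catch expert-prefixed names.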
1 parent 4cd1d97 commit 0b25630

2 files changed, +6 -1 lines changed


deepspeed/module_inject/auto_tp.py

Lines changed: 5 additions & 1 deletion
@@ -346,11 +346,15 @@ def _replace(self, child, name, conv_linear_layer):
                 weight, bias = shard_value_with_share_qk(child.weight.data, child.bias, dist.get_rank(),
                                                          dist.get_world_size(), False)
                 return LinearAllreduce(weight, bias, self.mp_group)
+        # For Arctic model, bypass to all_reduce replacement for w2 weights
+        arctic_w2_all_reduce_linear = False
+        if 'Arctic' in str(self.module) and 'w2' in name:
+            arctic_w2_all_reduce_linear = True
         # For MLP including chunk layer.
         if 'gate_up_proj' in name or ('dense_h_to_4h' in name and 'GLM' in str(self.module)):
             weight, bias = shard_chunk_mlp(child.weight.data, child.bias, dist.get_rank(), dist.get_world_size())
             return LinearLayer(weight=weight, bias=bias)
-        if name in self.all_reduce_linears:
+        if name in self.all_reduce_linears or arctic_w2_all_reduce_linear:
             # if conv_linear_layer [weight_shape[1], weight_shape[0] // mp_size]
             # else [weight_shape[0], weight_shape[1] // mp_size]
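Why `w2` needs the `all_reduce` treatment can be seen in a small pure-Python sketch. It is not DeepSpeed's implementation. It assumes a gated MLP of the form `w2(act(w1 x) * (w3 x))`, uses ReLU in place of the model's real activation, and uses made-up integer weights so the comparison is exact. `w1` and `w3` are split by output rows, `w2` is split by input columns, and each rank's partial output must be summed to recover the full result.

```python
def matvec(W, x):
    # Multiply a matrix (list of rows) by a vector.
    return [sum(w * v for w, v in zip(row, x)) for row in W]

def relu(v):
    return [max(0, a) for a in v]

def mlp(W1, W2, W3, x):
    # Assumed gated form: w2(act(w1 x) * (w3 x)).
    h = [a * b for a, b in zip(relu(matvec(W1, x)), matvec(W3, x))]
    return matvec(W2, h)

def shard_rows(W, rank, world):
    # Split output features (rows) evenly across ranks.
    n = len(W) // world
    return W[rank * n:(rank + 1) * n]

def shard_cols(W, rank, world):
    # Split input features (columns) evenly across ranks.
    n = len(W[0]) // world
    return [row[rank * n:(rank + 1) * n] for row in W]

# Made-up weights and input, for illustration only.
W1 = [[1, 0], [0, 1], [1, 1], [1, -1]]
W3 = [[2, 1], [1, 2], [0, 1], [1, 0]]
W2 = [[1, 2, 3, 4], [0, 1, 0, 1]]
x = [3, 2]
world = 2

full = mlp(W1, W2, W3, x)
# Each simulated rank holds row shards of w1/w3 and a column shard of w2.
partials = [
    mlp(shard_rows(W1, r, world), shard_cols(W2, r, world), shard_rows(W3, r, world), x)
    for r in range(world)
]
# Summing the per-rank partial outputs plays the role of all_reduce.
reduced = [sum(vals) for vals in zip(*partials)]
print(full, partials, reduced)
```

Without the final sum, each rank holds only a partial output, which is consistent with the commit's choice to route Arctic `w2` layers through the `all_reduce` path.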

docs/_tutorials/automatic-tensor-parallelism.md

Lines changed: 1 addition & 0 deletions
@@ -121,6 +121,7 @@ The following results were collected using V100 SXM2 32GB GPUs.
 The following model families have been successfully tested with automatic tensor parallelism. Other models may work but have not been tested yet.
 
 - albert
+- arctic
 - baichuan
 - bert
 - bigbird_pegasus
