Skip to content

Commit

Permalink
tp world_size fix
Browse files Browse the repository at this point in the history
Signed-off-by: sahil suneja <[email protected]>
  • Loading branch information
sahilsuneja1 committed Sep 6, 2024
1 parent b1937db commit f057729
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions fms_extras/modules/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,8 +348,8 @@ def forward(
# if use_cache=True, we return the hidden_state as well as the kv cache.
# We only reduce the output, and keep the cache thread-local
if use_cache:
out = reduce_from_tensor_model_parallel_region(out_par[0])
out = reduce_from_tensor_model_parallel_region(out_par[0], self.world_size)
return out, out_par[1]
else:
out = reduce_from_tensor_model_parallel_region(out_par)
out = reduce_from_tensor_model_parallel_region(out_par, self.world_size)
return out

0 comments on commit f057729

Please sign in to comment.