From f0577298c12be0ff6c7066da8b5ea062762bdc40 Mon Sep 17 00:00:00 2001
From: sahil suneja
Date: Fri, 6 Sep 2024 19:23:09 +0000
Subject: [PATCH] tp world_size fix

Signed-off-by: sahil suneja
---
 fms_extras/modules/attention.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/fms_extras/modules/attention.py b/fms_extras/modules/attention.py
index 8d3031a..b0023d3 100644
--- a/fms_extras/modules/attention.py
+++ b/fms_extras/modules/attention.py
@@ -348,8 +348,8 @@ def forward(
         # if use_cache=True, we return the hidden_state as well as the kv cache.
         # We only reduce the output, and keep the cache thread-local
         if use_cache:
-            out = reduce_from_tensor_model_parallel_region(out_par[0])
+            out = reduce_from_tensor_model_parallel_region(out_par[0], self.world_size)
             return out, out_par[1]
         else:
-            out = reduce_from_tensor_model_parallel_region(out_par)
+            out = reduce_from_tensor_model_parallel_region(out_par, self.world_size)
             return out
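
Note on the change: both call sites now pass the tensor-parallel world
size explicitly instead of letting the reduction helper infer it. As a
minimal sketch of the assumed contract (not the actual fms_extras
implementation; the signature and the early-out are assumptions),
reduce_from_tensor_model_parallel_region would look roughly like:

    import torch
    import torch.distributed as dist

    def reduce_from_tensor_model_parallel_region(
        input_: torch.Tensor, world_size: int
    ) -> torch.Tensor:
        # Assumed early-out: with a single tensor-parallel rank there
        # are no partial results to combine, so skip the collective.
        if world_size == 1:
            return input_
        # Sum the rank-local partial outputs across the group in place.
        dist.all_reduce(input_, op=dist.ReduceOp.SUM)
        return input_

Under that reading, threading self.world_size through lets the helper
take the single-rank fast path without touching torch.distributed
state that may not be initialized when running on one GPU.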