diff --git a/fastsafetensors/tensor_factory.py b/fastsafetensors/tensor_factory.py index 6556aee..46b0174 100644 --- a/fastsafetensors/tensor_factory.py +++ b/fastsafetensors/tensor_factory.py @@ -106,7 +106,7 @@ def shuffle(self, pg: dist.ProcessGroup, tensor_name: str, dim: int, group = Non if self.metadata.framework == "pytorch": dist.broadcast(dst, self.rank, group=pg) elif paddle_loaded and self.metadata.framework == "paddle": - pdist.broadcast(dst, self.rank, group=group, sync_op=False) + pdist.broadcast(dst, self.rank, group=group) else: rank_slices: List[Tuple] = [() for i in range(0, pg.size())] size = frame.shape[dim] @@ -135,7 +135,7 @@ def shuffle(self, pg: dist.ProcessGroup, tensor_name: str, dim: int, group = Non if self.metadata.framework == "pytorch": dist.scatter(dst, scatter_list=scatter_list, src=self.rank, group=pg) elif paddle_loaded and self.metadata.framework == "paddle": - pdist.scatter(dst, tensor_list=scatter_list, src=self.rank, group=group, sync_op=False) + pdist.scatter(dst, tensor_list=scatter_list, src=self.rank, group=group) self.shuffled[tensor_name] = dst return dst @@ -173,7 +173,7 @@ def shuffle_packed_qkv(self, pg: dist.ProcessGroup, tensor_name: str, group = No dist.scatter(dst, scatter_list=scatter_list, src=self.rank, group=pg) elif paddle_loaded and self.metadata.framework == "paddle": dst = paddle.to_tensor(paddle.empty(shape=new_shape, dtype=frame.dtype),place=self.device) - pdist.scatter(dst, tensor_list=scatter_list, src=self.rank, group=group, sync_op=False) + pdist.scatter(dst, tensor_list=scatter_list, src=self.rank, group=group) self.shuffled[tensor_name] = dst return dst @@ -214,7 +214,7 @@ def shuffle_multi_cols(self, pg: dist.ProcessGroup, tensor_names: List[str], dim dist.scatter(dst, scatter_list=scatter_list, src=self.rank, group=pg) elif paddle_loaded and self.metadata.framework == "paddle": dst = paddle.to_tensor(paddle.empty(shape=new_shape, dtype=frame.dtype), place=self.device )# dst should be eariler than scatter_list for less fragmentation - pdist.scatter(dst, tensor_list=scatter_list, src=self.rank, group=group, sync_op=False) + pdist.scatter(dst, tensor_list=scatter_list, src=self.rank, group=group) return dst def free_dev_ptrs(self):