Commit 20756cd

fix import jit.marker.unified (#4622)

1 parent 561b9f3

File tree

1 file changed: +21 additions, -18 deletions

fastdeploy/distributed/communication.py

Lines changed: 21 additions & 18 deletions
@@ -77,22 +77,25 @@ def tensor_model_parallel_all_reduce(
 from paddle.distributed.communication import stream
 from paddle.distributed.communication.reduce import ReduceOp
 
+try:
+
+    def all_reduce(
+        tensor,
+        op,
+        group,
+        sync_op: bool = True,
+    ):
+        return stream.all_reduce(tensor, op=op, group=group, sync_op=sync_op, use_calc_stream=True)
 
-def all_reduce(
-    tensor,
-    op,
-    group,
-    sync_op: bool = True,
-):
-    return stream.all_reduce(tensor, op=op, group=group, sync_op=sync_op, use_calc_stream=True)
-
-
-@paddle.jit.marker.unified
-def tensor_model_parallel_all_reduce_custom(input_: paddle.Tensor) -> paddle.Tensor:
-    """All-reduce the input tensor across model parallel group on calc stream."""
-    if paddle.in_dynamic_mode():
-        hcg = dist.fleet.get_hybrid_communicate_group()
-        mp_group = hcg.get_model_parallel_group()
-        all_reduce(input_, op=ReduceOp.SUM, group=mp_group)
-    else:
-        dist.all_reduce(input_)
+    @paddle.jit.marker.unified
+    def tensor_model_parallel_all_reduce_custom(input_: paddle.Tensor) -> paddle.Tensor:
+        """All-reduce the input tensor across model parallel group on calc stream."""
+        if paddle.in_dynamic_mode():
+            hcg = dist.fleet.get_hybrid_communicate_group()
+            mp_group = hcg.get_model_parallel_group()
+            all_reduce(input_, op=ReduceOp.SUM, group=mp_group)
+        else:
+            dist.all_reduce(input_)
+
+except:
+    tensor_model_parallel_all_reduce_custom = None
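
The fix wraps both module-level definitions in a try/except: on Paddle builds where paddle.jit.marker.unified does not exist, the attribute lookup raises at import time, and the module now degrades to tensor_model_parallel_all_reduce_custom = None instead of breaking the import of fastdeploy.distributed.communication. A minimal sketch of the same guard-and-fallback pattern; the names all_reduce_custom and all_reduce_dispatch are hypothetical illustrations, not FastDeploy call sites:

import paddle
import paddle.distributed as dist

try:
    # paddle.jit.marker.unified only exists in newer Paddle builds;
    # on older ones this attribute lookup raises and we fall through.
    @paddle.jit.marker.unified
    def all_reduce_custom(x):
        dist.all_reduce(x)
        return x

except Exception:
    # Sentinel: callers test for None instead of re-probing Paddle.
    all_reduce_custom = None


def all_reduce_dispatch(x):
    """Use the decorated variant when available, else plain all_reduce."""
    if all_reduce_custom is not None:
        return all_reduce_custom(x)
    dist.all_reduce(x)
    return x

The commit's bare except: catches everything, including the AttributeError raised when the marker is absent; the sketch narrows it to Exception, which still covers that case while not swallowing KeyboardInterrupt.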
