@@ -77,22 +77,25 @@ def tensor_model_parallel_all_reduce(
 from paddle.distributed.communication import stream
 from paddle.distributed.communication.reduce import ReduceOp
 
+try:
+
+    def all_reduce(
+        tensor,
+        op,
+        group,
+        sync_op: bool = True,
+    ):
+        return stream.all_reduce(tensor, op=op, group=group, sync_op=sync_op, use_calc_stream=True)
 
-def all_reduce(
-    tensor,
-    op,
-    group,
-    sync_op: bool = True,
-):
-    return stream.all_reduce(tensor, op=op, group=group, sync_op=sync_op, use_calc_stream=True)
-
-
-@paddle.jit.marker.unified
-def tensor_model_parallel_all_reduce_custom(input_: paddle.Tensor) -> paddle.Tensor:
-    """All-reduce the input tensor across model parallel group on calc stream."""
-    if paddle.in_dynamic_mode():
-        hcg = dist.fleet.get_hybrid_communicate_group()
-        mp_group = hcg.get_model_parallel_group()
-        all_reduce(input_, op=ReduceOp.SUM, group=mp_group)
-    else:
-        dist.all_reduce(input_)
+    @paddle.jit.marker.unified
+    def tensor_model_parallel_all_reduce_custom(input_: paddle.Tensor) -> paddle.Tensor:
+        """All-reduce the input tensor across model parallel group on calc stream."""
+        if paddle.in_dynamic_mode():
+            hcg = dist.fleet.get_hybrid_communicate_group()
+            mp_group = hcg.get_model_parallel_group()
+            all_reduce(input_, op=ReduceOp.SUM, group=mp_group)
+        else:
+            dist.all_reduce(input_)
+
+except:
+    tensor_model_parallel_all_reduce_custom = None
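
Note: by moving the definitions inside try/except, the module still imports on builds where `paddle.jit.marker.unified` or the calc-stream `stream.all_reduce` path is unavailable; in that case `tensor_model_parallel_all_reduce_custom` is exported as `None`. A minimal caller-side sketch of that contract follows; the wrapper name `mp_all_reduce` and the plain-collective fallback are illustrative assumptions, not part of this patch.

import paddle
import paddle.distributed as dist

def mp_all_reduce(x: paddle.Tensor) -> paddle.Tensor:
    # Prefer the calc-stream variant when the patched module exported it;
    # otherwise fall back to an ordinary all-reduce (hypothetical fallback).
    if tensor_model_parallel_all_reduce_custom is not None:
        tensor_model_parallel_all_reduce_custom(x)
    else:
        dist.all_reduce(x)
    return x  # both paths reduce x in place across the process group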