1111#include < host_ir/lower_to_communication.h>
1212#include < host_ir/lowering.h>
1313#include < host_ir/pass/insert_deallocations.h>
14+ #include < multidevice/utils.h>
1415#include < runtime/executor_abstract.h>
1516
1617namespace nvfuser {
@@ -36,6 +37,26 @@ void recomputeOutputTvs(Expr* e, IrCloner& ir_cloner) {
3637 }
3738}
3839
40+ IterDomain* findStreamIterDomain (TensorView* tv) {
41+ const std::vector<IterDomain*>& loop = tv->getLoopDomain ();
42+ // FinalizeMultideviceDomains pass puts the stream IterDomain to the
43+ // front.
44+ if (!loop.empty () && loop.front ()->isStream ()) {
45+ return loop.front ();
46+ }
47+ return nullptr ;
48+ }
49+
50+ // Finds the stream IterDomain in the outputs of a segment.
51+ IterDomain* findStreamIterDomain (const std::vector<Val*>& outs) {
52+ for (auto * out : ir_utils::filterByType<TensorView>(outs)) {
53+ if (auto * stream_id = findStreamIterDomain (out)) {
54+ return stream_id;
55+ }
56+ }
57+ return nullptr ;
58+ }
59+
3960void lowerSegment (
4061 const SegmentedGroup& group,
4162 const AliasInfoMap& aliases,
@@ -72,15 +93,99 @@ void lowerSegment(
7293 }
7394 } break ;
7495 case SchedulerType::ExprEval: {
75- // push back segment's exprs into the container as top level
76- // expressions
77- for (auto * e : group.stablyOrderedExprs ()) {
96+ // Pseudocode:
97+ // clang-format off
98+ // ```
99+ // clone all expressions and store the copies to a list
100+ // if no expressions are stream parallelized:
101+ // append the list to the top level
102+ // return
103+ // for each non-input TensorView:
104+ // if it needs an out-of-loop allocation:
105+ // create an Allocate and append it to the top level
106+ // create a new, empty for loop
107+ // for each cloned expression:
108+ // for each input or output TensorView of that expression:
109+ // shard it by stream if it's allocated outside the loop
110+ // add the cloned expression to the loop body with the maybe-sharded inputs and outputs
111+ // ```
112+ // clang-format on
113+ std::vector<Expr*> cloned_exprs;
114+ cloned_exprs.reserve (group.exprs ().size ());
115+ for (Expr* e : group.stablyOrderedExprs ()) {
78116 auto * e_clone = ir_cloner.clone (e);
79117 recomputeOutputTvs (e, ir_cloner);
80- hic.pushBackTopLevelExprs (e_clone);
118+ cloned_exprs.push_back (e_clone);
119+ }
120+
121+ std::vector<Val*> cloned_outs = ir_cloner.clone (group.outputs ());
122+ // All expressions in the group are expected to be stream parallelized in
123+ // the same way. So it's safe to find the stream IterDomain from any of
124+ // them. Ideally, loop domains should be tied to expressions not
125+ // TensorViews.
126+ IterDomain* stream_id = findStreamIterDomain (cloned_outs);
127+ if (stream_id == nullptr ) {
128+ for (Expr* e : cloned_exprs) {
129+ hic.pushBackTopLevelExprs (e);
130+ }
131+ } else {
132+ for (Expr* e : cloned_exprs) {
133+ for (auto * out : ir_utils::filterByType<TensorView>(e->outputs ())) {
134+ if (getShardedIterDomain (out, ParallelType::Stream) == nullptr ) {
135+ auto * allocate =
136+ IrBuilder::create<kir::Allocate>(out, MemoryType::Global);
137+ hic.pushBackTopLevelExprs (allocate);
138+ }
139+ }
140+ }
141+
142+ auto * stream_index = IrBuilder::create<Val>(DataType::Index);
143+ auto * for_loop =
144+ hir::ForLoop::createFromIterDomain (stream_index, stream_id);
145+ hic.pushBackTopLevelExprs (for_loop);
146+
147+ std::unordered_map<Val*, Val*> replacement_map;
148+ for (Expr* e : cloned_exprs) {
149+ for (auto ins_or_out :
150+ {ir_utils::filterByType<TensorView>(e->inputs ()),
151+ ir_utils::filterByType<TensorView>(e->outputs ())}) {
152+ for (auto * tv : ins_or_out) {
153+ if (replacement_map.count (tv) > 0 ) {
154+ continue ;
155+ }
156+ if (findStreamIterDomain (tv) != nullptr &&
157+ getShardedIterDomain (tv, ParallelType::Stream) == nullptr ) {
158+ // Loop is stream parallelized but allocation is not.
159+ TensorView* sharded_tv = hir::shardByStream (tv, stream_index);
160+ for_loop->body ().push_back (sharded_tv->definition ());
161+ replacement_map[tv] = sharded_tv;
162+ }
163+ }
164+ }
165+
166+ std::vector<Val*> new_inputs;
167+ std::transform (
168+ e->inputs ().begin (),
169+ e->inputs ().end (),
170+ std::back_inserter (new_inputs),
171+ [&replacement_map](Val* input) {
172+ return getOrDefault (replacement_map, input, input);
173+ });
174+ std::vector<Val*> new_outputs;
175+ std::transform (
176+ e->outputs ().begin (),
177+ e->outputs ().end (),
178+ std::back_inserter (new_outputs),
179+ [&replacement_map](Val* output) {
180+ return getOrDefault (replacement_map, output, output);
181+ });
182+ Expr* new_e = e->newObjectFunc ()(
183+ e->container (), new_inputs, new_outputs, e->attributes ());
184+ for_loop->body ().push_back (new_e);
185+ }
81186 }
82187 } break ;
83- default :
188+ default : {
84189 const int group_id = group.groupId ();
85190
86191 // Copy the input/output TensorViews to the container.
@@ -123,6 +228,7 @@ void lowerSegment(
123228 cloned_outs,
124229 cache_id);
125230 hic.pushBackTopLevelExprs (launch_kernel);
231+ }
126232 }
127233}
128234} // namespace
0 commit comments