Skip to content

Commit dc3ece7

Browse files
committed
WIP
1 parent 06f565d commit dc3ece7

File tree

9 files changed

+184
-31
lines changed

9 files changed

+184
-31
lines changed

csrc/host_ir/container.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,10 +45,13 @@ class HostIrContainer final : public Fusion {
4545
Stream* getDefaultStream();
4646

4747
private:
48+
  // Consider using a linked list so insertion is faster.
4849
std::vector<Expr*> top_level_exprs_;
50+
4951
// Indexed by group ID. This way, parallel compilation can write to disjoint
5052
// locations without having to precompute a global index.
5153
std::vector<std::unique_ptr<KernelExecutor>> kernel_executors_;
54+
5255
Stream* default_stream_ = nullptr;
5356
};
5457

csrc/host_ir/host_ir.cpp

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#include <kernel_ir.h>
1818
#include <multidevice/communication.h>
1919
#include <ops/all_ops.h>
20+
#include <transform_replay.h>
2021
#include <utils.h>
2122

2223
namespace nvfuser::hir {
@@ -450,6 +451,16 @@ std::string ShardByStream::toInlineString(int indent_size) const {
450451
NVF_CHECK(false, "Cannot be printed inline");
451452
}
452453

454+
TensorView* shardByStream(TensorView* in, Val* stream_index) {
455+
auto* out = ops::newValLike(in, *in->getDataType())->as<TensorView>();
456+
457+
TransformReplay::selfReplay(in->domain(), out->domain());
458+
out->setAllocationDomain(out->getLoopDomain(), false);
459+
460+
IrBuilder::create<ShardByStream>(out, in, stream_index);
461+
return out;
462+
}
463+
453464
ForLoop::ForLoop(IrBuilderPasskey passkey, Val* index, Val* start, Val* stop)
454465
: Expr(passkey, {index, start, stop}, {}, {}) {
455466
NVF_ERROR(passkey.ir_container_ != nullptr);
@@ -473,7 +484,9 @@ std::string ForLoop::toInlineString(int indent_size) const {
473484
NVF_CHECK(false, "Cannot be printed inline");
474485
}
475486

476-
ForLoop* createForLoopFromIterDomain(Val* index, IterDomain* iter_domain) {
487+
/*static*/ ForLoop* ForLoop::createFromIterDomain(
488+
Val* index,
489+
IterDomain* iter_domain) {
477490
return IrBuilder::create<ForLoop>(
478491
index, iter_domain->start(), iter_domain->stop());
479492
}

csrc/host_ir/host_ir.h

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -479,6 +479,10 @@ class ShardByStream : public Expr {
479479
}
480480
};
481481

482+
// Creates a ShardByStream without needing the output TensorView. Returns the
483+
// output TensorView.
484+
TensorView* shardByStream(TensorView* in, Val* stream_index);
485+
482486
class ForLoop : public Expr {
483487
public:
484488
using Expr::Expr;
@@ -492,6 +496,8 @@ class ForLoop : public Expr {
492496

493497
NVFUSER_DECLARE_CLONE_AND_CREATE
494498

499+
static ForLoop* createFromIterDomain(Val* index, IterDomain* iter_domain);
500+
495501
std::string toString(int indent_size = 0) const override;
496502
std::string toInlineString(int indent_size = 0) const override;
497503
const char* getOpString() const override {
@@ -519,6 +525,4 @@ class ForLoop : public Expr {
519525
}
520526
};
521527

522-
ForLoop* createForLoopFromIterDomain(Val* index, IterDomain* iter_domain);
523-
524528
} // namespace nvfuser::hir

csrc/host_ir/lowering.cpp

Lines changed: 111 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
#include <host_ir/lower_to_communication.h>
1212
#include <host_ir/lowering.h>
1313
#include <host_ir/pass/insert_deallocations.h>
14+
#include <multidevice/utils.h>
1415
#include <runtime/executor_abstract.h>
1516

1617
namespace nvfuser {
@@ -36,6 +37,26 @@ void recomputeOutputTvs(Expr* e, IrCloner& ir_cloner) {
3637
}
3738
}
3839

40+
IterDomain* findStreamIterDomain(TensorView* tv) {
41+
const std::vector<IterDomain*>& loop = tv->getLoopDomain();
42+
  // FinalizeMultideviceDomains pass puts the stream IterDomain at the
43+
// front.
44+
if (!loop.empty() && loop.front()->isStream()) {
45+
return loop.front();
46+
}
47+
return nullptr;
48+
}
49+
50+
// Finds the stream IterDomain in the outputs of a segment.
51+
IterDomain* findStreamIterDomain(const std::vector<Val*>& outs) {
52+
for (auto* out : ir_utils::filterByType<TensorView>(outs)) {
53+
if (auto* stream_id = findStreamIterDomain(out)) {
54+
return stream_id;
55+
}
56+
}
57+
return nullptr;
58+
}
59+
3960
void lowerSegment(
4061
const SegmentedGroup& group,
4162
const AliasInfoMap& aliases,
@@ -72,15 +93,99 @@ void lowerSegment(
7293
}
7394
} break;
7495
case SchedulerType::ExprEval: {
75-
// push back segment's exprs into the container as top level
76-
// expressions
77-
for (auto* e : group.stablyOrderedExprs()) {
96+
// Pseudocode:
97+
// clang-format off
98+
// ```
99+
// clone all expressions and store the copies to a list
100+
// if no expressions are stream parallelized:
101+
// append the list to the top level
102+
// return
103+
// for each non-input TensorView:
104+
// if it needs an out-of-loop allocation:
105+
// create an Allocate and append it to the top level
106+
// create a new, empty for loop
107+
// for each cloned expression:
108+
// for each input or output TensorView of that expression:
109+
// shard it by stream if it's allocated outside the loop
110+
// add the cloned expression to the loop body with the maybe-sharded inputs and outputs
111+
// ```
112+
// clang-format on
113+
std::vector<Expr*> cloned_exprs;
114+
cloned_exprs.reserve(group.exprs().size());
115+
for (Expr* e : group.stablyOrderedExprs()) {
78116
auto* e_clone = ir_cloner.clone(e);
79117
recomputeOutputTvs(e, ir_cloner);
80-
hic.pushBackTopLevelExprs(e_clone);
118+
cloned_exprs.push_back(e_clone);
119+
}
120+
121+
std::vector<Val*> cloned_outs = ir_cloner.clone(group.outputs());
122+
// All expressions in the group are expected to be stream parallelized in
123+
// the same way. So it's safe to find the stream IterDomain from any of
124+
// them. Ideally, loop domains should be tied to expressions not
125+
// TensorViews.
126+
IterDomain* stream_id = findStreamIterDomain(cloned_outs);
127+
if (stream_id == nullptr) {
128+
for (Expr* e : cloned_exprs) {
129+
hic.pushBackTopLevelExprs(e);
130+
}
131+
} else {
132+
for (Expr* e : cloned_exprs) {
133+
for (auto* out : ir_utils::filterByType<TensorView>(e->outputs())) {
134+
if (getShardedIterDomain(out, ParallelType::Stream) == nullptr) {
135+
auto* allocate =
136+
IrBuilder::create<kir::Allocate>(out, MemoryType::Global);
137+
hic.pushBackTopLevelExprs(allocate);
138+
}
139+
}
140+
}
141+
142+
auto* stream_index = IrBuilder::create<Val>(DataType::Index);
143+
auto* for_loop =
144+
hir::ForLoop::createFromIterDomain(stream_index, stream_id);
145+
hic.pushBackTopLevelExprs(for_loop);
146+
147+
std::unordered_map<Val*, Val*> replacement_map;
148+
for (Expr* e : cloned_exprs) {
149+
for (auto ins_or_out :
150+
{ir_utils::filterByType<TensorView>(e->inputs()),
151+
ir_utils::filterByType<TensorView>(e->outputs())}) {
152+
for (auto* tv : ins_or_out) {
153+
if (replacement_map.count(tv) > 0) {
154+
continue;
155+
}
156+
if (findStreamIterDomain(tv) != nullptr &&
157+
getShardedIterDomain(tv, ParallelType::Stream) == nullptr) {
158+
// Loop is stream parallelized but allocation is not.
159+
TensorView* sharded_tv = hir::shardByStream(tv, stream_index);
160+
for_loop->body().push_back(sharded_tv->definition());
161+
replacement_map[tv] = sharded_tv;
162+
}
163+
}
164+
}
165+
166+
std::vector<Val*> new_inputs;
167+
std::transform(
168+
e->inputs().begin(),
169+
e->inputs().end(),
170+
std::back_inserter(new_inputs),
171+
[&replacement_map](Val* input) {
172+
return getOrDefault(replacement_map, input, input);
173+
});
174+
std::vector<Val*> new_outputs;
175+
std::transform(
176+
e->outputs().begin(),
177+
e->outputs().end(),
178+
std::back_inserter(new_outputs),
179+
[&replacement_map](Val* output) {
180+
return getOrDefault(replacement_map, output, output);
181+
});
182+
Expr* new_e = e->newObjectFunc()(
183+
e->container(), new_inputs, new_outputs, e->attributes());
184+
for_loop->body().push_back(new_e);
185+
}
81186
}
82187
} break;
83-
default:
188+
default: {
84189
const int group_id = group.groupId();
85190

86191
// Copy the input/output TensorViews to the container.
@@ -123,6 +228,7 @@ void lowerSegment(
123228
cloned_outs,
124229
cache_id);
125230
hic.pushBackTopLevelExprs(launch_kernel);
231+
}
126232
}
127233
}
128234
} // namespace

csrc/multidevice/multidevice.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88

99
#pragma once
1010

11+
#include <vector>
12+
1113
#include <c10/core/Device.h>
1214

1315
namespace nvfuser {

csrc/preseg_passes/finalize_multidevice_domains.cpp

Lines changed: 1 addition & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -113,27 +113,12 @@ void setLoopAndAllocationDomain(TensorView* tv, bool is_resharding) {
113113
}
114114

115115
// Most schedulers require DIDx to be at the front of the loop domain.
116-
auto old2new = reorderParallelizedToFront(tv);
117-
auto new2old = ir_utils::normalizeOld2New(old2new, tv->nDims());
118-
std::vector<std::optional<bool>> reordered_contiguity;
119-
std::transform(
120-
new2old.begin(),
121-
new2old.end(),
122-
std::back_inserter(reordered_contiguity),
123-
[&new_contiguity](int64_t i) -> std::optional<bool> {
124-
return new_contiguity[i];
125-
});
126-
tv->setAllocationDomain(tv->getLoopDomain(), reordered_contiguity);
116+
reorderParallelizedToFront(tv);
127117
}
128118

129119
} // namespace
130120

131121
void FinalizeMultideviceDomainsPass::runPass(Fusion* fusion) {
132-
bool has_mesh = validateMeshes(fusion);
133-
if (!has_mesh) {
134-
return;
135-
}
136-
137122
for (Expr* expr : fusion->exprs()) {
138123
auto inputs = ir_utils::filterByType<TensorView>(expr->inputs());
139124
auto outputs = ir_utils::filterByType<TensorView>(expr->outputs());

csrc/runtime/fusion_executor_cache.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,7 @@ class NVF_API FusionExecutorCache {
134134
//! what inputs and the fusion look like. This may be useful in some
135135
//! cases as our analysis of index type may be overly conservative
136136
//! for intermediate tensors.
137-
//! WARING: Correctness is not guaranteed.
137+
//! WARNING: Correctness is not guaranteed.
138138
//! TODO: Check usage of forced_index_type. It's a lot of plumbing, what's the
139139
//! value.
140140
KernelArgumentHolder runFusionWithInputs(

csrc/transform_replay.h

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,16 +7,13 @@
77
// clang-format on
88
#pragma once
99

10+
#include <unordered_map>
11+
1012
#include <exceptions.h>
1113
#include <ir/internal_nodes.h>
1214
#include <scheduler/tools/maxinfo_propagator.h>
1315
#include <visibility.h>
1416

15-
#include <algorithm>
16-
#include <unordered_map>
17-
#include <unordered_set>
18-
#include <vector>
19-
2017
namespace nvfuser {
2118

2219
/*

tests/cpp/test_stream.cpp

Lines changed: 44 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,19 @@
1414
#include <fusion_guard.h>
1515
#include <ir/interface_nodes.h>
1616
#include <ops/arith.h>
17+
#include <ops/composite.h>
18+
#include <runtime/fusion_executor_cache.h>
1719
#include <tests/cpp/utils.h>
20+
#include <tests/cpp/validator.h>
1821

1922
namespace nvfuser {
2023

21-
using StreamTest = NVFuserTest;
24+
class StreamTest : public NVFuserTest {
25+
public:
26+
StreamTest() {
27+
EnableOptionsGuard::getCurOptions().set(EnableOption::HostIrLowering);
28+
}
29+
};
2230

2331
TEST_F(StreamTest, AddPerStream) {
2432
constexpr int64_t c = 3;
@@ -55,4 +63,39 @@ TEST_F(StreamTest, AddPerStream) {
5563
<< out_tensor << " vs " << expected_out_tensor;
5664
}
5765

66+
TEST_F(StreamTest, Matmul) {
67+
constexpr int64_t c = 3;
68+
69+
auto fusion = std::make_unique<Fusion>();
70+
{
71+
FusionGuard fg(fusion.get());
72+
TensorView* in = makeSymbolicTensor(2);
73+
TensorView* w = makeSymbolicTensor(2);
74+
TensorView* out = matmul(in, w);
75+
fusion->addInput(in);
76+
fusion->addInput(w);
77+
fusion->addOutput(out);
78+
79+
w->outer_split(1, c);
80+
w->axis(1)->parallelize(ParallelType::Stream);
81+
out->outer_split(1, c);
82+
out->axis(1)->parallelize(ParallelType::Stream);
83+
}
84+
85+
auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA);
86+
at::Tensor in_tensor = at::randn({5, 7}, options);
87+
at::Tensor w_tensor = at::randn({7, c * 2}, options);
88+
89+
FusionExecutorCache executor_cache(std::move(fusion));
90+
auto out_tensor = executor_cache.runFusionWithInputs({in_tensor, w_tensor})[0]
91+
.as<at::Tensor>();
92+
93+
testValidate(
94+
executor_cache.fusion(),
95+
{out_tensor},
96+
{in_tensor, w_tensor},
97+
__LINE__,
98+
__FILE__);
99+
}
100+
58101
} // namespace nvfuser

0 commit comments

Comments
 (0)