46 changes: 38 additions & 8 deletions csrc/host_ir/lowering.cpp
@@ -36,6 +36,19 @@ void recomputeOutputTvs(Expr* e, IrCloner& ir_cloner) {
}
}

// Finds the stream IterDomain in the outputs of a segment.
IterDomain* findStreamIterDomain(const std::vector<Val*>& outs) {
for (auto* out : ir_utils::filterByType<TensorView>(outs)) {
const std::vector<IterDomain*>& loop = out->getLoopDomain();
// The FinalizeMultideviceDomains pass puts the stream IterDomain at the
// front.
if (!loop.empty() && loop.front()->isStream()) {
return loop.front();
}
}
return nullptr;
}

void lowerSegment(
const SegmentedGroup& group,
const AliasInfoMap& aliases,
@@ -112,17 +125,34 @@ void lowerSegment(

// Add the LaunchKernel instruction.
KernelExecutor& ke = hic.getKernelExecutor(group_id);
IterDomain* stream_id = findStreamIterDomain(cloned_outs);
// Needed for KernelExecutor. Should be removed once #4927 is fixed.
auto* cache_id =
IrBuilder::create<NamedScalar>("cacheId", DataType::UInt64);
auto launch_kernel = IrBuilder::create<hir::LaunchKernel>(
group_id,
launch_params,
ke.compiledKernel()->compileParams(),
cloned_ins,
cloned_outs,
cache_id);
hic.pushBackTopLevelExprs(launch_kernel);
if (stream_id == nullptr) {
auto launch_kernel = IrBuilder::create<hir::LaunchKernel>(
group_id,
launch_params,
ke.compiledKernel()->compileParams(),
cloned_ins,
cloned_outs,
cache_id);
hic.pushBackTopLevelExprs(launch_kernel);
} else {
auto* stream_index = IrBuilder::create<Val>(DataType::Index);
auto* for_loop =
hir::createForLoopFromIterDomain(stream_index, stream_id);
cloned_ins.push_back(stream_index);
auto launch_kernel = IrBuilder::create<hir::LaunchKernel>(
group_id,
launch_params,
ke.compiledKernel()->compileParams(),
cloned_ins,
cloned_outs,
cache_id);
for_loop->body().push_back(launch_kernel);
hic.pushBackTopLevelExprs(for_loop);
}
}
}
} // namespace
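For illustration only (not part of the diff): a minimal, self-contained C++ sketch of the control flow that the stream branch above produces. launchSegment, the data layout, and the loop bound are hypothetical stand-ins for hir::LaunchKernel and the segment's tensors; the real lowering builds host IR nodes (a for-loop wrapping a LaunchKernel) rather than running plain C++ loops.

#include <cstdint>
#include <vector>

// Stand-in for hir::LaunchKernel: processes only the slice of the Stream
// dimension owned by `stream_index`.
void launchSegment(
    const std::vector<float>& in,
    std::vector<float>& out,
    int64_t stream_index,
    int64_t num_streams) {
  const int64_t chunk = static_cast<int64_t>(in.size()) / num_streams;
  for (int64_t i = stream_index * chunk; i < (stream_index + 1) * chunk; ++i) {
    out[i] = in[i] + in[i];
  }
}

// Mirrors the host-side for-loop created via createForLoopFromIterDomain:
// the same kernel is launched once per stream index, and that index is
// appended to the kernel inputs.
void runLoweredSegment(const std::vector<float>& in, std::vector<float>& out) {
  constexpr int64_t kNumStreams = 3; // extent of the Stream IterDomain
  for (int64_t stream_index = 0; stream_index < kNumStreams; ++stream_index) {
    launchSegment(in, out, stream_index, kNumStreams);
  }
}
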
3 changes: 2 additions & 1 deletion csrc/index_compute.cpp
@@ -1227,7 +1227,8 @@ void ensureStaticIndexing(
tv->getLoopDomain().begin(),
tv->getLoopDomain().end(),
[loop_id, &id_map](IterDomain* id) {
if (id->isBroadcast() || id->isReduction() || id->isStride()) {
if (id->isBroadcast() || id->isReduction() || id->isStride() ||
id->isStream()) {
return false;
}
auto id_replacement = id_map.find(id);
7 changes: 5 additions & 2 deletions csrc/ir/internal_base_nodes.h
@@ -491,7 +491,7 @@ class NVF_API TensorDomain : public Val {
}

int64_t nDims() const {
return static_cast<int64_t>(loop_domain_.size());
return std::ssize(loop_domain_);
}

bool sameAs(const Statement* other) const override;
@@ -623,7 +623,7 @@ class NVF_API TensorDomain : public Val {
return std::find(loop().begin(), loop().end(), id) != loop().end();
}

// Check if id is an intial loop ID.
// Check if id is an initial loop ID.
bool isInitialLoop(const IterDomain* id) const {
return std::find(initialLoop().begin(), initialLoop().end(), id) !=
loop().end();
@@ -733,6 +733,7 @@ class NVF_API TensorDomain : public Val {
static std::vector<IterDomain*> noReductions(const std::vector<IterDomain*>&);
static std::vector<IterDomain*> noBroadcasts(const std::vector<IterDomain*>&);
static std::vector<IterDomain*> noDevices(const std::vector<IterDomain*>&);
static std::vector<IterDomain*> noStream(const std::vector<IterDomain*>&);
// Usage example: `domain | TensorDomain::kNoDevices`. Unlike noDevices, this
// returns a view so is more efficient. However, make sure `domain` outlives
// the view.
@@ -742,6 +743,8 @@
[](IterDomain* id) { return !id->isReduction() && !id->isStride(); });
inline static constexpr auto kNoBroadcasts =
std::views::filter([](IterDomain* id) { return !id->isBroadcast(); });
inline static constexpr auto kNoStreams =
std::views::filter([](IterDomain* id) { return !id->isStream(); });

static bool hasBroadcast(const std::vector<IterDomain*>&);
static bool hasReduction(const std::vector<IterDomain*>&);
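For illustration only (not part of the diff): a short sketch of how the new noStream helper and kNoStreams view filter might be used at a hypothetical call site, following the `domain | TensorDomain::kNoDevices` usage documented above. Assumes the nvFuser headers and namespace; visitNonStreamIds is not a real function in the codebase.

#include <ir/interface_nodes.h>

#include <vector>

namespace nvfuser {

void visitNonStreamIds(TensorView* tv) {
  const std::vector<IterDomain*>& loop = tv->getLoopDomain();

  // View-based filter: no copy is made, but `loop` must outlive the view.
  for (IterDomain* id : loop | TensorDomain::kNoStreams) {
    // ... inspect or schedule `id` here ...
    (void)id;
  }

  // Copying variant, mirroring noDevices()/noReductions().
  std::vector<IterDomain*> no_stream = TensorDomain::noStream(loop);
  (void)no_stream;
}

} // namespace nvfuser
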
43 changes: 25 additions & 18 deletions csrc/ir/nodes.cpp
@@ -3903,35 +3903,42 @@ void TensorDomain::resize(

std::vector<IterDomain*> TensorDomain::noReductions(
const std::vector<IterDomain*>& td) {
std::vector<IterDomain*> noReductionDomain;
std::vector<IterDomain*> filtered;
std::copy_if(
td.begin(),
td.end(),
std::back_inserter(noReductionDomain),
[](IterDomain* id) { return !id->isReduction() && !id->isStride(); });
return noReductionDomain;
td.begin(), td.end(), std::back_inserter(filtered), [](IterDomain* id) {
return !id->isReduction() && !id->isStride();
});
return filtered;
}

std::vector<IterDomain*> TensorDomain::noBroadcasts(
const std::vector<IterDomain*>& td) {
std::vector<IterDomain*> noBroadcastDomain;
std::vector<IterDomain*> filtered;
std::copy_if(
td.begin(),
td.end(),
std::back_inserter(noBroadcastDomain),
[](IterDomain* id) { return !id->isBroadcast(); });
return noBroadcastDomain;
td.begin(), td.end(), std::back_inserter(filtered), [](IterDomain* id) {
return !id->isBroadcast();
});
return filtered;
}

std::vector<IterDomain*> TensorDomain::noDevices(
const std::vector<IterDomain*>& td) {
std::vector<IterDomain*> noDeviceDomain;
std::vector<IterDomain*> filtered;
std::copy_if(
td.begin(), td.end(), std::back_inserter(filtered), [](IterDomain* id) {
return !id->isDeviceDim();
});
return filtered;
}

std::vector<IterDomain*> TensorDomain::noStream(
const std::vector<IterDomain*>& td) {
std::vector<IterDomain*> filtered;
std::copy_if(
td.begin(),
td.end(),
std::back_inserter(noDeviceDomain),
[](IterDomain* id) { return !id->isDeviceDim(); });
return noDeviceDomain;
td.begin(), td.end(), std::back_inserter(filtered), [](IterDomain* id) {
return !id->isStream();
});
return filtered;
}

/*static*/ std::vector<std::optional<bool>> TensorDomain::
6 changes: 5 additions & 1 deletion csrc/multidevice/utils.cpp
@@ -8,7 +8,9 @@
#include <multidevice/utils.h>

#include <algorithm>
#include <optional>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#include <device_lower/utils.h>
@@ -313,7 +315,9 @@ int64_t numDeviceDims(const TensorView* tv) {
return std::count_if(
tv->getLoopDomain().begin(),
tv->getLoopDomain().end(),
[](IterDomain* id) { return id->isDeviceDim() && !id->isReduction(); });
[](IterDomain* id) {
return (id->isDeviceDim() || id->isStream()) && !id->isReduction();
});
}

namespace {
20 changes: 3 additions & 17 deletions csrc/preseg_passes/finalize_multidevice_domains.cpp
@@ -112,28 +112,14 @@ void setLoopAndAllocationDomain(TensorView* tv, bool is_resharding) {
return;
}

// Most schedulers require DIDx to be at the front of the loop domain.
auto old2new = reorderParallelizedToFront(tv);
auto new2old = ir_utils::normalizeOld2New(old2new, tv->nDims());
std::vector<std::optional<bool>> reordered_contiguity;
std::transform(
new2old.begin(),
new2old.end(),
std::back_inserter(reordered_contiguity),
[&new_contiguity](int64_t i) -> std::optional<bool> {
return new_contiguity[i];
});
tv->setAllocationDomain(tv->getLoopDomain(), reordered_contiguity);
// Most schedulers require Stream and DIDx to be at the front of the loop
// domain.
reorderParallelizedToFront(tv);
}

} // namespace

void FinalizeMultideviceDomainsPass::runPass(Fusion* fusion) {
bool has_mesh = validateMeshes(fusion);
if (!has_mesh) {
return;
}

for (Expr* expr : fusion->exprs()) {
auto inputs = ir_utils::filterByType<TensorView>(expr->inputs());
auto outputs = ir_utils::filterByType<TensorView>(expr->outputs());
5 changes: 3 additions & 2 deletions csrc/scheduler/pointwise.cpp
@@ -232,7 +232,8 @@ std::unique_ptr<PointwiseParams> getPointwiseHeuristics(
}
// We always cacheBefore output at the beginning of the scheduling. And after
// cacheBefore, the reference tensor will have all reduction IDs removed.
ref_loop = TensorDomain::noDevices(TensorDomain::noReductions(ref_loop));
ref_loop = TensorDomain::noStream(
TensorDomain::noDevices(TensorDomain::noReductions(ref_loop)));

std::vector<int64_t> elem_counts;
elem_counts.reserve(ref_loop.size());
@@ -357,7 +358,7 @@ std::unique_ptr<PointwiseParams> getPointwiseHeuristics(

auto& view_disjoint_sets = broadcast_info.get().view_disjoint_set_ids;
auto& broadcast_bit_multiples = broadcast_info.get().broadcast_multiples;
NVF_ERROR(broadcast_bit_multiples.size() == ref_loop.size());
NVF_ERROR_EQ(broadcast_bit_multiples.size(), ref_loop.size());

int64_t dtype_sum_bit = 0;
for (auto inp : ir_utils::filterByType<TensorView>(fusion->inputs())) {
2 changes: 2 additions & 0 deletions csrc/transform_replay.cpp
@@ -819,6 +819,7 @@ std::pair<TensorDomain*, int64_t> TransformReplay::replayCasP(
", requested in replay.");
continue;
}
it->second->parallelize(p_id->getParallelType());
new_loop.push_back(it->second);
used_loop.emplace(it->second);
}
@@ -840,6 +841,7 @@ std::pair<TensorDomain*, int64_t> TransformReplay::replayCasP(
continue;
}
if (used_loop.find(id) == used_loop.end()) {
id->parallelize(p_id->getParallelType());
new_loop.push_back(id);
used_loop.emplace(id);
}
1 change: 0 additions & 1 deletion tests/cpp/test_multidevice_sharding.cpp
@@ -18,7 +18,6 @@

namespace nvfuser {

using testing::Contains;
using testing::Each;
using testing::ElementsAre;
using testing::Not;
41 changes: 20 additions & 21 deletions tests/cpp/test_stream.cpp
@@ -14,21 +14,28 @@
#include <fusion_guard.h>
#include <ir/interface_nodes.h>
#include <ops/arith.h>
#include <runtime/fusion_executor_cache.h>
#include <tests/cpp/utils.h>
#include <tests/cpp/validator.h>

namespace nvfuser {

using StreamTest = NVFuserTest;
class StreamTest : public NVFuserTest {
protected:
StreamTest() {
EnableOptionsGuard::getCurOptions().set(EnableOption::HostIrLowering);
}
};

TEST_F(StreamTest, AddPerStream) {
constexpr int64_t c = 3;
Fusion fusion;
FusionGuard fg(&fusion);
auto fusion = std::make_unique<Fusion>();
FusionGuard fg(fusion.get());

TensorView* in = makeContigTensor(2);
TensorView* out = add(in, in);
fusion.addInput(in);
fusion.addOutput(out);
fusion->addInput(in);
fusion->addOutput(out);

in->outer_split(1, c);
in->axis(1)->parallelize(ParallelType::Stream);
@@ -37,22 +44,14 @@ TEST_F(StreamTest, AddPerStream) {

auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA);
at::Tensor in_tensor = at::randn({5, c * 2}, options);
at::Tensor out_tensor = at::zeros_like(in_tensor);

KernelExecutor ke;
ke.compile(&fusion, {in_tensor});
constexpr int64_t kStreamIndex = 1;
ke.run({in_tensor, kStreamIndex}, {out_tensor});

at::Tensor expected_out_tensor = in_tensor + in_tensor;
std::vector<at::Tensor> chunks = expected_out_tensor.chunk(c, 1);
for (auto [i, chunk] : enumerate(chunks)) {
if (i != kStreamIndex) {
chunk.zero_();
}
}
EXPECT_TRUE(at::allclose(out_tensor, expected_out_tensor))
<< out_tensor << " vs " << expected_out_tensor;

FusionExecutorCache executor_cache(std::move(fusion));
KernelArgumentHolder out_tensors =
executor_cache.runFusionWithInputs({in_tensor});
auto out_tensor = out_tensors[0].as<at::Tensor>();

testValidate(
executor_cache.fusion(), {out_tensor}, {in_tensor}, __LINE__, __FILE__);
}

} // namespace nvfuser