Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 12 additions & 8 deletions csrc/ir/composite_nodes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -88,16 +88,20 @@ std::vector<PolymorphicValue> MatmulOp::evaluate(
matmul_out = matmul_out.unsqueeze(rfactor_did_idx);
}

const auto& [sizes, strides] = inferShapeAndContiguousStrides(out(), ee);
auto meta_out = at::detail::empty_strided_meta(sizes, strides, a.dtype());
// Without InferContiguity, we mistakenly assume the output is contiguous.
if (!isOptionEnabled(EnableOption::InferContiguity)) {
const auto& [sizes, strides] = inferShapeAndContiguousStrides(out(), ee);
auto meta_out = at::detail::empty_strided_meta(sizes, strides, a.dtype());

if (meta_out.is_contiguous()) {
return {matmul_out};
}
if (meta_out.is_contiguous()) {
return {matmul_out};
}

auto strided_matmul_out = at::empty_strided(sizes, strides, a.options());
strided_matmul_out = strided_matmul_out.copy_(matmul_out);
return {strided_matmul_out};
auto strided_matmul_out = at::empty_strided(sizes, strides, a.options());
strided_matmul_out = strided_matmul_out.copy_(matmul_out);
return {strided_matmul_out};
}
return {matmul_out};
}

LinearOp::LinearOp(
Expand Down
1 change: 1 addition & 0 deletions csrc/options.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,7 @@ const std::unordered_map<std::string, EnableOption>& getEnableOptions() {
{"p2p_protocol", EnableOption::P2pProtocol},
{"multicast_protocol", EnableOption::MulticastProtocol},
{"parallel_serde", EnableOption::ParallelSerde},
{"infer-contiguity", EnableOption::InferContiguity},
};
return available_options;
}
Expand Down
1 change: 1 addition & 0 deletions csrc/options.h
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,7 @@ enum class EnableOption {
MulticastProtocol, //! Prescribe multicast protocol:
//! memcpy|multimem|batch_memcpy
ParallelSerde, //! Enable deserializing FusionExecutorCache in parallel
InferContiguity, //! Enable contiguity inference
EndOfOption //! Placeholder for counting the number of elements
};

Expand Down
4 changes: 2 additions & 2 deletions csrc/runtime/allocations.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,12 @@

namespace nvfuser {

KernelArgumentHolder inferOutputShapeAndContiguousStrides(
KernelArgumentHolder inferContiguousOutputMetaTensor(
Fusion* fusion,
const KernelArgumentHolder& args,
PrecomputedValues* evaluator_precomputed_values) {
FUSER_PERF_SCOPE(
"fusion_executor::allocations::inferOutputShapeAndContiguousStrides");
"fusion_executor::allocations::inferContiguousOutputMetaTensor");
ExpressionEvaluator expr_eval;

std::unique_ptr<PrecomputedValues> evaluator_precomputed_values_up = nullptr;
Expand Down
2 changes: 1 addition & 1 deletion csrc/runtime/allocations.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ struct GlobalBufferInfo {
//! pushing scalar int 0 as a place-holder.
//! 2. This API does not allocate output in memory, but only returns the
//! inferred output sizes. Used in runtime/fusion_executor_cache.cpp.
KernelArgumentHolder inferOutputShapeAndContiguousStrides(
KernelArgumentHolder inferContiguousOutputMetaTensor(
Fusion* fusion,
const KernelArgumentHolder& args,
PrecomputedValues* evaluator_precomputed_values = nullptr);
Expand Down
1 change: 1 addition & 0 deletions csrc/runtime/fusion_cache_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

#include <fusion_segmenter.h>
#include <ir/all_nodes.h>
#include <ir/utils.h>
#include <polymorphic_value.h>
#include <runtime/executor_kernel_arg.h>

Expand Down
77 changes: 69 additions & 8 deletions csrc/runtime/fusion_kernel_runtime.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -337,6 +337,66 @@ KernelArgumentHolder FusionKernelRuntime::runWithInputs(
return fusion_outputs;
}

KernelArgumentHolder FusionKernelRuntime::inferOutputMetaTensor(
HeuristicParamsList* heuristics,
SegmentedGroup* group_to_run,
const KernelArgumentHolder& group_runtime_inputs,
PrecomputedValues* evaluator_precomputed_values) const {
FUSER_PERF_SCOPE("FusionKernelRuntime::inferOutputMetaTensor");
NVF_ERROR(heuristics != nullptr);
Fusion* fusion_to_run = group_to_run->getFusion();
KernelArgumentHolder group_runtime_outputs;
const auto& heuristic_params = heuristics->at(group_to_run->groupId());
const bool is_expr_eval =
heuristic_params->scheduler_type == SchedulerType::ExprEval;
if (is_expr_eval && isOptionEnabled(EnableOption::InferContiguity)) {
// For expr evaluated fusion, the striding rules follow that of ATen.
ExpressionEvaluator eval_fusion;
for (auto i : arange(group_to_run->inputs().size())) {
const auto& tensor_pv = group_runtime_inputs[i];
if (tensor_pv.is<at::Tensor>()) {
const auto& t = tensor_pv.as<at::Tensor>();
if (t.defined()) {
const auto meta_t = at::empty_strided(
t.sizes(),
t.strides(),
at::TensorOptions().device(at::kMeta).dtype(t.dtype()));
eval_fusion.bind(fusion_to_run->inputs()[i], meta_t);
} else {
eval_fusion.bind(fusion_to_run->inputs()[i], t);
}
} else {
eval_fusion.bind(fusion_to_run->inputs()[i], tensor_pv);
}
}
for (auto v : fusion_to_run->outputs()) {
auto result = eval_fusion.evaluate(v);
group_runtime_outputs.push(result);
}
} else {
auto fusion_to_run = group_to_run->getFusion();
return inferContiguousOutputMetaTensor(
fusion_to_run, group_runtime_inputs, evaluator_precomputed_values);
}
return group_runtime_outputs;
}

void FusionKernelRuntime::updateContiguityOfSegmentOutputs(
SegmentedGroup* group_to_run,
const KernelArgumentHolder& group_runtime_outputs) const {
FUSER_PERF_SCOPE("FusionKernelRuntime::updateContiguityOfSegmentOutputs");
if (!isOptionEnabled(EnableOption::InferContiguity)) {
return;
}
for (auto [i, output] : enumerate(group_to_run->outputs())) {
auto tv = dynamic_cast<TensorView*>(output);
if (tv) {
const at::Tensor& tensor = group_runtime_outputs[i].as<at::Tensor>();
ir_utils::resetContiguityFromTensor(tv, tensor);
}
}
}

std::vector<KernelArgumentHolder> FusionKernelRuntime::prepareInputs(
const KernelArgumentHolder& args) const {
std::vector<KernelArgumentHolder> all_runtime_inputs;
Expand All @@ -362,16 +422,14 @@ std::vector<KernelArgumentHolder> FusionKernelRuntime::prepareInputs(
group_runtime_inputs.setCacheId(group_cache_id.value());
}

// TODO: inferOutputShapeAndContiguousStrides doesn't seem to strictly
// require a Fusion for each segment. Consider using the complete fusion
// instead.
auto fusion_to_run = segmented_fusion_->makeFusion(group_to_run).second;
auto group_runtime_outputs = inferOutputShapeAndContiguousStrides(
fusion_to_run.get(), group_runtime_inputs);
auto group_runtime_outputs = inferOutputMetaTensor(
heuristics_.get(), group_to_run, group_runtime_inputs);

// map output args to tensor map
args_manager.updateWithSegmentOutputs(
group_to_run->outputs(), group_runtime_outputs, run_order_id);

updateContiguityOfSegmentOutputs(group_to_run, group_runtime_outputs);
}

return all_runtime_inputs;
Expand Down Expand Up @@ -599,13 +657,16 @@ std::optional<std::unique_ptr<HeuristicParamsList>> FusionKernelRuntime::
}

// Generate metadata for the fusion's outputs
auto group_runtime_outputs = inferOutputShapeAndContiguousStrides(
fusion_to_run,
auto group_runtime_outputs = inferOutputMetaTensor(
heuristics.get(),
group_to_run,
group_runtime_inputs,
evaluator_precomputed_values.get());

args_manager.updateWithSegmentOutputs(
group_to_run->outputs(), group_runtime_outputs, run_order_id);

updateContiguityOfSegmentOutputs(group_to_run, group_runtime_outputs);
}
return heuristics;
}
Expand Down
21 changes: 21 additions & 0 deletions csrc/runtime/fusion_kernel_runtime.h
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,27 @@ class FusionKernelRuntime {
//! Access the list of schedulers maintained in this runtime instance
const std::vector<std::unique_ptr<HeuristicParams>>& schedulers() const;

//! Infer the output shape and stride of the fusion as tensors on Meta device
//! If the group is scheduled to be evaluated using ExprEval, the output
//! tensors are inferred using the ExprEval on meta device. Otherwise, the
//! output tensors are inferred assuming they are contiguous.
KernelArgumentHolder inferOutputMetaTensor(
HeuristicParamsList* heuristics,
SegmentedGroup* group_to_run,
const KernelArgumentHolder& group_runtime_inputs,
PrecomputedValues* evaluator_precomputed_values = nullptr) const;

//! When a FusionIR is constructed, all intermediate tensors are assumed to
//! be contiguous. Unfortunately, this assumption is not always true, and
//! could not be determined at compile time. Depending on the runtime inputs,
//! we may segment the fusions differently, and some fusion segments would be
//! executed using ATen, which may not generate contiguous tensors. So we have
//! to update the contiguity of the segment outputs on the fly depending on
//! the runtime inputs.
void updateContiguityOfSegmentOutputs(
SegmentedGroup* group_to_run,
const KernelArgumentHolder& group_runtime_outputs) const;

// Create KernelArgumentHolders for all of the segments. Sorted in
// the run order.
std::vector<KernelArgumentHolder> prepareInputs(
Expand Down
3 changes: 3 additions & 0 deletions tests/cpp/test_alias.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -497,6 +497,9 @@ TEST_F(AliasTest, Issue1452) {
}

TEST_F(AliasTest, AliasOutputBeforeNonAliasOutput) {
EnableOptionsGuard opt_guard;
EnableOptionsGuard::getCurOptions().unset(EnableOption::InferContiguity);

auto fusion = std::make_unique<Fusion>();
FusionGuard fg(fusion.get());

Expand Down
2 changes: 2 additions & 0 deletions tests/cpp/test_indexing_advanced.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,15 @@ class AdvancedIndexingTest : public NVFuserFixtureParamTest<bool> {
} else {
EnableOptionsGuard::getCurOptions().unset(EnableOption::IdModel);
}
EnableOptionsGuard::getCurOptions().set(EnableOption::InferContiguity);
}
};

class AdvancedIndexingIdModelTest : public NVFuserTest {
protected:
void SetUp() override {
EnableOptionsGuard::getCurOptions().set(EnableOption::IdModel);
EnableOptionsGuard::getCurOptions().set(EnableOption::InferContiguity);
}
};

Expand Down
1 change: 1 addition & 0 deletions tests/cpp/test_layout_op.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ class LayoutOpTest : public NVFuserTest {
void SetUp() override {
NVFuserTest::SetUp();
EnableOptionsGuard::getCurOptions().set(EnableOption::IdModel);
EnableOptionsGuard::getCurOptions().unset(EnableOption::InferContiguity);
}
};

Expand Down
1 change: 1 addition & 0 deletions tests/cpp/test_loop_domain_scheduling.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ class LoopDomainSchedulingTest : public NVFuserTest {
protected:
void SetUp() override {
EnableOptionsGuard::getCurOptions().set(EnableOption::IdModel);
EnableOptionsGuard::getCurOptions().set(EnableOption::InferContiguity);
}
};

Expand Down
8 changes: 7 additions & 1 deletion tests/cpp/test_low_precision_recipe.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -974,7 +974,13 @@ TEST_F(BlockQuantizationValidationTest, MergesMustBeContiguous) {
class BlockQuantizationSchedulingTest
: public BlackwellBase,
public ::testing::WithParamInterface<
std::tuple<DataType, std::pair<int, int>, bool, bool>> {};
std::tuple<DataType, std::pair<int, int>, bool, bool>> {
protected:
void SetUp() override {
BlackwellBase::SetUp();
EnableOptionsGuard::getCurOptions().unset(EnableOption::InferContiguity);
}
};

TEST_P(BlockQuantizationSchedulingTest, AutoScheduleSingleOp) {
const auto data_type = std::get<0>(GetParam());
Expand Down
33 changes: 0 additions & 33 deletions tests/cpp/test_matmul_aten_evaluation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -371,37 +371,4 @@ INSTANTIATE_TEST_SUITE_P(
testing::Values(Sizes({n, 1})),
testing::Values(Sizes({n}))));

using MatmulNodeTest = NVFuserTest;

TEST_F(MatmulNodeTest, OutputStrides) {
auto fusion = std::make_unique<Fusion>();
FusionGuard fg(fusion.get());

TensorView* x = makeSymbolicTensor(2, DataType::Half);
TensorView* y = makeSymbolicTensor(2, DataType::Half);
TensorView* z = matmul(x, y);

fusion->addInput(x);
fusion->addInput(y);
fusion->addOutput(z);

z->setAllocationDomain({z->axis(1), z->axis(0), z->axis(2)}, true);

auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA);
at::Tensor x_tensor = at::randn({2, 3}, options);
at::Tensor y_tensor = at::randn({3, 5}, options);

FusionExecutorCache executor_cache(std::move(fusion));
auto outs = executor_cache.runFusionWithInputs({x_tensor, y_tensor});
at::Tensor z_tensor = outs[0].as<at::Tensor>();
testValidate(
executor_cache.fusion(),
{z_tensor},
{x_tensor, y_tensor},
__LINE__,
__FILE__);

EXPECT_THAT(z_tensor.strides(), ElementsAre(1, 2));
}

} // namespace nvfuser
1 change: 1 addition & 0 deletions tests/cpp/test_matmul_scheduler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2802,6 +2802,7 @@ class MatmulFusionTest
EnableOptionsGuard::getCurOptions().set(
EnableOption::FuseMultipleMatmuls);
}
EnableOptionsGuard::getCurOptions().set(EnableOption::InferContiguity);
}

bool fusion_enabled = GetParam().first;
Expand Down
1 change: 1 addition & 0 deletions tests/cpp/test_pointwise.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ class PointwiseTest : public NVFuserTest {
protected:
void SetUp() override {
EnableOptionsGuard::getCurOptions().set(EnableOption::IdModel);
EnableOptionsGuard::getCurOptions().set(EnableOption::InferContiguity);
}
};

Expand Down
1 change: 1 addition & 0 deletions tests/cpp/test_rng.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ at::Tensor generate_normal(int64_t size, at::ScalarType dtype) {
class RNGTest : public NVFuserTest {
void SetUp() override {
EnableOptionsGuard::getCurOptions().set(EnableOption::IdModel);
EnableOptionsGuard::getCurOptions().set(EnableOption::InferContiguity);
}
};

Expand Down
5 changes: 4 additions & 1 deletion tests/cpp/test_segmentation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -849,7 +849,10 @@ TEST_F(SegmentationTest, RevertPrivatizedUpcast) {

++num_upcast_ops;
}
EXPECT_EQ(num_upcast_ops, 1);
// There is an unswitched IfThenElse in the generated kernel, and in each
// of its branches, there is an upcast op with tv1 as its producer. So we
// should have two upcast ops.
EXPECT_EQ(num_upcast_ops, 2);
}
}

Expand Down
1 change: 1 addition & 0 deletions tests/cpp/utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ void NVFuserTest::SetUp() {
GTEST_SKIP() << "skipping tests on pre-PASCAL GPUs";
}
EnableOptionsGuard::getCurOptions().set(EnableOption::IdModel);
EnableOptionsGuard::getCurOptions().set(EnableOption::InferContiguity);
}

NVFuserTest::~NVFuserTest() {
Expand Down
Loading