diff --git a/csrc/ir/composite_nodes.cpp b/csrc/ir/composite_nodes.cpp
index 20b79a7c318..22b343485f6 100644
--- a/csrc/ir/composite_nodes.cpp
+++ b/csrc/ir/composite_nodes.cpp
@@ -88,16 +88,20 @@ std::vector<PolymorphicValue> MatmulOp::evaluate(
     matmul_out = matmul_out.unsqueeze(rfactor_did_idx);
   }
 
-  const auto& [sizes, strides] = inferShapeAndContiguousStrides(out(), ee);
-  auto meta_out = at::detail::empty_strided_meta(sizes, strides, a.dtype());
+  // Without InferContiguity, we mistakenly assume the output is contiguous.
+  if (!isOptionEnabled(EnableOption::InferContiguity)) {
+    const auto& [sizes, strides] = inferShapeAndContiguousStrides(out(), ee);
+    auto meta_out = at::detail::empty_strided_meta(sizes, strides, a.dtype());
 
-  if (meta_out.is_contiguous()) {
-    return {matmul_out};
-  }
+    if (meta_out.is_contiguous()) {
+      return {matmul_out};
+    }
 
-  auto strided_matmul_out = at::empty_strided(sizes, strides, a.options());
-  strided_matmul_out = strided_matmul_out.copy_(matmul_out);
-  return {strided_matmul_out};
+    auto strided_matmul_out = at::empty_strided(sizes, strides, a.options());
+    strided_matmul_out = strided_matmul_out.copy_(matmul_out);
+    return {strided_matmul_out};
+  }
+  return {matmul_out};
 }
 
 LinearOp::LinearOp(
diff --git a/csrc/options.cpp b/csrc/options.cpp
index 0cc602bb5c8..743d9c926fc 100644
--- a/csrc/options.cpp
+++ b/csrc/options.cpp
@@ -185,6 +185,7 @@ const std::unordered_map<std::string, EnableOption>& getEnableOptions() {
       {"p2p_protocol", EnableOption::P2pProtocol},
       {"multicast_protocol", EnableOption::MulticastProtocol},
       {"parallel_serde", EnableOption::ParallelSerde},
+      {"infer-contiguity", EnableOption::InferContiguity},
   };
   return available_options;
 }
diff --git a/csrc/options.h b/csrc/options.h
index 3f21c3d9392..3f70e8ad992 100644
--- a/csrc/options.h
+++ b/csrc/options.h
@@ -130,6 +130,7 @@ enum class EnableOption {
   MulticastProtocol, //! Prescribe multicast protocol:
                      //! memcpy|multimem|batch_memcpy
   ParallelSerde, //! Enable deserializing FusionExecutorCache in parallel
+  InferContiguity, //! Enable contiguity inference
   EndOfOption //! Placeholder for counting the number of elements
 };
 
diff --git a/csrc/runtime/allocations.cpp b/csrc/runtime/allocations.cpp
index 75f142033ed..45a87d13d4c 100644
--- a/csrc/runtime/allocations.cpp
+++ b/csrc/runtime/allocations.cpp
@@ -21,12 +21,12 @@ namespace nvfuser {
 
-KernelArgumentHolder inferOutputShapeAndContiguousStrides(
+KernelArgumentHolder inferContiguousOutputMetaTensor(
     Fusion* fusion,
     const KernelArgumentHolder& args,
     PrecomputedValues* evaluator_precomputed_values) {
   FUSER_PERF_SCOPE(
-      "fusion_executor::allocations::inferOutputShapeAndContiguousStrides");
+      "fusion_executor::allocations::inferContiguousOutputMetaTensor");
 
   ExpressionEvaluator expr_eval;
   std::unique_ptr<PrecomputedValues> evaluator_precomputed_values_up = nullptr;
diff --git a/csrc/runtime/allocations.h b/csrc/runtime/allocations.h
index 981d3a6bc17..5e9b2da0ae5 100644
--- a/csrc/runtime/allocations.h
+++ b/csrc/runtime/allocations.h
@@ -46,7 +46,7 @@ struct GlobalBufferInfo {
 //!    pushing scalar int 0 as a place-holder.
 //! 2. This API does not allocate output in memory, but only returns the
 //!    inferred output sizes. Used in runtime/fusion_executor_cache.cpp.
-KernelArgumentHolder inferOutputShapeAndContiguousStrides(
+KernelArgumentHolder inferContiguousOutputMetaTensor(
     Fusion* fusion,
     const KernelArgumentHolder& args,
     PrecomputedValues* evaluator_precomputed_values = nullptr);
diff --git a/csrc/runtime/fusion_cache_utils.cpp b/csrc/runtime/fusion_cache_utils.cpp
index 1c7ba77b1be..3f8b27ff0cd 100644
--- a/csrc/runtime/fusion_cache_utils.cpp
+++ b/csrc/runtime/fusion_cache_utils.cpp
@@ -11,6 +11,7 @@
 #include 
 #include 
+#include 
 #include 
 #include 
diff --git a/csrc/runtime/fusion_kernel_runtime.cpp b/csrc/runtime/fusion_kernel_runtime.cpp
index 18216fa46fd..b9964a07ea8 100644
--- a/csrc/runtime/fusion_kernel_runtime.cpp
+++ b/csrc/runtime/fusion_kernel_runtime.cpp
@@ -337,6 +337,66 @@ KernelArgumentHolder FusionKernelRuntime::runWithInputs(
   return fusion_outputs;
 }
 
+KernelArgumentHolder FusionKernelRuntime::inferOutputMetaTensor(
+    HeuristicParamsList* heuristics,
+    SegmentedGroup* group_to_run,
+    const KernelArgumentHolder& group_runtime_inputs,
+    PrecomputedValues* evaluator_precomputed_values) const {
+  FUSER_PERF_SCOPE("FusionKernelRuntime::inferOutputMetaTensor");
+  NVF_ERROR(heuristics != nullptr);
+  Fusion* fusion_to_run = group_to_run->getFusion();
+  KernelArgumentHolder group_runtime_outputs;
+  const auto& heuristic_params = heuristics->at(group_to_run->groupId());
+  const bool is_expr_eval =
+      heuristic_params->scheduler_type == SchedulerType::ExprEval;
+  if (is_expr_eval && isOptionEnabled(EnableOption::InferContiguity)) {
+    // For an expr-evaluated fusion, the striding rules follow those of ATen.
+    ExpressionEvaluator eval_fusion;
+    for (auto i : arange(group_to_run->inputs().size())) {
+      const auto& tensor_pv = group_runtime_inputs[i];
+      if (tensor_pv.is<at::Tensor>()) {
+        const auto& t = tensor_pv.as<at::Tensor>();
+        if (t.defined()) {
+          const auto meta_t = at::empty_strided(
+              t.sizes(),
+              t.strides(),
+              at::TensorOptions().device(at::kMeta).dtype(t.dtype()));
+          eval_fusion.bind(fusion_to_run->inputs()[i], meta_t);
+        } else {
+          eval_fusion.bind(fusion_to_run->inputs()[i], t);
+        }
+      } else {
+        eval_fusion.bind(fusion_to_run->inputs()[i], tensor_pv);
+      }
+    }
+    for (auto v : fusion_to_run->outputs()) {
+      auto result = eval_fusion.evaluate(v);
+      group_runtime_outputs.push(result);
+    }
+  } else {
+    auto fusion_to_run = group_to_run->getFusion();
+    return inferContiguousOutputMetaTensor(
+        fusion_to_run, group_runtime_inputs, evaluator_precomputed_values);
+  }
+  return group_runtime_outputs;
+}
+
+void FusionKernelRuntime::updateContiguityOfSegmentOutputs(
+    SegmentedGroup* group_to_run,
+    const KernelArgumentHolder& group_runtime_outputs) const {
+  FUSER_PERF_SCOPE("FusionKernelRuntime::updateContiguityOfSegmentOutputs");
+  if (!isOptionEnabled(EnableOption::InferContiguity)) {
+    return;
+  }
+  for (auto [i, output] : enumerate(group_to_run->outputs())) {
+    auto tv = dynamic_cast<TensorView*>(output);
+    if (tv) {
+      const at::Tensor& tensor = group_runtime_outputs[i].as<at::Tensor>();
+      ir_utils::resetContiguityFromTensor(tv, tensor);
+    }
+  }
+}
+
 std::vector<KernelArgumentHolder> FusionKernelRuntime::prepareInputs(
     const KernelArgumentHolder& args) const {
   std::vector<KernelArgumentHolder> all_runtime_inputs;
@@ -362,16 +422,14 @@ std::vector<KernelArgumentHolder> FusionKernelRuntime::prepareInputs(
       group_runtime_inputs.setCacheId(group_cache_id.value());
     }
 
-    // TODO: inferOutputShapeAndContiguousStrides doesn't seem to strictly
-    // require a Fusion for each segment. Consider using the complete fusion
-    // instead.
-    auto fusion_to_run = segmented_fusion_->makeFusion(group_to_run).second;
-    auto group_runtime_outputs = inferOutputShapeAndContiguousStrides(
-        fusion_to_run.get(), group_runtime_inputs);
+    auto group_runtime_outputs = inferOutputMetaTensor(
+        heuristics_.get(), group_to_run, group_runtime_inputs);
 
     // map output args to tensor map
     args_manager.updateWithSegmentOutputs(
         group_to_run->outputs(), group_runtime_outputs, run_order_id);
+
+    updateContiguityOfSegmentOutputs(group_to_run, group_runtime_outputs);
   }
 
   return all_runtime_inputs;
@@ -599,13 +657,16 @@ std::optional<std::unique_ptr<HeuristicParamsList>> FusionKernelRuntime::
     }
 
     // Generate metadata for the fusion's outputs
-    auto group_runtime_outputs = inferOutputShapeAndContiguousStrides(
-        fusion_to_run,
+    auto group_runtime_outputs = inferOutputMetaTensor(
+        heuristics.get(),
+        group_to_run,
         group_runtime_inputs,
         evaluator_precomputed_values.get());
     args_manager.updateWithSegmentOutputs(
         group_to_run->outputs(), group_runtime_outputs, run_order_id);
+
+    updateContiguityOfSegmentOutputs(group_to_run, group_runtime_outputs);
   }
   return heuristics;
 }
diff --git a/csrc/runtime/fusion_kernel_runtime.h b/csrc/runtime/fusion_kernel_runtime.h
index 31965df07c2..cf7f67fa2de 100644
--- a/csrc/runtime/fusion_kernel_runtime.h
+++ b/csrc/runtime/fusion_kernel_runtime.h
@@ -173,6 +173,27 @@ class FusionKernelRuntime {
   //! Access the list of schedulers maintained in this runtime instance
   const std::vector>& schedulers() const;
 
+  //! Infer the output shapes and strides of a fusion segment as tensors on
+  //! the meta device. If the group is scheduled to be evaluated with
+  //! ExprEval, the output tensors are inferred by running ExprEval on meta
+  //! tensors. Otherwise, the output tensors are assumed to be contiguous.
+  KernelArgumentHolder inferOutputMetaTensor(
+      HeuristicParamsList* heuristics,
+      SegmentedGroup* group_to_run,
+      const KernelArgumentHolder& group_runtime_inputs,
+      PrecomputedValues* evaluator_precomputed_values = nullptr) const;
+
+  //! When a Fusion IR is constructed, all intermediate tensors are assumed
+  //! to be contiguous. Unfortunately, this assumption is not always true and
+  //! cannot be verified at compile time. Depending on the runtime inputs, we
+  //! may segment the fusion differently, and some segments are executed with
+  //! ATen, which may not produce contiguous tensors. So we have to update
+  //! the contiguity of the segment outputs on the fly based on the runtime
+  //! inputs.
+  void updateContiguityOfSegmentOutputs(
+      SegmentedGroup* group_to_run,
+      const KernelArgumentHolder& group_runtime_outputs) const;
+
   // Create KernelArgumentHolders for all of the segments. Sorted in
   // the run order.
   std::vector<KernelArgumentHolder> prepareInputs(
diff --git a/tests/cpp/test_alias.cpp b/tests/cpp/test_alias.cpp
index 0f419f53d97..3fc9d68d1b7 100644
--- a/tests/cpp/test_alias.cpp
+++ b/tests/cpp/test_alias.cpp
@@ -497,6 +497,9 @@ TEST_F(AliasTest, Issue1452) {
 }
 
 TEST_F(AliasTest, AliasOutputBeforeNonAliasOutput) {
+  EnableOptionsGuard opt_guard;
+  EnableOptionsGuard::getCurOptions().unset(EnableOption::InferContiguity);
+
   auto fusion = std::make_unique<Fusion>();
   FusionGuard fg(fusion.get());
 
diff --git a/tests/cpp/test_indexing_advanced.cpp b/tests/cpp/test_indexing_advanced.cpp
index 8044baaafc9..547a797a01d 100644
--- a/tests/cpp/test_indexing_advanced.cpp
+++ b/tests/cpp/test_indexing_advanced.cpp
@@ -26,6 +26,7 @@ class AdvancedIndexingTest : public NVFuserFixtureParamTest<bool> {
     } else {
       EnableOptionsGuard::getCurOptions().unset(EnableOption::IdModel);
     }
+    EnableOptionsGuard::getCurOptions().set(EnableOption::InferContiguity);
   }
 };
 
@@ -33,6 +34,7 @@ class AdvancedIndexingIdModelTest : public NVFuserTest {
  protected:
   void SetUp() override {
     EnableOptionsGuard::getCurOptions().set(EnableOption::IdModel);
+    EnableOptionsGuard::getCurOptions().set(EnableOption::InferContiguity);
   }
 };
 
diff --git a/tests/cpp/test_layout_op.cpp b/tests/cpp/test_layout_op.cpp
index 4281c503019..e671c450192 100644
--- a/tests/cpp/test_layout_op.cpp
+++ b/tests/cpp/test_layout_op.cpp
@@ -87,6 +87,7 @@ class LayoutOpTest : public NVFuserTest {
   void SetUp() override {
     NVFuserTest::SetUp();
     EnableOptionsGuard::getCurOptions().set(EnableOption::IdModel);
+    EnableOptionsGuard::getCurOptions().unset(EnableOption::InferContiguity);
   }
 };
 
diff --git a/tests/cpp/test_loop_domain_scheduling.cpp b/tests/cpp/test_loop_domain_scheduling.cpp
index db6154608d4..a545c4e0dfc 100644
--- a/tests/cpp/test_loop_domain_scheduling.cpp
+++ b/tests/cpp/test_loop_domain_scheduling.cpp
@@ -41,6 +41,7 @@ class LoopDomainSchedulingTest : public NVFuserTest {
  protected:
   void SetUp() override {
     EnableOptionsGuard::getCurOptions().set(EnableOption::IdModel);
+    EnableOptionsGuard::getCurOptions().set(EnableOption::InferContiguity);
   }
 };
 
diff --git a/tests/cpp/test_low_precision_recipe.cpp b/tests/cpp/test_low_precision_recipe.cpp
index 1261a70a6fa..186b7eec239 100644
--- a/tests/cpp/test_low_precision_recipe.cpp
+++ b/tests/cpp/test_low_precision_recipe.cpp
@@ -974,7 +974,13 @@ TEST_F(BlockQuantizationValidationTest, MergesMustBeContiguous) {
 class BlockQuantizationSchedulingTest
     : public BlackwellBase,
       public ::testing::WithParamInterface<
-          std::tuple, bool, bool>> {};
+          std::tuple, bool, bool>> {
+ protected:
+  void SetUp() override {
+    BlackwellBase::SetUp();
+    EnableOptionsGuard::getCurOptions().unset(EnableOption::InferContiguity);
+  }
+};
 
 TEST_P(BlockQuantizationSchedulingTest, AutoScheduleSingleOp) {
   const auto data_type = std::get<0>(GetParam());
diff --git a/tests/cpp/test_matmul_aten_evaluation.cpp b/tests/cpp/test_matmul_aten_evaluation.cpp
index 32b69991453..9968c1d2b1f 100644
--- a/tests/cpp/test_matmul_aten_evaluation.cpp
+++ b/tests/cpp/test_matmul_aten_evaluation.cpp
@@ -371,37 +371,4 @@ INSTANTIATE_TEST_SUITE_P(
         testing::Values(Sizes({n, 1})),
         testing::Values(Sizes({n}))));
 
-using MatmulNodeTest = NVFuserTest;
-
-TEST_F(MatmulNodeTest, OutputStrides) {
-  auto fusion = std::make_unique<Fusion>();
-  FusionGuard fg(fusion.get());
-
-  TensorView* x = makeSymbolicTensor(2, DataType::Half);
-  TensorView* y = makeSymbolicTensor(2, DataType::Half);
-  TensorView* z = matmul(x, y);
-
-  fusion->addInput(x);
-  fusion->addInput(y);
-  fusion->addOutput(z);
-
-  z->setAllocationDomain({z->axis(1), z->axis(0), z->axis(2)}, true);
-
-  auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA);
-  at::Tensor x_tensor = at::randn({2, 3}, options);
-  at::Tensor y_tensor = at::randn({3, 5}, options);
-
-  FusionExecutorCache executor_cache(std::move(fusion));
-  auto outs = executor_cache.runFusionWithInputs({x_tensor, y_tensor});
-  at::Tensor z_tensor = outs[0].as<at::Tensor>();
-  testValidate(
-      executor_cache.fusion(),
-      {z_tensor},
-      {x_tensor, y_tensor},
-      __LINE__,
-      __FILE__);
-
-  EXPECT_THAT(z_tensor.strides(), ElementsAre(1, 2));
-}
-
 } // namespace nvfuser
diff --git a/tests/cpp/test_matmul_scheduler.cpp b/tests/cpp/test_matmul_scheduler.cpp
index fa76d19f1a2..9cdc9e68b50 100644
--- a/tests/cpp/test_matmul_scheduler.cpp
+++ b/tests/cpp/test_matmul_scheduler.cpp
@@ -2802,6 +2802,7 @@ class MatmulFusionTest
       EnableOptionsGuard::getCurOptions().set(
           EnableOption::FuseMultipleMatmuls);
     }
+    EnableOptionsGuard::getCurOptions().set(EnableOption::InferContiguity);
   }
 
   bool fusion_enabled = GetParam().first;
diff --git a/tests/cpp/test_pointwise.cpp b/tests/cpp/test_pointwise.cpp
index ba859c3f0be..3a404950a93 100644
--- a/tests/cpp/test_pointwise.cpp
+++ b/tests/cpp/test_pointwise.cpp
@@ -26,6 +26,7 @@ class PointwiseTest : public NVFuserTest {
  protected:
   void SetUp() override {
     EnableOptionsGuard::getCurOptions().set(EnableOption::IdModel);
+    EnableOptionsGuard::getCurOptions().set(EnableOption::InferContiguity);
   }
 };
 
diff --git a/tests/cpp/test_rng.cpp b/tests/cpp/test_rng.cpp
index 1c145ec6adc..3d258294b67 100644
--- a/tests/cpp/test_rng.cpp
+++ b/tests/cpp/test_rng.cpp
@@ -69,6 +69,7 @@ at::Tensor generate_normal(int64_t size, at::ScalarType dtype) {
 class RNGTest : public NVFuserTest {
   void SetUp() override {
     EnableOptionsGuard::getCurOptions().set(EnableOption::IdModel);
+    EnableOptionsGuard::getCurOptions().set(EnableOption::InferContiguity);
   }
 };
 
diff --git a/tests/cpp/test_segmentation.cpp b/tests/cpp/test_segmentation.cpp
index 0d0686d4d61..e7e2519b231 100644
--- a/tests/cpp/test_segmentation.cpp
+++ b/tests/cpp/test_segmentation.cpp
@@ -849,7 +849,10 @@ TEST_F(SegmentationTest, RevertPrivatizedUpcast) {
       ++num_upcast_ops;
     }
 
-    EXPECT_EQ(num_upcast_ops, 1);
+    // There is an unswitched IfThenElse in the generated kernel, and in each
+    // of its branches, there is an upcast op with tv1 as its producer. So we
+    // should have two upcast ops.
+    EXPECT_EQ(num_upcast_ops, 2);
   }
 }
 
diff --git a/tests/cpp/utils.cpp b/tests/cpp/utils.cpp
index e09614e9278..a74950e703b 100644
--- a/tests/cpp/utils.cpp
+++ b/tests/cpp/utils.cpp
@@ -59,6 +59,7 @@ void NVFuserTest::SetUp() {
     GTEST_SKIP() << "skipping tests on pre-PASCAL GPUs";
   }
   EnableOptionsGuard::getCurOptions().set(EnableOption::IdModel);
+  EnableOptionsGuard::getCurOptions().set(EnableOption::InferContiguity);
 }
 
 NVFuserTest::~NVFuserTest() {
diff --git a/tests/python/test_python_frontend.py b/tests/python/test_python_frontend.py
index 0f7cfe4bff0..b08e9ecc4a2 100644
--- a/tests/python/test_python_frontend.py
+++ b/tests/python/test_python_frontend.py
@@ -5314,3 +5314,102 @@ def qwen2_cat_fusion_2(fd: FusionDefinition) -> None:
         qwen2_cat_fusion_2(fd)
 
     fd.execute(inputs)
+
+
+def test_issue4888():
+    # https://github.com/NVIDIA/Fuser/issues/4888
+    def nvfuser_fusion_id2(fd: FusionDefinition) -> None:
+        T0 = fd.define_tensor(
+            shape=[4096, 4097],
+            contiguity=[True, True],
+            dtype=DataType.BFloat16,
+            is_cpu=False,
+            stride_order=[1, 0],
+        )
+        T1 = fd.define_tensor(
+            shape=[4096, 4097],
+            contiguity=[True, True],
+            dtype=DataType.Bool,
+            is_cpu=False,
+            stride_order=[1, 0],
+        )
+        T2 = fd.define_tensor(
+            shape=[4096, 4097],
+            contiguity=[True, True],
+            dtype=DataType.Bool,
+            is_cpu=False,
+            stride_order=[1, 0],
+        )
+        T3 = fd.define_tensor(
+            shape=[1, 32, 4096, 4096],
+            contiguity=[None, True, True, True],
+            dtype=DataType.BFloat16,
+            is_cpu=False,
+            stride_order=[3, 2, 1, 0],
+        )
+        T4 = fd.ops.cast(T0, dtype=DataType.Float)
+        T5 = fd.ops.bitwise_or(T1, T2)
+        T6 = fd.ops.set(T5)
+        fd.add_output(T6, T1)
+        T7 = fd.ops.cast(T6, dtype=DataType.Float)
+        T8 = fd.ops.mul(T4, T7)
+        T9 = fd.ops.cast(T8, dtype=DataType.BFloat16)
+        T10 = fd.ops.set(T9)
+        fd.add_output(T10, T0)
+        T15 = fd.ops.broadcast_in_dim(T10, shape=[1, 4096, 4097], broadcast_dims=[1, 2])
+        T21 = fd.ops.broadcast_in_dim(
+            T15, shape=[1, 1, 4096, 4097], broadcast_dims=[0, 2, 3]
+        )
+        T27 = fd.ops.broadcast_in_dim(
+            T21, shape=[1, 1, 4096, 4097], broadcast_dims=[0, 1, 2, 3]
+        )
+        T43 = fd.ops.slice(
+            T27,
+            start_indices=[0, 0, 0, 0],
+            end_indices=[1, 1, 4096, 4096],
+            strides=[1, 1, 1, 1],
+            manual_normalization=0,
+        )
+        T49 = fd.ops.broadcast_in_dim(
+            T43, shape=[1, 32, 4096, 4096], broadcast_dims=[0, 1, 2, 3]
+        )
+        T50 = fd.ops.cast(T49, dtype=DataType.Float)
+        T51 = fd.ops.cast(T3, dtype=DataType.Float)
+        S52 = fd.define_scalar(0.0883883, dtype=DataType.Double)
+        T53 = fd.ops.mul(T51, S52)
+        T54 = fd.ops.add(T53, T50)
+        T55 = fd.ops.max(T54, dims=[3], keepdim=False, dtype=DataType.Null)
+        T61 = fd.ops.broadcast_in_dim(
+            T55, shape=[1, 32, 4096, 1], broadcast_dims=[0, 1, 2]
+        )
+        T67 = fd.ops.broadcast_in_dim(
+            T61, shape=[1, 32, 4096, 4096], broadcast_dims=[0, 1, 2, 3]
+        )
+        T68 = fd.ops.sub(T54, T67)
+        T69 = fd.ops.exp(T68)
+        T70 = fd.ops.sum(T69, dims=[3], keepdim=False, dtype=DataType.Null)
+        T76 = fd.ops.broadcast_in_dim(
+            T70, shape=[1, 32, 4096, 1], broadcast_dims=[0, 1, 2]
+        )
+        T82 = fd.ops.broadcast_in_dim(
+            T76, shape=[1, 32, 4096, 4096], broadcast_dims=[0, 1, 2, 3]
+        )
+        T83 = fd.ops.reciprocal(T82)
+        T84 = fd.ops.mul(T69, T83)
+        T85 = fd.ops.cast(T84, dtype=DataType.BFloat16)
+        fd.add_output(T49)
+        fd.add_output(T84)
+        fd.add_output(T85)
+
+    with FusionDefinition() as fd:
+        nvfuser_fusion_id2(fd)
+
+    inputs = [
+        torch.testing.make_tensor((4096, 4097), dtype=torch.bfloat16, device="cuda:0"),
+        torch.testing.make_tensor((4096, 4097), dtype=torch.bool, device="cuda:0"),
+        torch.testing.make_tensor((4096, 4097), dtype=torch.bool, device="cuda:0"),
+        torch.testing.make_tensor(
+            (1, 32, 4096, 4096), dtype=torch.bfloat16, device="cuda:0"
+        ),
+    ]
+    fd.execute(inputs, _enable_options=["infer-contiguity"])