diff --git a/csrc/alias_analysis.cpp b/csrc/alias_analysis.cpp index ea277caf390..21fb066698f 100644 --- a/csrc/alias_analysis.cpp +++ b/csrc/alias_analysis.cpp @@ -59,6 +59,12 @@ bool okToRelayout( const TensorView* tv, const Layout& new_layout, const EmptyAllocationAs empty_allocation_as) { + // we can merge this with the one below + // when using logical domain as the allocation domain, we can basically ignore the layout when it's not used by codegen + if (empty_allocation_as == EmptyAllocationAs::kLogical && !ir_utils::canUsePresetAllocationDomain(tv, false)) { + return true; + } + if (empty_allocation_as == EmptyAllocationAs::kUndetermined && !tv->hasAllocation()) { return true; diff --git a/csrc/device_lower/validation.cpp b/csrc/device_lower/validation.cpp index c5b5521fd80..4a049fbc3a2 100644 --- a/csrc/device_lower/validation.cpp +++ b/csrc/device_lower/validation.cpp @@ -963,6 +963,9 @@ class VectorizeValidator : public OptInDispatch { TensorView* tv, std::string name, int64_t vector_word_size_bit) { + if (!ir_utils::canUsePresetAllocationDomain(tv)) { + return; + } // aten_element_size_bit is the minimum unit (one element) of tv's // corresponding at::Tensor. It may or may not be the same as // dataTypeSizeBit(tv->dtype()), because we support non-ATen data types as diff --git a/csrc/ir/internal_base_nodes.h b/csrc/ir/internal_base_nodes.h index 2832fd442e5..23266889584 100644 --- a/csrc/ir/internal_base_nodes.h +++ b/csrc/ir/internal_base_nodes.h @@ -181,6 +181,11 @@ class NVF_API IterDomain : public Val { return getIterType() == IterType::Iteration; } + IterDomain* resetRFactorProduct(bool is_rfactor_domain=false) { + is_rfactor_domain_ = is_rfactor_domain; + return this; + } + bool isRFactorProduct() const { return is_rfactor_domain_; } diff --git a/csrc/ir/utils.cpp b/csrc/ir/utils.cpp index 5b312f96700..f1c1f9db10f 100644 --- a/csrc/ir/utils.cpp +++ b/csrc/ir/utils.cpp @@ -1747,4 +1747,56 @@ bool isParallelizedBy(const std::vector& ids, ParallelType pt) { ids, [&](IterDomain* id) { return id->getParallelType() == pt; }); } +bool canUsePresetAllocationDomain(const TensorView* tv, bool ignore_empty_alloc) { + if (ignore_empty_alloc && !tv->hasAllocation()) { + return false; + } + // Honor the allocation domain if the tensor is global or Hopper MMA's + // output + if (tv->getMemoryType() == MemoryType::Global || + (tv->definition()->isA() && + isHopper(tv->definition()->as()->macro()))) { + return true; + } + // If it's a shared memory tensor, the set domain is likely + // valid if Swizzle or Bulk is used. Also, if the allocation + // domain is just a permutation of the loop domain, use the + // set allocation domain. This seems to happen only with + // AllocationDomainTest.TransposedIntermediate. + if (tv->getMemoryType() == MemoryType::Shared) { + if (std::any_of( + tv->getAllocationDomain().begin(), + tv->getAllocationDomain().end(), + [](IterDomain* allocation_domain) { + return dynamic_cast( + allocation_domain->definition()) != nullptr || + allocation_domain->getParallelType() == ParallelType::Bulk; + }) || + std::is_permutation( + tv->getLoopDomain().begin(), + tv->getLoopDomain().end(), + tv->getAllocationDomain().begin(), + tv->getAllocationDomain().end())) { + return true; + } + + // Honor the set allocation domain if the tensor is used by a + // TMA store or MmaOp + if (std::ranges::any_of(tv->uses(), [](Expr* expr) { + return ir_utils::isCpAsyncBulkStore(expr) || expr->isA(); + })) { + return true; + } + + // If a shared memory output produced by scatter has an + // allocation domain explicitly set, it's likely to be the + // valid allocation domain. + if (auto def = tv->definition(); + def != nullptr && def->isA()) { + return true; + } + } + return false; +} + } // namespace nvfuser::ir_utils diff --git a/csrc/ir/utils.h b/csrc/ir/utils.h index 13af07ff0dc..ee2474d7eda 100644 --- a/csrc/ir/utils.h +++ b/csrc/ir/utils.h @@ -861,4 +861,6 @@ std::vector propagateScatterAllocationDomain( bool isParallelizedBy(const std::vector& ids, ParallelType pt); +bool canUsePresetAllocationDomain(const TensorView* tv, bool ignore_empty_alloc=true); + } // namespace nvfuser::ir_utils diff --git a/csrc/scheduler/matmul.cpp b/csrc/scheduler/matmul.cpp index 8023d7b5f48..23089ce06c2 100644 --- a/csrc/scheduler/matmul.cpp +++ b/csrc/scheduler/matmul.cpp @@ -197,8 +197,21 @@ void Common::updateIdModel() { // IdModel std::unordered_map new_id_roles; for (auto& [k, v] : id_roles_) { - const ValGroup& new_group = new_graph.toGroup(k->front()); - new_id_roles.emplace(new_group, v); + // We need to traverse the ValGroup to find the remaining IDs that remains in the new id_model. This is because that cacheBefore could have eliminated the reduction ID. + + // e.g. + // output [m4, n5, rk6] = mma(A [m0, k1], B [n2, k3]) + // would become + // cache [m6, n7, rk8] = mma(A [m0, k1], B [n2, k3]) + // output [m4, n5 ] = set(cache [m6, n7, rk8]) + // + // So the old role rK6 wouldn't be mapped in new_graph. + auto old_vg = std::ranges::find_if(k->vector(), [&new_graph](Val* vg){return new_graph.hasGroup(vg);}); + NVF_ERROR( + old_vg != k->vector().end(), + "Old ValGroup not found in new ValGraph" + ); + new_id_roles.emplace(new_graph.toGroup(*old_vg), v); } id_roles_ = new_id_roles; } @@ -290,10 +303,8 @@ TensorView* Common::cacheBefore(TensorView* orig, LoadStoreOpType op_type) { const std::vector cache_logical = c->getLogicalDomain(); NVF_ERROR(orig_logical.size() == cache_logical.size()); for (size_t i : arange(orig_logical.size())) { - // The domain of orig gets transferred to c and a new domain is applied to - // orig - ValGroup vg = graph_->toGroup(cache_logical[i]); - graph_->initializeVal(orig_logical[i], vg); + ValGroup vg = graph_->toGroup(orig_logical[i]); + graph_->initializeVal(cache_logical[i], vg); } return c; diff --git a/csrc/tensor_view.cpp b/csrc/tensor_view.cpp index bb2c0d53a74..acb2de1a229 100644 --- a/csrc/tensor_view.cpp +++ b/csrc/tensor_view.cpp @@ -1128,18 +1128,105 @@ TensorView* TensorView::cacheBefore(LoadStoreOpType op_type) { "before computeAt."); } - // Create Producer Domain - // This domain will be the consumer which needs a new domain, so replace the - // producers domain with this domain. - - auto* producer = IrBuilder::createInContainer( - container(), - IrBuilder::createInContainer(container(), domain()), - getDataType().value()); + // TODO: 1. test reshape; 2. test reduction + // We want the producer domain to preserve `root` & `logical` + // meanwhile, we want consumer Tensor to preserve `logical` & `allocation` (while erasing all reductions). + + TensorView* producer; + + if (definition()->isA()) { + // TODO: is there any way to replay a scatter op?! + // scatter output's loop is not connected to its root. + NVF_ERROR(!domain()->hasRoot(), "scatter output's root is not replayed in cacheBefore"); + + std::vector logical; + std::vector loop; + std::unordered_map map_cloned_ids; + + std::ranges::transform(domain()->logical(), std::back_inserter(logical), [&](IterDomain* id) { + IterDomain* cloned_id = IrBuilder::createInContainer(container(), id); + map_cloned_ids[id] = cloned_id; + return cloned_id; + }); + std::ranges::transform(domain()->loop(), std::back_inserter(loop), [&](IterDomain* id) { + if (auto it = map_cloned_ids.find(id); it != map_cloned_ids.end()) { + // reuse cloned_ids + return it->second; + } + return IrBuilder::createInContainer(container(), id); + }); + producer = IrBuilder::createInContainer( + container(), + IrBuilder::createInContainer(container(), logical, loop, TensorDomain::getContiguityFilledWith(logical, true), /*skip_loop_validation=*/true), + getDataType().value()); + // TODO: we are not replaying the loop domain from consumer to producer, is that the right thing to do?! + } else { + // Create Producer Domain + // We only need root for full self replay. + std::vector root; + std::ranges::transform(domain()->hasRoot()?domain()->root():domain()->logical(), std::back_inserter(root), [&](IterDomain* id) { + return IrBuilder::createInContainer(container(), id); + }); + + producer = IrBuilder::createInContainer( + container(), + IrBuilder::createInContainer(container(), root, root, root, TensorDomain::getContiguityFilledWith(root, true)), + getDataType().value()); + // replay from `root`->`loop` on producer + producer->setDomain(TransformReplay::fullSelfReplay(producer->domain(), domain())); + } + + // clean up consumer domain to wipe out root and all reduction IDs + std::vector logical_dom; + std::vector alloc_dom; + std::vector loop_dom; + std::vector> contiguity; + + // NOTE: I need to clear definition otherwise BestEffortReplay will not work with dangling sources + // create an issue for this, use the example from ./bin/test_nvfuser --gtest_filter="PointwiseTest.VectorizeWithBroadcastAndReshape1" + // copy non-reduction IDs onto logical and loop + std::ranges::copy_if( + domain()->logical() | std::views::transform([](IterDomain* id) { id->setDefinition(nullptr); return id->resetRFactorProduct(); }), + std::back_inserter(logical_dom), + [](IterDomain* id) {return !id->isReduction();}); + if (definition()->isA()) { + // NOTE: this doesn't feel right. we would still want to replay the loop domain + // we are basically dropping transformations on loop domain for scatter op during cacheBefore + loop_dom = logical_dom; + } else { + std::ranges::copy_if( + domain()->loop() | std::views::transform([](IterDomain* id) { return id->resetRFactorProduct(); }), + std::back_inserter(loop_dom), + [](IterDomain* id) {return !id->isReduction();}); + } + for (auto&& [id, c] : zip(domain()->hasAllocation() ? domain()->allocation() : domain()->logical(), domain()->contiguity())) { + if (id->isReduction()) { + continue; + } + id->resetRFactorProduct(); + if (domain()->hasAllocation()) { + alloc_dom.push_back(id); + } + contiguity.push_back(c); + } + // TODO: We also need to clear all rfactor across IDs between logical->loop and logical->allocation. // Set domain of consumer TensorView* consumer = this; - + consumer->setDomain(IrBuilder::createInContainer( + container(), + std::vector{}, + logical_dom, + alloc_dom, + loop_dom, + contiguity)); + + // TODO: figure out scatter special handling. + // if (!producer->definition()->isA()) { + // } else { + // } + + /* FIXME std::vector new_logical_domain; new_logical_domain.reserve(getLogicalDomain().size()); for (IterDomain* dom : getLogicalDomain() | TensorDomain::kNoReductions) { @@ -1152,6 +1239,7 @@ TensorView* TensorView::cacheBefore(LoadStoreOpType op_type) { container(), new_logical_domain, TensorDomain::getContiguityFilledWith(new_logical_domain, true))); + */ // Insert producer - Cache_Before (CB) - before this TV. // Before: Prev TV -> [Definition Op] -> This TV @@ -1170,6 +1258,7 @@ TensorView* TensorView::cacheBefore(LoadStoreOpType op_type) { // definition_ is no longer valid // setDefinition(nullptr); + /* FIXME // We do not want to reproduce the loop domain if it's for // scatter. Recall that the loop domain of the scatter op is derived // from the logical domain of the scatter index tensor. Here, the @@ -1186,6 +1275,7 @@ TensorView* TensorView::cacheBefore(LoadStoreOpType op_type) { producer, consumer->getLogicalDomain()), true); } + */ if (consumer->hasDeviceMesh()) { producer->setDeviceMesh(consumer->getDeviceMesh()); diff --git a/csrc/transform_replay.cpp b/csrc/transform_replay.cpp index 2ca90cadc55..ef5835b28c6 100644 --- a/csrc/transform_replay.cpp +++ b/csrc/transform_replay.cpp @@ -70,6 +70,10 @@ class ReplaySelf : public ReplayTransformations { s->outer()->getIterType(), s->inner()->getIterType()); + // Parallelize type could include device from split. + ido->parallelize(s->outer()->getParallelType()); + idi->parallelize(s->inner()->getParallelType()); + // Remove mapped id from loop IDs loop_ids_.erase(mapped); @@ -106,7 +110,7 @@ class ReplaySelf : public ReplayTransformations { id_inner_mapped, " however one or both are not loop nodes."); - IterDomain* merged_id = IterDomain::merge(id_outer_mapped, id_inner_mapped); + IterDomain* merged_id = IterDomain::merge(id_outer_mapped, id_inner_mapped, m->out()->isRFactorProduct()); // Remove inputs from the loop IDs loop_ids_.erase(id_outer_mapped); @@ -147,11 +151,15 @@ class ReplaySelf : public ReplayTransformations { // output domain also an rfactor const auto resize_out_logical = resize->out()->isRFactorProduct(); + // Mark output IterType + const auto resize_out_iter_type = resize->out()->getIterType(); + auto replayed_out = IterDomain::resize( mapped, resize->leftExpand(), resize->rightExpand(), - resize_out_logical); + resize_out_logical, + resize_out_iter_type); loop_ids_.erase(mapped); @@ -234,7 +242,7 @@ TensorDomain* TransformReplay::fullSelfReplay( new_self_root->root(), new_logical_domain, new_domain, - self->contiguity()); + TensorDomain::getContiguityFilledWith(new_logical_domain, true)); } } diff --git a/tests/cpp/test_allocation_domain.cpp b/tests/cpp/test_allocation_domain.cpp index 60e7fa9a7fb..e713d6ff231 100644 --- a/tests/cpp/test_allocation_domain.cpp +++ b/tests/cpp/test_allocation_domain.cpp @@ -623,7 +623,6 @@ TEST_F(AllocationDomainTest, NHWC4d_To_NHWC4d_cacheBefore) { ASSERT_EQ(tv0->getAllocationDomain(), tv0_nhwc); ASSERT_EQ(tv1->getAllocationDomain(), expected_new_allocation_domain); - ASSERT_EQ(tv2->getAllocationDomain(), tv1_nhwc); for (auto tv : {tv1, tv2}) { // [N, C, H, W] @@ -708,7 +707,6 @@ TEST_F(AllocationDomainTest, NHWC2d_To_NHWC2d_cacheBefore) { ASSERT_EQ(tv0->getAllocationDomain(), tv0_2d); ASSERT_EQ(tv1->getAllocationDomain(), expected_new_allocation_domain); - ASSERT_EQ(tv2->getAllocationDomain(), tv1_2d); for (auto tv : {tv1, tv2}) { tv->split(0, 128); diff --git a/tests/cpp/test_layout_op.cpp b/tests/cpp/test_layout_op.cpp index 1a2fa739f27..7ccfdec612b 100644 --- a/tests/cpp/test_layout_op.cpp +++ b/tests/cpp/test_layout_op.cpp @@ -90,6 +90,79 @@ class LayoutOpTest : public NVFuserTest { } }; +TEST_F(LayoutOpTest, LogicalAndAllocationSizes) { + auto fusion_ptr = std::make_unique(); + Fusion& fusion = *fusion_ptr.get(); + FusionGuard fg(&fusion); + + auto inp = makeSymbolicTensor(2); + fusion.addInput(inp); + auto out = set(inp); + fusion.addOutput(out); + // padding output to multiple of 16 on allocation domain + auto&& [io, ii] = IterDomain::split( + out->axis(1), IrBuilder::create(16L, DataType::Index), true); + // NOTE: this doesn't feel right, we have to mark contiguity on axis(0) as + // `false` to avoid accidntal indexing collapsing, this should be figured out + // by indexing from the ceilDiv. + out->setAllocationDomain({out->axis(0), io, ii}, {false, true, true}); + + // Two issues with split and merge approach: + // 1. This causes predication to expand to the padded region. + // 2. Indexing with allocation domain set as `true` is wrong. + // out->split(1, 16); // padding output to multiple of 16 + // out->setAllocationDomain(out->getLoopDomain(), true); + // out->merge(1); // restore loop domain + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + int m = 512; + int k = 9; // note: padded column size would be 16 + auto t0 = at::randn({m, k}, options); + + FusionExecutorCache executor_cache(std::move(fusion_ptr)); + auto cg_outputs = executor_cache.runFusionWithInputs({t0}); + // padding on the inner dimension is represented as stride on the outer + // dimension + EXPECT_EQ( + cg_outputs[0].as().strides(), std::vector({16, 1})); + // We need to slice because output buffer shape is not right + EXPECT_TRUE(t0.equal(cg_outputs[0].as().slice(1, 0, k))); + // TODO: enable this when output buffer shape is fixed. + // output should remain the correct logical size + // EXPECT_EQ( + // cg_outputs[0].as().sizes(), std::vector({512, + // 9})); +} + +TEST_F(LayoutOpTest, AllocationDomainSplitVectorizationFactor) { + auto fusion_ptr = std::make_unique(); + Fusion& fusion = *fusion_ptr.get(); + FusionGuard fg(&fusion); + + auto inp = makeSymbolicTensor(3); + fusion.addInput(inp); + auto out = set(inp); + fusion.addOutput(out); + // split would prevent vectorization + auto&& [io, ii] = IterDomain::split( + out->axis(1), IrBuilder::create(16L, DataType::Index), true); + out->setAllocationDomain( + {out->axis(0), io, ii, out->axis(2)}, {false, true, true, true}); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + // because of the split on the middle dimension, we only have the fastest + // dimension participating in vectorization. + auto t0 = at::randn({512, 128, 2}, options); + + // NOTE force pointwise scheduler here just for testing purpose + auto cg_results = + scheduleAndRun(fusion_ptr.get(), SchedulerType::PointWise, {t0}); + auto pparams = cg_results.heuristic_params->as(); + EXPECT_EQ(pparams->vectorization_factor, 2); + + testValidate(fusion_ptr.get(), cg_results.outputs, {t0}, __LINE__, __FILE__); +} + TEST_F(LayoutOpTest, CppApi) { auto fusion_ptr = std::make_unique(); Fusion& fusion = *fusion_ptr.get(); diff --git a/tests/cpp/test_multidevice_transformer.cpp b/tests/cpp/test_multidevice_transformer.cpp index e13c806b89a..7b63a431e5f 100644 --- a/tests/cpp/test_multidevice_transformer.cpp +++ b/tests/cpp/test_multidevice_transformer.cpp @@ -1062,6 +1062,16 @@ TEST_P(DistributedTransformerTest, LoopSplitMHAFwd) { at::ScalarType at_dtype = data_type_to_aten(dtype); const int d = communicator_->size(); + const int rank = communicator_->deviceId(); + + pid_t pid = getpid(); + if (rank == 0) { + std::cout << "rank " << rank << " PID: " << pid << std::endl; + int i = 0; + while (i == 0) { + sleep(5); + } + } auto mesh = DeviceMesh::createForNumDevices(d); diff --git a/tests/cpp/test_transpose.cpp b/tests/cpp/test_transpose.cpp index 7f775486f0d..b1ce164431f 100644 --- a/tests/cpp/test_transpose.cpp +++ b/tests/cpp/test_transpose.cpp @@ -1096,7 +1096,7 @@ TEST_F(TransposeTest, ViewTransposeReshape) { } TEST_F(TransposeTest, ReshapePermuteTransposeScheduler) { - // This is extracted from CSA in nanogpt, where we want transpose scheduler + // This is extracted from CSA in nanogpt auto fusion = std::make_unique(); FusionGuard fg(fusion.get()); @@ -1105,6 +1105,7 @@ TEST_F(TransposeTest, ReshapePermuteTransposeScheduler) { auto tv0 = makeSymbolicTensor(3); fusion->addInput(tv0); + // NOTE: the two reshape can be handled by aliasing auto tv1 = reshape(tv0, {8, 1024, 1024}, {8, 1024, 16, 64}); auto tv2 = transpose(tv1, 1, 2); auto tv3 = transpose(tv2, 2, 3);