diff --git a/csrc/ir/utils.cpp b/csrc/ir/utils.cpp index 9418b137688..4d0e093da5d 100644 --- a/csrc/ir/utils.cpp +++ b/csrc/ir/utils.cpp @@ -14,6 +14,7 @@ #include #include #include +#include #include #include @@ -1778,4 +1779,66 @@ void swizzleBlockScales(TensorView* tv) { tv->merge(-2); } +void resetContiguityFromTensor(TensorView* tv, const at::Tensor& tensor) { + if (!tensor.defined()) { + return; + } + const auto [sizes, strides] = + inferAllocationSizesAndStrides(tensor, tv, ExpressionEvaluator()); + + const auto& alloc = tv->getMaybeAllocationDomain(); + + // Custom contiguity inference that considers IterDomain information + std::vector> contiguity(alloc.size(), std::nullopt); + + // Single pass from right to left with two dynamic indices: + // - alloc_idx: iterates through allocation domain + // - sizes_idx: tracks position in sizes/strides (excludes reductions) + int64_t sizes_idx = (int64_t)sizes.size() - 1; + int64_t prev_non_skipped_sizes_idx = -1; + + for (int64_t alloc_idx = (int64_t)alloc.size() - 1; alloc_idx >= 0; + --alloc_idx) { + auto id = alloc[alloc_idx]; + + // Reduction dimensions: nullopt contiguity (already set), no entry in + // sizes/strides + if (id->isReduction()) { + // Don't decrement sizes_idx since reductions have no entry + continue; + } + + // This dimension has an entry in sizes/strides + NVF_CHECK(sizes_idx >= 0, "Sizes index out of bounds"); + + // Broadcast dimensions: nullopt contiguity (already set), but has entry in + // sizes/strides + if (id->isBroadcast()) { + sizes_idx--; // Move to next dimension in sizes/strides + continue; + } + + // Non-broadcast, non-reduction dimension + if (prev_non_skipped_sizes_idx == -1) { + // This is the rightmost (innermost) non-skipped dimension + // It's contiguous if stride == 1 + contiguity[alloc_idx] = (strides[sizes_idx] == 1); + } else { + // A dimension is contiguous if its stride equals the stride of the + // next dimension multiplied by that dimension's size + contiguity[alloc_idx] = + (strides[sizes_idx] == + strides[prev_non_skipped_sizes_idx] * + sizes[prev_non_skipped_sizes_idx]); + } + + prev_non_skipped_sizes_idx = sizes_idx; + sizes_idx--; // Move to next dimension in sizes/strides + } + + NVF_CHECK(sizes_idx == -1, "Not all sizes/strides were consumed"); + + tv->setContiguity(contiguity); +} + } // namespace nvfuser::ir_utils diff --git a/csrc/ir/utils.h b/csrc/ir/utils.h index 2c9c94ea457..3f62b1ec4d3 100644 --- a/csrc/ir/utils.h +++ b/csrc/ir/utils.h @@ -869,4 +869,10 @@ bool isParallelizedBy(const std::vector& ids, ParallelType pt); // https://docs.nvidia.com/cutlass/media/docs/cpp/blackwell_functionality.html#scale-factor-layouts NVF_API void swizzleBlockScales(TensorView* tv); +// Infers and resets the contiguity of a TensorView based on the sizes and +// strides of an actual at::Tensor. The contiguity is computed by comparing +// consecutive strides in the allocation domain, accounting for broadcast and +// reduction dimensions. +void resetContiguityFromTensor(TensorView* tv, const at::Tensor& tensor); + } // namespace nvfuser::ir_utils diff --git a/tests/cpp/test_utils.cpp b/tests/cpp/test_utils.cpp index acafba4a144..b237d592497 100644 --- a/tests/cpp/test_utils.cpp +++ b/tests/cpp/test_utils.cpp @@ -2213,4 +2213,113 @@ TEST_F(UtilsTest, GetOrDefault) { EXPECT_EQ(getOrDefault(m, out), 0); } +TEST_F(UtilsTest, ResetContiguityFromTensor) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + + // Manually create a TensorView with broadcast, iteration, and reduction dims + // Pattern: [Iteration, Broadcast, Iteration, Reduction, Iteration] + // This tests skipping of broadcast dimension that's NOT at position 0 + std::vector domains; + + // Regular iteration dimension (symbolic size) + domains.push_back( + IterDomainBuilder( + fusion.zeroVal(), IrBuilder::create(DataType::Index)) + .iter_type(IterType::Iteration) + .build()); + + // Broadcast dimension (symbolic size) - NOT leftmost! + domains.push_back(IterDomainBuilder(fusion.zeroVal(), fusion.oneVal()) + .iter_type(IterType::Broadcast) + .build()); + + // Regular iteration dimension (symbolic size) + domains.push_back( + IterDomainBuilder( + fusion.zeroVal(), IrBuilder::create(DataType::Index)) + .iter_type(IterType::Iteration) + .build()); + + // Reduction dimension (symbolic size) + domains.push_back( + IterDomainBuilder( + fusion.zeroVal(), IrBuilder::create(DataType::Index)) + .iter_type(IterType::Reduction) + .build()); + + // Regular iteration dimension (symbolic size) + domains.push_back( + IterDomainBuilder( + fusion.zeroVal(), IrBuilder::create(DataType::Index)) + .iter_type(IterType::Iteration) + .build()); + + auto td = IrBuilder::create( + domains, TensorDomain::getContiguityFilledWith(domains, true)); + auto tv = IrBuilder::create(td, DataType::Float); + + // Test all 8 possible contiguity combinations for the 3 iteration dimensions + // Allocation domain: [Iteration, Broadcast, Iteration, Reduction, Iteration] + // Tensor shape: [3, 1, 4, 6] - broadcast has size 1, reduction has no entry + // We test all 2^3 = 8 combinations of [dim0, dim2, dim4] contiguity + + auto test_contiguity = [&](const std::vector& strides, + bool expected_dim0, + bool expected_dim2, + bool expected_dim4) { + at::Tensor tensor = at::empty_strided({3, 1, 4, 6}, strides, options); + ir_utils::resetContiguityFromTensor(tv, tensor); + auto contiguity = tv->domain()->contiguity(); + + EXPECT_EQ(contiguity.size(), 5); + EXPECT_TRUE(contiguity[0].has_value()); + EXPECT_EQ(contiguity[0].value(), expected_dim0); + EXPECT_FALSE(contiguity[1].has_value()); // Broadcast: nullopt + EXPECT_TRUE(contiguity[2].has_value()); + EXPECT_EQ(contiguity[2].value(), expected_dim2); + EXPECT_FALSE(contiguity[3].has_value()); // Reduction: nullopt + EXPECT_TRUE(contiguity[4].has_value()); + EXPECT_EQ(contiguity[4].value(), expected_dim4); + }; + + // Case 1: [F, F, F] - strides [100, X, 50, 2] + test_contiguity({100, 0, 50, 2}, false, false, false); + + // Case 2: [F, F, T] - strides [100, X, 50, 1] + test_contiguity({100, 0, 50, 1}, false, false, true); + + // Case 3: [F, T, F] - strides [100, X, 12, 2] + // dim4: stride 2 != 1 → F + // dim2: stride 12 = 2*6 → T + // dim0: stride 100 != 12*4=48 → F + test_contiguity({100, 0, 12, 2}, false, true, false); + + // Case 4: [F, T, T] - strides [100, X, 6, 1] + test_contiguity({100, 0, 6, 1}, false, true, true); + + // Case 5: [T, F, F] - strides [200, X, 50, 2] + // dim4: stride 2 != 1 → F + // dim2: stride 50 != 2*6=12 → F + // dim0: stride 200 = 50*4 → T + test_contiguity({200, 0, 50, 2}, true, false, false); + + // Case 6: [T, F, T] - strides [200, X, 50, 1] + // dim4: stride 1 → T + // dim2: stride 50 != 1*6=6 → F + // dim0: stride 200 = 50*4 → T + test_contiguity({200, 0, 50, 1}, true, false, true); + + // Case 7: [T, T, F] - strides [48, X, 12, 2] + // dim4: stride 2 != 1 → F + // dim2: stride 12 = 2*6 → T + // dim0: stride 48 = 12*4 → T + test_contiguity({48, 0, 12, 2}, true, true, false); + + // Case 8: [T, T, T] - strides [24, X, 6, 1] + test_contiguity({24, 0, 6, 1}, true, true, true); +} + } // namespace nvfuser