63 changes: 63 additions & 0 deletions csrc/ir/utils.cpp
@@ -14,6 +14,7 @@
#include <iter_visitor.h>
#include <ops/arith.h>
#include <scheduler/mma_utils.h>
#include <tensor_metadata.h>

#include <limits>
#include <ranges>
@@ -1778,4 +1779,66 @@ void swizzleBlockScales(TensorView* tv) {
tv->merge(-2);
}

void resetContiguityFromTensor(TensorView* tv, const at::Tensor& tensor) {
if (!tensor.defined()) {
return;
}
const auto [sizes, strides] =
inferAllocationSizesAndStrides(tensor, tv, ExpressionEvaluator());

const auto& alloc = tv->getMaybeAllocationDomain();

// Custom contiguity inference that considers IterDomain information
std::vector<std::optional<bool>> contiguity(alloc.size(), std::nullopt);

// Single pass from right to left with two dynamic indices:
// - alloc_idx: iterates through allocation domain
// - sizes_idx: tracks position in sizes/strides (excludes reductions)
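// Illustrative example: for allocation domain [I0, B1, I2, R3, I4] with
// sizes {3, 1, 4, 6} and strides {24, 24, 6, 1} (the broadcast stride is
// irrelevant), the pass visits I4, R3 (skipped), I2, B1 (consumes an entry
// but stays nullopt), then I0, yielding contiguity
// {true, nullopt, true, nullopt, true}.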
int64_t sizes_idx = (int64_t)sizes.size() - 1;
int64_t prev_non_skipped_sizes_idx = -1;

for (int64_t alloc_idx = (int64_t)alloc.size() - 1; alloc_idx >= 0;
--alloc_idx) {
auto id = alloc[alloc_idx];

// Reduction dimensions: nullopt contiguity (already set), no entry in
// sizes/strides
if (id->isReduction()) {
// Don't decrement sizes_idx since reductions have no entry
continue;
}

// This dimension has an entry in sizes/strides
NVF_CHECK(sizes_idx >= 0, "Sizes index out of bounds");

// Broadcast dimensions: nullopt contiguity (already set), but has entry in
// sizes/strides
if (id->isBroadcast()) {
sizes_idx--; // Move to next dimension in sizes/strides
continue;
}

// Non-broadcast, non-reduction dimension
if (prev_non_skipped_sizes_idx == -1) {
// This is the rightmost (innermost) non-skipped dimension
// It's contiguous if stride == 1
contiguity[alloc_idx] = (strides[sizes_idx] == 1);
} else {
// A dimension is contiguous if its stride equals the stride of the next
// (inner) non-skipped dimension multiplied by that inner dimension's size
contiguity[alloc_idx] =
(strides[sizes_idx] ==
strides[prev_non_skipped_sizes_idx] *
sizes[prev_non_skipped_sizes_idx]);
}

prev_non_skipped_sizes_idx = sizes_idx;
sizes_idx--; // Move to next dimension in sizes/strides
}

NVF_CHECK(sizes_idx == -1, "Not all sizes/strides were consumed");

tv->setContiguity(contiguity);
}

} // namespace nvfuser::ir_utils
6 changes: 6 additions & 0 deletions csrc/ir/utils.h
@@ -869,4 +869,10 @@ bool isParallelizedBy(const std::vector<IterDomain*>& ids, ParallelType pt);
// https://docs.nvidia.com/cutlass/media/docs/cpp/blackwell_functionality.html#scale-factor-layouts
NVF_API void swizzleBlockScales(TensorView* tv);

// Infers and resets the contiguity of a TensorView based on the sizes and
// strides of an actual at::Tensor. The contiguity is computed by comparing
// consecutive strides in the allocation domain, accounting for broadcast and
// reduction dimensions.
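// For example (illustrative): with allocation domain [I0, I1] and an
// at::Tensor of shape {4, 8} and strides {16, 1}, I1 is contiguous (stride 1)
// but I0 is not (16 != 1 * 8), so the contiguity becomes {false, true}.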
void resetContiguityFromTensor(TensorView* tv, const at::Tensor& tensor);

} // namespace nvfuser::ir_utils
109 changes: 109 additions & 0 deletions tests/cpp/test_utils.cpp
@@ -2213,4 +2213,113 @@ TEST_F(UtilsTest, GetOrDefault) {
EXPECT_EQ(getOrDefault(m, out), 0);
}

TEST_F(UtilsTest, ResetContiguityFromTensor) {
Fusion fusion;
FusionGuard fg(&fusion);

auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);

// Manually create a TensorView with broadcast, iteration, and reduction dims
// Pattern: [Iteration, Broadcast, Iteration, Reduction, Iteration]
// This tests skipping a broadcast dimension that is NOT at position 0
std::vector<IterDomain*> domains;

// Regular iteration dimension (symbolic size)
domains.push_back(
IterDomainBuilder(
fusion.zeroVal(), IrBuilder::create<Val>(DataType::Index))
.iter_type(IterType::Iteration)
.build());

// Broadcast dimension (constant size 1) - NOT leftmost!
domains.push_back(IterDomainBuilder(fusion.zeroVal(), fusion.oneVal())
.iter_type(IterType::Broadcast)
.build());

// Regular iteration dimension (symbolic size)
domains.push_back(
IterDomainBuilder(
fusion.zeroVal(), IrBuilder::create<Val>(DataType::Index))
.iter_type(IterType::Iteration)
.build());

// Reduction dimension (symbolic size)
domains.push_back(
IterDomainBuilder(
fusion.zeroVal(), IrBuilder::create<Val>(DataType::Index))
.iter_type(IterType::Reduction)
.build());

// Regular iteration dimension (symbolic size)
domains.push_back(
IterDomainBuilder(
fusion.zeroVal(), IrBuilder::create<Val>(DataType::Index))
.iter_type(IterType::Iteration)
.build());

auto td = IrBuilder::create<TensorDomain>(
domains, TensorDomain::getContiguityFilledWith(domains, true));
auto tv = IrBuilder::create<TensorView>(td, DataType::Float);

// Test all 8 possible contiguity combinations for the 3 iteration dimensions
// Allocation domain: [Iteration, Broadcast, Iteration, Reduction, Iteration]
// Tensor shape: [3, 1, 4, 6] - broadcast has size 1, reduction has no entry
// We test all 2^3 = 8 combinations of [dim0, dim2, dim4] contiguity

auto test_contiguity = [&](const std::vector<int64_t>& strides,
bool expected_dim0,
bool expected_dim2,
bool expected_dim4) {
at::Tensor tensor = at::empty_strided({3, 1, 4, 6}, strides, options);
ir_utils::resetContiguityFromTensor(tv, tensor);
auto contiguity = tv->domain()->contiguity();

EXPECT_EQ(contiguity.size(), 5);
EXPECT_TRUE(contiguity[0].has_value());
EXPECT_EQ(contiguity[0].value(), expected_dim0);
EXPECT_FALSE(contiguity[1].has_value()); // Broadcast: nullopt
EXPECT_TRUE(contiguity[2].has_value());
EXPECT_EQ(contiguity[2].value(), expected_dim2);
EXPECT_FALSE(contiguity[3].has_value()); // Reduction: nullopt
EXPECT_TRUE(contiguity[4].has_value());
EXPECT_EQ(contiguity[4].value(), expected_dim4);
};

// Case 1: [F, F, F] - strides [100, X, 50, 2]
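// dim4: stride 2 != 1 → F
// dim2: stride 50 != 2*6=12 → F
// dim0: stride 100 != 50*4=200 → F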
test_contiguity({100, 0, 50, 2}, false, false, false);

// Case 2: [F, F, T] - strides [100, X, 50, 1]
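// dim4: stride 1 → T
// dim2: stride 50 != 1*6=6 → F
// dim0: stride 100 != 50*4=200 → F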
test_contiguity({100, 0, 50, 1}, false, false, true);

// Case 3: [F, T, F] - strides [100, X, 12, 2]
// dim4: stride 2 != 1 → F
// dim2: stride 12 = 2*6 → T
// dim0: stride 100 != 12*4=48 → F
test_contiguity({100, 0, 12, 2}, false, true, false);

// Case 4: [F, T, T] - strides [100, X, 6, 1]
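// dim4: stride 1 → T
// dim2: stride 6 = 1*6 → T
// dim0: stride 100 != 6*4=24 → F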
test_contiguity({100, 0, 6, 1}, false, true, true);

// Case 5: [T, F, F] - strides [200, X, 50, 2]
// dim4: stride 2 != 1 → F
// dim2: stride 50 != 2*6=12 → F
// dim0: stride 200 = 50*4 → T
test_contiguity({200, 0, 50, 2}, true, false, false);

// Case 6: [T, F, T] - strides [200, X, 50, 1]
// dim4: stride 1 → T
// dim2: stride 50 != 1*6=6 → F
// dim0: stride 200 = 50*4 → T
test_contiguity({200, 0, 50, 1}, true, false, true);

// Case 7: [T, T, F] - strides [48, X, 12, 2]
// dim4: stride 2 != 1 → F
// dim2: stride 12 = 2*6 → T
// dim0: stride 48 = 12*4 → T
test_contiguity({48, 0, 12, 2}, true, true, false);

// Case 8: [T, T, T] - strides [24, X, 6, 1]
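// dim4: stride 1 → T
// dim2: stride 6 = 1*6 → T
// dim0: stride 24 = 6*4 → T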
test_contiguity({24, 0, 6, 1}, true, true, true);
}

} // namespace nvfuser