Draft

Commits (57 total; diff shown covers 17 commits)
fe90fb5
PR0: Relax assert on non-device split on allocation domain
jjsjann123 Sep 18, 2025
a0df5e9
relaxing the check
jjsjann123 Sep 18, 2025
5097533
Adding test validating vectorization
jjsjann123 Sep 18, 2025
d4b7c8b
renaming
jjsjann123 Sep 19, 2025
4b07e79
clangformat
jjsjann123 Sep 19, 2025
051fc9e
I think it's working now!
jjsjann123 Sep 19, 2025
bf85c0b
clangformat
jjsjann123 Sep 19, 2025
6ff1050
quick patch
jjsjann123 Sep 22, 2025
f2f43be
Merge remote-tracking branch 'origin/main' into HEAD
jjsjann123 Sep 22, 2025
a32f54b
fix clearing allocation domain on cache for cacheBefore
jjsjann123 Sep 22, 2025
cf6e609
revert changes
jjsjann123 Sep 23, 2025
b303923
updating tests
jjsjann123 Sep 23, 2025
17dbf23
i was dumb as always
jjsjann123 Sep 23, 2025
c2a3aeb
why is it so hard for me
jjsjann123 Sep 23, 2025
1a156be
Apply suggestions from code review
jjsjann123 Sep 23, 2025
2081d0c
clangformat
jjsjann123 Sep 23, 2025
6f674aa
Merge branch 'main' into jj/allocation_PR_0
jjsjann123 Sep 23, 2025
4f8ecfc
Merge remote-tracking branch 'origin/main' into jj/allocation_PR_0
jjsjann123 Sep 26, 2025
ee37038
reverting selfReplay & cacheBefore changes per reviewer's comments
jjsjann123 Sep 26, 2025
f87e99d
wip
jjsjann123 Sep 26, 2025
c5155ff
wip
jjsjann123 Sep 26, 2025
ded16ec
wip
jjsjann123 Sep 26, 2025
f02440c
wip
jjsjann123 Sep 26, 2025
aa084bc
errr zip
jjsjann123 Sep 26, 2025
f82ad1f
wip
jjsjann123 Sep 26, 2025
d9a33d8
err, forgot to push something last night
jjsjann123 Sep 26, 2025
6dda5e2
typo
jjsjann123 Sep 26, 2025
173a7e9
skipping checks
jjsjann123 Sep 26, 2025
98654a0
wip
jjsjann123 Sep 26, 2025
a870f9d
relaxing checks in tests
jjsjann123 Sep 26, 2025
bca1734
wip
jjsjann123 Sep 26, 2025
d91ac03
clean up IDs for cacheBefore
jjsjann123 Sep 26, 2025
eff3069
clear up definition of output TV for cacheBefore
jjsjann123 Sep 26, 2025
2105e1e
fixing one alias test!
jjsjann123 Sep 27, 2025
bdaaccb
wip
jjsjann123 Sep 27, 2025
fdf9dba
fixing definition
jjsjann123 Sep 27, 2025
65022bd
wip
jjsjann123 Sep 27, 2025
5431648
not set allocation domain when original output doesn't have it
jjsjann123 Sep 27, 2025
75b06b5
update output itertype
jjsjann123 Sep 27, 2025
c2bf4cf
wip
jjsjann123 Sep 27, 2025
c5d66b6
wip
jjsjann123 Sep 27, 2025
b1836f5
wip
jjsjann123 Sep 27, 2025
7ee9317
fixing contiguity in fullselfreplay
jjsjann123 Sep 27, 2025
f7bbab2
Merge remote-tracking branch 'origin/main' into HEAD
jjsjann123 Sep 27, 2025
afccea0
fixing transpose tests
jjsjann123 Sep 27, 2025
482afc8
set parallelization type after fullseflreplay
jjsjann123 Sep 27, 2025
255055d
fix mark alias
jjsjann123 Oct 1, 2025
599d809
fixing alias analysis
jjsjann123 Oct 1, 2025
eadf148
Merge remote-tracking branch 'origin/main' into HEAD
jjsjann123 Oct 24, 2025
fe9c1f6
quick patch on nvfuser::schedule_matmul::Common::cacheBefore
jjsjann123 Oct 24, 2025
f927809
quick patch on nvfuser::schedule_matmul::Common::updateIdModel
jjsjann123 Oct 25, 2025
493434a
agent you can do better!
jjsjann123 Oct 25, 2025
92fb6f9
err
jjsjann123 Oct 25, 2025
642d9a8
Merge branch 'main' into jj/allocation_PR_0
jjsjann123 Oct 28, 2025
10eb4e0
Merge branch 'main' into jj/allocation_PR_0
jjsjann123 Oct 30, 2025
6ff66f2
try self replay so allocation domain is preserved for multi device
jjsjann123 Oct 30, 2025
2950054
err revert something that's not working
jjsjann123 Oct 30, 2025
31 changes: 14 additions & 17 deletions csrc/multidevice/utils.cpp
@@ -776,21 +776,16 @@ std::unordered_set<TensorView*> getTvsWithDifferentSharding(
return ret;
}

void validateDeviceSplit(Expr* expr) {
NVF_ERROR(expr != nullptr, "Expected a valid expression.");
auto* split = dynamic_cast<Split*>(expr);
NVF_ERROR(
split != nullptr,
"Only split expressions are supported for producing device ids: ",
expr->toString());
NVF_ERROR(
split->outer()->isDeviceDim(),
"Expected the outer dimension to be a device dimension: ",
expr->toString());
NVF_ERROR(
!split->innerSplit(),
"Inner split by device dimension is not supported: ",
expr->toString());
bool isValidDeviceSplit(Expr* expr) {
if (expr == nullptr || !expr->isA<Split>()) {
return false;
}
auto* split = expr->as<Split>();
// as<Split> cannot return null after the isA<Split> check above.
return split->outer()->isDeviceDim() && !split->innerSplit();
}

IterDomain* projectShardedAllocationToLogical(
@@ -806,7 +801,8 @@ IterDomain* projectShardedAllocationToLogical(

IterDomain* logical_id = allocation_id;
for (Expr* expr : exprs | std::views::reverse) {
validateDeviceSplit(expr);
NVF_ERROR(
isValidDeviceSplit(expr), "invalid device split: ", expr->toString());
logical_id = expr->as<Split>()->in();
}
return logical_id;
@@ -825,7 +821,8 @@ IterDomain* projectLogicalToShardedAllocation(
tv->getMaybeAllocationDomain().end()});
IterDomain* allocation_id = logical_id;
for (auto expr : exprs) {
validateDeviceSplit(expr);
NVF_ERROR(
isValidDeviceSplit(expr), "invalid device split: ", expr->toString());
allocation_id = expr->as<Split>()->inner();
}
return allocation_id;
2 changes: 1 addition & 1 deletion csrc/multidevice/utils.h
@@ -169,7 +169,7 @@ std::vector<int64_t> unshardedSizes(

// Returns true if the expression is a valid DID split: an outer split with
// the device dim as the outer output.
void validateDeviceSplit(Expr* expr);
bool isValidDeviceSplit(Expr* expr);

// Find the producing logical id of the given allocation id traversing
// through device splits. For unsharded allocation_id, logical_id is the same as
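As context for the new predicate, a minimal sketch of the pattern it accepts; the helper (makeContigTensor) and the split call follow nvfuser test conventions elsewhere in this PR, and `num_devices` is a hypothetical mesh size, not something defined here:

// A valid DID split: an outer split whose outer output is bound to a
// device parallel type.
TensorView* tv = makeContigTensor(2);              // logical: [i0, i1]
// `num_devices` is a hypothetical mesh size.
tv->split(0, num_devices, /*inner_split=*/false);  // [d, i0/d, i1]
tv->axis(0)->parallelize(ParallelType::DIDx);
// The Split producing axis(0)/axis(1) now satisfies both checks:
//   split->outer()->isDeviceDim() == true
//   split->innerSplit()           == false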
11 changes: 10 additions & 1 deletion csrc/scheduler/vectorize_helper.cpp
@@ -807,13 +807,22 @@ Val* ContiguousInnerDimensionsMapper::getContigMergeOfInnerSize(
{alloc_iid});
IterDomain* logical_id = alloc_iid;
Val* num_devices = of_tv->container()->oneVal();
bool only_valid_device_split = true;
for (Expr* expr : exprs | std::views::reverse) {
validateDeviceSplit(expr);
if (!isValidDeviceSplit(expr)) {
Collaborator Author (jjsjann123): @protonu You might need this relaxed (things coming from vectorize_helper.cpp and multidevice/...). I'll start a PR on the side for this.

Collaborator: Thanks - for now, I modified it here: https://github.com/NVIDIA/Fuser/pull/5322/files

only_valid_device_split = false;
break;
}
auto* split = expr->as<Split>();
logical_id = split->in();
num_devices = SimplifyingIrBuilder::mulExpr(num_devices, split->factor());
}

// A non-device split could lead to padding, which prevents vectorization.
Collaborator: Can we have a non-device split if not for padding? Should we throw an error here so we do not get random transforms?

Collaborator Author (jjsjann123): I might have missed your suggestion. If I throw on seeing a split here, we wouldn't be able to support padding via transformations on allocation domains.

Collaborator: Sorry about the confusion. To clarify, I was wondering if we can assert that a non-device split is in fact padding, and not a random transform. For example, if it is a divisible split, we can include it in vectorization, correct?

Collaborator Author (jjsjann123): Got'ya. I think a split on the allocation domain is meant for padding; vectorization should be handled in the loop domain. A more complicated case is if we use split and permute on the allocation domain to represent blocking, i.e. in some sense it is indeed used to facilitate vectorization/TMA. But conceptually I think they are still different.
Going back to divisible vs. non-divisible splits: we would need to specialize during concretization if we are to distinguish them (assuming dynamic shape). I tend to think that's more of an optimization. Without that assert, we are leaving margin on the max vectorization factor, which I think isn't too big a deal. 🤞
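To make the padding point concrete, a back-of-the-envelope sketch (plain C++, illustrative only) of why a non-divisible allocation-domain split implies padding:

// Splitting extent K by factor f allocates ceilDiv(K, f) * f elements.
int64_t K = 9;                    // logical extent, as in the test below
int64_t f = 16;                   // split factor
int64_t outer = (K + f - 1) / f;  // ceilDiv(9, 16) == 1
int64_t allocated = outer * f;    // 16 elements backing 9 valid ones
// The 7-element tail is padding; a vectorized access sized off the logical
// extent could straddle valid and padded elements, so the mapper stops
// growing the vectorization dimension at such a split.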

if (!only_valid_device_split) {
break;
}

// Mapping order isn't correct, cannot expand vectorization dimension.
if (projected_dims[--projected_dims_i] != logical_id) {
break;
17 changes: 13 additions & 4 deletions csrc/tensor_view.cpp
@@ -1163,10 +1163,19 @@ TensorView* TensorView::cacheBefore(LoadStoreOpType op_type) {
// consumer tensor needs to copy the whole producer tensor, so the
// loop domain must be based on the logical domain.
if (!producer->definition()->isA<ScatterOp>()) {
auto replayed_consumer_pair = TransformReplay::replayCasP(
consumer, producer, -1, TransformReplayOptions().replayAllocation());

consumer->setDomain(replayed_consumer_pair.first);
// NOTE: Refactored away from TransformReplay::replayCasP, which doesn't
// replay transformations between the logical and allocation domains; its
// map only works when the transformations are also on the path from
// logical to loop. Switched to selfReplay, which targets replay of the
// loop and allocation domains.
// NOTE: producer and consumer are linked by a LoadStoreOp; otherwise we
// cannot use selfReplay on general pari of producer-consumer TVs.
Collaborator: Suggested change:
-  // cannot use selfReplay on general pari of producer-consumer TVs.
+  // cannot use selfReplay on a general pair of producer-consumer TVs.

TransformReplay::selfReplay(
producer->domain(), consumer->domain(), /*ignore_reductions=*/true);
// TODO: remove allocation domain from cached TV
// technically we shouldn't let the output allocation domain dictate the layout
Collaborator (@naoyam, Sep 23, 2025): I don't think that's what's happening here. The transformations of producer are replayed onto consumer, so the producer allocation domain, i.e., the layout of the cache tensor, dictates the layout of the fusion output, which seems wrong.
I think the simplest way, at least conceptually, is to never touch the allocation domain of the consumer throughout cacheBefore, since it doesn't make sense to modify the memory layout when a cache is created.
I think we just need to remove the reduction allocation domains.

Collaborator Author (jjsjann123):
> The transformations of producer are replayed onto consumer, so the producer allocation domain, i.e., the layout of the cache tensor, dictates the layout of the fusion output, which seems wrong.

The layout of the producer tensor is inherited from the original output (Fuser/csrc/tensor_view.cpp, lines 1122 to 1125 in ccbc581):

    auto* producer = IrBuilder::createInContainer<TensorView>(
        container(),
        IrBuilder::createInContainer<TensorDomain>(container(), domain()),
        getDataType().value());

If we don't replay it, we are mutating the allocation domain of an output TV, and that doesn't seem right.

Collaborator: I'm not sure what you mean by that. My proposal is that when doing setDomain with the output tensor, we should not modify its allocation domain, except for reduction IDs since they are stripped if any.
It seems that what's actually implemented here is to create an invalid allocation domain and try to fix it up later, and I'm just saying why not avoid breaking it from the beginning.

Collaborator Author (jjsjann123): Recording an offline conversation with @naoyam: I'll refactor this one more time and follow up with the failing cases.
Wondering if you can mark "request changes" on this PR to avoid an accidental merge.

Collaborator Author (jjsjann123): Converted this to draft to avoid accidental merge.

// of the cache. But existing schedulers expect this behavior, with the allocation
Collaborator: Which scheduler is impacted if the allocation domain of an intermediate TensorView is not propagated?

Collaborator: CC: @wujingyue, as he ran into a different issue with cacheBefore: #5090 (comment)

Collaborator: The comment in the PR makes sense to me. The allocation of the cache (thus a MemoryType::Local TV) doesn't change external behavior and can be safely removed when schedulers start. See my definition of consistency (point 3) here: #5090 (comment)

Collaborator Author (jjsjann123): I'm hitting these errors. Out of the existing schedulers, only the pointwise scheduler has somewhat good allocation domain support, so it's not surprising that we only have existing tests on that.
Note: the AllocationDomainTest cases are failing on explicit assertions; the DistributedTransformerTest cases are failing validation, but I'm not sure exactly what's causing the failure yet.

00:04:55 [  FAILED  ] PointwiseTest.Issue1567VectorizeAllocationDomain
00:06:09 [  FAILED  ] PointwiseTest.VectorizeAllocationDomain
00:04:55 [  FAILED  ] PointwiseTest.VectorizePadLoweringPermuted

00:05:04 [  FAILED  ] NVFuserTest.IndexSelectVectorizationIndexTensor
00:05:04 [  FAILED  ] NVFuserTest.IndexSelectVectorization3DCase1
00:06:09 [  FAILED  ] NVFuserTest.IndexSelectVectorizationIndexTensorNoBroadcast
00:04:55 [  FAILED  ] NVFuserTest.IndexSelectVectorization3DCase0

00:05:04 [  FAILED  ] AllocationDomainTest.NHWC4d_To_NHWC4d_cacheBefore
00:05:04 [  FAILED  ] AllocationDomainTest.VectorizationIssue902
00:06:09 [  FAILED  ] AllocationDomainTest.NHWC2d_To_NHWC2d_cacheBefore

00:04:55 [  FAILED  ] AliasTest.AliasOutputBeforeNonAliasOutput

00:11:28 [  FAILED  ] DistributedTransformerTest.MLP_Backward/__half, where GetParam() = __half
00:11:28 [  FAILED  ] DistributedTransformerTest.MHA_Backward/__bfloat, where GetParam() = __bfloat
00:11:28 [  FAILED  ] DistributedTransformerTest.Backward/__bfloat, where GetParam() = __bfloat
00:11:21 [  FAILED  ] DistributedTransformerTest.MLP_Backward/__bfloat, where GetParam() = __bfloat

Collaborator: FWIW, you could debug the validation error using optimization fuel. For example, you could put a static counter before a transform:

    static int64_t count = 0;
    if (count < threshold) {
      // perform a certain transform, e.g., set allocation domain
      count++;
    }

then bisect `threshold`.

// domain preserved on the cache.
} else if (producer->hasAllocation()) {
consumer->setAllocationDomain(
ir_utils::propagateScatterAllocationDomain(
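For readers following the thread, a hedged sketch of the flow under discussion, reusing the helpers from tests/cpp/test_layout_op.cpp (the fusion is illustrative, not taken from this PR):

// An output with a non-trivial allocation domain, then cacheBefore().
Fusion fusion;
FusionGuard fg(&fusion);
TensorView* in = makeSymbolicTensor(2);
fusion.addInput(in);
TensorView* out = set(in);
fusion.addOutput(out);
out->setAllocationDomain({out->axis(1), out->axis(0)}, true);  // transposed
// cacheBefore() inserts a Local producer that writes `out`. With the change
// above, TransformReplay::selfReplay replays the original output's loop and
// allocation transforms (now carried by the producer) onto the new output,
// instead of going through replayCasP.
TensorView* cache = out->cacheBefore();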
26 changes: 18 additions & 8 deletions csrc/transform_replay.cpp
@@ -321,6 +321,9 @@ void TransformReplay::selfReplay(
// Replay allocation.
if (self->hasAllocation()) {
const std::vector<IterDomain*>& self_allocation = self->allocation();
// Replay on the allocation domain, for cases where the transformation is
// not on the path to the loop domain.
ReplaySelf allocation_dom_replay(self_allocation, axis_map);
const std::vector<std::optional<bool>>& self_contiguity =
self->contiguity();
NVF_ERROR_EQ(self_allocation.size(), self_contiguity.size());
@@ -347,18 +350,25 @@
if (ignore_reductions && alloc_id->isReduction()) {
continue;
}
auto it = replay.getReplay().find(alloc_id);
NVF_ERROR(
it != replay.getReplay().end(),
"failed to replay IterDomain: ",
alloc_id);
IterDomain* id = nullptr;
// NOTE: try the loop-domain replay first, to avoid unnecessarily
// duplicated transformations.
for (const auto& re :
{replay.getReplay(), allocation_dom_replay.getReplay()}) {
auto it = re.find(alloc_id);
if (it != re.end()) {
id = it->second;
break;
}
}
NVF_ERROR(id, "failed to replay IterDomain: ", alloc_id);
NVF_ERROR_EQ(
it->second->isBroadcast(),
id->isBroadcast(),
!contiguity.has_value(),
"Contiguity should be nullopt iff broadcast.");
new_contiguity.push_back(contiguity);
it->second->parallelize(alloc_id->getParallelType());
new_alloc_domain.push_back(it->second);
id->parallelize(alloc_id->getParallelType());
new_alloc_domain.push_back(id);
}

new_self->setAllocationDomain(new_alloc_domain, new_contiguity);
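A hedged example of the situation the extra allocation-domain replay handles: a transform recorded only on the logical-to-allocation path (this mirrors the LogicalAndAllocationSizes test below):

// The split below exists only between the logical and allocation domains;
// the loop domain is untouched, so a replay map built from the loop domain
// alone would not contain io/ii, and the allocation_dom_replay fallback
// above supplies them.
TensorView* tv = makeSymbolicTensor(2);
auto&& [io, ii] = IterDomain::split(
    tv->axis(1), IrBuilder::create<Val>(16L, DataType::Index), true);
tv->setAllocationDomain({tv->axis(0), io, ii}, {false, true, true});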
73 changes: 73 additions & 0 deletions tests/cpp/test_layout_op.cpp
@@ -70,6 +70,79 @@ class LayoutOpTest : public NVFuserTest {
}
};

TEST_F(LayoutOpTest, LogicalAndAllocationSizes) {
Collaborator: What is being tested here?

Collaborator Author (jjsjann123): Without the relaxation in the vectorization analysis, this test would trigger an assert. So the test just verifies that we now allow a split on the allocation domain. In the follow-up PR, we added more validation to this test to check that the produced tensor matches the logical sizes.

auto fusion_ptr = std::make_unique<Fusion>();
Fusion& fusion = *fusion_ptr.get();
FusionGuard fg(&fusion);

auto inp = makeSymbolicTensor(2);
fusion.addInput(inp);
auto out = set(inp);
fusion.addOutput(out);
// padding output to multiple of 16 on allocation domain
auto&& [io, ii] = IterDomain::split(
out->axis(1), IrBuilder::create<Val>(16L, DataType::Index), true);
Collaborator Author (jjsjann123): Tagging @naoyam: changed the test to only apply the split on logical -> allocation.

// NOTE: this doesn't feel right; we have to mark contiguity on axis(0) as
// `false` to avoid accidental index collapsing. This should be figured out
// by indexing from the ceilDiv.
out->setAllocationDomain({out->axis(0), io, ii}, {false, true, true});
Collaborator: Am I understanding this issue correctly?
a) The tensor actually is contiguous with respect to this allocation domain, which has size M, ceilDiv(K, 16), 16.
b) The tensor winds up not being contiguous with respect to its logical domain, which is of size M, K, because the non-divisible split adds some padding to K.
c) By "index collapsing" you mean it does contiguous indexing so that stride is not part of the index? Is that wrong? It seems like indexing as a contiguous allocation is what we want here.

My question is what specifically goes wrong when allocation is set to contiguous?

Collaborator Author (jjsjann123): Yes, you are absolutely correct about a) and b).
Index collapsing is wrong here, because we are mapping from logical to allocation, which does not access contiguous memory (because of the non-divisible split).
This is the before and after of the indexing.

With the false contiguity flag:

root@812ada01cb39:/opt/pytorch/nvfuser# NVFUSER_DUMP=cuda_kernel ./bin/test_layout_op --gtest_filter="*LogicalAndAllocationSizes"
Running main() from /opt/pytorch/nvfuser/third_party/googletest/googletest/src/gtest_main.cc
Note: Google Test filter = *LogicalAndAllocationSizes
[==========] Running 1 test from 1 test suite.
[----------] Global test environment set-up.
[----------] 1 test from LayoutOpTest
[ RUN      ] LayoutOpTest.LogicalAndAllocationSizes

======= Codegen output for kernel: nvfuser_pointwise_f0_c1_r0_g0 =======

// Codegen generated code
__global__ void nvfuser_pointwise_f0_c1_r0_g0(Tensor<float, 2, 2> T0, Tensor<float, 2, 3> T1) {
  nvfuser_index_t i0;
  i0 = ((nvfuser_index_t)threadIdx.x) + (128LL * ((nvfuser_index_t)blockIdx.x));
  nvfuser_index_t i1;
  i1 = i0 % T0.logical_size[1LL];
  nvfuser_index_t i2;
  i2 = i0 / T0.logical_size[1LL];
  if ((i0 < (T0.logical_size[0LL] * T0.logical_size[1LL]))) {
    Array<float, 1LL, 1> T2;
    T2[0LL] = 0LL;
    T2[0LL]
       = T0[((T0.alloc_stride[0LL] * i2) + (T0.alloc_stride[1LL] * i1))];
    Array<float, 1LL, 1> T3;
    T3[0LL]
       = T2[0LL];
    T1[(i1 + (T1.alloc_stride[0LL] * i2))]
       = T3[0LL];
  }
}

======================================

[       OK ] LayoutOpTest.LogicalAndAllocationSizes (966 ms)
[----------] 1 test from LayoutOpTest (966 ms total)

With the true contiguity flag:

root@558d9dfeefb8:/opt/pytorch/nvfuser# NVFUSER_DUMP=cuda_kernel ./bin/test_layout_op --gtest_filter="*LogicalAndAllocationSizes"
Running main() from /opt/pytorch/nvfuser/third_party/googletest/googletest/src/gtest_main.cc
Note: Google Test filter = *LogicalAndAllocationSizes
[==========] Running 1 test from 1 test suite.
[----------] Global test environment set-up.
[----------] 1 test from LayoutOpTest
[ RUN      ] LayoutOpTest.LogicalAndAllocationSizes

======= Codegen output for kernel: nvfuser_pointwise_f0_c1_r0_g0 =======

// Codegen generated code
__global__ void nvfuser_pointwise_f0_c1_r0_g0(Tensor<float, 2, 2> T0, Tensor<float, 2, 3> T1) {
  nvfuser_index_t i0;
  i0 = ((nvfuser_index_t)threadIdx.x) + (128 * ((nvfuser_index_t)blockIdx.x));
  if ((i0 < (T0.logical_size[0LL] * T0.logical_size[1LL]))) {
    Array<float, 1, 1> T2;
    T2[0] = 0;
    T2[0]
       = T0[((T0.alloc_stride[0LL] * (i0 / T0.logical_size[1LL])) + (T0.alloc_stride[1LL] * (i0 % T0.logical_size[1LL])))];
    Array<float, 1, 1> T3;
    T3[0]
       = T2[0];
    T1[i0]
       = T3[0];
  }
}

======================================

/opt/pytorch/nvfuser/tests/cpp/test_layout_op.cpp:128: Failure
Value of: t0.equal(cg_outputs[0].as<at::Tensor>().slice(1, 0, k))
  Actual: false
Expected: true


// Two issues with the split-and-merge approach:
// 1. It causes predication to expand into the padded region.
// 2. Indexing with the allocation contiguity set to `true` is wrong.
// out->split(1, 16); // padding output to multiple of 16
// out->setAllocationDomain(out->getLoopDomain(), true);
// out->merge(1); // restore loop domain

auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
int m = 512;
int k = 9; // note: padded column size would be 16
auto t0 = at::randn({m, k}, options);

FusionExecutorCache executor_cache(std::move(fusion_ptr));
auto cg_outputs = executor_cache.runFusionWithInputs({t0});
// padding on the inner dimension is represented as stride on the outer
// dimension
EXPECT_EQ(
cg_outputs[0].as<at::Tensor>().strides(), std::vector<int64_t>({16, 1}));
// We need to slice because output buffer shape is not right
EXPECT_TRUE(t0.equal(cg_outputs[0].as<at::Tensor>().slice(1, 0, k)));
// TODO: enable this when output buffer shape is fixed.
// output should remain the correct logical size
// EXPECT_EQ(
// cg_outputs[0].as<at::Tensor>().sizes(), std::vector<int64_t>({512,
// 9}));
}
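The arithmetic behind the expected strides, spelled out (plain C++, values taken from the test above):

// k = 9 rounded up by the outer split of 16 gives an inner allocation
// extent of 16, so consecutive rows of the (512, 9) logical tensor start
// 16 elements apart.
int64_t k = 9, factor = 16;
int64_t padded = ((k + factor - 1) / factor) * factor;  // == 16
// logical sizes {512, 9}; allocation-backed strides {padded, 1} == {16, 1}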

TEST_F(LayoutOpTest, AllocationDomainSplitVectorizationFactor) {
auto fusion_ptr = std::make_unique<Fusion>();
Fusion& fusion = *fusion_ptr.get();
FusionGuard fg(&fusion);

auto inp = makeSymbolicTensor(3);
fusion.addInput(inp);
auto out = set(inp);
fusion.addOutput(out);
// split would prevent vectorization
auto&& [io, ii] = IterDomain::split(
out->axis(1), IrBuilder::create<Val>(16L, DataType::Index), true);
out->setAllocationDomain(
{out->axis(0), io, ii, out->axis(2)}, {false, true, true, true});

auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
// because of the split on the middle dimension, we only have the fastest
// dimension participating in vectorization.
auto t0 = at::randn({512, 128, 2}, options);

// NOTE force pointwise scheduler here just for testing purpose
auto cg_results =
scheduleAndRun(fusion_ptr.get(), SchedulerType::PointWise, {t0});
auto pparams = cg_results.heuristic_params->as<PointwiseParams>();
EXPECT_EQ(pparams->vectorization_factor, 2);

testValidate(fusion_ptr.get(), cg_results.outputs, {t0}, __LINE__, __FILE__);
}
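Why a factor of 2 is expected, spelled out (values from the test above; the byte math is an illustration, not code from this PR):

// The allocation domain is {i0, io, ii, i2}; the contiguous-inner-dims
// mapping stops at the split producing {io, ii}, so only the innermost i2
// can participate in vectorization.
int64_t innermost_extent = 2;  // extent of out->axis(2) for the {512, 128, 2} input
int64_t vec_bytes = innermost_extent * sizeof(float);  // 8-byte accesses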

TEST_F(LayoutOpTest, CppApi) {
auto fusion_ptr = std::make_unique<Fusion>();
Fusion& fusion = *fusion_ptr.get();