Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 40 additions & 1 deletion csrc/device_lower/utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2130,7 +2130,46 @@ IterDomain* getConcreteLoopID(IterDomain* id) {
return promotion;
} else {
const auto& ca_map = FusionInfoGuard::current()->caMap();
return ca_map.getConcreteMappedID(id, IdMappingMode::LOOP);
auto disjoint_set = ca_map.disjointSetOf(id, IdMappingMode::LOOP);
auto concrete = ca_map.getConcreteMappedID(id, IdMappingMode::LOOP);

// The following code is a WAR fix for issue #5326
// (https://github.com/NVIDIA/Fuser/issues/5326):
// the CA map's concrete ID may have an incompatible extent.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please link the issue number.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

added

// Similar to IdModel's loop promotion, we should prefer non-broadcast IDs
// with the largest extent in the loop group.
//
// Check if the concrete ID has a broadcast or size-one extent while
// other IDs in the group have larger extents.
bool concrete_is_broadcast_or_one = concrete->isBroadcast() ||
(concrete->extent()->isConstInt() &&
concrete->extent()->evaluate().as<int64_t>() == 1);

if (concrete_is_broadcast_or_one && disjoint_set->vector().size() > 1) {
// Look for a non-broadcast ID with a larger extent in the same loop group
IterDomain* better_concrete = nullptr;
int64_t max_extent = 1;

for (auto loop_id : disjoint_set->vector()) {
if (loop_id->isBroadcast()) {
continue;
}

if (loop_id->extent()->isConstInt()) {
auto extent_val = loop_id->extent()->evaluate().as<int64_t>();
if (extent_val > max_extent) {
max_extent = extent_val;
better_concrete = loop_id;
}
}
}

// If we found a better candidate, use it instead
if (better_concrete != nullptr && better_concrete != concrete) {
concrete = better_concrete;
}
}

return concrete;
}
}

Expand Down
47 changes: 47 additions & 0 deletions tests/python/direct/test_repro.py
Original file line number Diff line number Diff line change
Expand Up @@ -4544,3 +4544,50 @@ def nvfuser_fusion_id0(fd: FusionDefinition) -> None:
torch.testing.make_tensor((24578,), dtype=torch.bfloat16, device="cuda:0"),
]
fd.validate(inputs)


def test_ca_map_concrete_loop_id(nvfuser_direct_test):
    # Repro test for the WAR in getConcreteLoopID (csrc/device_lower/utils.cpp,
    # same PR): the CA map could pick a broadcast/size-one ID as the concrete
    # loop ID even when other IDs in the loop group have larger extents.
    # NOTE(review): this looks like a generated repro — the exact op sequence
    # and shapes below matter; do not simplify or reorder.
    def nvfuser_fusion_id10(fd: FusionDefinition) -> None:
        # T0: [16, 1, 1] — two size-one (broadcast-like) trailing dims.
        T0 = fd.define_tensor(
            shape=[16, 1, 1],
            contiguity=[True, None, None],
            dtype=DataType.Float,
            is_cpu=False,
        )
        # T1: [1, 1, 1024] — only the last dim is non-trivial.
        T1 = fd.define_tensor(
            shape=[1, 1, 1024],
            contiguity=[None, None, True],
            dtype=DataType.Float,
            is_cpu=False,
        )
        # T5 = exp(T0 - T0) * reciprocal(exp(T0 - T0)); mathematically ones,
        # but kept as ops so the broadcast/loop-group structure is exercised.
        T2 = fd.ops.sub(T0, T0)
        T3 = fd.ops.exp(T2)
        T4 = fd.ops.reciprocal(T3)
        T5 = fd.ops.mul(T3, T4)
        # Insert a broadcast dim at position 2 of T5.
        T6 = fd.ops.broadcast(T5, is_broadcast_dim=[False, False, True, False])
        # S7 is the (size-one) extent of T5's dim 1, used as a dynamic reshape
        # extent below — presumably this dynamic-vs-constant mix is part of
        # the repro; confirm against the original issue.
        S7 = fd.ops.size(T5, dim=1)
        S8 = fd.define_scalar(16, dtype=DataType.Int)
        S9 = fd.define_scalar(64, dtype=DataType.Int)
        # Reshape [1, 1, 1024] -> [S7, S7, 16, 64], then shuffle axes.
        T11 = fd.ops.reshape(T1, new_shape=[S7, S7, S8, S9])
        T12 = fd.ops.permute(T11, dims=[0, 2, 1, 3])
        T13 = fd.ops.squeeze(T12, dims=[0], squeeze_expanded=True)
        T14 = fd.ops.permute(T13, dims=[0, 2, 1])
        # Re-broadcast so T15 aligns with T6 for the multiply.
        T15 = fd.ops.broadcast(T14, is_broadcast_dim=[False, True, False, False])
        T16 = fd.ops.mul(T6, T15)
        # Squeeze/broadcast/permute round-trip before the final reshape back
        # to a [S7, S7, 1024] output.
        T17 = fd.ops.squeeze(T16, dims=[3], squeeze_expanded=True)
        T18 = fd.ops.broadcast(T17, is_broadcast_dim=[True, False, False, False])
        T19 = fd.ops.permute(T18, dims=[0, 2, 1, 3])
        S20 = fd.define_scalar(1024, dtype=DataType.Int)
        T22 = fd.ops.reshape(T19, new_shape=[S7, S7, S20])
        fd.add_output(T5)
        fd.add_output(T13)
        fd.add_output(T22)

    with FusionDefinition() as fd:
        nvfuser_fusion_id10(fd)

    # Shapes must match the define_tensor() declarations above.
    inputs = [
        torch.testing.make_tensor((16, 1, 1), dtype=torch.float32, device="cuda:0"),
        torch.testing.make_tensor((1, 1, 1024), dtype=torch.float32, device="cuda:0"),
    ]
    # validate() compiles/executes the fusion and checks outputs against the
    # eager reference; before the fix this fusion failed.
    fd.validate(inputs)