diff --git a/csrc/device_lower/utils.cpp b/csrc/device_lower/utils.cpp
index 581fb77253c..1b4d14fe1cd 100644
--- a/csrc/device_lower/utils.cpp
+++ b/csrc/device_lower/utils.cpp
@@ -2130,7 +2130,46 @@ IterDomain* getConcreteLoopID(IterDomain* id) {
     return promotion;
   } else {
     const auto& ca_map = FusionInfoGuard::current()->caMap();
-    return ca_map.getConcreteMappedID(id, IdMappingMode::LOOP);
+    auto disjoint_set = ca_map.disjointSetOf(id, IdMappingMode::LOOP);
+    auto concrete = ca_map.getConcreteMappedID(id, IdMappingMode::LOOP);
+
+    // The following code is a WAR fix for issue-5326.
+    // The CA map's concrete ID may have an incompatible extent.
+    // Similar to IdModel's loop promotion, we should prefer non-broadcast IDs
+    // with the largest extent in the loop group.
+    //
+    // Check if the concrete ID has a broadcast or size-one extent while
+    // other IDs in the group have larger extents.
+    bool concrete_is_broadcast_or_one = concrete->isBroadcast() ||
+        (concrete->extent()->isConstInt() &&
+         concrete->extent()->evaluate().as<int64_t>() == 1);
+
+    if (concrete_is_broadcast_or_one && disjoint_set->vector().size() > 1) {
+      // Look for a non-broadcast ID with a larger extent in the same loop group.
+      IterDomain* better_concrete = nullptr;
+      int64_t max_extent = 1;
+
+      for (auto loop_id : disjoint_set->vector()) {
+        if (loop_id->isBroadcast()) {
+          continue;
+        }
+
+        if (loop_id->extent()->isConstInt()) {
+          auto extent_val = loop_id->extent()->evaluate().as<int64_t>();
+          if (extent_val > max_extent) {
+            max_extent = extent_val;
+            better_concrete = loop_id;
+          }
+        }
+      }
+
+      // If we found a better candidate, use it instead.
+      if (better_concrete != nullptr && better_concrete != concrete) {
+        concrete = better_concrete;
+      }
+    }
+
+    return concrete;
   }
 }
 
diff --git a/tests/python/direct/test_repro.py b/tests/python/direct/test_repro.py
index eabad2eb72b..0bfe970311f 100644
--- a/tests/python/direct/test_repro.py
+++ b/tests/python/direct/test_repro.py
@@ -4544,3 +4544,50 @@ def nvfuser_fusion_id0(fd: FusionDefinition) -> None:
         torch.testing.make_tensor((24578,), dtype=torch.bfloat16, device="cuda:0"),
     ]
     fd.validate(inputs)
+
+
+def test_ca_map_concrete_loop_id(nvfuser_direct_test):
+    def nvfuser_fusion_id10(fd: FusionDefinition) -> None:
+        T0 = fd.define_tensor(
+            shape=[16, 1, 1],
+            contiguity=[True, None, None],
+            dtype=DataType.Float,
+            is_cpu=False,
+        )
+        T1 = fd.define_tensor(
+            shape=[1, 1, 1024],
+            contiguity=[None, None, True],
+            dtype=DataType.Float,
+            is_cpu=False,
+        )
+        T2 = fd.ops.sub(T0, T0)
+        T3 = fd.ops.exp(T2)
+        T4 = fd.ops.reciprocal(T3)
+        T5 = fd.ops.mul(T3, T4)
+        T6 = fd.ops.broadcast(T5, is_broadcast_dim=[False, False, True, False])
+        S7 = fd.ops.size(T5, dim=1)
+        S8 = fd.define_scalar(16, dtype=DataType.Int)
+        S9 = fd.define_scalar(64, dtype=DataType.Int)
+        T11 = fd.ops.reshape(T1, new_shape=[S7, S7, S8, S9])
+        T12 = fd.ops.permute(T11, dims=[0, 2, 1, 3])
+        T13 = fd.ops.squeeze(T12, dims=[0], squeeze_expanded=True)
+        T14 = fd.ops.permute(T13, dims=[0, 2, 1])
+        T15 = fd.ops.broadcast(T14, is_broadcast_dim=[False, True, False, False])
+        T16 = fd.ops.mul(T6, T15)
+        T17 = fd.ops.squeeze(T16, dims=[3], squeeze_expanded=True)
+        T18 = fd.ops.broadcast(T17, is_broadcast_dim=[True, False, False, False])
+        T19 = fd.ops.permute(T18, dims=[0, 2, 1, 3])
+        S20 = fd.define_scalar(1024, dtype=DataType.Int)
+        T22 = fd.ops.reshape(T19, new_shape=[S7, S7, S20])
+        fd.add_output(T5)
+        fd.add_output(T13)
+        fd.add_output(T22)
+
+    with FusionDefinition() as fd:
+        nvfuser_fusion_id10(fd)
+
+    inputs = [
+        torch.testing.make_tensor((16, 1, 1), dtype=torch.float32, device="cuda:0"),
+        torch.testing.make_tensor((1, 1, 1024), dtype=torch.float32, device="cuda:0"),
+    ]
+    fd.validate(inputs)
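
For context, the heuristic the WAR applies can be stated in isolation: keep the CA map's concrete ID unless it is broadcast or has constant extent 1 while some other member of the same loop group has a larger known constant extent, in which case take the non-broadcast member with the largest such extent. Below is a minimal, self-contained Python sketch of that rule; FakeIterDomain and pick_concrete are illustrative stand-ins, not nvFuser API.

from dataclasses import dataclass
from typing import List, Optional


# Illustrative stand-in for nvFuser's IterDomain; extent=None models a
# symbolic (non-constant) extent.
@dataclass
class FakeIterDomain:
    name: str
    extent: Optional[int]
    is_broadcast: bool = False


def pick_concrete(
    concrete: FakeIterDomain, loop_group: List[FakeIterDomain]
) -> FakeIterDomain:
    # Only reconsider when the concrete ID is broadcast-like and the loop
    # group has other candidates.
    broadcast_or_one = concrete.is_broadcast or concrete.extent == 1
    if not (broadcast_or_one and len(loop_group) > 1):
        return concrete
    best, max_extent = None, 1
    for loop_id in loop_group:
        if loop_id.is_broadcast:
            continue
        if loop_id.extent is not None and loop_id.extent > max_extent:
            best, max_extent = loop_id, loop_id.extent
    return best if best is not None else concrete


# The repro pairs a size-1 ID with a size-16 ID in one loop group; the
# heuristic picks the size-16 ID as the concrete loop ID.
group = [FakeIterDomain("b1", 1, is_broadcast=True), FakeIterDomain("i16", 16)]
assert pick_concrete(group[0], group).name == "i16"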