Commit fc76733

Minor fix (#5313)
`ThreadPredicateMap` is now used outside of lowering, so it can't assume `GpuLower::current()` is available. This fixes an assertion error hit while testing #5278.
1 parent 117a4b8 commit fc76733

1 file changed: 10 additions, 2 deletions

csrc/device_lower/analysis/thread_predicate.cpp

Lines changed: 10 additions & 2 deletions
@@ -669,14 +669,22 @@ class ConcretizedBroadcastRedundantWriteRemover {
     return merged_logical_domains_sorted;
   }
 
-  // Get the index of the loop domain if we skip the broadcasted logical domains
+  // Get the index of the loop domain if we skip the broadcasted
+  // logical domains
+  // TODO: The result of this function is not necessary when
+  // ThreadPredicateMap is used only for its analysis result. This
+  // function is used to prepare for predicate generation, e.g.,
+  // ThreadPredicateMap::getPredicate. Consider a different design
+  // so that this preparation is only done when necessary.
   std::vector<Val*> getIndexOfBroadcastLogicalDomains(
       const std::vector<IterDomain*>& merged_logical_domains,
       ParallelType pt) {
     const int64_t ndim = (int64_t)merged_logical_domains.size();
     // get the stride if we index the loop domain using its logical domains
     std::vector<Val*> logical_stride(ndim);
-    logical_stride.at(ndim - 1) = GpuLower::current()->kernel()->oneVal();
+    NVF_ERROR(!merged_logical_domains.empty());
+    logical_stride.at(ndim - 1) =
+        merged_logical_domains.front()->fusion()->oneVal();
     for (int64_t i = ndim - 2; i >= 0; i--) {
       auto pre_crd = merged_logical_domains.at(i + 1);
       Val* pre_extent = pre_crd->isBroadcast()
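
The idea behind this change is to stop reading the constant through a pass-scoped singleton and read it through the function's own input instead. Below is a minimal sketch of that pattern. Every type and function in it (`Val`, `Fusion`, `IterDomain`, `Lowering`, `oneFromGlobal`, `oneFromDomains`) is a hypothetical stand-in written for illustration, not nvFuser's actual class or API.

```cpp
#include <cassert>
#include <cstdint>
#include <stdexcept>
#include <vector>

// Hypothetical stand-ins for illustration; NOT nvFuser's real classes.
struct Val {
  int64_t value;
};

// Container that owns shared constants such as the "one" value.
class Fusion {
 public:
  Val* oneVal() { return &one_; }

 private:
  Val one_{1};
};

// Each domain keeps a pointer to the container that owns it.
class IterDomain {
 public:
  explicit IterDomain(Fusion* fusion) : fusion_(fusion) {
    assert(fusion != nullptr);
  }
  Fusion* fusion() const { return fusion_; }

 private:
  Fusion* fusion_;
};

// Stand-in for a pass-scoped singleton: non-null only while a
// (hypothetical) lowering pass is running.
struct Lowering {
  Fusion* fusion;
  static Lowering*& current() {
    static Lowering* instance = nullptr;
    return instance;
  }
};

// Fragile variant: fails whenever no lowering pass is active.
Val* oneFromGlobal() {
  if (Lowering::current() == nullptr) {
    throw std::runtime_error("no active lowering");
  }
  return Lowering::current()->fusion->oneVal();
}

// Robust variant: derives the container from the input itself, after
// checking that there is at least one domain to ask.
Val* oneFromDomains(const std::vector<IterDomain*>& domains) {
  if (domains.empty()) {
    throw std::invalid_argument("expected at least one domain");
  }
  return domains.front()->fusion()->oneVal();
}
```

In this sketch, `oneFromDomains` works whether or not a lowering pass is active, at the cost of requiring a non-empty input, which is why the emptiness check comes first.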
