File tree Expand file tree Collapse file tree 1 file changed +10
-2
lines changed
csrc/device_lower/analysis Expand file tree Collapse file tree 1 file changed +10
-2
lines changed Original file line number Diff line number Diff line change @@ -669,14 +669,22 @@ class ConcretizedBroadcastRedundantWriteRemover {
669669 return merged_logical_domains_sorted;
670670 }
671671
672- // Get the index of the loop domain if we skip the broadcasted logical domains
672+ // Get the index of the loop domain if we skip the broadcasted
673+ // logical domains
674+ // TODO: The result of this function is not necessary when
675+ // ThreadPredicateMap is used only for its analysis result. This
676+ // function is used to prepare for predicate generation, e.g.,
677+ // ThreadPredicateMap::getPredicate. Consider a different design
678+ // so that this preparation is only done when necessary.
673679 std::vector<Val*> getIndexOfBroadcastLogicalDomains (
674680 const std::vector<IterDomain*>& merged_logical_domains,
675681 ParallelType pt) {
676682 const int64_t ndim = (int64_t )merged_logical_domains.size ();
677683 // get the stride if we index the loop domain using its logical domains
678684 std::vector<Val*> logical_stride (ndim);
679- logical_stride.at (ndim - 1 ) = GpuLower::current ()->kernel ()->oneVal ();
685+ NVF_ERROR (!merged_logical_domains.empty ());
686+ logical_stride.at (ndim - 1 ) =
687+ merged_logical_domains.front ()->fusion ()->oneVal ();
680688 for (int64_t i = ndim - 2 ; i >= 0 ; i--) {
681689 auto pre_crd = merged_logical_domains.at (i + 1 );
682690 Val* pre_extent = pre_crd->isBroadcast ()
You can’t perform that action at this time.
0 commit comments