diff --git a/csrc/device_lower/validation.cpp b/csrc/device_lower/validation.cpp index 9d20b86363f..fc52ef9d57c 100644 --- a/csrc/device_lower/validation.cpp +++ b/csrc/device_lower/validation.cpp @@ -699,7 +699,7 @@ class VectorizeValidator : public OptInDispatch { size_t last_alloc_dim_pos = 0; for (size_t i = tv->getMaybeAllocationDomain().size(); i > 0; i--) { auto r_id = tv->getMaybeAllocationDomain()[i - 1]; - if (r_id->isReduction() || r_id->isBroadcast()) { + if (r_id->isReduction() || r_id->isBroadcast() || r_id->isDeviceDim()) { continue; } if ((tv->getMemoryType() == MemoryType::Shared || diff --git a/csrc/multidevice/utils.cpp b/csrc/multidevice/utils.cpp index cd0accab25e..161b0de2b74 100644 --- a/csrc/multidevice/utils.cpp +++ b/csrc/multidevice/utils.cpp @@ -228,7 +228,7 @@ IterDomain* getShardedIterDomain( NVF_THROW("Unexpected parallel type: ", parallel_type); }(); - for (auto&& [index, id] : enumerate(domain)) { + for (IterDomain* id : domain | TensorDomain::kNoReductions) { if (id->getParallelType() == parallel_type) { return id; } @@ -266,9 +266,10 @@ std::vector<int64_t> unshardedSizes( } const int64_t sharded_axis = getProducingLogicalAxis(tv, sharded_id); - if (sharded_axis == -1) { - continue; - } + NVF_ERROR( + sharded_axis != -1, + "Producing logical axis not found for ", + sharded_id); auto multiplier = [&]() -> int64_t { if (parallel_type == ParallelType::Stream) { diff --git a/csrc/preseg_passes/finalize_multidevice_domains.cpp b/csrc/preseg_passes/finalize_multidevice_domains.cpp index bca2d97fd2c..74212039e39 100644 --- a/csrc/preseg_passes/finalize_multidevice_domains.cpp +++ b/csrc/preseg_passes/finalize_multidevice_domains.cpp @@ -91,39 +91,8 @@ void setLoopAndAllocationDomain(TensorView* tv, bool is_resharding) { new_allocation_domain.push_back(id); new_contiguity.push_back(contiguity); } - - std::optional<std::vector<int64_t>> permutation = - ir_utils::computePermutation(new_allocation_domain, tv->getLoopDomain()); - NVF_ERROR( - permutation.has_value(), - 
"Failed to find a valid permutation for reordering ", - tv->getLoopDomain(), - " as ", - new_allocation_domain); - tv->reorder(permutation.value()); - - if (is_resharding) { - // Resharding expressions have specific requirements on position of - // gathered/scattered dimensions in the allocation domain that is ensured - // by ReorderShardedAxisPass. So we do not move the DIDx to the front in - // this case. For example, in reduce-scatter, the scattered axis is the - // outer-most dimension in communication input and output. - tv->setAllocationDomain(tv->getLoopDomain(), new_contiguity); - return; - } - - // Most schedulers require DIDx to be at the front of the loop domain. - auto old2new = reorderParallelizedToFront(tv); - auto new2old = ir_utils::normalizeOld2New(old2new, tv->nDims()); - std::vector<std::optional<bool>> reordered_contiguity; - std::transform( - new2old.begin(), - new2old.end(), - std::back_inserter(reordered_contiguity), - [&new_contiguity](int64_t i) -> std::optional<bool> { - return new_contiguity[i]; - }); - tv->setAllocationDomain(tv->getLoopDomain(), reordered_contiguity); + tv->setAllocationDomain(new_allocation_domain, new_contiguity); + reorderParallelizedToFront(tv); } } // namespace