-
Notifications
You must be signed in to change notification settings - Fork 74
Combine for RaggedIterDomain #5716
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from 36 commits
Commits
Show all changes
52 commits
Select commit
Hold shift + click to select a range
d87e6d7
Initial introduction of RaggedIterDomain
naoyam 77c6a07
Merge remote-tracking branch 'origin/main' into raggediterdomain_init…
naoyam f16fc4d
cleanup
naoyam 23d55f1
fix
naoyam 8392332
fix
naoyam 787dfec
unit test
naoyam a0b40a3
cleanup
naoyam dbdd917
Fix IterVisitor
naoyam cdbd81e
cleanup
naoyam d4c8d7f
WIP: partition
naoyam 9575a13
Partition expr
naoyam a054ae0
TensorView::partition
naoyam 69dbe0f
cleanup
naoyam db3b359
Merge remote-tracking branch 'origin/main' into raggediterdomain_part…
naoyam 2348dde
cleanup
naoyam 7090b9c
WIP: asNested
naoyam b07e285
cleanup
naoyam a2c504b
asNested
naoyam b1d8cf4
warpdim
naoyam 201c148
Make sure RaggedIterDomain is propagated to output tensors
naoyam 9e0b161
Extend ops to be aware with RaggediterDomain
naoyam 60a2dd5
RaggedIterDomain and reduction
naoyam 566d63d
WIP
naoyam 144b206
WIP
naoyam e2efe75
cleanup
naoyam 0b68d6b
cleanup
naoyam 8a73bb2
cleanup
naoyam 550e0c5
Merge branch 'raggediterdomain_partition' into raggediterdomain-asnested
naoyam 82bd85e
Merge remote-tracking branch 'origin/main' into raggediterdomain-asne…
naoyam f215f07
Use extents as a parameter
naoyam 5b99432
Merge remote-tracking branch 'origin/main' into raggediterdomain-asne…
naoyam 2dd9287
Merge branch 'raggediterdomain-asnested' into raggediterdomain_clone
naoyam c3aebec
combine
naoyam a22bb1f
Add tests
naoyam f521c38
WIP
naoyam 8d0d9cb
don't hold component ID in RaggedIterDomain
naoyam 67aac1b
Add design doc
naoyam 3a80926
license
naoyam f75ecb6
Merge branch 'main' into raggediterdomain-asnested
naoyam 8aa854e
feedback
naoyam 72ae14f
fix
naoyam 85d48df
Merge branch 'raggediterdomain-asnested' into raggediterdomain_clone
naoyam 5f86d9c
Merge branch 'main' into raggediterdomain_clone
naoyam bf5b627
Merge remote-tracking branch 'origin/main' into raggediterdomain_clone
naoyam bec4c09
Merge remote-tracking branch 'origin/main' into raggediterdomain_clone
naoyam 4d8acab
cleanup
naoyam 3b082ba
cleanup
naoyam 72dbc41
Merge branch 'raggediterdomain_clone' into ragged_combine
naoyam 5002407
expand doc
naoyam be0e2ea
cleanup
naoyam 05a6201
Merge remote-tracking branch 'origin/main' into ragged_combine
naoyam d2b5384
format
naoyam File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -30,8 +30,8 @@ | |||||
|
|
||||||
| namespace nvfuser { | ||||||
|
|
||||||
| IterDomainBuilder::IterDomainBuilder(Val* _start, Val* _extent) | ||||||
| : start_(_start), extent_(_extent) { | ||||||
| IterDomainBuilder::IterDomainBuilder(Val* start, Val* extent) | ||||||
| : start_(start), extent_(extent) { | ||||||
| NVF_ERROR( | ||||||
| start_ != nullptr && extent_ != nullptr, | ||||||
| "Start and extent are required to build an iter domain."); | ||||||
|
|
@@ -48,7 +48,11 @@ IterDomainBuilder::IterDomainBuilder(const IterDomain* id) | |||||
| is_rfactor_domain_(id->isRFactorProduct()), | ||||||
| is_padded_dimension_(id->hasPaddingToMultipleOfWarp()), | ||||||
| is_clustered_dimension_(id->isClusteredBlockDim()), | ||||||
| padded_to_size_(id->getMaybeSizeAfterPadding()) {} | ||||||
| padded_to_size_(id->getMaybeSizeAfterPadding()) { | ||||||
| if (id->isA<RaggedIterDomain>()) { | ||||||
| ragged_extents_ = id->as<RaggedIterDomain>()->extents(); | ||||||
| } | ||||||
| } | ||||||
|
|
||||||
| IterDomainBuilder& IterDomainBuilder::resetSchedulingParams() { | ||||||
| parallel_type_ = ParallelType::Serial; | ||||||
|
|
@@ -63,60 +67,72 @@ IterDomainBuilder& IterDomainBuilder::resetRfactor() { | |||||
| return is_rfactor_domain(false); | ||||||
| } | ||||||
|
|
||||||
| IterDomainBuilder& IterDomainBuilder::start(Val* _start) { | ||||||
| start_ = _start; | ||||||
| IterDomainBuilder& IterDomainBuilder::start(Val* start) { | ||||||
| start_ = start; | ||||||
| return *this; | ||||||
| } | ||||||
|
|
||||||
| IterDomainBuilder& IterDomainBuilder::extent(Val* _extent) { | ||||||
| extent_ = _extent; | ||||||
| IterDomainBuilder& IterDomainBuilder::extent(Val* extent) { | ||||||
| extent_ = extent; | ||||||
| return *this; | ||||||
| } | ||||||
|
|
||||||
| IterDomainBuilder& IterDomainBuilder::expanded_extent(Val* _expanded_extent) { | ||||||
| expanded_extent_ = _expanded_extent; | ||||||
| IterDomainBuilder& IterDomainBuilder::expanded_extent(Val* expanded_extent) { | ||||||
| expanded_extent_ = expanded_extent; | ||||||
| return *this; | ||||||
| } | ||||||
|
|
||||||
| IterDomainBuilder& IterDomainBuilder::stop_offset(Val* _stop_offset) { | ||||||
| stop_offset_ = _stop_offset; | ||||||
| IterDomainBuilder& IterDomainBuilder::stop_offset(Val* stop_offset) { | ||||||
| stop_offset_ = stop_offset; | ||||||
| return *this; | ||||||
| } | ||||||
|
|
||||||
| IterDomainBuilder& IterDomainBuilder::parallel_type( | ||||||
| ParallelType _parallel_type) { | ||||||
| parallel_type_ = _parallel_type; | ||||||
| ParallelType parallel_type) { | ||||||
| parallel_type_ = parallel_type; | ||||||
| return *this; | ||||||
| } | ||||||
|
|
||||||
| IterDomainBuilder& IterDomainBuilder::iter_type(IterType _iter_type) { | ||||||
| iter_type_ = _iter_type; | ||||||
| IterDomainBuilder& IterDomainBuilder::iter_type(IterType iter_type) { | ||||||
| iter_type_ = iter_type; | ||||||
| return *this; | ||||||
| } | ||||||
|
|
||||||
| IterDomainBuilder& IterDomainBuilder::is_rfactor_domain( | ||||||
| bool _is_rfactor_domain) { | ||||||
| is_rfactor_domain_ = _is_rfactor_domain; | ||||||
| bool is_rfactor_domain) { | ||||||
| is_rfactor_domain_ = is_rfactor_domain; | ||||||
| return *this; | ||||||
| } | ||||||
|
|
||||||
| IterDomainBuilder& IterDomainBuilder::is_padded_dimension( | ||||||
| bool _is_padded_dimension) { | ||||||
| is_padded_dimension_ = _is_padded_dimension; | ||||||
| bool is_padded_dimension) { | ||||||
| is_padded_dimension_ = is_padded_dimension; | ||||||
| return *this; | ||||||
| } | ||||||
|
|
||||||
| IterDomainBuilder& IterDomainBuilder::padded_to_size( | ||||||
| std::optional<int64_t> _padded_to_size) { | ||||||
| padded_to_size_ = _padded_to_size; | ||||||
| std::optional<int64_t> padded_to_size) { | ||||||
| padded_to_size_ = padded_to_size; | ||||||
| return *this; | ||||||
| } | ||||||
|
|
||||||
| IterDomainBuilder& IterDomainBuilder::ragged_extents( | ||||||
| TensorView* ragged_extents) { | ||||||
| ragged_extents_ = ragged_extents; | ||||||
| return *this; | ||||||
| } | ||||||
|
|
||||||
| IterDomain* IterDomainBuilder::build() const { | ||||||
| NVF_ERROR( | ||||||
| start_ != nullptr && extent_ != nullptr, | ||||||
| "Start and extent are required to build an iter domain."); | ||||||
| return IrBuilder::createInContainer<IterDomain>(start_->container(), *this); | ||||||
|
|
||||||
| if (ragged_extents_ != nullptr) { | ||||||
| return IrBuilder::createInContainer<RaggedIterDomain>( | ||||||
| start_->container(), *this); | ||||||
| } else { | ||||||
| return IrBuilder::createInContainer<IterDomain>(start_->container(), *this); | ||||||
| } | ||||||
| } | ||||||
|
|
||||||
| IterDomain::IterDomain( | ||||||
|
|
@@ -815,6 +831,77 @@ void validateLoopDomain( | |||||
|
|
||||||
| } // namespace | ||||||
|
|
||||||
| RaggedIterDomain::RaggedIterDomain( | ||||||
| IrBuilderPasskey passkey, | ||||||
| const IterDomainBuilder& args) | ||||||
| : IterDomain( | ||||||
| passkey, | ||||||
| ValType::RaggedIterDomain, | ||||||
| args.start_, | ||||||
| args.extent_, | ||||||
| args.expanded_extent_, | ||||||
| args.stop_offset_, | ||||||
| args.parallel_type_, | ||||||
| args.iter_type_, | ||||||
| args.is_rfactor_domain_, | ||||||
| args.is_padded_dimension_, | ||||||
| args.is_clustered_dimension_, | ||||||
| args.padded_to_size_), | ||||||
| extents_(args.ragged_extents_) { | ||||||
| // Extents must be non-null | ||||||
| NVF_ERROR( | ||||||
| extents_ != nullptr, "RaggedIterDomain requires non-null extents tensor"); | ||||||
|
|
||||||
| // Extents must have integer dtype | ||||||
| NVF_ERROR_EQ( | ||||||
| extents_->dtype(), | ||||||
| DataType::Index, | ||||||
| "RaggedIterDomain extents must have index type, got ", | ||||||
| extents_->dtype()); | ||||||
|
|
||||||
| // Only IterType::Iteration is supported at this moment | ||||||
| NVF_ERROR_EQ( | ||||||
| iter_type_, | ||||||
| IterType::Iteration, | ||||||
| "Only IterType::Iteration is supported: ", | ||||||
| iter_type_); | ||||||
|
|
||||||
| // RaggedIterDomain has specific requirements on member values | ||||||
| NVF_ERROR( | ||||||
| start_->isZeroInt(), | ||||||
| "RaggedIterDomain start must be zero, got: ", | ||||||
| start_->toInlineString()); | ||||||
|
|
||||||
| NVF_ERROR( | ||||||
| extent_->isOneInt(), | ||||||
| "RaggedIterDomain extent must be one (placeholder), got: ", | ||||||
| extent_->toInlineString()); | ||||||
|
|
||||||
| NVF_ERROR( | ||||||
| expanded_extent_ == nullptr, | ||||||
| "RaggedIterDomain does not support expanded_extent"); | ||||||
|
|
||||||
| NVF_ERROR( | ||||||
| stop_offset_ == nullptr || stop_offset_->isZeroInt(), | ||||||
| "RaggedIterDomain stop_offset must be nullptr or zero, got: ", | ||||||
| stop_offset_ ? stop_offset_->toInlineString() : "nullptr"); | ||||||
|
|
||||||
| NVF_ERROR( | ||||||
| !is_rfactor_domain_, "RaggedIterDomain does not support rfactor domains"); | ||||||
|
|
||||||
| NVF_ERROR( | ||||||
| !is_padded_dimension_, | ||||||
| "RaggedIterDomain does not support padded dimensions"); | ||||||
|
|
||||||
| NVF_ERROR( | ||||||
| !is_clustered_dimension_, | ||||||
| "RaggedIterDomain does not support clustered dimensions"); | ||||||
|
|
||||||
| NVF_ERROR( | ||||||
| !padded_to_size_.has_value(), | ||||||
| "RaggedIterDomain does not support padded_to_size"); | ||||||
| } | ||||||
|
|
||||||
| RaggedIterDomain::RaggedIterDomain( | ||||||
| IrBuilderPasskey passkey, | ||||||
| TensorView* extents, | ||||||
|
|
@@ -895,6 +982,20 @@ std::string RaggedIterDomain::toString(int indent_size) const { | |||||
| return toInlineString(indent_size); | ||||||
| } | ||||||
|
|
||||||
| IterDomain* RaggedIterDomain::cloneWithoutRFactor(bool map_with_original) { | ||||||
| // Create a new RaggedIterDomain with the same extents and properties | ||||||
| auto cloned = IrBuilder::create<RaggedIterDomain>( | ||||||
| extents_, getIterType(), getParallelType()); | ||||||
|
|
||||||
| // Optionally map the clone with the original in the Exact graph | ||||||
| if (map_with_original) { | ||||||
| // TODO: Implement mapping if needed | ||||||
| NVF_THROW("Not implemented"); | ||||||
| } | ||||||
|
|
||||||
| return cloned; | ||||||
| } | ||||||
|
|
||||||
| std::pair<IterDomain*, RaggedIterDomain*> RaggedIterDomain::partition( | ||||||
| IterDomain* in, | ||||||
| TensorView* extents) { | ||||||
|
|
@@ -953,6 +1054,106 @@ std::pair<IterDomain*, RaggedIterDomain*> RaggedIterDomain::partition( | |||||
| return {component_id, ragged_id}; | ||||||
| } | ||||||
|
|
||||||
| IterDomain* RaggedIterDomain::combine( | ||||||
| IterDomain* component, | ||||||
| RaggedIterDomain* ragged) { | ||||||
| NVF_ERROR(component != nullptr, "combine: component IterDomain is null"); | ||||||
| NVF_ERROR(ragged != nullptr, "combine: ragged IterDomain is null"); | ||||||
|
|
||||||
| NVF_ERROR( | ||||||
| !component->isA<RaggedIterDomain>(), | ||||||
| "combine: component must be a regular IterDomain, got RaggedIterDomain: ", | ||||||
| component->toString()); | ||||||
|
|
||||||
| // Validate that component and ragged have compatible properties | ||||||
| NVF_ERROR_EQ( | ||||||
| component->getParallelType(), | ||||||
| ParallelType::Serial, | ||||||
| "Combining parallelized IterDomain not supported: ", | ||||||
| component->toString()); | ||||||
|
|
||||||
| NVF_ERROR_EQ( | ||||||
| ragged->getParallelType(), | ||||||
| ParallelType::Serial, | ||||||
| "Combining parallelized RaggedIterDomain not supported: ", | ||||||
| ragged->toString()); | ||||||
|
|
||||||
| NVF_ERROR_EQ( | ||||||
| component->getIterType(), | ||||||
| IterType::Iteration, | ||||||
| "combine: only IterType::Iteration is supported for component, got ", | ||||||
| component->getIterType(), | ||||||
| " for IterDomain: ", | ||||||
| component->toString()); | ||||||
|
|
||||||
| NVF_ERROR_EQ( | ||||||
| ragged->getIterType(), | ||||||
| IterType::Iteration, | ||||||
| "combine: only IterType::Iteration is supported for ragged, got ", | ||||||
| ragged->getIterType(), | ||||||
| " for RaggedIterDomain: ", | ||||||
| ragged->toString()); | ||||||
|
|
||||||
| // Validate component-ragged pairing when Partition definition is available | ||||||
| // (Option 3: Best-effort validation) | ||||||
| // Only validate when the RaggedIterDomain has a direct Partition definition. | ||||||
| // After propagation (e.g., set() operations), the definition may be nullptr, | ||||||
| // in which case we trust the user to provide the correct component. | ||||||
| if (ragged->definition() != nullptr && | ||||||
| ragged->definition()->isA<Partition>()) { | ||||||
| auto* partition = ragged->definition()->as<Partition>(); | ||||||
| IterDomain* expected_component = partition->component(); | ||||||
|
|
||||||
| NVF_ERROR( | ||||||
| component == expected_component, | ||||||
| "combine: component mismatch. The provided component does not match ", | ||||||
| "the component from the Partition that created this " | ||||||
| "RaggedIterDomain.\n", | ||||||
| " Provided component: ", | ||||||
| component->toString(), | ||||||
| "\n", | ||||||
| " Expected component: ", | ||||||
| expected_component->toString()); | ||||||
| } | ||||||
| // If no Partition definition (after set, in segmented fusion, or external | ||||||
| // input), trust the user and proceed without validation | ||||||
|
|
||||||
| // The combined extent is the sum of all extents in the ragged dimension | ||||||
| // For a 1D extents tensor [e0, e1, ..., en-1], the total is sum(extents) | ||||||
| TensorView* extents_tv = ragged->extents(); | ||||||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
The type already says it. Also, in the context of RaggedIterDomain, extents has to be a TensorView. |
||||||
| NVF_ERROR(extents_tv != nullptr, "combine: ragged extents tensor is null"); | ||||||
|
|
||||||
| // It is still assumed the extents tensor is just 1D | ||||||
| NVF_ERROR_EQ( | ||||||
| std::ssize(extents_tv->getLogicalDomain()), | ||||||
| 1, | ||||||
| "Unexpected rank of extent tensor: ", | ||||||
| extents_tv->toString()); | ||||||
|
|
||||||
| auto container = component->container(); | ||||||
| auto zero = container->zeroVal(DataType::Index); | ||||||
|
|
||||||
| // Create a symbolic extent for the combined IterDomain | ||||||
| // This represents the sum of all ragged extents, i.e., | ||||||
| // sum(extents_tv, {0}). We could use the sum output as the extent | ||||||
| // but we would need to extract the scalar value out of the 0-dim | ||||||
| // tensor. For now, we leave it as a symbolic Val. | ||||||
| Val* combined_extent = | ||||||
| IrBuilder::createInContainer<Val>(container, DataType::Index); | ||||||
naoyam marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||
|
|
||||||
| // Create the combined IterDomain with the symbolic extent | ||||||
| IterDomain* combined_id = IterDomainBuilder(zero, combined_extent) | ||||||
| .parallel_type(ParallelType::Serial) | ||||||
| .iter_type(IterType::Iteration) | ||||||
| .build(); | ||||||
|
|
||||||
| // Create the Combine expression linking component + ragged -> combined | ||||||
| IrBuilder::createInContainer<Combine>( | ||||||
| container, combined_id, component, ragged); | ||||||
naoyam marked this conversation as resolved.
Show resolved
Hide resolved
naoyam marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||
|
|
||||||
| return combined_id; | ||||||
| } | ||||||
|
|
||||||
| TensorDomain::TensorDomain( | ||||||
| IrBuilderPasskey passkey, | ||||||
| std::vector<IterDomain*> logical_domain, | ||||||
|
|
@@ -1472,6 +1673,13 @@ bool TensorDomain::hasVectorize() const { | |||||
| }); | ||||||
| } | ||||||
|
|
||||||
| bool TensorDomain::hasRaggedIterDomain() const { | ||||||
| return std::any_of( | ||||||
| logical().begin(), logical().end(), [](IterDomain* logical_id) { | ||||||
| return logical_id->isA<RaggedIterDomain>(); | ||||||
| }); | ||||||
| } | ||||||
|
|
||||||
| std::optional<int64_t> TensorDomain::getReductionAxis() const { | ||||||
| auto it = std::find_if( | ||||||
| loop_domain_.begin(), loop_domain_.end(), [](const auto& id) { | ||||||
|
|
||||||
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.