-
Notifications
You must be signed in to change notification settings - Fork 74
RaggedIterDomain cloning #5707
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
RaggedIterDomain cloning #5707
Changes from all commits
d87e6d7
77c6a07
f16fc4d
23d55f1
8392332
787dfec
a0b40a3
dbdd917
cdbd81e
d4c8d7f
9575a13
a054ae0
69dbe0f
db3b359
2348dde
7090b9c
b07e285
a2c504b
b1d8cf4
201c148
9e0b161
60a2dd5
566d63d
144b206
e2efe75
0b68d6b
8a73bb2
550e0c5
82bd85e
f215f07
5b99432
2dd9287
f75ecb6
8aa854e
72ae14f
85d48df
5f86d9c
bf5b627
bec4c09
4d8acab
3b082ba
bfc3da9
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -30,8 +30,8 @@ | |
|
|
||
| namespace nvfuser { | ||
|
|
||
| IterDomainBuilder::IterDomainBuilder(Val* _start, Val* _extent) | ||
| : start_(_start), extent_(_extent) { | ||
| IterDomainBuilder::IterDomainBuilder(Val* start, Val* extent) | ||
| : start_(start), extent_(extent) { | ||
| NVF_ERROR( | ||
| start_ != nullptr && extent_ != nullptr, | ||
| "Start and extent are required to build an iter domain."); | ||
|
|
@@ -48,7 +48,11 @@ IterDomainBuilder::IterDomainBuilder(const IterDomain* id) | |
| is_rfactor_domain_(id->isRFactorProduct()), | ||
| is_padded_dimension_(id->hasPaddingToMultipleOfWarp()), | ||
| is_clustered_dimension_(id->isClusteredBlockDim()), | ||
| padded_to_size_(id->getMaybeSizeAfterPadding()) {} | ||
| padded_to_size_(id->getMaybeSizeAfterPadding()) { | ||
| if (id->isA<RaggedIterDomain>()) { | ||
| ragged_extents_ = id->as<RaggedIterDomain>()->extents(); | ||
| } | ||
| } | ||
|
Comment on lines
48
to
+55
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Consider using The new Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time! |
||
|
|
||
| IterDomainBuilder& IterDomainBuilder::resetSchedulingParams() { | ||
| parallel_type_ = ParallelType::Serial; | ||
|
|
@@ -63,60 +67,72 @@ IterDomainBuilder& IterDomainBuilder::resetRfactor() { | |
| return is_rfactor_domain(false); | ||
| } | ||
|
|
||
| IterDomainBuilder& IterDomainBuilder::start(Val* _start) { | ||
| start_ = _start; | ||
| IterDomainBuilder& IterDomainBuilder::start(Val* start) { | ||
| start_ = start; | ||
| return *this; | ||
| } | ||
|
|
||
| IterDomainBuilder& IterDomainBuilder::extent(Val* _extent) { | ||
| extent_ = _extent; | ||
| IterDomainBuilder& IterDomainBuilder::extent(Val* extent) { | ||
| extent_ = extent; | ||
| return *this; | ||
| } | ||
|
|
||
| IterDomainBuilder& IterDomainBuilder::expanded_extent(Val* _expanded_extent) { | ||
| expanded_extent_ = _expanded_extent; | ||
| IterDomainBuilder& IterDomainBuilder::expanded_extent(Val* expanded_extent) { | ||
| expanded_extent_ = expanded_extent; | ||
| return *this; | ||
| } | ||
|
|
||
| IterDomainBuilder& IterDomainBuilder::stop_offset(Val* _stop_offset) { | ||
| stop_offset_ = _stop_offset; | ||
| IterDomainBuilder& IterDomainBuilder::stop_offset(Val* stop_offset) { | ||
| stop_offset_ = stop_offset; | ||
| return *this; | ||
| } | ||
|
|
||
| IterDomainBuilder& IterDomainBuilder::parallel_type( | ||
| ParallelType _parallel_type) { | ||
| parallel_type_ = _parallel_type; | ||
| ParallelType parallel_type) { | ||
| parallel_type_ = parallel_type; | ||
| return *this; | ||
| } | ||
|
|
||
| IterDomainBuilder& IterDomainBuilder::iter_type(IterType _iter_type) { | ||
| iter_type_ = _iter_type; | ||
| IterDomainBuilder& IterDomainBuilder::iter_type(IterType iter_type) { | ||
| iter_type_ = iter_type; | ||
| return *this; | ||
| } | ||
|
|
||
| IterDomainBuilder& IterDomainBuilder::is_rfactor_domain( | ||
| bool _is_rfactor_domain) { | ||
| is_rfactor_domain_ = _is_rfactor_domain; | ||
| bool is_rfactor_domain) { | ||
| is_rfactor_domain_ = is_rfactor_domain; | ||
| return *this; | ||
| } | ||
|
|
||
| IterDomainBuilder& IterDomainBuilder::is_padded_dimension( | ||
| bool _is_padded_dimension) { | ||
| is_padded_dimension_ = _is_padded_dimension; | ||
| bool is_padded_dimension) { | ||
| is_padded_dimension_ = is_padded_dimension; | ||
|
Comment on lines
+70
to
+109
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Just renaming |
||
| return *this; | ||
| } | ||
|
|
||
| IterDomainBuilder& IterDomainBuilder::padded_to_size( | ||
| std::optional<int64_t> _padded_to_size) { | ||
| padded_to_size_ = _padded_to_size; | ||
| std::optional<int64_t> padded_to_size) { | ||
| padded_to_size_ = padded_to_size; | ||
| return *this; | ||
| } | ||
|
|
||
| IterDomainBuilder& IterDomainBuilder::ragged_extents( | ||
| TensorView* ragged_extents) { | ||
| ragged_extents_ = ragged_extents; | ||
| return *this; | ||
| } | ||
|
|
||
| IterDomain* IterDomainBuilder::build() const { | ||
| NVF_ERROR( | ||
| start_ != nullptr && extent_ != nullptr, | ||
| "Start and extent are required to build an iter domain."); | ||
| return IrBuilder::createInContainer<IterDomain>(start_->container(), *this); | ||
|
|
||
| if (ragged_extents_ != nullptr) { | ||
| return IrBuilder::createInContainer<RaggedIterDomain>( | ||
| start_->container(), *this); | ||
| } else { | ||
| return IrBuilder::createInContainer<IterDomain>(start_->container(), *this); | ||
| } | ||
| } | ||
|
|
||
| IterDomain::IterDomain( | ||
|
|
@@ -815,6 +831,77 @@ void validateLoopDomain( | |
|
|
||
| } // namespace | ||
|
|
||
| RaggedIterDomain::RaggedIterDomain( | ||
| IrBuilderPasskey passkey, | ||
| const IterDomainBuilder& args) | ||
| : IterDomain( | ||
| passkey, | ||
| ValType::RaggedIterDomain, | ||
| args.start_, | ||
| args.extent_, | ||
| args.expanded_extent_, | ||
| args.stop_offset_, | ||
| args.parallel_type_, | ||
| args.iter_type_, | ||
| args.is_rfactor_domain_, | ||
| args.is_padded_dimension_, | ||
| args.is_clustered_dimension_, | ||
| args.padded_to_size_), | ||
| extents_(args.ragged_extents_) { | ||
| // Extents must be non-null | ||
| NVF_ERROR( | ||
| extents_ != nullptr, "RaggedIterDomain requires non-null extents tensor"); | ||
|
|
||
| // Extents must have integer dtype | ||
| NVF_ERROR_EQ( | ||
| extents_->dtype(), | ||
| DataType::Index, | ||
| "RaggedIterDomain extents must have index type, got ", | ||
| extents_->dtype()); | ||
|
|
||
| // Only IterType::Iteration is supported at this moment | ||
| NVF_ERROR_EQ( | ||
| iter_type_, | ||
| IterType::Iteration, | ||
| "Only IterType::Iteration is supported: ", | ||
| iter_type_); | ||
|
|
||
| // RaggedIterDomain has specific requirements on member values | ||
| NVF_ERROR( | ||
| start_->isZeroInt(), | ||
| "RaggedIterDomain start must be zero, got: ", | ||
| start_->toInlineString()); | ||
|
|
||
| NVF_ERROR( | ||
| extent_->isOneInt(), | ||
| "RaggedIterDomain extent must be one (placeholder), got: ", | ||
| extent_->toInlineString()); | ||
|
|
||
| NVF_ERROR( | ||
| expanded_extent_ == nullptr, | ||
| "RaggedIterDomain does not support expanded_extent"); | ||
|
|
||
| NVF_ERROR( | ||
| stop_offset_ == nullptr || stop_offset_->isZeroInt(), | ||
| "RaggedIterDomain stop_offset must be nullptr or zero, got: ", | ||
| stop_offset_ ? stop_offset_->toInlineString() : "nullptr"); | ||
|
|
||
| NVF_ERROR( | ||
| !is_rfactor_domain_, "RaggedIterDomain does not support rfactor domains"); | ||
|
|
||
| NVF_ERROR( | ||
| !is_padded_dimension_, | ||
| "RaggedIterDomain does not support padded dimensions"); | ||
|
|
||
| NVF_ERROR( | ||
| !is_clustered_dimension_, | ||
| "RaggedIterDomain does not support clustered dimensions"); | ||
|
|
||
| NVF_ERROR( | ||
| !padded_to_size_.has_value(), | ||
| "RaggedIterDomain does not support padded_to_size"); | ||
| } | ||
|
|
||
| RaggedIterDomain::RaggedIterDomain( | ||
| IrBuilderPasskey passkey, | ||
| TensorView* extents, | ||
|
|
@@ -895,6 +982,18 @@ std::string RaggedIterDomain::toString(int indent_size) const { | |
| return toInlineString(indent_size); | ||
| } | ||
|
|
||
| IterDomain* RaggedIterDomain::cloneWithoutRFactor(bool map_with_original) { | ||
| auto cloned = IterDomainBuilder(this).resetRfactor().build(); | ||
|
|
||
| // Optionally map the clone with the original in the Exact graph | ||
| if (map_with_original) { | ||
| // TODO: Implement mapping if needed | ||
| NVF_THROW("Not implemented"); | ||
| } | ||
|
Comment on lines
+989
to
+992
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. style: mapping implementation missing - should call |
||
|
|
||
| return cloned; | ||
| } | ||
|
|
||
| std::pair<IterDomain*, RaggedIterDomain*> RaggedIterDomain::partition( | ||
| IterDomain* in, | ||
| TensorView* extents) { | ||
|
|
@@ -1472,6 +1571,13 @@ bool TensorDomain::hasVectorize() const { | |
| }); | ||
| } | ||
|
|
||
| bool TensorDomain::hasRaggedIterDomain() const { | ||
| return std::any_of( | ||
| logical().begin(), logical().end(), [](IterDomain* logical_id) { | ||
| return logical_id->isA<RaggedIterDomain>(); | ||
| }); | ||
| } | ||
|
|
||
| std::optional<int64_t> TensorDomain::getReductionAxis() const { | ||
| auto it = std::find_if( | ||
| loop_domain_.begin(), loop_domain_.end(), [](const auto& id) { | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -40,7 +40,7 @@ struct AnalyzeViewResult; | |
| class IterDomainBuilder { | ||
| public: | ||
| // Match legacy constructor | ||
| IterDomainBuilder(Val* _start, Val* _extent); | ||
| IterDomainBuilder(Val* start, Val* extent); | ||
|
|
||
| // Grab all the parameters from id to set the IterDomainBuilder | ||
| IterDomainBuilder(const IterDomain* id); | ||
|
|
@@ -52,15 +52,16 @@ class IterDomainBuilder { | |
| // Resets is_rfactor_domain | ||
| IterDomainBuilder& resetRfactor(); | ||
|
|
||
| IterDomainBuilder& start(Val* _start); | ||
| IterDomainBuilder& extent(Val* _extent); | ||
| IterDomainBuilder& expanded_extent(Val* _expanded_extent); | ||
| IterDomainBuilder& stop_offset(Val* _stop_offset); | ||
| IterDomainBuilder& parallel_type(ParallelType _parallel_type); | ||
| IterDomainBuilder& iter_type(IterType _iter_type); | ||
| IterDomainBuilder& is_rfactor_domain(bool _is_rfactor_domain); | ||
| IterDomainBuilder& is_padded_dimension(bool _is_padded_dimension); | ||
| IterDomainBuilder& padded_to_size(std::optional<int64_t> _padded_to_size); | ||
| IterDomainBuilder& start(Val* start); | ||
| IterDomainBuilder& extent(Val* extent); | ||
| IterDomainBuilder& expanded_extent(Val* expanded_extent); | ||
| IterDomainBuilder& stop_offset(Val* stop_offset); | ||
| IterDomainBuilder& parallel_type(ParallelType parallel_type); | ||
| IterDomainBuilder& iter_type(IterType iter_type); | ||
| IterDomainBuilder& is_rfactor_domain(bool is_rfactor_domain); | ||
| IterDomainBuilder& is_padded_dimension(bool is_padded_dimension); | ||
| IterDomainBuilder& padded_to_size(std::optional<int64_t> padded_to_size); | ||
|
Comment on lines
+55
to
+63
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Just renaming |
||
| IterDomainBuilder& ragged_extents(TensorView* ragged_extents); | ||
|
|
||
| IterDomain* build() const; | ||
|
|
||
|
|
@@ -79,6 +80,9 @@ class IterDomainBuilder { | |
| bool is_padded_dimension_ = false; | ||
| bool is_clustered_dimension_ = false; | ||
| std::optional<int64_t> padded_to_size_ = std::nullopt; | ||
|
|
||
| // For RaggedIterDomain: stores the extents tensor | ||
| TensorView* ragged_extents_ = nullptr; | ||
| }; | ||
|
|
||
| //! Simply a representation of an annotated 1D iterable from start to extent. | ||
|
|
@@ -122,7 +126,7 @@ class NVF_API IterDomain : public Val { | |
| //! | ||
| //! When map_with_original is true, the clone of the original is | ||
| //! mapped in the Exact graph. | ||
| IterDomain* cloneWithoutRFactor(bool map_with_original = false); | ||
| virtual IterDomain* cloneWithoutRFactor(bool map_with_original = false); | ||
|
|
||
| //! Clone a vector domains | ||
| static std::vector<IterDomain*> clone( | ||
|
|
@@ -448,6 +452,8 @@ class NVF_API IterDomain : public Val { | |
| //! components | ||
| class NVF_API RaggedIterDomain : public IterDomain { | ||
| public: | ||
| RaggedIterDomain(IrBuilderPasskey passkey, const IterDomainBuilder& args); | ||
|
|
||
| //! \param extents TensorView containing component extents (must be integer | ||
| //! type) | ||
| //! \param iter_type Iteration type (Iteration, Reduction, etc.) | ||
|
|
@@ -493,6 +499,9 @@ class NVF_API RaggedIterDomain : public IterDomain { | |
| IterDomain* in, | ||
| TensorView* extents); | ||
|
|
||
| //! Override cloneWithoutRFactor to preserve RaggedIterDomain type | ||
| IterDomain* cloneWithoutRFactor(bool map_with_original = false) override; | ||
|
|
||
| private: | ||
| //! Extent tensor containing all component extents | ||
| //! Can be 1D, 2D, or N-D depending on nesting structure | ||
|
|
@@ -643,6 +652,8 @@ class NVF_API TensorDomain : public Val { | |
|
|
||
| bool hasSymbolicAxis() const; | ||
|
|
||
| bool hasRaggedIterDomain() const; | ||
|
|
||
| std::optional<int64_t> getReductionAxis() const; | ||
|
|
||
| // The input logical domain. The root domain of a consumer should equal the | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just renaming