RaggedIterDomain cloning #5707
@@ -22,6 +22,7 @@
 #include <ir/internal_base_nodes.h>
 #include <ir/iostream.h>
 #include <ir/utils.h>
+#include <ops/alias.h>
 #include <ops/arith.h>
 #include <transform_rfactor.h>
 #include <transform_view.h>
@@ -47,7 +48,11 @@ IterDomainBuilder::IterDomainBuilder(const IterDomain* id)
       is_rfactor_domain_(id->isRFactorProduct()),
       is_padded_dimension_(id->hasPaddingToMultipleOfWarp()),
       is_clustered_dimension_(id->isClusteredBlockDim()),
-      padded_to_size_(id->getMaybeSizeAfterPadding()) {}
+      padded_to_size_(id->getMaybeSizeAfterPadding()) {
+  if (id->isA<RaggedIterDomain>()) {
+    ragged_extents_ = id->as<RaggedIterDomain>()->extents();
+  }
+}
 
 IterDomainBuilder& IterDomainBuilder::resetSchedulingParams() {
   parallel_type_ = ParallelType::Serial;
@@ -115,7 +120,13 @@ IterDomain* IterDomainBuilder::build() const {
   NVF_ERROR(
       start_ != nullptr && extent_ != nullptr,
       "Start and extent are required to build an iter domain.");
-  return IrBuilder::createInContainer<IterDomain>(start_->container(), *this);
+
+  if (ragged_extents_ != nullptr) {
+    return IrBuilder::createInContainer<RaggedIterDomain>(
+        start_->container(), *this);
+  } else {
+    return IrBuilder::createInContainer<IterDomain>(start_->container(), *this);
+  }
 }
 
 IterDomain::IterDomain(
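For orientation, here is a hypothetical sketch (not taken from the PR or its tests) of how the two build() paths now differ. IterDomainBuilder, ragged_extents(), and build() are the entities changed in this diff; the function name buildRaggedId, the IrContainer* parameter, and extents_tv are illustrative assumptions. Per the RaggedIterDomain constructor checks further down, the start must be zero and the extent the placeholder one.

// Hypothetical sketch: with ragged_extents() set, build() creates a
// RaggedIterDomain; without it, behavior is unchanged and a plain IterDomain
// is built. `container` is assumed to be an existing IrContainer and
// `extents_tv` a 1D, Index-typed TensorView of per-component extents.
IterDomain* buildRaggedId(IrContainer* container, TensorView* extents_tv) {
  return IterDomainBuilder(
             container->zeroVal(DataType::Index),
             container->oneVal(DataType::Index))
      .iter_type(IterType::Iteration)
      .ragged_extents(extents_tv)
      .build();
}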
@@ -814,6 +825,77 @@ void validateLoopDomain(
 
 } // namespace
 
+RaggedIterDomain::RaggedIterDomain(
+    IrBuilderPasskey passkey,
+    const IterDomainBuilder& args)
+    : IterDomain(
+          passkey,
+          ValType::RaggedIterDomain,
+          args.start_,
+          args.extent_,
+          args.expanded_extent_,
+          args.stop_offset_,
+          args.parallel_type_,
+          args.iter_type_,
+          args.is_rfactor_domain_,
+          args.is_padded_dimension_,
+          args.is_clustered_dimension_,
+          args.padded_to_size_),
+      extents_(args.ragged_extents_) {
+  // Extents must be non-null
+  NVF_ERROR(
+      extents_ != nullptr, "RaggedIterDomain requires non-null extents tensor");
+
+  // Extents must have integer dtype
+  NVF_ERROR_EQ(
+      extents_->dtype(),
+      DataType::Index,
+      "RaggedIterDomain extents must have index type, got ",
+      extents_->dtype());
+
+  // Only IterType::Iteration is supported at this moment
+  NVF_ERROR_EQ(
+      iter_type_,
+      IterType::Iteration,
+      "Only IterType::Iteration is supported: ",
+      iter_type_);
+
+  // RaggedIterDomain has specific requirements on member values
+  NVF_ERROR(
+      start_->isZeroInt(),
+      "RaggedIterDomain start must be zero, got: ",
+      start_->toInlineString());
+
+  NVF_ERROR(
+      extent_->isOneInt(),
+      "RaggedIterDomain extent must be one (placeholder), got: ",
Collaborator: Something to consider: this could be repurposed to store an upper bound of the sum of all component sizes (this came from an offline discussion with @jjsjann123). It would be useful for deciding how large a buffer to allocate for block quantization. Given the current grouped gemm kernel we borrowed from cutlass, the block scaling factor tensor has to be padded to a multiple of 16 (?) for each group, so the allocation size has to be at least the sum of the per-group sizes after padding. In the IterDomain graph, this can be represented as follows: [diagram omitted].

Collaborator: Or maybe this fits into is_padded_dimension and padded_to_size?

Author: Yes, I think this is how we could use ragged dimensions to define the smallest allocation size.
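To make the allocation-size idea above concrete, here is a standalone sketch in plain C++ (not part of the PR; the multiple-of-16 figure is the reviewer's tentative recollection, and paddedAllocationSize is an illustrative name) of the smallest buffer that fits every group after per-group padding, i.e. the kind of upper bound the reviewer proposes to track:

#include <cstdint>
#include <vector>

// Smallest allocation that holds every group's block scaling factors when
// each group is padded up to a multiple of `alignment` (16 per the grouped
// gemm kernel mentioned above). `group_sizes` plays the role of the ragged
// extents.
int64_t paddedAllocationSize(
    const std::vector<int64_t>& group_sizes,
    int64_t alignment = 16) {
  int64_t total = 0;
  for (int64_t size : group_sizes) {
    // Round each group size up to the next multiple of `alignment`.
    total += (size + alignment - 1) / alignment * alignment;
  }
  return total;
}

// Example: group sizes {3, 17, 16} need 16 + 32 + 16 = 64 elements.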
+      extent_->toInlineString());
+
+  NVF_ERROR(
+      expanded_extent_ == nullptr,
+      "RaggedIterDomain does not support expanded_extent");
+
+  NVF_ERROR(
+      stop_offset_ == nullptr || stop_offset_->isZeroInt(),
+      "RaggedIterDomain stop_offset must be nullptr or zero, got: ",
+      stop_offset_ ? stop_offset_->toInlineString() : "nullptr");
+
+  NVF_ERROR(
+      !is_rfactor_domain_, "RaggedIterDomain does not support rfactor domains");
+
+  NVF_ERROR(
+      !is_padded_dimension_,
+      "RaggedIterDomain does not support padded dimensions");
+
+  NVF_ERROR(
+      !is_clustered_dimension_,
+      "RaggedIterDomain does not support clustered dimensions");
+
+  NVF_ERROR(
+      !padded_to_size_.has_value(),
+      "RaggedIterDomain does not support padded_to_size");
+}
+
 RaggedIterDomain::RaggedIterDomain(
     IrBuilderPasskey passkey,
     TensorView* extents,
@@ -894,6 +976,102 @@ std::string RaggedIterDomain::toString(int indent_size) const {
   return toInlineString(indent_size);
 }
 
+IterDomain* RaggedIterDomain::cloneWithoutRFactor(bool map_with_original) {
+  // Create a new RaggedIterDomain with the same extents and properties
+  auto cloned = IrBuilder::create<RaggedIterDomain>(
+      extents_, getIterType(), getParallelType());
+
+  // Optionally map the clone with the original in the Exact graph
+  if (map_with_original) {
+    // TODO: Implement mapping if needed
+    NVF_THROW("Not implemented");
+  }
+
+  return cloned;
+}

Contributor (comment on lines +989 to +992): style: mapping implementation missing - should call ...

naoyam marked this conversation as resolved.
+
+std::pair<IterDomain*, RaggedIterDomain*> RaggedIterDomain::partition(
+    IterDomain* in,
+    TensorView* offsets) {
+  NVF_ERROR(in != nullptr, "partition: input IterDomain is null");
+
+  NVF_ERROR(
+      !in->isA<RaggedIterDomain>(),
+      "partition: input is already RaggedIterDomain, cannot partition again");
+
+  NVF_ERROR_EQ(
+      in->getParallelType(),
+      ParallelType::Serial,
+      "Partitioning of parallelized IterDomain not supported: ",
+      in->toString());
+
+  NVF_ERROR_EQ(
+      in->getIterType(),
+      IterType::Iteration,
+      "partition: only IterType::Iteration is supported, got ",
+      in->getIterType(),
+      " for IterDomain: ",
+      in->toString());
+
+  NVF_ERROR(offsets != nullptr, "partition: offsets tensor is null");
+
+  NVF_ERROR_EQ(
+      offsets->dtype(),
+      DataType::Index,
+      "partition: offsets must have Index type, got ",
+      offsets->dtype());
+
+  const auto& offsets_domain = offsets->getLogicalDomain();
+  NVF_ERROR_EQ(
+      offsets_domain.size(),
+      1,
+      "partition: offsets tensor must be 1D, got ",
+      offsets_domain.size(),
+      "D tensor. Multi-dimensional offsets not yet supported.");
+
+  auto container = in->container();
+
+  // Compute extents from offsets: extents[i] = offsets[i+1] - offsets[i]
+  // offsets_left = offsets[:-1] (all but last element)
+  // offsets_right = offsets[1:] (all but first element)
+
+  auto offsets_len = offsets_domain[0]->extent();
+
+  auto zero = container->zeroVal(DataType::Index);
+  auto one = container->oneVal(DataType::Index);
+  auto len_minus_one = sub(offsets_len, one);
+
+  // Slice offsets[:-1]
+  Slice left_slice;
+  left_slice.start = zero;
+  left_slice.stop = len_minus_one;
+  auto offsets_left = slice(offsets, {left_slice});
+
+  // Slice offsets[1:]
+  Slice right_slice;
+  right_slice.start = one;
+  right_slice.stop = offsets_len;
+  auto offsets_right = slice(offsets, {right_slice});
+
+  // Compute extents: extents = offsets_right - offsets_left
+  auto extents = sub(offsets_right, offsets_left);
+
+  // Create component IterDomain
+  // Component extent = number of components = len(offsets) - 1
+  auto component_extent = len_minus_one;
+  auto component_id = IterDomainBuilder(zero, component_extent)
+                          .parallel_type(ParallelType::Serial)
+                          .iter_type(IterType::Iteration)
+                          .build();
+
+  auto ragged_id =
+      IrBuilder::create<RaggedIterDomain>(extents, in->getIterType());
+
+  IrBuilder::create<Partition>(component_id, ragged_id, in, extents);
+
+  return {component_id, ragged_id};
+}
+
 TensorDomain::TensorDomain(
     IrBuilderPasskey passkey,
     std::vector<IterDomain*> logical_domain,
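To make the offsets-to-extents arithmetic above concrete, here is a standalone sketch in plain C++ (not part of the PR; extentsFromOffsets is an illustrative name) that mirrors on host data what the slice and sub ops compute:

#include <cassert>
#include <cstdint>
#include <vector>

// extents[i] = offsets[i + 1] - offsets[i], i.e. offsets[1:] - offsets[:-1].
// The number of components is offsets.size() - 1.
std::vector<int64_t> extentsFromOffsets(const std::vector<int64_t>& offsets) {
  assert(offsets.size() >= 2);
  std::vector<int64_t> extents;
  extents.reserve(offsets.size() - 1);
  for (size_t i = 0; i + 1 < offsets.size(); ++i) {
    extents.push_back(offsets[i + 1] - offsets[i]);
  }
  return extents;
}

// Example: offsets {0, 3, 5, 9} describe 3 components with extents {3, 2, 4}.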
@@ -1413,6 +1591,13 @@ bool TensorDomain::hasVectorize() const {
   });
 }
 
+bool TensorDomain::hasRaggedIterDomain() const {
+  return std::any_of(
+      logical().begin(), logical().end(), [](IterDomain* logical_id) {
+        return logical_id->isA<RaggedIterDomain>();
+      });
+}
+
 std::optional<int64_t> TensorDomain::getReductionAxis() const {
   auto it = std::find_if(
       loop_domain_.begin(), loop_domain_.end(), [](const auto& id) {
@@ -1498,6 +1683,22 @@ void TensorDomain::merge(int64_t axis_o, int64_t axis_i) {
   loop_domain_.insert(loop_domain_.begin() + td_outer_pos, merged_id);
 }
 
+// Partition "axis" into component and ragged dimensions. Follow the
+// pattern of TensorDomain::split.
+void TensorDomain::partition(int64_t axis, TensorView* offsets) {
+  NVF_ERROR(nDims() > 0, "Tried to do partition on a 0-dim domain");
+  axis = wrapDim(axis);
+
+  IterDomain* id = this->axis(axis);
+
+  auto [component_id, ragged_id] = RaggedIterDomain::partition(id, offsets);
+
+  // Remove the original axis and insert component and ragged dimensions
+  loop_domain_.erase(loop_domain_.begin() + axis);
+  loop_domain_.insert(loop_domain_.begin() + axis, ragged_id);
+  loop_domain_.insert(loop_domain_.begin() + axis, component_id);
+}
+
 // Reorder axes according to map[old_pos] = new_pos
 void TensorDomain::reorder(
     const std::unordered_map<int64_t, int64_t>& old2new_) {
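For orientation, a hypothetical before/after of the loop domain (the names td, offsets, i0, i1 and the offset values are illustrative, not from the PR or its tests):

// Hypothetical sketch: `td` is a TensorDomain* whose loop domain is [i0, i1],
// and `offsets` is a 1D, Index-typed TensorView holding, say, {0, 3, 5, 9}.
td->partition(1, offsets);
// Axis 1 is replaced in place by two axes, giving the loop domain
//   [i0, component, ragged]
// where `component` iterates over the 3 groups (extent = len(offsets) - 1)
// and `ragged` is a RaggedIterDomain whose per-component extents are
// offsets[i + 1] - offsets[i] = {3, 2, 4}.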
@@ -61,6 +61,7 @@ class IterDomainBuilder {
   IterDomainBuilder& is_rfactor_domain(bool _is_rfactor_domain);
   IterDomainBuilder& is_padded_dimension(bool _is_padded_dimension);
   IterDomainBuilder& padded_to_size(std::optional<int64_t> _padded_to_size);
+  IterDomainBuilder& ragged_extents(TensorView* _ragged_extents);

Suggested change:
-  IterDomainBuilder& ragged_extents(TensorView* _ragged_extents);
+  // IterDomainBuilder& ragged_extents(TensorView* _ragged_extents); // TODO: implement or use direct member access
Consider using IterDomainBuilder consistently: the new RaggedIterDomain::cloneWithoutRFactor (lines 985-997) doesn't use IterDomainBuilder but directly calls IrBuilder::create<RaggedIterDomain>, bypassing the builder pattern established here. For consistency, both paths should use the same construction method.

Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!
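One possible shape of that refactor, sketched under the assumption that the IterDomainBuilder(const IterDomain*) constructor change earlier in this diff carries the ragged extents through to build(); this is an illustration, not the PR's code:

// Hypothetical builder-based clone, so that both construction paths go
// through IterDomainBuilder. The builder copies extents() for
// RaggedIterDomain inputs, and build() then creates a RaggedIterDomain.
IterDomain* RaggedIterDomain::cloneWithoutRFactor(bool map_with_original) {
  auto cloned = IterDomainBuilder(this).is_rfactor_domain(false).build();
  if (map_with_original) {
    // Mapping with the original is still left unimplemented in this PR.
    NVF_THROW("Not implemented");
  }
  return cloned;
}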