Skip to content
Merged
Show file tree
Hide file tree
Changes from 25 commits
Commits
Show all changes
45 commits
Select commit Hold shift + click to select a range
d87e6d7
Initial introduction of RaggedIterDomain
naoyam Dec 12, 2025
77c6a07
Merge remote-tracking branch 'origin/main' into raggediterdomain_init…
naoyam Dec 12, 2025
f16fc4d
cleanup
naoyam Dec 12, 2025
23d55f1
fix
naoyam Dec 12, 2025
8392332
fix
naoyam Dec 12, 2025
787dfec
unit test
naoyam Dec 12, 2025
a0b40a3
cleanup
naoyam Dec 12, 2025
dbdd917
Fix IterVisitor
naoyam Dec 12, 2025
cdbd81e
cleanup
naoyam Dec 12, 2025
d4c8d7f
WIP: partition
naoyam Dec 12, 2025
9575a13
Partition expr
naoyam Dec 13, 2025
a054ae0
TensorView::partition
naoyam Dec 13, 2025
69dbe0f
cleanup
naoyam Dec 13, 2025
db3b359
Merge remote-tracking branch 'origin/main' into raggediterdomain_part…
naoyam Dec 13, 2025
2348dde
cleanup
naoyam Dec 13, 2025
7090b9c
WIP: asNested
naoyam Dec 13, 2025
b07e285
cleanup
naoyam Dec 13, 2025
a2c504b
asNested
naoyam Dec 15, 2025
b1d8cf4
warpdim
naoyam Dec 15, 2025
201c148
Make sure RaggedIterDomain is propagated to output tensors
naoyam Dec 17, 2025
9e0b161
Extend ops to be aware of RaggedIterDomain
naoyam Dec 17, 2025
60a2dd5
RaggedIterDomain and reduction
naoyam Dec 17, 2025
566d63d
WIP
naoyam Dec 18, 2025
144b206
WIP
naoyam Dec 18, 2025
e2efe75
cleanup
naoyam Dec 18, 2025
0b68d6b
cleanup
naoyam Dec 18, 2025
8a73bb2
cleanup
naoyam Dec 18, 2025
550e0c5
Merge branch 'raggediterdomain_partition' into raggediterdomain-asnested
naoyam Dec 18, 2025
82bd85e
Merge remote-tracking branch 'origin/main' into raggediterdomain-asne…
naoyam Dec 18, 2025
f215f07
Use extents as a parameter
naoyam Dec 18, 2025
5b99432
Merge remote-tracking branch 'origin/main' into raggediterdomain-asne…
naoyam Dec 18, 2025
2dd9287
Merge branch 'raggediterdomain-asnested' into raggediterdomain_clone
naoyam Dec 18, 2025
f75ecb6
Merge branch 'main' into raggediterdomain-asnested
naoyam Jan 7, 2026
8aa854e
feedback
naoyam Jan 7, 2026
72ae14f
fix
naoyam Jan 7, 2026
85d48df
Merge branch 'raggediterdomain-asnested' into raggediterdomain_clone
naoyam Jan 7, 2026
5f86d9c
Merge branch 'main' into raggediterdomain_clone
naoyam Jan 7, 2026
bf5b627
Merge remote-tracking branch 'origin/main' into raggediterdomain_clone
naoyam Jan 7, 2026
bec4c09
Merge remote-tracking branch 'origin/main' into raggediterdomain_clone
naoyam Jan 7, 2026
4d8acab
cleanup
naoyam Jan 9, 2026
3b082ba
cleanup
naoyam Jan 9, 2026
bfc3da9
Update csrc/ops/utils.cpp
naoyam Jan 9, 2026
54b2093
Merge remote-tracking branch 'origin/main' into raggediterdomain_clone
naoyam Jan 16, 2026
f260245
cleanup
naoyam Jan 16, 2026
10bcf33
Merge remote-tracking branch 'origin/raggediterdomain_clone' into rag…
naoyam Jan 16, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions csrc/dispatch.h
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ class Val;
f(TopKOp); \
f(ScanOp); \
f(Merge); \
f(Partition); \
f(Swizzle); \
f(Swizzle2D); \
f(Resize); \
Expand Down
9 changes: 9 additions & 0 deletions csrc/ir/interface_nodes.h
Original file line number Diff line number Diff line change
Expand Up @@ -619,6 +619,15 @@ class NVF_API TensorView : public Val {
return merge(axis, axis + 1);
}

// Partition "axis" into component and ragged dimensions based on offsets
// The offsets tensor defines partition boundaries where:
// Shape: [num_components + 1], values: [0, off1, off2, ..., total]
// Extents are computed as: extents[i] = offsets[i+1] - offsets[i]
// Returns this TensorView with the axis replaced by component and ragged dims
// e.g. partition(0, offsets) on tv[id{N}] results in:
// tv[id{num_components}, ragged_id{extents}]
TensorView* partition(int64_t axis, TensorView* offsets);

// Flatten the axis from `from` to `to` into a single axis.
// Both `from` and `to` are inclusive.
TensorView* flatten(int64_t from = 0, int64_t to = -1);
Expand Down
205 changes: 203 additions & 2 deletions csrc/ir/internal_base_nodes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#include <ir/internal_base_nodes.h>
#include <ir/iostream.h>
#include <ir/utils.h>
#include <ops/alias.h>
#include <ops/arith.h>
#include <transform_rfactor.h>
#include <transform_view.h>
Expand All @@ -47,7 +48,11 @@ IterDomainBuilder::IterDomainBuilder(const IterDomain* id)
is_rfactor_domain_(id->isRFactorProduct()),
is_padded_dimension_(id->hasPaddingToMultipleOfWarp()),
is_clustered_dimension_(id->isClusteredBlockDim()),
padded_to_size_(id->getMaybeSizeAfterPadding()) {}
padded_to_size_(id->getMaybeSizeAfterPadding()) {
if (id->isA<RaggedIterDomain>()) {
ragged_extents_ = id->as<RaggedIterDomain>()->extents();
}
}
Comment on lines 48 to +55
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Consider using IterDomainBuilder consistently

The new RaggedIterDomain::cloneWithoutRFactor (line 985-997) doesn't use IterDomainBuilder but directly calls IrBuilder::create<RaggedIterDomain>, bypassing the builder pattern established here. For consistency, both paths should use the same construction method.

Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!


IterDomainBuilder& IterDomainBuilder::resetSchedulingParams() {
parallel_type_ = ParallelType::Serial;
Expand Down Expand Up @@ -115,7 +120,13 @@ IterDomain* IterDomainBuilder::build() const {
NVF_ERROR(
start_ != nullptr && extent_ != nullptr,
"Start and extent are required to build an iter domain.");
return IrBuilder::createInContainer<IterDomain>(start_->container(), *this);

if (ragged_extents_ != nullptr) {
return IrBuilder::createInContainer<RaggedIterDomain>(
start_->container(), *this);
} else {
return IrBuilder::createInContainer<IterDomain>(start_->container(), *this);
}
}

IterDomain::IterDomain(
Expand Down Expand Up @@ -814,6 +825,77 @@ void validateLoopDomain(

} // namespace

// Builder-based constructor for RaggedIterDomain.
//
// Delegates all common IterDomain state to the base-class constructor (tagged
// with ValType::RaggedIterDomain) and takes the per-component extents tensor
// from args.ragged_extents_. The checks below pin down the restricted
// invariants a ragged domain currently supports: a non-null Index-typed
// extents tensor, IterType::Iteration only, zero start, a placeholder scalar
// extent of one, and none of the expanded/rfactor/padded/clustered features of
// a regular IterDomain.
RaggedIterDomain::RaggedIterDomain(
    IrBuilderPasskey passkey,
    const IterDomainBuilder& args)
    : IterDomain(
          passkey,
          ValType::RaggedIterDomain,
          args.start_,
          args.extent_,
          args.expanded_extent_,
          args.stop_offset_,
          args.parallel_type_,
          args.iter_type_,
          args.is_rfactor_domain_,
          args.is_padded_dimension_,
          args.is_clustered_dimension_,
          args.padded_to_size_),
      extents_(args.ragged_extents_) {
  // Extents must be non-null
  NVF_ERROR(
      extents_ != nullptr, "RaggedIterDomain requires non-null extents tensor");

  // Extents must have integer dtype
  NVF_ERROR_EQ(
      extents_->dtype(),
      DataType::Index,
      "RaggedIterDomain extents must have index type, got ",
      extents_->dtype());

  // Only IterType::Iteration is supported at this moment
  NVF_ERROR_EQ(
      iter_type_,
      IterType::Iteration,
      "Only IterType::Iteration is supported: ",
      iter_type_);

  // RaggedIterDomain has specific requirements on member values
  NVF_ERROR(
      start_->isZeroInt(),
      "RaggedIterDomain start must be zero, got: ",
      start_->toInlineString());

  // The scalar extent is only a placeholder; the true per-component sizes
  // live in the extents_ tensor.
  NVF_ERROR(
      extent_->isOneInt(),
      "RaggedIterDomain extent must be one (placeholder), got: ",
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Something to consider: this could be repurposed to store an upper bound of the sum of all component sizes. This came from an offline discussion with @jjsjann123

This will be useful to decide how large a buffer to allocate for block quantization. Given the current grouped gemm kernel we borrowed from cutlass, the block scaling factor tensor has to be padded to a multiple of 16 (?) for each group. Therefore, the allocation size has to be at least ceildiv(num_tokens + g * 15, 16) * 16.

In the IterDomain graph, this can be represented as the following:

# m = number of tokens

[i{g},      ragged{m}  ]               <= logical domain
        /               \
     ragged              i{16}         <= allocation domain
{ceildiv(m+g*15, 16)}

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Or maybe this fits into is_padded_dimension and padded_to_size?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I think this is the idea how we could use ragged dimensions to define the smallest allocation size.

      extent_->toInlineString());

  NVF_ERROR(
      expanded_extent_ == nullptr,
      "RaggedIterDomain does not support expanded_extent");

  NVF_ERROR(
      stop_offset_ == nullptr || stop_offset_->isZeroInt(),
      "RaggedIterDomain stop_offset must be nullptr or zero, got: ",
      stop_offset_ ? stop_offset_->toInlineString() : "nullptr");

  NVF_ERROR(
      !is_rfactor_domain_, "RaggedIterDomain does not support rfactor domains");

  NVF_ERROR(
      !is_padded_dimension_,
      "RaggedIterDomain does not support padded dimensions");

  NVF_ERROR(
      !is_clustered_dimension_,
      "RaggedIterDomain does not support clustered dimensions");

  NVF_ERROR(
      !padded_to_size_.has_value(),
      "RaggedIterDomain does not support padded_to_size");
}

RaggedIterDomain::RaggedIterDomain(
IrBuilderPasskey passkey,
TensorView* extents,
Expand Down Expand Up @@ -894,6 +976,102 @@ std::string RaggedIterDomain::toString(int indent_size) const {
return toInlineString(indent_size);
}

// Clone this RaggedIterDomain without any rfactor flag, preserving its
// extents tensor, iter type, and parallel type.
//
// NOTE(review): unlike the base IterDomain::cloneWithoutRFactor, the
// map_with_original == true path (registering an Exact-graph mapping between
// original and clone) is not implemented yet and throws — confirm callers
// never request it for ragged domains.
IterDomain* RaggedIterDomain::cloneWithoutRFactor(bool map_with_original) {
  // Create a new RaggedIterDomain with the same extents and properties
  auto cloned = IrBuilder::create<RaggedIterDomain>(
      extents_, getIterType(), getParallelType());

  // Optionally map the clone with the original in the Exact graph
  if (map_with_original) {
    // TODO: Implement mapping if needed
    NVF_THROW("Not implemented");
  }
Comment on lines +989 to +992
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

style: mapping implementation missing - should call fusion()->registerExactMapping(this, cloned) like base IterDomain::cloneWithoutRFactor does (line 334)


  return cloned;
}

// Split a dense IterDomain into a fixed "component" dimension plus a ragged
// dimension whose per-component sizes come from a 1D offsets tensor.
//
// offsets holds [0, off1, ..., total]; component i covers the half-open range
// [offsets[i], offsets[i+1]), so extents[i] = offsets[i+1] - offsets[i].
// Returns {component_id, ragged_id} and records a Partition expression
// connecting the input to the two new domains.
std::pair<IterDomain*, RaggedIterDomain*> RaggedIterDomain::partition(
    IterDomain* in,
    TensorView* offsets) {
  NVF_ERROR(in != nullptr, "partition: input IterDomain is null");

  NVF_ERROR(
      !in->isA<RaggedIterDomain>(),
      "partition: input is already RaggedIterDomain, cannot partition again");

  NVF_ERROR_EQ(
      in->getParallelType(),
      ParallelType::Serial,
      "Partitioning of parallelized IterDomain not supported: ",
      in->toString());

  NVF_ERROR_EQ(
      in->getIterType(),
      IterType::Iteration,
      "partition: only IterType::Iteration is supported, got ",
      in->getIterType(),
      " for IterDomain: ",
      in->toString());

  NVF_ERROR(offsets != nullptr, "partition: offsets tensor is null");

  NVF_ERROR_EQ(
      offsets->dtype(),
      DataType::Index,
      "partition: offsets must have Index type, got ",
      offsets->dtype());

  const auto& offsets_logical = offsets->getLogicalDomain();
  NVF_ERROR_EQ(
      offsets_logical.size(),
      1,
      "partition: offsets tensor must be 1D, got ",
      offsets_logical.size(),
      "D tensor. Multi-dimensional offsets not yet supported.");

  auto* ir_container = in->container();

  auto* num_offsets = offsets_logical[0]->extent();

  auto* zero = ir_container->zeroVal(DataType::Index);
  auto* one = ir_container->oneVal(DataType::Index);
  // One fewer component than there are offsets.
  auto* num_components = sub(num_offsets, one);

  // Per-component extents come from the pairwise difference of adjacent
  // offsets: extents = offsets[1:] - offsets[:-1].
  Slice lower_slice;
  lower_slice.start = zero;
  lower_slice.stop = num_components;
  auto* lower_bounds = slice(offsets, {lower_slice}); // offsets[:-1]

  Slice upper_slice;
  upper_slice.start = one;
  upper_slice.stop = num_offsets;
  auto* upper_bounds = slice(offsets, {upper_slice}); // offsets[1:]

  auto* extents = sub(upper_bounds, lower_bounds);

  // The component dimension iterates over the components themselves.
  auto* component_id = IterDomainBuilder(zero, num_components)
                           .parallel_type(ParallelType::Serial)
                           .iter_type(IterType::Iteration)
                           .build();

  // The ragged dimension carries the variable per-component extents.
  auto* ragged_id =
      IrBuilder::create<RaggedIterDomain>(extents, in->getIterType());

  // Record the transform from the input domain to the new pair of domains.
  IrBuilder::create<Partition>(component_id, ragged_id, in, extents);

  return {component_id, ragged_id};
}

TensorDomain::TensorDomain(
IrBuilderPasskey passkey,
std::vector<IterDomain*> logical_domain,
Expand Down Expand Up @@ -1413,6 +1591,13 @@ bool TensorDomain::hasVectorize() const {
});
}

bool TensorDomain::hasRaggedIterDomain() const {
return std::any_of(
logical().begin(), logical().end(), [](IterDomain* logical_id) {
return logical_id->isA<RaggedIterDomain>();
});
}

std::optional<int64_t> TensorDomain::getReductionAxis() const {
auto it = std::find_if(
loop_domain_.begin(), loop_domain_.end(), [](const auto& id) {
Expand Down Expand Up @@ -1498,6 +1683,22 @@ void TensorDomain::merge(int64_t axis_o, int64_t axis_i) {
loop_domain_.insert(loop_domain_.begin() + td_outer_pos, merged_id);
}

// Partition "axis" into component and ragged dimensions based on offsets,
// mirroring how TensorDomain::split rewrites the loop domain in place.
void TensorDomain::partition(int64_t axis, TensorView* offsets) {
  NVF_ERROR(nDims() > 0, "Tried to do partition on a 0-dim domain");
  axis = wrapDim(axis);

  auto [component_id, ragged_id] =
      RaggedIterDomain::partition(this->axis(axis), offsets);

  // Replace the original axis in place with the (component, ragged) pair.
  auto insert_pos = loop_domain_.erase(loop_domain_.begin() + axis);
  loop_domain_.insert(insert_pos, {component_id, ragged_id});
}

// Reorder axes according to map[old_pos] = new_pos
void TensorDomain::reorder(
const std::unordered_map<int64_t, int64_t>& old2new_) {
Expand Down
32 changes: 31 additions & 1 deletion csrc/ir/internal_base_nodes.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ class IterDomainBuilder {
IterDomainBuilder& is_rfactor_domain(bool _is_rfactor_domain);
IterDomainBuilder& is_padded_dimension(bool _is_padded_dimension);
IterDomainBuilder& padded_to_size(std::optional<int64_t> _padded_to_size);
IterDomainBuilder& ragged_extents(TensorView* _ragged_extents);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

syntax: missing implementation for ragged_extents method - declared in header but not implemented in csrc/ir/internal_base_nodes.cpp

Suggested change
IterDomainBuilder& ragged_extents(TensorView* _ragged_extents);
// IterDomainBuilder& ragged_extents(TensorView* _ragged_extents); // TODO: implement or use direct member access


IterDomain* build() const;

Expand All @@ -79,6 +80,9 @@ class IterDomainBuilder {
bool is_padded_dimension_ = false;
bool is_clustered_dimension_ = false;
std::optional<int64_t> padded_to_size_ = std::nullopt;

// For RaggedIterDomain: stores the extents tensor
TensorView* ragged_extents_ = nullptr;
};

//! Simply a representation of an annotated 1D iterable from start to extent.
Expand Down Expand Up @@ -122,7 +126,7 @@ class NVF_API IterDomain : public Val {
//!
//! When map_with_original is true, the clone of the original is
//! mapped in the Exact graph.
IterDomain* cloneWithoutRFactor(bool map_with_original = false);
virtual IterDomain* cloneWithoutRFactor(bool map_with_original = false);

//! Clone a vector domains
static std::vector<IterDomain*> clone(
Expand Down Expand Up @@ -448,6 +452,8 @@ class NVF_API IterDomain : public Val {
//! components
class NVF_API RaggedIterDomain : public IterDomain {
public:
RaggedIterDomain(IrBuilderPasskey passkey, const IterDomainBuilder& args);

//! \param extents TensorView containing component extents (must be integer
//! type)
//! \param iter_type Iteration type (Iteration, Reduction, etc.)
Expand Down Expand Up @@ -476,6 +482,25 @@ class NVF_API RaggedIterDomain : public IterDomain {
return extents_;
}

//! Partition an IterDomain into component and ragged dimensions
//! Creates a component IterDomain and a RaggedIterDomain based on offsets
//!
//! \param in Input IterDomain to partition (must be regular IterDomain)
//! \param offsets Offset tensor defining partition boundaries (must be 1D)
//! Shape: [num_components + 1], values: [0, off1, off2, ..., total]
//! Extents are computed as: extents[i] = offsets[i+1] - offsets[i]
//! \return Pair of (component_id, ragged_id)
//! component_id: IterDomain with extent = num_components
//! ragged_id: RaggedIterDomain with extents computed from offsets
//!
//! TODO: Support multi-dimensional offsets for nested ragged structures
static std::pair<IterDomain*, RaggedIterDomain*> partition(
IterDomain* in,
TensorView* offsets);

//! Override cloneWithoutRFactor to preserve RaggedIterDomain type
IterDomain* cloneWithoutRFactor(bool map_with_original = false) override;

private:
//! Extent tensor containing all component extents
//! Can be 1D, 2D, or N-D depending on nesting structure
Expand Down Expand Up @@ -626,6 +651,8 @@ class NVF_API TensorDomain : public Val {

bool hasSymbolicAxis() const;

bool hasRaggedIterDomain() const;

std::optional<int64_t> getReductionAxis() const;

// The input logical domain. The root domain of a consumer should equal the
Expand Down Expand Up @@ -776,6 +803,9 @@ class NVF_API TensorDomain : public Val {
// axis is by default placed at original position axis_o
void merge(int64_t axis_o, int64_t axis_i);

// Partition axis into component and ragged dimensions based on offsets
void partition(int64_t axis, TensorView* offsets);

// Reorder axes according to map[old_pos] = new_pos
void reorder(const std::unordered_map<int64_t, int64_t>& old2new);

Expand Down
33 changes: 33 additions & 0 deletions csrc/ir/internal_nodes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2612,6 +2612,39 @@ std::string Merge::toInlineString(int indent_size) const {

NVFUSER_DEFINE_CLONE_AND_CREATE(Merge)

// IR expression recording that IterDomain `in` was partitioned into a fixed
// `component` dimension and a variable-size `ragged` dimension.
//
// Output ordering is significant: output(0) is the component domain and
// output(1) is the ragged domain, matching the component()/ragged() accessors
// used by toString().
Partition::Partition(
    IrBuilderPasskey passkey,
    IterDomain* component, // output(0): iterates over the components
    RaggedIterDomain* ragged, // output(1): carries per-component extents
    IterDomain* in, // input(0): the dense domain being partitioned
    TensorView* extents) // per-component sizes derived from the offsets
    : Expr(passkey) {
  addOutput(component);
  addOutput(ragged);
  addInput(in);
  // Should the extents tensor be an input rather than an attribute?
  addAttribute(extents);
}

// Renders the partition expression as
// "Partition: <in> by extents <extents> -> component: <c>, ragged: <r>".
std::string Partition::toString(int indent_size) const {
  std::stringstream ss;
  ss << "Partition: " << in()->toString() << " by extents "
     << extents()->toString() << " -> component: " << component()->toString()
     << ", ragged: " << ragged()->toString() << "\n";
  return ss.str();
}

// Partition produces two output domains, so it has no single-expression
// inline form; requesting one is an error (same convention as Merge above).
std::string Partition::toInlineString(int indent_size) const {
  NVF_CHECK(false, "Partition can not be printed inline");
}

NVFUSER_DEFINE_CLONE_AND_CREATE(Partition)

Swizzle::Swizzle(
IrBuilderPasskey passkey,
IterDomain* out_x,
Expand Down
Loading
Loading