Skip to content
Merged
Show file tree
Hide file tree
Changes from 36 commits
Commits
Show all changes
52 commits
Select commit Hold shift + click to select a range
d87e6d7
Initial introduction of RaggedIterDomain
naoyam Dec 12, 2025
77c6a07
Merge remote-tracking branch 'origin/main' into raggediterdomain_init…
naoyam Dec 12, 2025
f16fc4d
cleanup
naoyam Dec 12, 2025
23d55f1
fix
naoyam Dec 12, 2025
8392332
fix
naoyam Dec 12, 2025
787dfec
unit test
naoyam Dec 12, 2025
a0b40a3
cleanup
naoyam Dec 12, 2025
dbdd917
Fix IterVisitor
naoyam Dec 12, 2025
cdbd81e
cleanup
naoyam Dec 12, 2025
d4c8d7f
WIP: partition
naoyam Dec 12, 2025
9575a13
Partition expr
naoyam Dec 13, 2025
a054ae0
TensorView::partition
naoyam Dec 13, 2025
69dbe0f
cleanup
naoyam Dec 13, 2025
db3b359
Merge remote-tracking branch 'origin/main' into raggediterdomain_part…
naoyam Dec 13, 2025
2348dde
cleanup
naoyam Dec 13, 2025
7090b9c
WIP: asNested
naoyam Dec 13, 2025
b07e285
cleanup
naoyam Dec 13, 2025
a2c504b
asNested
naoyam Dec 15, 2025
b1d8cf4
warpdim
naoyam Dec 15, 2025
201c148
Make sure RaggedIterDomain is propagated to output tensors
naoyam Dec 17, 2025
9e0b161
Extend ops to be aware of RaggedIterDomain
naoyam Dec 17, 2025
60a2dd5
RaggedIterDomain and reduction
naoyam Dec 17, 2025
566d63d
WIP
naoyam Dec 18, 2025
144b206
WIP
naoyam Dec 18, 2025
e2efe75
cleanup
naoyam Dec 18, 2025
0b68d6b
cleanup
naoyam Dec 18, 2025
8a73bb2
cleanup
naoyam Dec 18, 2025
550e0c5
Merge branch 'raggediterdomain_partition' into raggediterdomain-asnested
naoyam Dec 18, 2025
82bd85e
Merge remote-tracking branch 'origin/main' into raggediterdomain-asne…
naoyam Dec 18, 2025
f215f07
Use extents as a parameter
naoyam Dec 18, 2025
5b99432
Merge remote-tracking branch 'origin/main' into raggediterdomain-asne…
naoyam Dec 18, 2025
2dd9287
Merge branch 'raggediterdomain-asnested' into raggediterdomain_clone
naoyam Dec 18, 2025
c3aebec
combine
naoyam Dec 19, 2025
a22bb1f
Add tests
naoyam Dec 19, 2025
f521c38
WIP
naoyam Dec 19, 2025
8d0d9cb
don't hold component ID in RaggedIterDomain
naoyam Dec 19, 2025
67aac1b
Add design doc
naoyam Dec 19, 2025
3a80926
license
naoyam Dec 19, 2025
f75ecb6
Merge branch 'main' into raggediterdomain-asnested
naoyam Jan 7, 2026
8aa854e
feedback
naoyam Jan 7, 2026
72ae14f
fix
naoyam Jan 7, 2026
85d48df
Merge branch 'raggediterdomain-asnested' into raggediterdomain_clone
naoyam Jan 7, 2026
5f86d9c
Merge branch 'main' into raggediterdomain_clone
naoyam Jan 7, 2026
bf5b627
Merge remote-tracking branch 'origin/main' into raggediterdomain_clone
naoyam Jan 7, 2026
bec4c09
Merge remote-tracking branch 'origin/main' into raggediterdomain_clone
naoyam Jan 7, 2026
4d8acab
cleanup
naoyam Jan 9, 2026
3b082ba
cleanup
naoyam Jan 9, 2026
72dbc41
Merge branch 'raggediterdomain_clone' into ragged_combine
naoyam Jan 13, 2026
5002407
expand doc
naoyam Jan 13, 2026
be0e2ea
cleanup
naoyam Jan 13, 2026
05a6201
Merge remote-tracking branch 'origin/main' into ragged_combine
naoyam Jan 16, 2026
d2b5384
format
naoyam Jan 16, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions csrc/dispatch.h
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ class Val;
f(ScanOp); \
f(Merge); \
f(Partition); \
f(Combine); \
f(Swizzle); \
f(Swizzle2D); \
f(Resize); \
Expand Down
252 changes: 230 additions & 22 deletions csrc/ir/internal_base_nodes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@

namespace nvfuser {

IterDomainBuilder::IterDomainBuilder(Val* _start, Val* _extent)
: start_(_start), extent_(_extent) {
IterDomainBuilder::IterDomainBuilder(Val* start, Val* extent)
: start_(start), extent_(extent) {
NVF_ERROR(
start_ != nullptr && extent_ != nullptr,
"Start and extent are required to build an iter domain.");
Expand All @@ -48,7 +48,11 @@ IterDomainBuilder::IterDomainBuilder(const IterDomain* id)
is_rfactor_domain_(id->isRFactorProduct()),
is_padded_dimension_(id->hasPaddingToMultipleOfWarp()),
is_clustered_dimension_(id->isClusteredBlockDim()),
padded_to_size_(id->getMaybeSizeAfterPadding()) {}
padded_to_size_(id->getMaybeSizeAfterPadding()) {
if (id->isA<RaggedIterDomain>()) {
ragged_extents_ = id->as<RaggedIterDomain>()->extents();
}
}

IterDomainBuilder& IterDomainBuilder::resetSchedulingParams() {
parallel_type_ = ParallelType::Serial;
Expand All @@ -63,60 +67,72 @@ IterDomainBuilder& IterDomainBuilder::resetRfactor() {
return is_rfactor_domain(false);
}

IterDomainBuilder& IterDomainBuilder::start(Val* _start) {
start_ = _start;
IterDomainBuilder& IterDomainBuilder::start(Val* start) {
start_ = start;
return *this;
}

IterDomainBuilder& IterDomainBuilder::extent(Val* _extent) {
extent_ = _extent;
IterDomainBuilder& IterDomainBuilder::extent(Val* extent) {
extent_ = extent;
return *this;
}

IterDomainBuilder& IterDomainBuilder::expanded_extent(Val* _expanded_extent) {
expanded_extent_ = _expanded_extent;
IterDomainBuilder& IterDomainBuilder::expanded_extent(Val* expanded_extent) {
expanded_extent_ = expanded_extent;
return *this;
}

IterDomainBuilder& IterDomainBuilder::stop_offset(Val* _stop_offset) {
stop_offset_ = _stop_offset;
IterDomainBuilder& IterDomainBuilder::stop_offset(Val* stop_offset) {
stop_offset_ = stop_offset;
return *this;
}

IterDomainBuilder& IterDomainBuilder::parallel_type(
ParallelType _parallel_type) {
parallel_type_ = _parallel_type;
ParallelType parallel_type) {
parallel_type_ = parallel_type;
return *this;
}

IterDomainBuilder& IterDomainBuilder::iter_type(IterType _iter_type) {
iter_type_ = _iter_type;
IterDomainBuilder& IterDomainBuilder::iter_type(IterType iter_type) {
iter_type_ = iter_type;
return *this;
}

IterDomainBuilder& IterDomainBuilder::is_rfactor_domain(
bool _is_rfactor_domain) {
is_rfactor_domain_ = _is_rfactor_domain;
bool is_rfactor_domain) {
is_rfactor_domain_ = is_rfactor_domain;
return *this;
}

IterDomainBuilder& IterDomainBuilder::is_padded_dimension(
bool _is_padded_dimension) {
is_padded_dimension_ = _is_padded_dimension;
bool is_padded_dimension) {
is_padded_dimension_ = is_padded_dimension;
return *this;
}

IterDomainBuilder& IterDomainBuilder::padded_to_size(
std::optional<int64_t> _padded_to_size) {
padded_to_size_ = _padded_to_size;
std::optional<int64_t> padded_to_size) {
padded_to_size_ = padded_to_size;
return *this;
}

IterDomainBuilder& IterDomainBuilder::ragged_extents(
TensorView* ragged_extents) {
ragged_extents_ = ragged_extents;
return *this;
}

IterDomain* IterDomainBuilder::build() const {
NVF_ERROR(
start_ != nullptr && extent_ != nullptr,
"Start and extent are required to build an iter domain.");
return IrBuilder::createInContainer<IterDomain>(start_->container(), *this);

if (ragged_extents_ != nullptr) {
return IrBuilder::createInContainer<RaggedIterDomain>(
start_->container(), *this);
} else {
return IrBuilder::createInContainer<IterDomain>(start_->container(), *this);
}
}

IterDomain::IterDomain(
Expand Down Expand Up @@ -815,6 +831,77 @@ void validateLoopDomain(

} // namespace

// Constructs a RaggedIterDomain from an IterDomainBuilder.
//
// A RaggedIterDomain is an IterDomain whose per-component extents are held in
// a separate 1-D extents TensorView rather than a single scalar extent. The
// scalar IterDomain fields are therefore constrained to fixed placeholder
// values, which this constructor validates below.
RaggedIterDomain::RaggedIterDomain(
    IrBuilderPasskey passkey,
    const IterDomainBuilder& args)
    : IterDomain(
          passkey,
          ValType::RaggedIterDomain,
          args.start_,
          args.extent_,
          args.expanded_extent_,
          args.stop_offset_,
          args.parallel_type_,
          args.iter_type_,
          args.is_rfactor_domain_,
          args.is_padded_dimension_,
          args.is_clustered_dimension_,
          args.padded_to_size_),
      extents_(args.ragged_extents_) {
  // Extents must be non-null: the extents tensor is what makes this domain
  // ragged, so there is no meaningful default.
  NVF_ERROR(
      extents_ != nullptr, "RaggedIterDomain requires non-null extents tensor");

  // Extents must have integer dtype (DataType::Index) since its values are
  // used as per-component iteration extents.
  NVF_ERROR_EQ(
      extents_->dtype(),
      DataType::Index,
      "RaggedIterDomain extents must have index type, got ",
      extents_->dtype());

  // Only IterType::Iteration is supported at this moment
  NVF_ERROR_EQ(
      iter_type_,
      IterType::Iteration,
      "Only IterType::Iteration is supported: ",
      iter_type_);

  // RaggedIterDomain has specific requirements on member values: start is
  // pinned to zero, and the scalar extent is a placeholder of one because the
  // real extents live in the extents tensor.
  NVF_ERROR(
      start_->isZeroInt(),
      "RaggedIterDomain start must be zero, got: ",
      start_->toInlineString());

  NVF_ERROR(
      extent_->isOneInt(),
      "RaggedIterDomain extent must be one (placeholder), got: ",
      extent_->toInlineString());

  // The remaining IterDomain features (expanded extents, stop offsets,
  // rfactor, padding, clustering) are not meaningful for ragged domains and
  // are rejected explicitly.
  NVF_ERROR(
      expanded_extent_ == nullptr,
      "RaggedIterDomain does not support expanded_extent");

  NVF_ERROR(
      stop_offset_ == nullptr || stop_offset_->isZeroInt(),
      "RaggedIterDomain stop_offset must be nullptr or zero, got: ",
      stop_offset_ ? stop_offset_->toInlineString() : "nullptr");

  NVF_ERROR(
      !is_rfactor_domain_, "RaggedIterDomain does not support rfactor domains");

  NVF_ERROR(
      !is_padded_dimension_,
      "RaggedIterDomain does not support padded dimensions");

  NVF_ERROR(
      !is_clustered_dimension_,
      "RaggedIterDomain does not support clustered dimensions");

  NVF_ERROR(
      !padded_to_size_.has_value(),
      "RaggedIterDomain does not support padded_to_size");
}

RaggedIterDomain::RaggedIterDomain(
IrBuilderPasskey passkey,
TensorView* extents,
Expand Down Expand Up @@ -895,6 +982,20 @@ std::string RaggedIterDomain::toString(int indent_size) const {
return toInlineString(indent_size);
}

// Clones this RaggedIterDomain, sharing the same extents tensor and
// scheduling properties (iter type and parallel type).
//
// map_with_original: if true, the clone would additionally be mapped with the
// original in the Exact graph. That path is not implemented yet for ragged
// domains, so it throws.
IterDomain* RaggedIterDomain::cloneWithoutRFactor(bool map_with_original) {
  // Fail fast BEFORE creating the clone. The original code allocated the new
  // RaggedIterDomain first and then threw, which registered an orphan IR node
  // in the container on the unsupported path.
  if (map_with_original) {
    // TODO: Implement mapping if needed
    NVF_THROW("Not implemented");
  }

  // Create a new RaggedIterDomain with the same extents and properties
  return IrBuilder::create<RaggedIterDomain>(
      extents_, getIterType(), getParallelType());
}

std::pair<IterDomain*, RaggedIterDomain*> RaggedIterDomain::partition(
IterDomain* in,
TensorView* extents) {
Expand Down Expand Up @@ -953,6 +1054,106 @@ std::pair<IterDomain*, RaggedIterDomain*> RaggedIterDomain::partition(
return {component_id, ragged_id};
}

// Combines a (component, ragged) IterDomain pair — typically produced by
// RaggedIterDomain::partition — back into a single IterDomain that covers the
// flattened iteration space.
//
// component: the regular IterDomain indexing the components of the ragged
//   dimension. Must be serial and IterType::Iteration.
// ragged: the RaggedIterDomain holding per-component extents. Must be serial
//   and IterType::Iteration.
//
// Returns a new IterDomain whose extent is a fresh symbolic Index value that
// conceptually equals the sum of the ragged extents. A Combine expression is
// created linking (component, ragged) -> combined.
IterDomain* RaggedIterDomain::combine(
    IterDomain* component,
    RaggedIterDomain* ragged) {
  NVF_ERROR(component != nullptr, "combine: component IterDomain is null");
  NVF_ERROR(ragged != nullptr, "combine: ragged IterDomain is null");

  NVF_ERROR(
      !component->isA<RaggedIterDomain>(),
      "combine: component must be a regular IterDomain, got RaggedIterDomain: ",
      component->toString());

  // Validate that component and ragged have compatible properties
  NVF_ERROR_EQ(
      component->getParallelType(),
      ParallelType::Serial,
      "Combining parallelized IterDomain not supported: ",
      component->toString());

  NVF_ERROR_EQ(
      ragged->getParallelType(),
      ParallelType::Serial,
      "Combining parallelized RaggedIterDomain not supported: ",
      ragged->toString());

  NVF_ERROR_EQ(
      component->getIterType(),
      IterType::Iteration,
      "combine: only IterType::Iteration is supported for component, got ",
      component->getIterType(),
      " for IterDomain: ",
      component->toString());

  NVF_ERROR_EQ(
      ragged->getIterType(),
      IterType::Iteration,
      "combine: only IterType::Iteration is supported for ragged, got ",
      ragged->getIterType(),
      " for RaggedIterDomain: ",
      ragged->toString());

  // Validate component-ragged pairing when Partition definition is available
  // (Option 3: Best-effort validation)
  // Only validate when the RaggedIterDomain has a direct Partition definition.
  // After propagation (e.g., set() operations), the definition may be nullptr,
  // in which case we trust the user to provide the correct component.
  if (ragged->definition() != nullptr &&
      ragged->definition()->isA<Partition>()) {
    auto* partition = ragged->definition()->as<Partition>();
    IterDomain* expected_component = partition->component();

    NVF_ERROR(
        component == expected_component,
        "combine: component mismatch. The provided component does not match ",
        "the component from the Partition that created this "
        "RaggedIterDomain.\n",
        "  Provided component: ",
        component->toString(),
        "\n",
        "  Expected component: ",
        expected_component->toString());
  }
  // If no Partition definition (after set, in segmented fusion, or external
  // input), trust the user and proceed without validation

  // The combined extent is the sum of all extents in the ragged dimension
  // For a 1D extents tensor [e0, e1, ..., en-1], the total is sum(extents).
  // Note: renamed from extents_tv per review feedback — the type already says
  // it is a TensorView.
  TensorView* extents = ragged->extents();
  NVF_ERROR(extents != nullptr, "combine: ragged extents tensor is null");

  // It is still assumed the extents tensor is just 1D
  NVF_ERROR_EQ(
      std::ssize(extents->getLogicalDomain()),
      1,
      "Unexpected rank of extent tensor: ",
      extents->toString());

  auto container = component->container();
  auto zero = container->zeroVal(DataType::Index);

  // Create a symbolic extent for the combined IterDomain
  // This represents the sum of all ragged extents, i.e.,
  // sum(extents, {0}). We could use the sum output as the extent
  // but we would need to extract the scalar value out of the 0-dim
  // tensor. For now, we leave it as a symbolic Val.
  Val* combined_extent =
      IrBuilder::createInContainer<Val>(container, DataType::Index);

  // Create the combined IterDomain with the symbolic extent
  IterDomain* combined_id = IterDomainBuilder(zero, combined_extent)
                                .parallel_type(ParallelType::Serial)
                                .iter_type(IterType::Iteration)
                                .build();

  // Create the Combine expression linking component + ragged -> combined
  IrBuilder::createInContainer<Combine>(
      container, combined_id, component, ragged);

  return combined_id;
}

TensorDomain::TensorDomain(
IrBuilderPasskey passkey,
std::vector<IterDomain*> logical_domain,
Expand Down Expand Up @@ -1472,6 +1673,13 @@ bool TensorDomain::hasVectorize() const {
});
}

bool TensorDomain::hasRaggedIterDomain() const {
return std::any_of(
logical().begin(), logical().end(), [](IterDomain* logical_id) {
return logical_id->isA<RaggedIterDomain>();
});
}

std::optional<int64_t> TensorDomain::getReductionAxis() const {
auto it = std::find_if(
loop_domain_.begin(), loop_domain_.end(), [](const auto& id) {
Expand Down
Loading
Loading