-
Notifications
You must be signed in to change notification settings - Fork 74
Add utility ir_utils::resetContiguityFromTensor
#5766
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
|
!test |
|
Review updated until commit b7778bd Description
|
| Relevant files | |||||
|---|---|---|---|---|---|
| Enhancement |
| ||||
| Tests |
|
PR Reviewer Guide
Here are some key observations to aid the review process:
| 🧪 PR contains tests |
| ⚡ No major issues detected |
|
!test |
|
!test |
1 similar comment
|
!test |
|
!test |
Greptile SummaryAdded utility function Key changes:
Testing coverage: The test validates all 2^3=8 combinations of contiguity patterns for a 5-dimension allocation domain Confidence Score: 4/5
Important Files Changed
Sequence DiagramsequenceDiagram
participant Client
participant resetContiguityFromTensor
participant inferAllocationSizesAndStrides
participant TensorView
participant TensorDomain
Client->>resetContiguityFromTensor: Call with TensorView & at::Tensor
alt tensor not defined
resetContiguityFromTensor->>Client: Early return
end
resetContiguityFromTensor->>inferAllocationSizesAndStrides: Get sizes & strides
Note over inferAllocationSizesAndStrides: Returns vectors excluding<br/>reduction dimensions
inferAllocationSizesAndStrides-->>resetContiguityFromTensor: (sizes, strides)
resetContiguityFromTensor->>TensorView: getMaybeAllocationDomain()
TensorView-->>resetContiguityFromTensor: alloc domain (all dims)
Note over resetContiguityFromTensor: Initialize contiguity vector<br/>(all nullopt)
loop Right-to-left traversal (alloc_idx)
alt Reduction dimension
Note over resetContiguityFromTensor: Keep nullopt<br/>Don't consume sizes/strides
else Broadcast dimension
Note over resetContiguityFromTensor: Keep nullopt<br/>Consume sizes/strides entry<br/>(decrement sizes_idx)
else Regular iteration dimension
alt Rightmost non-skipped dim
Note over resetContiguityFromTensor: Check: stride == 1
else Other dimensions
Note over resetContiguityFromTensor: Check: stride ==<br/>prev_stride * prev_size
end
Note over resetContiguityFromTensor: Set contiguity[alloc_idx]<br/>Update prev_non_skipped_sizes_idx
end
end
Note over resetContiguityFromTensor: Validate all sizes/strides<br/>consumed (sizes_idx == -1)
resetContiguityFromTensor->>TensorView: setContiguity(contiguity)
TensorView->>TensorDomain: setContiguity(contiguity)
resetContiguityFromTensor->>Client: Return
|
naoyam
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
|
!test |
|
!test |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
3 files reviewed, 1 comment
| // A dimension is contiguous if its stride equals the stride of the | ||
| // next dimension multiplied by that dimension's size | ||
| contiguity[alloc_idx] = | ||
| (strides[sizes_idx] == | ||
| strides[prev_non_skipped_sizes_idx] * | ||
| sizes[prev_non_skipped_sizes_idx]); | ||
| } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
style: Consider handling size-0 dimensions explicitly. When sizes[prev_non_skipped_sizes_idx] == 0, the expected stride is 0 regardless of actual stride (which is undefined for size-0 dims in PyTorch). Current logic may incorrectly mark such dimensions as non-contiguous.
| // A dimension is contiguous if its stride equals the stride of the | |
| // next dimension multiplied by that dimension's size | |
| contiguity[alloc_idx] = | |
| (strides[sizes_idx] == | |
| strides[prev_non_skipped_sizes_idx] * | |
| sizes[prev_non_skipped_sizes_idx]); | |
| } | |
| } else { | |
| // A dimension is contiguous if its stride equals the stride of the | |
| // next dimension multiplied by that dimension's size | |
| // Special case: size-0 dimensions have undefined stride semantics | |
| if (sizes[prev_non_skipped_sizes_idx] == 0) { | |
| contiguity[alloc_idx] = std::nullopt; | |
| } else { | |
| contiguity[alloc_idx] = | |
| (strides[sizes_idx] == | |
| strides[prev_non_skipped_sizes_idx] * | |
| sizes[prev_non_skipped_sizes_idx]); | |
| } |
This PR just adds a utility and its test. This utility will be used in #5772, please see #5772 for the context.