Conversation

@wujingyue (Collaborator) commented Sep 24, 2025

For #5289

github-actions bot commented Sep 25, 2025

Review updated until commit daf57cf

Description

  • Add support for stream-based for loops in kernel launch (see the conceptual sketch after this list)

  • Filter out stream IterDomains in tensor domain operations

  • Improve multidevice domain finalization with stream handling

  • Update test infrastructure for stream parallelization
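
To make the first bullet concrete, here is a minimal sketch, at the plain CUDA level, of what per-stream kernel execution amounts to: a host-side for loop over streams, with one kernel launch per stream covering a slice of the data. The kernel and helper names (axpy, launchPerStream, num_streams) are illustrative only; the PR builds this loop structure during lowering rather than emitting C++ like this.

#include <algorithm>
#include <vector>

#include <cuda_runtime.h>

// Illustrative 1-D pointwise kernel; stands in for whatever the fusion compiles to.
__global__ void axpy(float a, const float* x, float* y, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    y[i] = a * x[i] + y[i];
  }
}

// The stream IterDomain becomes the trip count of this host-side loop:
// one stream and one kernel launch per slice of the output.
void launchPerStream(float a, const float* x, float* y, int n, int num_streams) {
  std::vector<cudaStream_t> streams(num_streams);
  for (cudaStream_t& s : streams) {
    cudaStreamCreate(&s);
  }
  const int chunk = (n + num_streams - 1) / num_streams;
  for (int s = 0; s < num_streams; ++s) {
    const int offset = s * chunk;
    const int count = std::min(chunk, n - offset);
    if (count <= 0) {
      break;
    }
    constexpr int kBlock = 256;
    const int grid = (count + kBlock - 1) / kBlock;
    // Launches on different streams may overlap on the device.
    axpy<<<grid, kBlock, 0, streams[s]>>>(a, x + offset, y + offset, count);
  }
  for (cudaStream_t& s : streams) {
    cudaStreamSynchronize(s);
    cudaStreamDestroy(s);
  }
}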


Changes walkthrough 📝

Relevant files:

Enhancement (7 files)
  • lowering.cpp: Add stream ID detection and for-loop wrapping for kernel launch (+38/-8)
  • index_compute.cpp: Exclude stream domains from static indexing (+2/-1)
  • nodes.cpp: Add noStream filter for tensor domains (+25/-18)
  • finalize_multidevice_domains.cpp: Reorder stream and device IDs to the front (+3/-17)
  • pointwise.cpp: Exclude stream domains in pointwise heuristics (+3/-2)
  • transform_replay.cpp: Parallelize IDs with stream type (+2/-0)
  • internal_base_nodes.h: Add noStream utility and kNoStreams view (+5/-2)

Bug fix (1 file)
  • utils.cpp: Include stream dims in device count (+5/-1)

Formatting (1 file)
  • test_multidevice_sharding.cpp: Update test includes and using declarations (+0/-1)

Tests (1 file)
  • test_stream.cpp: Add test for per-stream kernel execution (+20/-21)

PR Reviewer Guide 🔍

Here are some key observations to aid the review process:

🧪 PR contains tests
⚡ Recommended focus areas for review

Possible Issue

The function findStreamIterDomain only checks the front of the loop domain for a stream IterDomain, based on the assumption that the FinalizeMultideviceDomains pass places it there. However, if this assumption is violated or not consistently enforced, the function may fail to detect a stream IterDomain that exists elsewhere in the loop domain, leading to incorrect lowering behavior.

IterDomain* findStreamIterDomain(const std::vector<Val*>& outs) {
  for (auto* out : ir_utils::filterByType<TensorView>(outs)) {
    const std::vector<IterDomain*>& loop = out->getLoopDomain();
    // FinalizeMultideviceDomains pass puts the stream IterDomain to the
    // front.
    if (!loop.empty() && loop.front()->isStream()) {
      return loop.front();
    }
  }
  return nullptr;
}
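
If the front-of-domain invariant ever needs to be relaxed, the concern above could be addressed by scanning the whole loop domain. A hypothetical sketch (findStreamIterDomainAnywhere is not in the PR; it reuses only the calls visible above plus std::find_if from <algorithm>):

// Hypothetical variant, not in the PR: search the entire loop domain rather
// than relying on FinalizeMultideviceDomains keeping the stream ID in front.
IterDomain* findStreamIterDomainAnywhere(const std::vector<Val*>& outs) {
  for (auto* out : ir_utils::filterByType<TensorView>(outs)) {
    const std::vector<IterDomain*>& loop = out->getLoopDomain();
    // Check every position instead of only loop.front().
    auto it = std::find_if(loop.begin(), loop.end(), [](IterDomain* id) {
      return id->isStream();
    });
    if (it != loop.end()) {
      return *it;
    }
  }
  return nullptr;
}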
Logic Error

The numDeviceDims function now includes stream IterDomains in its count when they are not reductions. This could lead to incorrect device dimension counts if stream dimensions are not intended to be treated as device dimensions in this context, potentially affecting downstream logic that relies on accurate device dimension counts.

int64_t numDeviceDims(const TensorView* tv) {
  return std::count_if(
      tv->getLoopDomain().begin(),
      tv->getLoopDomain().end(),
      [](IterDomain* id) {
        return (id->isDeviceDim() || id->isStream()) && !id->isReduction();
      });
}
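
If some call sites must not treat stream dimensions as device dimensions, one option is to keep the two counts separable. A hypothetical helper, not part of the PR, built only from the calls visible above:

// Hypothetical helper, not in the PR: counts stream-parallel loop dimensions
// on their own, so callers that must not treat streams as device dimensions
// can subtract this from numDeviceDims().
int64_t numStreamDims(const TensorView* tv) {
  return std::count_if(
      tv->getLoopDomain().begin(),
      tv->getLoopDomain().end(),
      [](IterDomain* id) { return id->isStream() && !id->isReduction(); });
}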
Possible Issue

The getPointwiseHeuristics function now filters out stream IterDomains when computing ref_loop, but it is unclear if this is the intended behavior for all pointwise operations. Removing stream dimensions at this stage might interfere with correct scheduling or code generation for operations that should account for stream parallelism.

ref_loop = TensorDomain::noStream(
    TensorDomain::noDevices(TensorDomain::noReductions(ref_loop)));
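
For reference, TensorDomain::noStream is the utility added in internal_base_nodes.h. Its implementation is not quoted here; the sketch below is only an assumed shape that mirrors the existing noDevices/noReductions filters (requires <algorithm> and <iterator>):

// Assumed shape only; the real declaration lives in internal_base_nodes.h.
// Keeps every IterDomain that is not stream-parallelized.
std::vector<IterDomain*> noStream(const std::vector<IterDomain*>& ids) {
  std::vector<IterDomain*> result;
  result.reserve(ids.size());
  std::copy_if(
      ids.begin(), ids.end(), std::back_inserter(result),
      [](IterDomain* id) { return !id->isStream(); });
  return result;
}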

wujingyue added a commit that referenced this pull request Sep 25, 2025
NFC. Needed by #5229.

This way, I can pass around only the map without the containing fusion.
wujingyue added a commit that referenced this pull request Sep 26, 2025
NFC. Needed by #5229.

This way, I can pass around only the map without the containing fusion.
@wujingyue force-pushed the wjy/stream branch 2 times, most recently from 744f167 to 003e429, on September 29, 2025 at 00:59
@wujingyue (Collaborator, Author) commented:

!test

@wujingyue changed the base branch from main to wjy/order on September 29, 2025 at 01:40
wujingyue added a commit that referenced this pull request Sep 29, 2025
This PR changed reorderDIDToFront to reorder all parallelized dimensions
to the front. This is less controversial than I expected because currently
we only call reorderDIDToFront before intra-GPU scheduling kicks in.

Needed by #5229
Base automatically changed from wjy/order to main on September 29, 2025 at 18:22
@Priya2698 (Collaborator) commented:

@wujingyue it might be good to merge some of the changes from this PR or the matmul PR so we can actually test #5309 end-to-end. That is, once we update FinalizeMultideviceDomainsPass, the example in this PR should work with no additional changes.

Do you think that is feasible?

@wujingyue (Collaborator, Author) commented:

> Do you think that is feasible?

Sure. Which parts do you need? Feel free to cherry-pick if it's faster to do it yourself.

@Priya2698 (Collaborator) commented:

> Do you think that is feasible?

> Sure. Which parts do you need?

My idea was to push the changes in lowering.cpp ahead of that so I can try running the test example with #5309.

> Feel free to cherry-pick if it's faster to do it yourself.

Sounds good. Once I have a draft ready, we can try doing the above, or cherry-pick things into that PR.

@wujingyue (Collaborator, Author) commented:

> push the changes in lowering.cpp

Sure! I'm waiting for #5323 to be merged. Meanwhile, feel free to cherry-pick.
