
Conversation


@naoyam naoyam commented Oct 6, 2025

In the greedy scheduler so far, there has been only a very naive limit on the usage of shared memory buffers (https://github.com/NVIDIA/Fuser/blob/main/csrc/scheduler/greedy.cpp#L703-L719). This PR improves the estimation of the buffer requirement. Specifically, it adds two utility functions to compute the required temp storage sizes of the block radix sort and block scan (computeBlockRadixSortTempStorageBytes and computeBlockScanTempStorageBytes), both of which were automatically generated with Cursor. They appear to be correct, and the unit tests agree.

The scheduler is updated to use them whenever batching is supported. If batching is not supported (e.g., topk), each call still needs to be limited to a size that a single thread block can handle, i.e., the constrained size is allowed to be at most 1024.

Note: We are still using statically allocated buffers, but we should manage them ourselves and dynamically allocate or reuse buffers, much as we do for other normal shared memory uses. That is not yet critical for our immediate problems. In either case, we would still need to know how much shared memory each CUB call requires, so the two util functions would still be necessary.
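For a sense of what such an estimate looks like, here is a minimal sketch in the spirit of computeBlockScanTempStorageBytes, assuming the BLOCK_SCAN_WARP_SCANS layout (one aggregate per warp plus a single block prefix). The body is illustrative only, not the actual implementation in cub_utils.cpp:

#include <cstdint>

// Illustrative sketch only; the real computeBlockScanTempStorageBytes
// lives in cub_utils.cpp. Assumes BLOCK_SCAN_WARP_SCANS, where the
// shared temp storage is roughly one aggregate per warp plus a single
// carried block prefix, padded to element alignment.
int64_t estimateBlockScanTempStorageBytes(
    int64_t block_threads,
    int64_t bytes_per_element) {
  constexpr int64_t kWarpSize = 32;
  const int64_t num_warps = (block_threads + kWarpSize - 1) / kWarpSize;
  // One per-warp aggregate plus the block prefix.
  int64_t bytes = (num_warps + 1) * bytes_per_element;
  // Round up to element alignment, as the struct layout would.
  return (bytes + bytes_per_element - 1) / bytes_per_element *
      bytes_per_element;
}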

Closes #5044


github-actions bot commented Oct 6, 2025

Review updated until commit 0ef2c6f

Description

  • Improved shared memory usage estimation for CUB operations

  • Added precise buffer size computation for BlockRadixSort and BlockScan

  • Enhanced scheduler constraints checking with accurate resource modeling

  • Fixed Host IR JIT configuration and build system integration


Changes walkthrough 📝

Relevant files

Bug fix (1 file)

  • jit.cpp: Fix size comparison in tensor shape inference (+2/-2)

Configuration changes (7 files)

  • options.cpp: Remove deprecated HostIrJit option (+0/-1)
  • fusion_kernel_runtime.cpp: Conditional compilation for Host IR JIT (+18/-15)
  • test_host_ir_integration.cpp: Remove HostIrJit from test setup (+0/-1)
  • utils.py: Add build_with_host_ir_jit configuration (+11/-0)
  • options.h: Remove deprecated HostIrJit option (+0/-1)
  • fusion_kernel_runtime.h: Conditional inclusion of HostIrJit (+12/-9)
  • CMakeLists.txt: Conditional compilation for Host IR JIT (+46/-35)

Enhancement (4 files)

  • greedy.cpp: Enhanced shared memory buffer usage checking (+115/-34)
  • cub_utils.cpp: Implement CUB shared memory size computation (+375/-0)
  • scan.cu: Use BLOCK_SCAN_WARP_SCANS algorithm (+5/-2)
  • cub_utils.h: Declare CUB shared memory buffer utilities (+95/-0)

Tests (5 files)

  • test_argsort.cpp: Add shared memory requirement tests for argsort (+123/-0)
  • test_greedy.cpp: Add shared memory size tests for greedy scheduler (+199/-0)
  • test_host_ir_jit.cpp: Simplify HostIrJit test initialization (+4/-9)
  • test_scan.cpp: Add shared memory requirement tests for scan (+135/-0)
  • test_topk.cpp: Add shared memory requirement tests for topk (+118/-0)

Documentation (1 file)

  • host_ir_jit.md: Update Host IR JIT documentation (+14/-11)

PR Reviewer Guide 🔍

Here are some key observations to aid the review process:

🧪 PR contains tests
⚡ Recommended focus areas for review

Possible Issue

The checkDomainConstraints function now returns the size of constrained IDs in bytes, but the value is used as if it represents the number of elements. This could lead to incorrect calculations when checking thread limits and register usage.

int64_t checkDomainConstraints(
    const std::vector<IterDomain*>& domain,
    const std::vector<int64_t>& constrained_id_offsets,
    int64_t bytes_per_element,
    bool support_batching = false) {
  int64_t size_of_constrained_ids = 1;
  for (const auto i : constrained_id_offsets) {
    auto constrained_id = domain.at(i);
    auto extent_val = runtime_info_.expressionEvaluator().evaluate(
        constrained_id->extent());
    NVF_ERROR(
        extent_val.hasValue(),
        "Cannot infer the extent of a constrained ID: ",
        constrained_id->toString());
    size_of_constrained_ids *= extent_val.as<int64_t>();
  }

  const int64_t threads_per_block = max_threads_per_block_;

  // At this moment, not all constrained ops support batching. If
  // batching is not supported, the limit is simply set as the
  // maximum number of threads per thread block. This is likely
  // a sufficient condition even for shared memory, although not
  // guaranteed.
  if (!support_batching) {
    if (size_of_constrained_ids > threads_per_block) {
      reject(
          "Extent of constrained logical IDs, ",
          size_of_constrained_ids,
          ", exceeds the number of threads per thread block: ",
          threads_per_block);
    }
  }

  // The maximum supported size depends on several factors. The hard
  // limit is the shared memory capacity since the kernel launch
  // would just fail if the shared memory usage exceeds the
  // available size. It is checked at the end of the RunTimeChecker
  // constructor.
  //
  // The next important limit would be the register usage as we
  // would not want to have excessive register spilling. The
  // register usage would be linearly correlated with the batching
  // factor. For now, just put a simple upper limit to avoid
  // disastrous regressions. Fine tuning would be necessary.
  const int64_t register_count_per_thread =
      ceilDiv(size_of_constrained_ids, threads_per_block) *
      bytes_per_element / 4;
  const int64_t available_register_count_per_thread =
      at::cuda::getCurrentDeviceProperties()->regsPerBlock /
      threads_per_block;
  // Make sure at least 20 registers are always available
  const int64_t reserved_register_count_per_thread = 20;
  if (register_count_per_thread + reserved_register_count_per_thread >
      available_register_count_per_thread) {
    reject(
        "Expected register usage, ",
        register_count_per_thread,
        ", exceeds the available count, ",
        available_register_count_per_thread);
  }

  return size_of_constrained_ids;
}
Possible Issue

The register usage calculation appears to be incorrect as it uses size_of_constrained_ids (in bytes) directly in the calculation, which would overestimate register usage since it doesn't account for the actual data type size.

const int64_t register_count_per_thread =
    ceilDiv(size_of_constrained_ids, threads_per_block) *
    bytes_per_element / 4;
const int64_t available_register_count_per_thread =
    at::cuda::getCurrentDeviceProperties()->regsPerBlock /
    threads_per_block;
// Make sure at least 20 registers are always available
const int64_t reserved_register_count_per_thread = 20;
if (register_count_per_thread + reserved_register_count_per_thread >
    available_register_count_per_thread) {
  reject(
      "Expected register usage, ",
      register_count_per_thread,
      ", exceeds the available count, ",
      available_register_count_per_thread);
}
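As a concrete check of the arithmetic (hypothetical numbers, not taken from the PR): with 4096 constrained fp32 elements and 256 threads per block, each thread holds ceilDiv(4096, 256) = 16 elements, i.e., 16 * 4 / 4 = 16 registers; on a device with 64K registers per block, 65536 / 256 = 256 registers per thread are available, so 16 plus the 20 reserved registers fits comfortably and the op is accepted.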
Possible Issue

The computeBlockRadixSortTempStorageBytes function uses a hardcoded value of 36 bytes per thread for rank_aliasable_bytes, but this value is specific to RADIX_BITS=4 and may not be accurate for other configurations.

const int64_t rank_aliasable_bytes = int64_t{36} * block_threads;
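One plausible derivation of the 36 bytes/thread constant, assuming CUB's BlockRadixRank counter packing (two 16-bit digit counters per 32-bit word, with one extra lane of padding). This is a reading of where the constant likely comes from, not the PR's code:

// Hypothetical sketch: for RADIX_BITS = 4 with 16-bit digit counters
// packed two per 32-bit word, counter lanes = 2^(4 - 1) = 8, padded to
// 9 lanes to avoid bank conflicts, giving 9 * 4 = 36 bytes per thread.
int64_t rankAliasableBytes(int64_t block_threads, int64_t radix_bits = 4) {
  const int64_t log_packing_ratio = 1; // two 16-bit counters per 32-bit word
  const int64_t counter_lanes = int64_t{1} << (radix_bits - log_packing_ratio);
  const int64_t padded_lanes = counter_lanes + 1;
  return padded_lanes * int64_t{4} * block_threads; // 36 * block_threads
}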


void handle(TopKOp* topk) override {
  checkDomainConstraints(
      ir_utils::getTvOutput(topk)->getLogicalDomain(), {topk->dim()});
Collaborator Author


Use of the output here was wrong. Need to check the input as it has a larger extent.
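A minimal sketch of the described fix, assuming ir_utils::getTvInput is the input-side counterpart of getTvOutput (illustrative; the actual change is in the PR diff):

void handle(TopKOp* topk) override {
  // Check the input's logical domain; the topk input has the larger
  // extent along the constrained dimension.
  checkDomainConstraints(
      ir_utils::getTvInput(topk)->getLogicalDomain(), {topk->dim()});
}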


naoyam commented Oct 7, 2025

!test --diff


naoyam commented Oct 7, 2025

!test --diff

@naoyam naoyam force-pushed the greedy_scheduler_shmem_buffer_size branch from 1876804 to 9855676 on October 7, 2025 07:42

naoyam commented Oct 7, 2025

!test --diff

// may be too strict and fragile as a test. After all, we would
// just need a reasonably tight upper bound. Consider relaxing the
// condition if necessary.
EXPECT_EQ(expected_size, ke.getStaticSmemSize())
Collaborator


👏 impressive job by the agent!

// It doesn't seem consistent whether compilation or launch should
// fail if the requirement of static shared memory exceeds the default
// limit but within the opt-in larger limit. As we should move to
// dynamic allocaitons anyway, don't assert for now.
Collaborator


Suggested change
// dynamic allocaitons anyway, don't assert for now.
// dynamic allocations anyway, don't assert for now.

Collaborator


similarly, I think we should skip on these since we didn't run anything.


naoyam commented Oct 16, 2025

!build

@naoyam naoyam merged commit effc230 into main Oct 16, 2025
14 of 15 checks passed
@naoyam naoyam deleted the greedy_scheduler_shmem_buffer_size branch October 16, 2025 22:26
wujingyue pushed a commit that referenced this pull request Oct 22, 2025
It got accidentally reverted by #5328
@wujingyue wujingyue mentioned this pull request Oct 22, 2025
wujingyue added a commit that referenced this pull request Oct 22, 2025
It got accidentally reverted by #5328

Co-authored-by: Michael Davis <[email protected]>


Successfully merging this pull request may close these issues.

Ensure shared memory capacity is sufficient
