
Conversation

@liqiangxl
Collaborator

No description provided.


github-actions bot commented Nov 25, 2025

Review updated until commit 5cb25f3

Description

  • Implement TMA inner persistent scheduler for normalization operations

  • Split normalization inner operations into TMA and non-TMA paths

  • Add manual scheduling optimizations for dynamic block dimensions

  • Implement vectorized smem2regs operations for improved memory access

  • Add RMS norm and layer norm implementations with workspace support

Changes walkthrough

Relevant files

PR Reviewer Guide

Here are some key observations to aid the review process:

🧪 PR contains tests
⚡ Recommended focus areas for review
Missing Error Handling

In the computeHeuristics method, getInnerPersistentHeuristics may return nullptr when TMA is enabled but no valid heuristic can be derived; the NVF_ERROR on line 184 then aborts instead of reporting a recoverable scheduling failure. Consider checking for nullptr and handling that case explicitly before the result is used.

auto tma_params = normalization_inner::tma::getInnerPersistentHeuristics(
    fusion, runtime_info, data_cache);
NVF_ERROR(tma_params != nullptr);
return tma_params;
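As a sketch of the suggested defensive pattern (the types and helper below are illustrative stand-ins, not nvFuser's actual API), the pointer can be validated with a clear failure path before any dereference:

```cpp
#include <memory>
#include <stdexcept>

// Illustrative stand-in for the heuristics parameter type; not the
// real nvFuser class.
struct HeuristicParams {
  int persistent_batch = 1;
};

// Hypothetical helper that can fail and return nullptr, mirroring
// getInnerPersistentHeuristics in the review comment.
std::unique_ptr<HeuristicParams> tryComputeTmaHeuristics(bool schedulable) {
  if (!schedulable) {
    return nullptr;  // heuristics unavailable for this fusion
  }
  return std::make_unique<HeuristicParams>();
}

std::unique_ptr<HeuristicParams> computeHeuristics(bool schedulable) {
  auto tma_params = tryComputeTmaHeuristics(schedulable);
  // Check before any dereference: fail with a clear message rather
  // than crashing on a null pointer.
  if (tma_params == nullptr) {
    throw std::runtime_error("TMA heuristics could not be computed");
  }
  return tma_params;
}
```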
Unchecked Cast

The dynamic_cast in the schedule method returns nullptr when params is not a ReductionParams; the result must be validated before use, since dereferencing a null pointer is undefined behavior. Note also that the second NVF_ERROR repeats the scheduler_type check already performed by the first.

auto rparams = dynamic_cast<const ReductionParams*>(params);
NVF_ERROR(
    rparams != nullptr && rparams->scheduler_type == schedulerType(),
    "Incorrect parameters sent to InnerPersistentKernelScheduler::schedule",
    params);
NVF_ERROR(
    rparams->scheduler_type ==
    InnerPersistentKernelScheduler::schedulerType());
normalization_inner::non_tma::scheduleInnerPersistent(fusion, rparams);
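The same guard pattern applies to the cast itself. A minimal sketch, using illustrative stand-in types rather than nvFuser's real class hierarchy:

```cpp
#include <stdexcept>

// Illustrative stand-ins for the scheduler parameter hierarchy.
struct HeuristicParams {
  virtual ~HeuristicParams() = default;
};
struct ReductionParams : HeuristicParams {
  int vectorization_factor = 1;
};

// dynamic_cast on a pointer yields nullptr on a type mismatch, so the
// result must be checked before it is dereferenced.
const ReductionParams* asReductionParams(const HeuristicParams* params) {
  auto* rparams = dynamic_cast<const ReductionParams*>(params);
  if (rparams == nullptr) {
    throw std::runtime_error(
        "Incorrect parameters sent to InnerPersistentKernelScheduler::schedule");
  }
  return rparams;
}
```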
Division by Zero Risk

In the getRegisterSharing function, the register redistribution divides by computation_threads and by padded_threads (line 162). If either value is 0, integer division by zero is undefined behavior. Add validation to prevent this.

      (reg_per_thread - tma_branch_regs) * padded_threads / computation_threads;
  if (compute_branch_regs % regs_granularity != 0) {
    compute_branch_regs -= compute_branch_regs % regs_granularity;
    tma_branch_regs = reg_per_thread -
        (compute_branch_regs - reg_per_thread) * computation_threads /
            padded_threads;
  }
  compute_branch_regs =
      std::min(compute_branch_regs, scheduler_utils::max_registers_per_thread);
  return std::make_pair(tma_branch_regs, compute_branch_regs);
}
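A self-contained sketch of the suggested validation, with the divisors checked up front. The formula and constants are illustrative, following the shape of the snippet above rather than the exact nvFuser code:

```cpp
#include <cstdint>
#include <stdexcept>
#include <utility>

// Hypothetical register split between the TMA (producer) branch and the
// computation branch of a warp-specialized kernel. Returns
// {tma_branch_regs, compute_branch_regs}.
std::pair<int64_t, int64_t> splitRegisters(
    int64_t reg_per_thread,
    int64_t tma_branch_regs,
    int64_t padded_threads,
    int64_t computation_threads,
    int64_t regs_granularity) {
  // Validate divisors before any division: integer division by zero is
  // undefined behavior in C++.
  if (padded_threads <= 0 || computation_threads <= 0 ||
      regs_granularity <= 0) {
    throw std::invalid_argument("thread counts and granularity must be > 0");
  }
  // Registers given up by the TMA branch are redistributed to the
  // computation threads, scaled by the thread-count ratio.
  int64_t compute_branch_regs = reg_per_thread +
      (reg_per_thread - tma_branch_regs) * padded_threads /
          computation_threads;
  // Round down to the register-allocation granularity.
  compute_branch_regs -= compute_branch_regs % regs_granularity;
  return {tma_branch_regs, compute_branch_regs};
}
```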

