Conversation

@zasdfgbnm (Collaborator)

No description provided.

github-actions bot commented Jan 7, 2026

Review updated until commit 4afe5b1

Description

  • Introduce new infer-contiguity enable option for runtime contiguity inference

  • Rename inferOutputShapeAndContiguousStrides to inferContiguousOutputMetaTensor

  • Add inferOutputMetaTensor function to handle expr-eval segments using meta device tensors (see the meta-device sketch after this list)

  • Add updateContiguityOfSegmentOutputs function to update tensor contiguity at runtime

  • Enable contiguity inference in all test files for comprehensive coverage
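
The core mechanism behind these changes is ATen's meta device: a tensor that carries sizes, strides, and dtype but no storage, on which ATen operations still run their normal shape and stride propagation. The following is a minimal, standalone sketch of that idea using only ATen; it is an illustration of the pattern, not code from this PR, and the variable names are made up.

#include <ATen/ATen.h>
#include <iostream>

int main() {
  // A non-contiguous "real" input, e.g. a transposed view.
  at::Tensor x = at::rand({64, 65}).t();
  // Meta clone: same sizes, strides, and dtype, but no storage is allocated.
  at::Tensor meta_x = at::empty_strided(
      x.sizes(),
      x.strides(),
      at::TensorOptions().device(at::kMeta).dtype(x.dtype()));
  // Ops on the meta device still compute output metadata, so the output's
  // strides and contiguity are known without allocating or reading real data.
  at::Tensor meta_y = meta_x * 2;
  std::cout << "strides=(" << meta_y.stride(0) << ", " << meta_y.stride(1)
            << ") contiguous=" << std::boolalpha << meta_y.is_contiguous()
            << std::endl;
  return 0;
}

The inferOutputMetaTensor function quoted in the reviewer guide below applies the same pattern per expr-eval segment: it binds meta clones of the segment inputs into an ExpressionEvaluator, evaluates the segment outputs, and reads strides and contiguity off the resulting meta tensors.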

Changes walkthrough

Relevant files

Configuration changes (2 files)
  • options.cpp: Add infer-contiguity enable option mapping (+1/-0)
  • options.h: Add InferContiguity enum value (+1/-0)

Enhancement (5 files)
  • allocations.cpp: Rename function to inferContiguousOutputMetaTensor (+2/-2)
  • allocations.h: Update function signature for the renamed allocation function (+1/-1)
  • fusion_kernel_runtime.cpp: Add output meta tensor inference and contiguity update functions (+69/-8)
  • fusion_kernel_runtime.h: Declare new output inference and contiguity update methods (+21/-0)
  • composite_nodes.cpp: Add conditional contiguity handling in MatmulOp evaluation (+12/-8)

Miscellaneous (1 file)
  • fusion_cache_utils.cpp: Add missing include for IR utils (+1/-0)

Tests (7 files; a sketch of the C++ test setup follows this list)
  • test_indexing_advanced.cpp: Enable infer-contiguity option in test setup (+2/-0)
  • test_loop_domain_scheduling.cpp: Enable infer-contiguity option in test setup (+1/-0)
  • test_matmul_scheduler.cpp: Enable infer-contiguity option in test setup (+1/-0)
  • test_pointwise.cpp: Enable infer-contiguity option in test setup (+1/-0)
  • test_rng.cpp: Enable infer-contiguity option in test setup (+1/-0)
  • utils.cpp: Enable infer-contiguity option in test setup (+1/-0)
  • test_python_frontend.py: Add test case for issue 4888 with the infer-contiguity option (+99/-0)
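
For reference on the "Enable infer-contiguity option in test setup" rows above, here is a rough sketch of what such a C++ test fixture could look like. It assumes the EnableOptionsGuard / getCurOptions() pattern used elsewhere in the nvFuser C++ tests; the fixture name, include paths, and exact calls are assumptions, not lines taken from this PR.

#include <gtest/gtest.h>

#include <options.h>
#include <tests/cpp/utils.h>

namespace nvfuser {

// Hypothetical fixture: every test deriving from it runs with runtime
// contiguity inference enabled.
class InferContiguityTest : public NVFuserTest {
 protected:
  void SetUp() override {
    NVFuserTest::SetUp();
    // Assumed API: turn the new option on for the duration of the test.
    EnableOptionsGuard::getCurOptions().set(EnableOption::InferContiguity);
  }

 private:
  // Snapshots the option state on construction and restores it on destruction.
  EnableOptionsGuard enable_options_guard_;
};

} // namespace nvfuser

The Python frontend test takes a different route and passes _enable_options=["infer-contiguity"] to fd.execute, as shown in the reviewer guide below.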

PR Reviewer Guide

Here are some key observations to aid the review process:

🧪 PR contains tests
⚡ Recommended focus areas for review
Performance Impact

The new inferOutputMetaTensor method creates meta device tensors for expr-eval segments. While this provides better contiguity inference, we should verify that the performance overhead of creating these meta tensors is acceptable, especially for small segments or when the option is disabled.

KernelArgumentHolder FusionKernelRuntime::inferOutputMetaTensor(
    HeuristicParamsList* heuristics,
    SegmentedGroup* group_to_run,
    const KernelArgumentHolder& group_runtime_inputs,
    PrecomputedValues* evaluator_precomputed_values) const {
  FUSER_PERF_SCOPE("FusionKernelRuntime::inferOutputMetaTensor");
  NVF_ERROR(heuristics != nullptr);
  Fusion* fusion_to_run = group_to_run->getFusion();
  KernelArgumentHolder group_runtime_outputs;
  const auto& heuristic_params = heuristics->at(group_to_run->groupId());
  const bool is_expr_eval =
      heuristic_params->scheduler_type == SchedulerType::ExprEval;
  if (is_expr_eval) {
    // For expr evaluated fusion, the striding rules follow that of ATen.
    ExpressionEvaluator eval_fusion;
    for (auto i : arange(group_to_run->inputs().size())) {
      const auto& tensor_pv = group_runtime_inputs[i];
      if (tensor_pv.is<at::Tensor>()) {
        const auto& t = tensor_pv.as<at::Tensor>();
        if (t.defined()) {
          const auto meta_t = at::empty_strided(
              t.sizes(),
              t.strides(),
              at::TensorOptions().device(at::kMeta).dtype(t.dtype()));
          eval_fusion.bind(fusion_to_run->inputs()[i], meta_t);
        } else {
          eval_fusion.bind(fusion_to_run->inputs()[i], t);
        }
      } else {
        eval_fusion.bind(fusion_to_run->inputs()[i], tensor_pv);
      }
    }
    for (auto v : fusion_to_run->outputs()) {
      auto result = eval_fusion.evaluate(v);
      group_runtime_outputs.push(result);
    }
  } else {
    auto fusion_to_run = group_to_run->getFusion();
    return inferContiguousOutputMetaTensor(
        fusion_to_run, group_runtime_inputs, evaluator_precomputed_values);
  }
  return group_runtime_outputs;
}
Conditional Logic

The MatmulOp::evaluate method now has conditional logic based on EnableOption::InferContiguity. We should ensure this doesn't introduce any unexpected behavior when the option is toggled, and that the fallback behavior is correct.

// Without InferContiguity, we mistakenly assume the output is contiguous.
if (!isOptionEnabled(EnableOption::InferContiguity)) {
  const auto& [sizes, strides] = inferShapeAndContiguousStrides(out(), ee);
  auto meta_out = at::detail::empty_strided_meta(sizes, strides, a.dtype());

  if (meta_out.is_contiguous()) {
    return {matmul_out};
  }

  auto strided_matmul_out = at::empty_strided(sizes, strides, a.options());
  strided_matmul_out = strided_matmul_out.copy_(matmul_out);
  return {strided_matmul_out};
}
return {matmul_out};
Test Coverage

The new test test_issue4888 exercises the infer-contiguity functionality. We should verify this test adequately covers the edge cases and that the test passes consistently across different GPU architectures.

def test_issue4888():
    # https://github.com/NVIDIA/Fuser/issues/4888
    def nvfuser_fusion_id2(fd: FusionDefinition) -> None:
        T0 = fd.define_tensor(
            shape=[4096, 4097],
            contiguity=[True, True],
            dtype=DataType.BFloat16,
            is_cpu=False,
            stride_order=[1, 0],
        )
        T1 = fd.define_tensor(
            shape=[4096, 4097],
            contiguity=[True, True],
            dtype=DataType.Bool,
            is_cpu=False,
            stride_order=[1, 0],
        )
        T2 = fd.define_tensor(
            shape=[4096, 4097],
            contiguity=[True, True],
            dtype=DataType.Bool,
            is_cpu=False,
            stride_order=[1, 0],
        )
        T3 = fd.define_tensor(
            shape=[1, 32, 4096, 4096],
            contiguity=[None, True, True, True],
            dtype=DataType.BFloat16,
            is_cpu=False,
            stride_order=[3, 2, 1, 0],
        )
        T4 = fd.ops.cast(T0, dtype=DataType.Float)
        T5 = fd.ops.bitwise_or(T1, T2)
        T6 = fd.ops.set(T5)
        fd.add_output(T6, T1)
        T7 = fd.ops.cast(T6, dtype=DataType.Float)
        T8 = fd.ops.mul(T4, T7)
        T9 = fd.ops.cast(T8, dtype=DataType.BFloat16)
        T10 = fd.ops.set(T9)
        fd.add_output(T10, T0)
        T15 = fd.ops.broadcast_in_dim(T10, shape=[1, 4096, 4097], broadcast_dims=[1, 2])
        T21 = fd.ops.broadcast_in_dim(
            T15, shape=[1, 1, 4096, 4097], broadcast_dims=[0, 2, 3]
        )
        T27 = fd.ops.broadcast_in_dim(
            T21, shape=[1, 1, 4096, 4097], broadcast_dims=[0, 1, 2, 3]
        )
        T43 = fd.ops.slice(
            T27,
            start_indices=[0, 0, 0, 0],
            end_indices=[1, 1, 4096, 4096],
            strides=[1, 1, 1, 1],
            manual_normalization=0,
        )
        T49 = fd.ops.broadcast_in_dim(
            T43, shape=[1, 32, 4096, 4096], broadcast_dims=[0, 1, 2, 3]
        )
        T50 = fd.ops.cast(T49, dtype=DataType.Float)
        T51 = fd.ops.cast(T3, dtype=DataType.Float)
        S52 = fd.define_scalar(0.0883883, dtype=DataType.Double)
        T53 = fd.ops.mul(T51, S52)
        T54 = fd.ops.add(T53, T50)
        T55 = fd.ops.max(T54, dims=[3], keepdim=False, dtype=DataType.Null)
        T61 = fd.ops.broadcast_in_dim(
            T55, shape=[1, 32, 4096, 1], broadcast_dims=[0, 1, 2]
        )
        T67 = fd.ops.broadcast_in_dim(
            T61, shape=[1, 32, 4096, 4096], broadcast_dims=[0, 1, 2, 3]
        )
        T68 = fd.ops.sub(T54, T67)
        T69 = fd.ops.exp(T68)
        T70 = fd.ops.sum(T69, dims=[3], keepdim=False, dtype=DataType.Null)
        T76 = fd.ops.broadcast_in_dim(
            T70, shape=[1, 32, 4096, 1], broadcast_dims=[0, 1, 2]
        )
        T82 = fd.ops.broadcast_in_dim(
            T76, shape=[1, 32, 4096, 4096], broadcast_dims=[0, 1, 2, 3]
        )
        T83 = fd.ops.reciprocal(T82)
        T84 = fd.ops.mul(T69, T83)
        T85 = fd.ops.cast(T84, dtype=DataType.BFloat16)
        fd.add_output(T49)
        fd.add_output(T84)
        fd.add_output(T85)

    with FusionDefinition() as fd:
        nvfuser_fusion_id2(fd)

    inputs = [
        torch.testing.make_tensor((4096, 4097), dtype=torch.bfloat16, device="cuda:0"),
        torch.testing.make_tensor((4096, 4097), dtype=torch.bool, device="cuda:0"),
        torch.testing.make_tensor((4096, 4097), dtype=torch.bool, device="cuda:0"),
        torch.testing.make_tensor(
            (1, 32, 4096, 4096), dtype=torch.bfloat16, device="cuda:0"
        ),
    ]
    fd.execute(inputs, _enable_options=["infer-contiguity"])

Test failures (partial, pipeline still running)

  • (High, 6) nvFuser evaluator_common internal assert in multidevice overlap (row_parallel_linear_forward) tests

    Failing tests (A100 dist., GB200 dist.):
      tests.python.multidevice.test_overlap.test_row_parallel_linear_forward
      tests.python.multidevice.test_overlap.test_row_parallel_linear_forward_benchmark[s=2]
      tests.python.multidevice.test_overlap.test_row_parallel_linear_forward_benchmark[s=4]
  • (Medium, 16) nvFuser internal assert on block_scaling_factor contiguity in BlockQuantizationSchedulingTestSuite

    Failing tests (GB200):
      BlockQuantizationSchedulingTestSuite/BlockQuantizationSchedulingTest.AutoScheduleSingleOp/__bfloat_1024x1024_NoGlobalScale_WithSwizzle
      BlockQuantizationSchedulingTestSuite/BlockQuantizationSchedulingTest.AutoScheduleSingleOp/__bfloat_1024x1024_WithGlobalScale_WithSwizzle
      BlockQuantizationSchedulingTestSuite/BlockQuantizationSchedulingTest.AutoScheduleSingleOp/__bfloat_128x64_NoGlobalScale_WithSwizzle
      BlockQuantizationSchedulingTestSuite/BlockQuantizationSchedulingTest.AutoScheduleSingleOp/__bfloat_128x64_WithGlobalScale_WithSwizzle
      BlockQuantizationSchedulingTestSuite/BlockQuantizationSchedulingTest.AutoScheduleSingleOp/__bfloat_2048x128_NoGlobalScale_WithSwizzle
      BlockQuantizationSchedulingTestSuite/BlockQuantizationSchedulingTest.AutoScheduleSingleOp/__bfloat_2048x128_WithGlobalScale_WithSwizzle
      BlockQuantizationSchedulingTestSuite/BlockQuantizationSchedulingTest.AutoScheduleSingleOp/__bfloat_2048x2048_NoGlobalScale_WithSwizzle
      BlockQuantizationSchedulingTestSuite/BlockQuantizationSchedulingTest.AutoScheduleSingleOp/__bfloat_2048x2048_WithGlobalScale_WithSwizzle
      BlockQuantizationSchedulingTestSuite/BlockQuantizationSchedulingTest.AutoScheduleSingleOp/float_1024x1024_NoGlobalScale_WithSwizzle
      BlockQuantizationSchedulingTestSuite/BlockQuantizationSchedulingTest.AutoScheduleSingleOp/float_1024x1024_WithGlobalScale_WithSwizzle
    ... with 6 more test failures omitted. Check internal logs.
  • (Medium, 8) nvFuser stride/contiguity mismatch in multidevice SDPA & Transformer tests (BSHE layout)

    Failing tests (A100 dist., GB200 dist.):
    tests.python.multidevice.test_multidevice.test_sdpa[qkv_format=QkvFormat.BSHE]
    tests.python.multidevice.test_multidevice.test_sdpa_loop_split[qkv_format=QkvFormat.BSHE]
    tests.python.multidevice.test_transformer.test_transformer_forward[SEQUENCE_PARALLEL]
    tests.python.multidevice.test_transformer.test_transformer_forward[TENSOR_PARALLEL]
  • (Medium, 6) nvFuser allocation-domain assertion failures in LayoutOpTest suite

    Failing tests (A100, GB200):
      LayoutOpTest.SchedulerKernel
      LayoutOpTest.SchedulerKernelWithExplicitQuantizationPattern
      LayoutOpTest.SchedulerKernelWithOffsetsProducer
  • (Medium, 2) nvFuser segmentation logic mismatch in RevertPrivatizedUpcast test suite

    Failing test (A100, GB200):
      SegmentationTest.RevertPrivatizedUpcast
  • (Medium, 2) nvFuser AliasTest aliasing mismatch in test_alias.cpp across multiple runners

    Failing test (A100, GB200):
      AliasTest.AliasOutputBeforeNonAliasOutput
  • (Medium, 2) nvFuser matmul output stride mismatch in MatmulNodeTest

    Failing test (A100, GB200):
      MatmulNodeTest.OutputStrides

zasdfgbnm and others added 13 commits January 13, 2026 09:52
…andling

- Renamed `inferOutputShapeAndContiguousStrides` to `inferContiguousOutputMetaTensor` for clarity.
- Updated function signatures to remove unnecessary parameters.
- Introduced `inferOutputMetaTensor` in `FusionKernelRuntime` to handle output shape inference for segmented groups.
- Enhanced `updateWithSegmentOutputs` to streamline output management without updating contiguity directly.
- Improved overall code organization and readability.
@zasdfgbnm (Collaborator, Author)

!test

@zasdfgbnm (Collaborator, Author)

!test

@zasdfgbnm (Collaborator, Author)

!test
