
Conversation

@rdspring1
Collaborator

torch::jit::toIValue converts a py::handle to an IValue, and KernelArgumentHolder then converts that IValue to a PolymorphicValue. The PyTorch function handles many types beyond what NvFuser supports, which adds unnecessary latency. toPolymorphicValue converts a py::handle directly to a PolymorphicValue, skipping the intermediate step.

  • This is a prerequisite for using nanobind.

tests/python/direct/test_python_frontend.py

| function | time |
| --- | --- |
| `toPolymorphicValue` | 41.04s |
| `torch::jit::toIValue` | 41.33s |

All Python tests: 7% improvement

| function | time |
| --- | --- |
| `toPolymorphicValue` | 317.45s |
| `torch::jit::toIValue` | 340.92s |
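As a sanity check, the full-suite wall times above do work out to roughly the quoted 7%:

```python
# Reported full-suite wall times from the table above, in seconds.
old = 340.92  # torch::jit::toIValue path
new = 317.45  # toPolymorphicValue path

improvement = (old - new) / old * 100
print(f"{improvement:.1f}% faster")  # 6.9% faster
```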

@rdspring1 rdspring1 requested a review from jjsjann123 January 7, 2026 17:47
@rdspring1 rdspring1 added the Direct Bindings Python extension with direct mapping to NvFuser CPP objects. label Jan 7, 2026
@github-actions

github-actions bot commented Jan 7, 2026

Review updated until commit a49a6c1

Description

  • Add toPolymorphicValue helper function for direct py::handle to PolymorphicValue conversion

  • Replace two-step conversion (py::handle → IValue → PolymorphicValue) with a single-step conversion

  • Support multiple types: torch.Tensor, bool, int64_t, double, complex

  • Achieve 7% performance improvement in all python tests by reducing conversion latency

Changes walkthrough

Relevant files
Enhancement
direct_utils.cpp
Add direct py::handle to PolymorphicValue conversion         

python/python_direct/direct_utils.cpp

  • Add toPolymorphicValue function for direct conversion from py::handle
    to PolymorphicValue
  • Update from_pyiterable to use new direct conversion instead of
    torch::jit::toIValue
  • Handle torch.Tensor, bool, int64_t, double, and complex types
  • Remove dependency on two-step conversion process for improved
    performance
  • +22/-2   

    PR Reviewer Guide

    Here are some key observations to aid the review process:

    🧪 PR contains tests
    ⚡ Recommended focus areas for review
    Type Coverage

    The new toPolymorphicValue function only handles 5 specific types (Tensor, bool, int64_t, double, complex), while the previous torch::jit::toIValue with c10::AnyType::get() would accept any type. Need to verify that all types that were previously supported and actually used in practice are covered by the new implementation.

    ```cpp
    PolymorphicValue toPolymorphicValue(const py::handle& obj) {
      static py::object torch_Tensor = py::module_::import("torch").attr("Tensor");
      if (py::isinstance(obj, torch_Tensor)) {
        return PolymorphicValue(py::cast<at::Tensor>(obj));
      } else if (py::isinstance<py::bool_>(obj)) {
        return PolymorphicValue(py::cast<bool>(obj));
      } else if (py::isinstance<py::int_>(obj)) {
        return PolymorphicValue(py::cast<int64_t>(obj));
      } else if (py::isinstance<py::float_>(obj)) {
        return PolymorphicValue(py::cast<double>(obj));
      } else if (PyComplex_Check(obj.ptr())) {
        return PolymorphicValue(py::cast<std::complex<double>>(obj));
      }
      NVF_THROW("Cannot convert provided py::handle to a PolymorphicValue.");
    }
    ```
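    The branch order matters: in Python, `bool` is a subclass of `int`, so an int check alone would also match `True` and `False`, and the bool branch must come first. A minimal Python sketch of the same dispatch order (the `classify` helper is hypothetical, for illustration only):

    ```python
    # bool is a subclass of int in Python, so an isinstance(obj, int)
    # check alone would also match True/False; check bool first.
    assert isinstance(True, int)

    def classify(obj):
        # Hypothetical Python analogue of toPolymorphicValue's dispatch order.
        if isinstance(obj, bool):
            return "bool"
        if isinstance(obj, int):
            return "int64_t"
        if isinstance(obj, float):
            return "double"
        if isinstance(obj, complex):
            return "std::complex<double>"
        raise TypeError("Cannot convert object to a PolymorphicValue.")

    print(classify(True))    # bool
    print(classify(3))       # int64_t
    print(classify(2.5))     # double
    print(classify(1 + 2j))  # std::complex<double>
    ```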
    Complex Type Handling

    The complex number handling uses PyComplex_Check which is a CPython C API function. Since this PR is described as a prerequisite for using nanobind, verify that this approach is compatible with nanobind or if an alternative implementation is needed.

    ```cpp
    } else if (PyComplex_Check(obj.ptr())) {
      return PolymorphicValue(py::cast<std::complex<double>>(obj));
    ```

    Test failures

    • (High, 49) NCCL NVLS multicast binding failures across multidevice/nvfuser test suites on dlcluster_viking_ci

      Test Name H100 (dist.) Source
      tests.python.multidevice.test_communication.test_allgather
      tests.python.multidevice.test_communication.test_allgather_expanded_broadcast
      tests.python.multidevice.test_communication.test_allreduce
      tests.python.multidevice.test_communication.test_reduce_scatter
      tests.python.multidevice.test_communication.test_reduce_scatter_noncontiguous
      tests.python.multidevice.test_dtensor.test_column_parallel_linear
      tests.python.multidevice.test_dtensor.test_plus_one
      tests.python.multidevice.test_dtensor.test_row_parallel_linear
      tests.python.multidevice.test_expert_parallel.test_dispatch_and_combine
      tests.python.multidevice.test_matmul.test_column_parallel_grouped_mm
      ... with 39 more test failures omitted. Check internal logs.
    • (Medium, 1) NCCL invalid usage in multidevice overlap all-gather tests (tests/python/multidevice/test_overlap.py)

      Test Name H100 (dist.) Source
      tests.python.multidevice.test_overlap.test_overlap_allgather_matmul_shard_outermost[backend_type=CommunicatorBackend.cuda]

    @greptile-apps
    Contributor

    greptile-apps bot commented Jan 7, 2026

    Greptile Overview

    Greptile Summary

    Replaces torch::jit::toIValue with a custom toPolymorphicValue function that converts Python objects directly to PolymorphicValue, eliminating the intermediate IValue conversion step. The new function handles torch.Tensor, bool, int, float, and complex types while properly checking bool before int to respect Python's type hierarchy. Performance measurements show a 7% improvement across all Python tests (317.45s vs 340.92s) and this change is a prerequisite for future nanobind migration.

    Confidence Score: 5/5

    • This PR is safe to merge with no identified issues
    • The implementation correctly handles all supported types (Tensor, bool, int64_t, double, complex) with proper type checking order. The bool check before int check correctly handles Python's type hierarchy where bool is a subclass of int. Static initialization of torch_Tensor is thread-safe in C++11+. The error handling with NVF_THROW is consistent with codebase patterns. Performance improvements are verified by test results.
    • No files require special attention

    Important Files Changed

    File Analysis

    Filename Score Overview
    python/python_direct/direct_utils.cpp 5/5 Replaces torch::jit::toIValue with custom toPolymorphicValue for direct Python-to-C++ conversion, improving performance by 7%

    Sequence Diagram

    ```mermaid
    sequenceDiagram
        participant Python as Python Code
        participant Converter as toPolymorphicValue
        participant Holder as KernelArgumentHolder

        Python->>Converter: py::handle (tensor/scalar)

        alt torch.Tensor
            Converter->>Converter: py::cast<at::Tensor>
        else bool
            Converter->>Converter: py::cast<bool>
        else int
            Converter->>Converter: py::cast<int64_t>
        else float
            Converter->>Converter: py::cast<double>
        else complex
            Converter->>Converter: py::cast<std::complex<double>>
        end

        Converter->>Holder: PolymorphicValue
        Note over Converter,Holder: Direct conversion without IValue intermediate
    ```

    Collaborator

    @jjsjann123 jjsjann123 left a comment

    Sweet improvements for such a trivial change.

    Looking forward to seeing how nanobind can help with latency!

    @rdspring1 rdspring1 force-pushed the direct_pv_conversion branch from 1605983 to 761abfa Compare January 9, 2026 18:38
    @rdspring1
    Copy link
    Collaborator Author

    !test

    Contributor

    @greptile-apps greptile-apps bot left a comment

    Greptile Overview

    Greptile Summary

    This PR optimizes the Python to C++ type conversion process in the direct Python API by replacing torch::jit::toIValue with a custom toPolymorphicValue function. The new implementation eliminates the intermediate IValue conversion step and directly converts Python objects to PolymorphicValue, resulting in a 7% performance improvement across all Python tests.

    Key Changes:

    • Added toPolymorphicValue helper function in an anonymous namespace that handles direct conversion from py::handle to PolymorphicValue
    • Supports all PyTorch-compatible scalar types: torch.Tensor, bool, int64_t, double, and std::complex<double>
    • Type checking correctly orders bool before int to handle Python's type hierarchy where bool is a subclass of int
    • Uses static initialization for efficient repeated torch.Tensor type checking
    • Maintains the same error handling semantics, throwing NVF_THROW for unsupported types
    • The change is isolated to the direct Python API (python_direct) and does not affect public interfaces

    Compatibility:

    • The public API signature of from_pyiterable remains unchanged
    • All supported types in the old implementation are supported in the new implementation
    • Error handling is consistent with existing codebase patterns
    • This change is a prerequisite for future nanobind migration as noted in the PR description

    Confidence Score: 5/5

    • This PR is safe to merge with minimal risk - the implementation correctly handles all supported types with proper error handling and maintains API compatibility.
    • Score of 5 reflects: (1) Correct type checking order (bool before int) handling Python's type hierarchy, (2) All supported types properly handled (Tensor, bool, int64_t, double, complex), (3) Proper error handling with NVF_THROW for unsupported types, (4) Efficient static initialization for torch.Tensor type checking, (5) Unchanged public API signature maintaining backward compatibility, (6) Performance improvement of 7% without behavioral changes, (7) Isolated change to python_direct module with no side effects on other code, (8) Includes are properly handled through existing header dependencies.
    • No files require special attention

    Important Files Changed

    File Analysis

    Filename Score Overview
    python/python_direct/direct_utils.cpp 5/5 Replaced torch::jit::toIValue with a custom toPolymorphicValue function that directly converts py::handle to PolymorphicValue, improving performance by avoiding intermediate IValue conversion. Correctly handles all supported types (torch.Tensor, bool, int64_t, double, complex) with proper type checking order (bool before int) and appropriate error handling.

    Sequence Diagram

    ```mermaid
    sequenceDiagram
        participant Python
        participant PyIterable as from_pyiterable
        participant OldPath as Old Path
        participant NewPath as New Path
        participant KAH as KernelArgumentHolder

        Python->>PyIterable: py::iterable(scalars, tensors)

        rect rgb(200, 220, 255)
        Note over PyIterable,OldPath: Old Conversion: torch::jit::toIValue
        PyIterable->>OldPath: py::handle to IValue
        OldPath->>OldPath: IValueToPolymorphicValue
        OldPath->>KAH: push PolymorphicValue
        end

        rect rgb(220, 255, 220)
        Note over PyIterable,NewPath: New Conversion: toPolymorphicValue
        PyIterable->>NewPath: py::handle to PolymorphicValue
        NewPath->>NewPath: Direct type checks
        NewPath->>KAH: push PolymorphicValue
        end
    ```

    @rdspring1 rdspring1 force-pushed the direct_pv_conversion branch from 761abfa to a49a6c1 Compare January 10, 2026 18:22
    @rdspring1
    Copy link
    Collaborator Author

    !test

    Contributor

    @greptile-apps greptile-apps bot left a comment

    No files reviewed, no comments

    @rdspring1 rdspring1 merged commit 5ae31f2 into main Jan 10, 2026
    61 checks passed
    @rdspring1 rdspring1 deleted the direct_pv_conversion branch January 10, 2026 21:42