diff --git a/python/python_direct/direct_utils.cpp b/python/python_direct/direct_utils.cpp index 8ce601d5c53..6a534c539bb 100644 --- a/python/python_direct/direct_utils.cpp +++ b/python/python_direct/direct_utils.cpp @@ -11,6 +11,26 @@ namespace nvfuser::python { +namespace { + +PolymorphicValue toPolymorphicValue(const py::handle& obj) { + static py::object torch_Tensor = py::module_::import("torch").attr("Tensor"); + if (py::isinstance(obj, torch_Tensor)) { + return PolymorphicValue(py::cast(obj)); + } else if (py::isinstance(obj)) { + return PolymorphicValue(py::cast(obj)); + } else if (py::isinstance(obj)) { + return PolymorphicValue(py::cast(obj)); + } else if (py::isinstance(obj)) { + return PolymorphicValue(py::cast(obj)); + } else if (PyComplex_Check(obj.ptr())) { + return PolymorphicValue(py::cast>(obj)); + } + NVF_THROW("Cannot convert provided py::handle to a PolymorphicValue."); +} + +} // namespace + KernelArgumentHolder from_pyiterable( const py::iterable& iter, std::optional device) { @@ -19,10 +39,10 @@ KernelArgumentHolder from_pyiterable( // Allows for a Vector of Sizes to be inputed as a list/tuple if (py::isinstance(obj) || py::isinstance(obj)) { for (py::handle item : obj) { - args.push(torch::jit::toIValue(item, c10::AnyType::get())); + args.push(toPolymorphicValue(item)); } } else { - args.push(torch::jit::toIValue(obj, c10::AnyType::get())); + args.push(toPolymorphicValue(obj)); } }