Skip to content

Commit

Permalink
pybind attempt
Browse files Browse the repository at this point in the history
  • Loading branch information
manickavela29 committed Jun 28, 2024
1 parent 4f5a58f commit 389cbf5
Showing 1 changed file with 46 additions and 3 deletions.
49 changes: 46 additions & 3 deletions sherpa-onnx/python/csrc/provider-config.cc
Original file line number Diff line number Diff line change
Expand Up @@ -11,21 +11,64 @@

namespace sherpa_onnx {

static void PybindCudaConfig(py::module *m) {
using PyClass = CudaConfig;
py::class_<PyClass>(*m, "CudaConfig")
.def(py::init<int32_t>(),
py::arg("cudnn_conv_algo_search") = 1)
.def_readwrite("cudnn_conv_algo_search", &PyClass::cudnn_conv_algo_search)
.def("__str__", &PyClass::ToString);
}

static void PybindTensorrtConfig(py::module *m) {
using PyClass = TensorrtConfig;
py::class_<PyClass>(*m, "TensorrtConfig")
.def(py::init<int32_t, int32_t, int32_t, bool,
bool, bool, bool, const std::string &,
const std::string &, bool>(),
py::arg("trt_max_workspace_size") = 2147483648,
py::arg("trt_max_partition_iterations") = 10,
py::arg("trt_min_subgraph_size") = 5,
py::arg("trt_fp16_enable") = true,
py::arg("trt_detailed_build_log") = false,
py::arg("trt_engine_cache_enable") = true,
py::arg("trt_timing_cache_enable") = true,
py::arg("trt_engine_cache_path") = ".",
py::arg("trt_timing_cache_path") = ".",
py::arg("trt_dump_subgraphs") = false)
.def_readwrite("trt_max_workspace_size", &PyClass::trt_max_workspace_size)
.def_readwrite("trt_max_partition_iterations", &PyClass::trt_max_partition_iterations)
.def_readwrite("trt_min_subgraph_size", &PyClass::trt_min_subgraph_size)
.def_readwrite("trt_fp16_enable", &PyClass::trt_fp16_enable)
.def_readwrite("trt_detailed_build_log", &PyClass::trt_detailed_build_log)
.def_readwrite("trt_engine_cache_enable", &PyClass::trt_engine_cache_enable)
.def_readwrite("trt_timing_cache_enable", &PyClass::trt_timing_cache_enable)
.def_readwrite("trt_engine_cache_path", &PyClass::trt_engine_cache_path)
.def_readwrite("trt_timing_cache_path", &PyClass::trt_timing_cache_path)
.def_readwrite("trt_dump_subgraphs", &PyClass::trt_dump_subgraphs)
.def("__str__", &PyClass::ToString)
.def("validate", &PyClass::Validate);
}

void PybindProviderConfig(py::module *m) {
PybindCudaConfig(m);
PybindTensorrtConfig(m);

using PyClass = ProviderConfig;
py::class_<PyClass>(*m, "ProviderConfig")
.def(py::init<const TensorrtConfig &,
const CudaConfig &, const std::string,
const CudaConfig &, const std::string &,
int32_t>(),
py::arg("trt_config") = TensorrtConfig(),
py::arg("cuda_config") = CudaConfig(),
py::arg("provider") = "cpu",
py::arg("device") = 0)
.def_readwrite("cuda_config", &PyClass::cuda_config)
.def_readwrite("trt_config", &PyClass::trt_config)
.def_readwrite("cuda_config", &PyClass::cuda_config)
.def_readwrite("provider", &PyClass::provider)
.def_readwrite("device", &PyClass::device)
.def("__str__", &PyClass::ToString)
.def("validate", &PyClass::Validate);}
.def("validate", &PyClass::Validate);
}

} // namespace sherpa_onnx

0 comments on commit 389cbf5

Please sign in to comment.