Skip to content

Commit

Permalink
rfc: support for torch-tensorrt
Browse files Browse the repository at this point in the history
  • Loading branch information
hietalajulius committed Nov 14, 2024
1 parent 4841b68 commit cc71c52
Show file tree
Hide file tree
Showing 4 changed files with 85 additions and 0 deletions.
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ download-libtorch = ["torch-sys/download-libtorch"]
python-extension = ["torch-sys/python-extension"]
rl-python = ["cpython"]
doc-only = ["torch-sys/doc-only"]
torch-tensorrt = ["torch-sys/torch-tensorrt"]
cuda-tests = []

[package.metadata.docs.rs]
Expand Down
1 change: 1 addition & 0 deletions torch-sys/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ zip = "0.6"
download-libtorch = ["ureq", "serde", "serde_json"]
doc-only = []
python-extension = []
torch-tensorrt = []

[package.metadata.docs.rs]
features = [ "doc-only" ]
79 changes: 79 additions & 0 deletions torch-sys/build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -412,9 +412,88 @@ impl SystemInfo {
}
}

const PYTHON_PRINT_TORCH_TENSORRT_PATH: &str = r#"
try:
import torch_tensorrt
import os
lib_path = os.path.join(os.path.dirname(torch_tensorrt.__file__), 'lib')
print('TORCH_TENSORRT_LIB:', lib_path)
except ImportError:
print('TORCH_TENSORRT_LIB: NOT_FOUND')
"#;

const PYTHON_PRINT_TENSORRT_LIBS_PATH: &str = r#"
try:
import tensorrt_libs
import os
lib_path = os.path.dirname(tensorrt_libs.__file__)
print('TENSORRT_LIB:', lib_path)
except ImportError:
print('TENSORRT_LIB: NOT_FOUND')
"#;

fn get_torch_tensorrt_lib_path(python_interpreter: &Path) -> Option<PathBuf> {
let output = std::process::Command::new(python_interpreter)
.arg("-c")
.arg(PYTHON_PRINT_TORCH_TENSORRT_PATH)
.output()
.ok()?;

for line in String::from_utf8_lossy(&output.stdout).lines() {
if let Some(path) = line.strip_prefix("TORCH_TENSORRT_LIB: ") {
if path != "NOT_FOUND" {
return Some(PathBuf::from(path));
}
}
}
None
}

fn get_tensorrt_libs_path(python_interpreter: &Path) -> Option<PathBuf> {
let output = std::process::Command::new(python_interpreter)
.arg("-c")
.arg(PYTHON_PRINT_TENSORRT_LIBS_PATH)
.output()
.ok()?;

for line in String::from_utf8_lossy(&output.stdout).lines() {
if let Some(path) = line.strip_prefix("TENSORRT_LIB: ") {
if path != "NOT_FOUND" {
return Some(PathBuf::from(path));
}
}
}
None
}

fn link_torch_tensorrt(system_info: &SystemInfo) {
if let Some(tensorrt_path) = get_torch_tensorrt_lib_path(&system_info.python_interpreter) {
println!("cargo:rustc-link-search=native={}", tensorrt_path.display());
println!("cargo:rustc-link-arg=-Wl,-rpath={}", tensorrt_path.display());
println!("cargo:rustc-link-lib=dylib:-as-needed=torchtrt_runtime");
}
}

fn link_tensorrt_libs(system_info: &SystemInfo) {
if let Some(tensorrt_libs_path) = get_tensorrt_libs_path(&system_info.python_interpreter) {
println!("cargo:rustc-link-search=native={}", tensorrt_libs_path.display());
println!("cargo:rustc-link-arg=-Wl,-rpath={}", tensorrt_libs_path.display());
println!("cargo:rustc-link-lib=nvinfer_plugin");
println!("cargo:rustc-link-lib=nvinfer");
println!("cargo:rustc-link-lib=nvonnxparser");
}
}

fn main() -> anyhow::Result<()> {
if !cfg!(feature = "doc-only") {
let system_info = SystemInfo::new()?;

if cfg!(feature = "torch-tensorrt") {
//create the feature flag for cpp called USE_TORCH_TENSORRT
cc::Build::new().define("USE_TORCH_TENSORRT", None);
link_torch_tensorrt(&system_info);
link_tensorrt_libs(&system_info);
}
// use_cuda is a hacky way to detect whether cuda is available and
// if it's the case link to it by explicitly depending on a symbol
// from the torch_cuda library.
Expand Down
4 changes: 4 additions & 0 deletions torch-sys/libtch/torch_api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,10 @@
#include<vector>
#include "torch_api.h"

#ifdef USE_TORCH_TENSORRT
#include "torch_tensorrt/torch_tensorrt.h"
#endif

#define STB_IMAGE_IMPLEMENTATION
#include "stb_image.h"

Expand Down

0 comments on commit cc71c52

Please sign in to comment.