Skip to content

Commit

Permalink
Merge pull request #642 from LaurentMazare/remove-cuda-hack
Browse files Browse the repository at this point in the history
Another attempt at removing the cuda hack.
  • Loading branch information
LaurentMazare authored Oct 4, 2024
2 parents a2e44c4 + 16e8f59 commit fbb7039
Show file tree
Hide file tree
Showing 7 changed files with 6 additions and 60 deletions.
1 change: 0 additions & 1 deletion examples/stable-diffusion/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2448,7 +2448,6 @@ impl DDIMScheduler {
}

fn main() -> anyhow::Result<()> {
tch::maybe_init_cuda();
println!("Cuda available: {}", tch::Cuda::is_available());
println!("Cudnn available: {}", tch::Cuda::cudnn_is_available());
// TODO: Switch to using claps to allow more flags?
Expand Down
6 changes: 0 additions & 6 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,3 @@ pub use tensor::{

pub mod nn;
pub mod vision;

pub fn maybe_init_cuda() {
unsafe {
torch_sys::dummy_cuda_dependency();
}
}
4 changes: 0 additions & 4 deletions src/tensor/mod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
//! A Torch tensor.
use crate::{Device, Kind, TchError};
use torch_sys::*;

mod convert;
pub mod display;
Expand Down Expand Up @@ -252,6 +251,3 @@ impl Tensor {
self.g_to_mkldnn(self.kind())
}
}

#[used]
static INIT_ARRAY: [unsafe extern "C" fn(); 1] = [dummy_cuda_dependency];
16 changes: 6 additions & 10 deletions torch-sys/build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -349,13 +349,7 @@ impl SystemInfo {
}
}

fn make(&self, use_cuda: bool, use_hip: bool) {
let cuda_dependency = if use_cuda || use_hip {
"libtch/dummy_cuda_dependency.cpp"
} else {
"libtch/fake_cuda_dependency.cpp"
};
println!("cargo:rerun-if-changed={}", cuda_dependency);
fn make(&self) {
println!("cargo:rerun-if-changed=libtch/torch_python.cpp");
println!("cargo:rerun-if-changed=libtch/torch_python.h");
println!("cargo:rerun-if-changed=libtch/torch_api_generated.cpp");
Expand All @@ -365,8 +359,7 @@ impl SystemInfo {
println!("cargo:rerun-if-changed=libtch/stb_image_write.h");
println!("cargo:rerun-if-changed=libtch/stb_image_resize.h");
println!("cargo:rerun-if-changed=libtch/stb_image.h");
let mut c_files =
vec!["libtch/torch_api.cpp", "libtch/torch_api_generated.cpp", cuda_dependency];
let mut c_files = vec!["libtch/torch_api.cpp", "libtch/torch_api_generated.cpp"];
if cfg!(feature = "python-extension") {
c_files.push("libtch/torch_python.cpp")
}
Expand Down Expand Up @@ -440,6 +433,9 @@ fn main() -> anyhow::Result<()> {
// if this issue.
// TODO: Try out the as-needed native link modifier when it lands.
// https://github.com/rust-lang/rust/issues/99424
//
// Update: it seems that the dummy dependency is not necessary anymore, so just
// removing it and keeping this comment around for legacy.
let si_lib = &system_info.libtorch_lib_dir;
let use_cuda =
si_lib.join("libtorch_cuda.so").exists() || si_lib.join("torch_cuda.dll").exists();
Expand All @@ -451,7 +447,7 @@ fn main() -> anyhow::Result<()> {
si_lib.join("libtorch_hip.so").exists() || si_lib.join("torch_hip.dll").exists();
println!("cargo:rustc-link-search=native={}", si_lib.display());

system_info.make(use_cuda, use_hip);
system_info.make();

println!("cargo:rustc-link-lib=static=tch");
if use_cuda {
Expand Down
29 changes: 0 additions & 29 deletions torch-sys/libtch/dummy_cuda_dependency.cpp

This file was deleted.

6 changes: 0 additions & 6 deletions torch-sys/libtch/fake_cuda_dependency.cpp

This file was deleted.

4 changes: 0 additions & 4 deletions torch-sys/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -333,7 +333,3 @@ extern "C" {
pub fn atm_set_tensor_expr_fuser_enabled(enabled: c_int);
pub fn atm_get_tensor_expr_fuser_enabled() -> bool;
}

extern "C" {
pub fn dummy_cuda_dependency();
}

0 comments on commit fbb7039

Please sign in to comment.