Skip to content

Commit

Permalink
Remove autodiff from generate (#2759)
Browse files Browse the repository at this point in the history
  • Loading branch information
laggui authored Jan 31, 2025
1 parent 6d8dd69 commit cb0854c
Showing 1 changed file with 10 additions and 19 deletions.
29 changes: 10 additions & 19 deletions examples/wgan/examples/wgan-generate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,24 +11,18 @@ pub fn launch<B: Backend>(device: B::Device) {
feature = "ndarray-blas-accelerate",
))]
mod ndarray {
use burn::backend::{
ndarray::{NdArray, NdArrayDevice},
Autodiff,
};
use burn::backend::ndarray::{NdArray, NdArrayDevice};

use crate::launch;

pub fn run() {
launch::<Autodiff<NdArray>>(NdArrayDevice::Cpu);
launch::<NdArray>(NdArrayDevice::Cpu);
}
}

#[cfg(feature = "tch-gpu")]
mod tch_gpu {
use burn::backend::{
libtorch::{LibTorch, LibTorchDevice},
Autodiff,
};
use burn::backend::libtorch::{LibTorch, LibTorchDevice};

use crate::launch;

Expand All @@ -38,41 +32,38 @@ mod tch_gpu {
#[cfg(target_os = "macos")]
let device = LibTorchDevice::Mps;

launch::<Autodiff<LibTorch>>(device);
launch::<LibTorch>(device);
}
}

#[cfg(feature = "tch-cpu")]
mod tch_cpu {
use burn::backend::{
libtorch::{LibTorch, LibTorchDevice},
Autodiff,
};
use burn::backend::libtorch::{LibTorch, LibTorchDevice};

use crate::launch;

pub fn run() {
launch::<Autodiff<LibTorch>>(LibTorchDevice::Cpu);
launch::<LibTorch>(LibTorchDevice::Cpu);
}
}

#[cfg(feature = "wgpu")]
mod wgpu {
use crate::launch;
use burn::backend::{wgpu::Wgpu, Autodiff};
use burn::backend::wgpu::Wgpu;

pub fn run() {
launch::<Autodiff<Wgpu>>(Default::default());
launch::<Wgpu>(Default::default());
}
}

#[cfg(feature = "cuda")]
mod cuda {
use crate::launch;
use burn::backend::{Autodiff, Cuda};
use burn::backend::Cuda;

pub fn run() {
launch::<Autodiff<Cuda>>(Default::default());
launch::<Cuda>(Default::default());
}
}

Expand Down

0 comments on commit cb0854c

Please sign in to comment.