65 changes: 64 additions & 1 deletion crates/burn-core/src/module/base.rs
@@ -126,6 +126,21 @@ pub trait Module<B: Backend>: Clone + Send + core::fmt::Debug {
)
}

/// Move the module and all of its sub-modules to the autodiff backend.
///
/// # Notes
///
/// * Only plain modules (not already on an autodiff backend) can be moved.
/// * Calling `train()` on a module that is already on an autodiff backend
/// will result in a type error, because the module's inner backend does not match.
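///
/// # Example
///
/// A minimal sketch of the intended call site (the `Linear` module and the
/// `Autodiff<NdArray>` backend stack here are illustrative, not taken from
/// this diff):
///
/// ```ignore
/// let module: Linear<NdArray> = config.init(&device);
/// // Name the target backend explicitly; the output module type is inferred
/// // through `HasAutodiffModule`.
/// let module: Linear<Autodiff<NdArray>> = module.train::<Autodiff<NdArray>>();
/// ```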
fn train<AB>(self) -> <Self as HasAutodiffModule<AB>>::TrainModule
where
AB: AutodiffBackend<InnerBackend = B>,
Self: HasAutodiffModule<AB>,
{
<Self as HasAutodiffModule<AB>>::TrainModule::from_inner(self)
}

/// Get the number of parameters the module has, including all of its sub-modules.
fn num_params(&self) -> usize {
module!(
@@ -370,6 +385,54 @@ pub trait AutodiffModule<B: AutodiffBackend>: Module<B> + Send + core::fmt::Debug {
/// Inner module without auto-differentiation.
type InnerModule: Module<B::InnerBackend>;

/// Returns the same module, but on the inner backend without auto-differentiation.
fn valid(&self) -> Self::InnerModule;

/// Wraps an inner module back into an auto-diff module.
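///
/// This is the inverse of `valid`; `Module::train` calls it to move a plain
/// module onto the autodiff backend.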
fn from_inner(module: Self::InnerModule) -> Self;
}

/// Helper trait to associate a module with its autodiff version.
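///
/// A sketch of the intended implementation shape (`Linear` is an illustrative
/// module type, mirroring the `Param` impl in this diff):
///
/// ```ignore
/// impl<B: AutodiffBackend> HasAutodiffModule<B> for Linear<B::InnerBackend> {
///     type TrainModule = Linear<B>;
/// }
/// ```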
pub trait HasAutodiffModule<B: AutodiffBackend> {
/// The module with auto-differentiation.
type TrainModule: AutodiffModule<B, InnerModule = Self>;
}

#[cfg(test)]
mod tests {
use super::*;

use crate::TestAutodiffBackend;
use crate::test_utils::SimpleLinear;

#[test]
fn test_module_val_train_stateful() {
let device = Default::default();
let module = SimpleLinear::<TestAutodiffBackend>::new(4, 4, &device);

assert!(module.weight.is_require_grad());
assert!(module.weight.require_grad);

let module = module.valid();
assert!(!module.weight.is_require_grad());
assert!(module.weight.require_grad); // stateful

// Without `HasAutodiffModule`, we would need to specify the module type as well, which would be annoying
// let module: SimpleLinear<TestAutodiffBackend> = module.train();
let module = module.train::<TestAutodiffBackend>();
assert!(module.weight.is_require_grad());
assert!(module.weight.require_grad); // stateful

let module = module.no_grad();
assert!(!module.weight.is_require_grad());
assert!(!module.weight.require_grad); // stateful

let module = module.valid();
assert!(!module.weight.is_require_grad()); // always
assert!(!module.weight.require_grad); // stateful

let module = module.train::<TestAutodiffBackend>();
assert!(!module.weight.is_require_grad());
assert!(!module.weight.require_grad); // stateful
}
}
9 changes: 9 additions & 0 deletions crates/burn-core/src/module/param/base.rs
@@ -56,6 +56,8 @@ pub struct Param<T: Parameter> {
/// - After lazy init triggers: `Some(RwLock<None>)` (inner Option is taken)
pub(crate) initialization: Option<RwLock<Option<Uninitialized<T>>>>,
pub(crate) param_mapper: ParamMapper<T>,
// Preserves the grad requirement across stateful `module.valid()` / `module.train()` round-trips
pub(crate) require_grad: bool,
}

#[derive(Clone)]
@@ -170,11 +172,13 @@ impl<P: Parameter> Uninitialized<P> {
impl<T: Parameter> Param<T> {
/// Create a new parameter that is already initialized.
pub fn initialized(id: ParamId, value: T) -> Self {
let require_grad = value.is_require_grad();
Self {
id,
state: OnceCell::from(value),
initialization: None,
param_mapper: Default::default(),
require_grad,
}
}

@@ -199,6 +203,7 @@ impl<T: Parameter> Param<T> {
shape,
}))),
param_mapper: Default::default(),
require_grad: is_require_grad,
}
}

@@ -247,12 +252,14 @@ impl<T: Parameter> Param<T> {
pub fn map<F: FnOnce(T) -> T>(self, func: F) -> Self {
let (id, tensor, param_mapper) = self.consume();
let tensor = func(tensor);
let require_grad = tensor.is_require_grad();

Self {
id,
state: OnceCell::from(tensor),
initialization: None,
param_mapper,
require_grad,
}
}

@@ -261,11 +268,13 @@ impl<T: Parameter> Param<T> {
/// This is a helper method for creating parameters while preserving the param mapper,
/// typically used in ModuleMapper implementations.
pub fn from_mapped_value(id: ParamId, value: T, param_mapper: ParamMapper<T>) -> Self {
let require_grad = value.is_require_grad();
Self {
id,
state: OnceCell::from(value),
initialization: None,
param_mapper,
require_grad,
}
}

16 changes: 16 additions & 0 deletions crates/burn-core/src/module/param/constant.rs
@@ -91,6 +91,10 @@ macro_rules! constant {
fn valid(&self) -> Self::InnerModule {
self.clone()
}

fn from_inner(module: Self::InnerModule) -> Self {
module
}
};

($type:ty) => {
@@ -197,6 +201,10 @@ impl<const D: usize, B: AutodiffBackend, K: BasicAutodiffOps<B>> AutodiffModule<
fn valid(&self) -> Self::InnerModule {
self.clone().inner()
}

fn from_inner(tensor: Self::InnerModule) -> Self {
Tensor::from_inner(tensor)
}
}

impl<B: Backend> Module<B> for PhantomData<B> {
@@ -245,6 +253,10 @@ impl<B: AutodiffBackend> AutodiffModule<B> for PhantomData<B> {
fn valid(&self) -> Self::InnerModule {
PhantomData
}

fn from_inner(_module: Self::InnerModule) -> Self {
PhantomData
}
}

/// Container to satisfy the Module trait for types that are not modules.
@@ -318,6 +330,10 @@ where
fn valid(&self) -> Self::InnerModule {
self.clone()
}

fn from_inner(module: Self::InnerModule) -> Self {
module
}
}

// Implement deref for Ignored
19 changes: 19 additions & 0 deletions crates/burn-core/src/module/param/primitive.rs
@@ -81,6 +81,10 @@ where
fn valid(&self) -> Self::InnerModule {
self.as_ref().map(|module| module.valid())
}

fn from_inner(module: Self::InnerModule) -> Self {
module.map(|module| T::from_inner(module))
}
}

impl<T, B> Module<B> for Vec<T>
@@ -184,6 +188,13 @@ where
fn valid(&self) -> Self::InnerModule {
self.iter().map(|module| module.valid()).collect()
}

fn from_inner(module: Self::InnerModule) -> Self {
module
.into_iter()
.map(|module| T::from_inner(module))
.collect()
}
}

impl<const N: usize, T, B> Module<B> for [T; N]
@@ -281,6 +292,10 @@ where
fn valid(&self) -> Self::InnerModule {
self.clone().map(|module| module.valid())
}

fn from_inner(module: Self::InnerModule) -> Self {
module.map(|module| T::from_inner(module))
}
}

/// A macro for generating implementations for tuple modules of different sizes.
@@ -339,6 +354,10 @@ macro_rules! impl_module_tuple {
fn valid(&self) -> Self::InnerModule {
($(self.$i.valid(),)*)
}

fn from_inner(module: Self::InnerModule) -> Self {
($($l::from_inner(module.$i),)*)
}
}

impl<$($l,)*> ModuleDisplayDefault for ($($l,)*)
7 changes: 7 additions & 0 deletions crates/burn-core/src/module/param/running.rs
@@ -248,4 +248,11 @@ impl<const D: usize, B: AutodiffBackend> AutodiffModule<B> for RunningState<Tensor<B, D>> {

RunningState::with_id(self.id, value.inner())
}

fn from_inner(module: Self::InnerModule) -> Self {
module.sync();
let value = module.value();

RunningState::with_id(module.id, Tensor::from_inner(value))
}
}
62 changes: 59 additions & 3 deletions crates/burn-core/src/module/param/tensor.rs
@@ -1,7 +1,7 @@
use super::{Param, ParamId, Parameter};
use crate::module::{
AutodiffModule, Content, Module, ModuleDisplay, ModuleDisplayDefault, ModuleMapper,
ModuleVisitor,
AutodiffModule, Content, HasAutodiffModule, Module, ModuleDisplay, ModuleDisplayDefault,
ModuleMapper, ModuleVisitor,
};
use crate::tensor::{
Tensor,
@@ -458,16 +458,36 @@ impl<const D: usize, B: AutodiffBackend> AutodiffModule<B> for Param<Tensor<B, D>> {
type InnerModule = Param<Tensor<B::InnerBackend, D>>;

fn valid(&self) -> Self::InnerModule {
Param::initialized(self.id, self.val().inner().set_require_grad(false))
// Preserve the param's `require_grad` state, but disable gradients on the inner value
let require_grad = self.require_grad;
let mut param = Param::initialized(self.id, self.val().inner().set_require_grad(false));
param.require_grad = require_grad;
param
}

fn from_inner(module: Self::InnerModule) -> Self {
// Reinstate the param's `require_grad` state
let tensor = Tensor::from_inner(module.val()).set_require_grad(module.require_grad);
Param::initialized(module.id, tensor)
}
}

impl<const D: usize, B: AutodiffBackend> HasAutodiffModule<B>
for Param<Tensor<B::InnerBackend, D>>
{
type TrainModule = Param<Tensor<B, D>>;
}

impl<const D: usize, B: AutodiffBackend> AutodiffModule<B> for Param<Tensor<B, D, Int>> {
type InnerModule = Param<Tensor<B::InnerBackend, D, Int>>;

fn valid(&self) -> Self::InnerModule {
Param::initialized(self.id, self.val().inner())
}

fn from_inner(module: Self::InnerModule) -> Self {
Param::initialized(module.id, Tensor::from_inner(module.val()))
}
}

impl<const D: usize, B: AutodiffBackend> AutodiffModule<B> for Param<Tensor<B, D, Bool>> {
@@ -476,6 +496,10 @@ impl<const D: usize, B: AutodiffBackend> AutodiffModule<B> for Param<Tensor<B, D, Bool>> {
fn valid(&self) -> Self::InnerModule {
Param::initialized(self.id, self.val().inner())
}

fn from_inner(module: Self::InnerModule) -> Self {
Param::initialized(module.id, Tensor::from_inner(module.val()))
}
}

#[cfg(all(test, feature = "std"))]
@@ -512,4 +536,36 @@ mod tests {
assert!(!no_grad_is_require_grad);
assert!(with_default_is_require_grad);
}

#[test]
fn test_param_require_grad_stateful() {
let device = Default::default();
let tensor = Tensor::<TestAutodiffBackend, 2>::ones([3, 3], &device).require_grad();

let param = Param::initialized(ParamId::new(), tensor);
assert!(param.is_require_grad());
assert!(param.require_grad);

let param = param.valid();
assert!(!param.is_require_grad());
assert!(param.require_grad); // stateful

// Without `HasAutodiffModule`, we would need to specify the param type as well, which would be annoying:
// let param: Param<Tensor<TestAutodiffBackend, _>> = param.train();
let param = param.train::<TestAutodiffBackend>();
assert!(param.is_require_grad());
assert!(param.require_grad); // stateful

let param = param.no_grad();
assert!(!param.is_require_grad());
assert!(!param.require_grad); // stateful

let param = param.valid();
assert!(!param.is_require_grad()); // always
assert!(!param.require_grad); // stateful

let param = param.train::<TestAutodiffBackend>();
assert!(!param.is_require_grad());
assert!(!param.require_grad); // stateful
}
}