diff --git a/crates/burn-core/src/module/base.rs b/crates/burn-core/src/module/base.rs
index 099bbb9fd3..981b09dfc8 100644
--- a/crates/burn-core/src/module/base.rs
+++ b/crates/burn-core/src/module/base.rs
@@ -126,6 +126,21 @@ pub trait Module<B: Backend>: Clone + Send + core::fmt::Debug {
         )
     }
 
+    /// Move the module and all of its sub-modules to the autodiff backend.
+    ///
+    /// # Notes
+    ///
+    /// * Only plain modules (not already on an autodiff backend) can be moved.
+    /// * Calling `train()` on a module that is already on an autodiff backend
+    ///   will result in a type error, because the module's inner backend does not match.
+    fn train<AB>(self) -> <Self as HasAutodiffModule<AB>>::TrainModule
+    where
+        AB: AutodiffBackend,
+        Self: HasAutodiffModule<AB>,
+    {
+        <Self as HasAutodiffModule<AB>>::TrainModule::from_inner(self)
+    }
+
     /// Get the number of parameters the module has, including all of its sub-modules.
     fn num_params(&self) -> usize {
         module!(
@@ -370,6 +385,54 @@ pub trait AutodiffModule<B: AutodiffBackend>: Module<B> + Send + core::fmt::Debug
     /// Inner module without auto-differentiation.
     type InnerModule: Module<B::InnerBackend>;
 
-    /// Get the same module, but on the inner backend without auto-differentiation.
+    /// Returns the same module, but on the inner backend without auto-differentiation.
     fn valid(&self) -> Self::InnerModule;
+
+    /// Wraps an inner module back into an auto-diff module.
+    fn from_inner(module: Self::InnerModule) -> Self;
+}
+
+/// Helper trait to associate a module with its autodiff version.
+pub trait HasAutodiffModule<B: AutodiffBackend> {
+    /// The module with auto-differentiation.
+    type TrainModule: AutodiffModule<B, InnerModule = Self>;
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    use crate::TestAutodiffBackend;
+    use crate::test_utils::SimpleLinear;
+
+    #[test]
+    fn test_module_val_train_stateful() {
+        let device = Default::default();
+        let module = SimpleLinear::<TestAutodiffBackend>::new(4, 4, &device);
+
+        assert!(module.weight.is_require_grad());
+        assert!(module.weight.require_grad);
+
+        let module = module.valid();
+        assert!(!module.weight.is_require_grad());
+        assert!(module.weight.require_grad); // stateful
+
+        // Without `HasAutodiffModule`, we would need to specify the module type as well, which would be annoying:
+        // let module: SimpleLinear<TestAutodiffBackend> = module.train();
+        let module = module.train::<TestAutodiffBackend>();
+        assert!(module.weight.is_require_grad());
+        assert!(module.weight.require_grad); // stateful
+
+        let module = module.no_grad();
+        assert!(!module.weight.is_require_grad());
+        assert!(!module.weight.require_grad); // stateful
+
+        let module = module.valid();
+        assert!(!module.weight.is_require_grad()); // always
+        assert!(!module.weight.require_grad); // stateful
+
+        let module = module.train::<TestAutodiffBackend>();
+        assert!(!module.weight.is_require_grad());
+        assert!(!module.weight.require_grad); // stateful
+    }
 }
diff --git a/crates/burn-core/src/module/param/base.rs b/crates/burn-core/src/module/param/base.rs
index 685ff99ff2..62b76beb5f 100644
--- a/crates/burn-core/src/module/param/base.rs
+++ b/crates/burn-core/src/module/param/base.rs
@@ -56,6 +56,8 @@ pub struct Param<T: Parameter> {
     /// - After lazy init triggers: `Some(RwLock)` (inner Option is taken)
     pub(crate) initialization: Option<RwLock<Option<Uninitialized<T>>>>,
     pub(crate) param_mapper: ParamMapper<T>,
+    // For stateful `module.valid()` <> `module.train()`
+    pub(crate) require_grad: bool,
 }
 
 #[derive(Clone)]
@@ -170,11 +172,13 @@ impl<T: Parameter> Uninitialized<T>
 }
 
 impl<T: Parameter> Param<T> {
     /// Create a new parameter that is already initialized.
     pub fn initialized(id: ParamId, value: T) -> Self {
+        let require_grad = value.is_require_grad();
         Self {
             id,
             state: OnceCell::from(value),
             initialization: None,
             param_mapper: Default::default(),
+            require_grad,
         }
     }
 
@@ -199,6 +203,7 @@ impl<T: Parameter> Param<T> {
                 shape,
             }))),
             param_mapper: Default::default(),
+            require_grad: is_require_grad,
         }
     }
 
@@ -247,12 +252,14 @@ impl<T: Parameter> Param<T> {
     pub fn map<F: FnOnce(T) -> T>(self, func: F) -> Self {
         let (id, tensor, param_mapper) = self.consume();
         let tensor = func(tensor);
+        let require_grad = tensor.is_require_grad();
 
         Self {
             id,
             state: OnceCell::from(tensor),
             initialization: None,
             param_mapper,
+            require_grad,
         }
     }
 
@@ -261,11 +268,13 @@ impl<T: Parameter> Param<T> {
     /// This is a helper method for creating parameters while preserving the param mapper,
     /// typically used in ModuleMapper implementations.
     pub fn from_mapped_value(id: ParamId, value: T, param_mapper: ParamMapper<T>) -> Self {
+        let require_grad = value.is_require_grad();
         Self {
             id,
             state: OnceCell::from(value),
             initialization: None,
             param_mapper,
+            require_grad,
         }
     }
diff --git a/crates/burn-core/src/module/param/constant.rs b/crates/burn-core/src/module/param/constant.rs
index e391365efa..f465b7aa30 100644
--- a/crates/burn-core/src/module/param/constant.rs
+++ b/crates/burn-core/src/module/param/constant.rs
@@ -91,6 +91,10 @@ macro_rules! constant {
         fn valid(&self) -> Self::InnerModule {
             self.clone()
         }
+
+        fn from_inner(module: Self::InnerModule) -> Self {
+            module
+        }
     };
 
     ($type:ty) => {
@@ -197,6 +201,10 @@ impl<const D: usize, B: AutodiffBackend, K: BasicAutodiffOps<B>> AutodiffModule<B>
     fn valid(&self) -> Self::InnerModule {
         self.clone().inner()
     }
+
+    fn from_inner(tensor: Self::InnerModule) -> Self {
+        Tensor::from_inner(tensor)
+    }
 }
 
 impl<T, B: Backend> Module<B> for PhantomData<T> {
@@ -245,6 +253,10 @@ impl<T, B: AutodiffBackend> AutodiffModule<B> for PhantomData<T> {
     fn valid(&self) -> Self::InnerModule {
         PhantomData
     }
+
+    fn from_inner(_module: Self::InnerModule) -> Self {
+        PhantomData
+    }
 }
 
 /// Container to satisfy the Module trait for types that are not modules.
@@ -318,6 +330,10 @@ where
     fn valid(&self) -> Self::InnerModule {
         self.clone()
     }
+
+    fn from_inner(module: Self::InnerModule) -> Self {
+        module
+    }
 }
 
 // Implement deref for Ignored
diff --git a/crates/burn-core/src/module/param/primitive.rs b/crates/burn-core/src/module/param/primitive.rs
index 46e0d21555..e197b116e8 100644
--- a/crates/burn-core/src/module/param/primitive.rs
+++ b/crates/burn-core/src/module/param/primitive.rs
@@ -81,6 +81,10 @@ where
     fn valid(&self) -> Self::InnerModule {
         self.as_ref().map(|module| module.valid())
     }
+
+    fn from_inner(module: Self::InnerModule) -> Self {
+        module.map(|module| T::from_inner(module))
+    }
 }
 
 impl<T, B> Module<B> for Vec<T>
@@ -184,6 +188,13 @@ where
     fn valid(&self) -> Self::InnerModule {
         self.iter().map(|module| module.valid()).collect()
     }
+
+    fn from_inner(module: Self::InnerModule) -> Self {
+        module
+            .into_iter()
+            .map(|module| T::from_inner(module))
+            .collect()
+    }
 }
 
 impl<const N: usize, T, B> Module<B> for [T; N]
@@ -281,6 +292,10 @@ where
     fn valid(&self) -> Self::InnerModule {
         self.clone().map(|module| module.valid())
     }
+
+    fn from_inner(module: Self::InnerModule) -> Self {
+        module.map(|module| T::from_inner(module))
+    }
 }
 
 /// A macro for generating implementations for tuple modules of different sizes.
@@ -339,6 +354,10 @@ macro_rules! impl_module_tuple {
         fn valid(&self) -> Self::InnerModule {
             ($(self.$i.valid(),)*)
         }
+
+        fn from_inner(module: Self::InnerModule) -> Self {
+            ($($l::from_inner(module.$i),)*)
+        }
     }
 
     impl<$($l,)*> ModuleDisplayDefault for ($($l,)*)
diff --git a/crates/burn-core/src/module/param/running.rs b/crates/burn-core/src/module/param/running.rs
index be494a9485..16ff1d628a 100644
--- a/crates/burn-core/src/module/param/running.rs
+++ b/crates/burn-core/src/module/param/running.rs
@@ -248,4 +248,11 @@ impl<const D: usize, B: AutodiffBackend> AutodiffModule<B> for RunningState<Tensor<B, D>>
         )
     }
+
+    fn from_inner(module: Self::InnerModule) -> Self {
+        module.sync();
+        let value = module.value();
+
+        RunningState::with_id(module.id, Tensor::from_inner(value))
+    }
 }
diff --git a/crates/burn-core/src/module/param/tensor.rs b/crates/burn-core/src/module/param/tensor.rs
index 29b14b4a8a..88bbfba7af 100644
--- a/crates/burn-core/src/module/param/tensor.rs
+++ b/crates/burn-core/src/module/param/tensor.rs
@@ -1,7 +1,7 @@
 use super::{Param, ParamId, Parameter};
 use crate::module::{
-    AutodiffModule, Content, Module, ModuleDisplay, ModuleDisplayDefault, ModuleMapper,
-    ModuleVisitor,
+    AutodiffModule, Content, HasAutodiffModule, Module, ModuleDisplay, ModuleDisplayDefault,
+    ModuleMapper, ModuleVisitor,
 };
 use crate::tensor::{
     Tensor,
@@ -458,16 +458,36 @@ impl<const D: usize, B: AutodiffBackend> AutodiffModule<B> for Param<Tensor<B, D>> {
     type InnerModule = Param<Tensor<B::InnerBackend, D>>;
 
     fn valid(&self) -> Self::InnerModule {
-        Param::initialized(self.id, self.val().inner().set_require_grad(false))
+        // Preserve the initialized param's `require_grad` state, but reset the inner value's
+        let require_grad = self.require_grad;
+        let mut param = Param::initialized(self.id, self.val().inner().set_require_grad(false));
+        param.require_grad = require_grad;
+        param
+    }
+
+    fn from_inner(module: Self::InnerModule) -> Self {
+        // Reinstate the param's `require_grad` state
+        let tensor = Tensor::from_inner(module.val()).set_require_grad(module.require_grad);
+        Param::initialized(module.id, tensor)
     }
 }
 
+impl<const D: usize, B: AutodiffBackend> HasAutodiffModule<B>
+    for Param<Tensor<B::InnerBackend, D>>
+{
+    type TrainModule = Param<Tensor<B, D>>;
+}
+
 impl<const D: usize, B: AutodiffBackend> AutodiffModule<B> for Param<Tensor<B, D, Int>> {
     type InnerModule = Param<Tensor<B::InnerBackend, D, Int>>;
 
     fn valid(&self) -> Self::InnerModule {
         Param::initialized(self.id, self.val().inner())
     }
+
+    fn from_inner(module: Self::InnerModule) -> Self {
+        Param::initialized(module.id, Tensor::from_inner(module.val()))
+    }
 }
 
 impl<const D: usize, B: AutodiffBackend> AutodiffModule<B> for Param<Tensor<B, D, Bool>> {
@@ -476,6 +496,10 @@ impl<const D: usize, B: AutodiffBackend> AutodiffModule<B> for Param<Tensor<B, D, Bool>> {
     fn valid(&self) -> Self::InnerModule {
         Param::initialized(self.id, self.val().inner())
     }
+
+    fn from_inner(module: Self::InnerModule) -> Self {
+        Param::initialized(module.id, Tensor::from_inner(module.val()))
+    }
 }
 
 #[cfg(all(test, feature = "std"))]
@@ -512,4 +536,36 @@ mod tests {
         assert!(!no_grad_is_require_grad);
         assert!(with_default_is_require_grad);
     }
+
+    #[test]
+    fn test_param_require_grad_stateful() {
+        let device = Default::default();
+        let tensor = Tensor::<TestAutodiffBackend, 2>::ones([3, 3], &device).require_grad();
+
+        let param = Param::initialized(ParamId::new(), tensor);
+        assert!(param.is_require_grad());
+        assert!(param.require_grad);
+
+        let param = param.valid();
+        assert!(!param.is_require_grad());
+        assert!(param.require_grad); // stateful
+
+        // Without `HasAutodiffModule`, we would need to specify the param type as well, which would be annoying:
+        // let param: Param<Tensor<TestAutodiffBackend, 2>> = param.train();
+        let param = param.train::<TestAutodiffBackend>();
+        assert!(param.is_require_grad());
+        assert!(param.require_grad); // stateful
+
+        let param = param.no_grad();
+        assert!(!param.is_require_grad());
+        assert!(!param.require_grad); // stateful
+
+        let param = param.valid();
+        assert!(!param.is_require_grad()); // always
+        assert!(!param.require_grad); // stateful
+
+        let param = param.train::<TestAutodiffBackend>();
+        assert!(!param.is_require_grad());
+        assert!(!param.require_grad); // stateful
+    }
 }
diff --git a/crates/burn-derive/src/module/codegen.rs b/crates/burn-derive/src/module/codegen.rs
index ce114217ad..268118f520 100644
--- a/crates/burn-derive/src/module/codegen.rs
+++ b/crates/burn-derive/src/module/codegen.rs
@@ -15,6 +15,7 @@ pub(crate) trait ModuleCodegen {
     fn gen_fork(&self) -> TokenStream;
     fn gen_map(&self) -> TokenStream;
     fn gen_valid(&self) -> TokenStream;
+    fn gen_from_inner(&self) -> TokenStream;
     fn gen_into_record(&self) -> TokenStream;
     fn gen_load_record(&self) -> TokenStream;
     fn gen_clone(&self) -> TokenStream;
@@ -39,6 +40,7 @@ pub(crate) fn generate_module_standard(
     let to_device = codegen.gen_to_device();
     let fork = codegen.gen_fork();
     let valid_fn = codegen.gen_valid();
+    let from_inner_fn = codegen.gen_from_inner();
     let into_record_fn = codegen.gen_into_record();
     let load_record_fn = codegen.gen_load_record();
     let clone_fn = codegen.gen_clone();
@@ -51,8 +53,12 @@ pub(crate) fn generate_module_standard(
         generics.module.split_for_impl();
     let (generics_module_autodiff, generics_ty_module_autodiff, generics_where_module_autodiff) =
         generics.module_autodiff.split_for_impl();
+    let (generics_module_has_autodiff, _generics_ty, generics_where_module_has_autodiff) =
+        generics.module_has_autodiff.split_for_impl();
 
     let generics_ty_inner_module = generics.inner_module_ty;
+    let generics_ty_train_module = generics.train_module_ty;
+    let generics_ty_train_inner_module = generics.train_inner_ty;
 
     let mut codegen = quote! {
         impl #generics_module burn::module::Module<B> for #name #generics_ty_module #generics_where_module {
@@ -77,6 +83,13 @@ pub(crate) fn generate_module_standard(
             type InnerModule=#name<B::InnerBackend, #generics_ty_inner_module>;
 
             #valid_fn
+
+            #from_inner_fn
+        }
+
+        impl #generics_module_has_autodiff burn::module::HasAutodiffModule<B> for #name<B::InnerBackend, #generics_ty_train_module> #generics_where_module_has_autodiff
+        {
+            type TrainModule=#name<B, #generics_ty_train_inner_module>;
         }
 
         impl #generics_module core::fmt::Display for #name #generics_ty_module #generics_where_module {
@@ -168,13 +181,17 @@ pub(crate) fn generate_module_const(ast: &syn::DeriveInput) -> TokenStream {
 struct GenericsParser {
     module: Generics,
     module_autodiff: Generics,
+    module_has_autodiff: Generics,
     inner_module_ty: TokenStream,
+    train_module_ty: TokenStream,
+    train_inner_ty: TokenStream,
 }
 
 impl GenericsParser {
     fn from_ast(generics: &Generics) -> Self {
         let mut module = GenericsHelper::new(generics.clone());
         let mut module_autodiff = GenericsHelper::new(generics.clone());
+        let mut module_has_autodiff = GenericsHelper::new(generics.clone());
 
         let backend_trait = module.fetch_backend_trait();
 
@@ -186,7 +203,17 @@ impl GenericsParser {
             <B as burn::tensor::backend::AutodiffBackend>::InnerBackend: #backend_trait
         });
 
+        module_has_autodiff.add_predicate(parse_quote! {
+            B: burn::tensor::backend::AutodiffBackend
+        });
+
+        module_has_autodiff.add_predicate(parse_quote! {
+            <B as burn::tensor::backend::AutodiffBackend>::InnerBackend: #backend_trait
+        });
+
         let mut generics_names_except_backend = quote! {};
+        let mut train_generics_names_except_backend = quote! {};
+        let mut train_inner_generics_names_except_backend = quote! {};
 
         module
             .types()
@@ -231,16 +258,47 @@ impl GenericsParser {
                     }
                 );
 
+                module_has_autodiff.add_predicate(
+                    parse_quote! {
+                        #ident: burn::module::Module<B::InnerBackend>
+                    }
+                );
+
+                module_has_autodiff.add_predicate(
+                    parse_quote! {
+                        #ident: burn::module::ModuleDisplay
+                    }
+                );
+
+                module_has_autodiff.add_predicate(
+                    parse_quote! {
+                        #ident: burn::module::HasAutodiffModule<B>
+                    }
+                );
+
+                module_has_autodiff.add_predicate(
+                    parse_quote! {
+                        #ident::TrainModule: burn::module::ModuleDisplay
+                    }
+                );
+
                 generics_names_except_backend.extend(quote! { #ident, });
+                train_generics_names_except_backend.extend(quote! { #ident, });
+                train_inner_generics_names_except_backend.extend(quote! { #ident::TrainModule, });
+
             });
 
         module.consts().into_iter().for_each(|ident| {
             generics_names_except_backend.extend(quote! { #ident, });
+            train_generics_names_except_backend.extend(quote! { #ident, });
+            train_inner_generics_names_except_backend.extend(quote! { #ident, });
         });
 
         Self {
             module: module.generics,
             module_autodiff: module_autodiff.generics,
+            module_has_autodiff: module_has_autodiff.generics,
             inner_module_ty: generics_names_except_backend,
+            train_module_ty: train_generics_names_except_backend,
+            train_inner_ty: train_inner_generics_names_except_backend,
         }
     }
 }
diff --git a/crates/burn-derive/src/module/codegen_enum.rs b/crates/burn-derive/src/module/codegen_enum.rs
index 0a7eaf7141..57ec79a084 100644
--- a/crates/burn-derive/src/module/codegen_enum.rs
+++ b/crates/burn-derive/src/module/codegen_enum.rs
@@ -1,6 +1,6 @@
 use super::{codegen::ModuleCodegen, record_enum::EnumModuleRecordCodegen};
 use crate::shared::enum_variant::{EnumVariant, parse_variants};
-use proc_macro2::{Ident, TokenStream};
+use proc_macro2::{Ident, Span, TokenStream};
 use quote::quote;
 use syn::Visibility;
@@ -127,6 +127,21 @@ impl ModuleCodegen for EnumModuleCodegen {
         }
     }
 
+    fn gen_from_inner(&self) -> TokenStream {
+        let match_body =
+            self.gen_variants_match_fn_param("module", "Self::InnerModule::", |variant| {
+                quote! {
+                    Self::#variant(burn::module::AutodiffModule::<B>::from_inner(module))
+                }
+            });
+
+        quote! {
+            fn from_inner(module: Self::InnerModule) -> Self {
+                #match_body
+            }
+        }
+    }
+
     fn gen_into_record(&self) -> TokenStream {
         let match_body = self.gen_variants_match_fn(|variant| {
             quote! {
@@ -186,24 +201,33 @@ impl EnumModuleCodegen {
         }
     }
 
-    /// Generate the enum variants' match arm with the provided function
+    /// Generate the enum variants' match arms with the provided function
     fn gen_variants_match_fn<F>(&self, func: F) -> TokenStream
     where
         F: Fn(Ident) -> TokenStream,
     {
-        let mut match_arms = quote! {};
+        self.gen_variants_match_fn_param("self", "Self::", func)
+    }
 
-        for variant in self.variants.iter() {
+    /// Generate a match expression over the given argument (e.g., `self`),
+    /// using the provided prefix for variants (e.g., `Self::` or `Self::InnerModule::`)
+    fn gen_variants_match_fn_param<F>(&self, arg: &str, prefix: &str, func: F) -> TokenStream
+    where
+        F: Fn(Ident) -> TokenStream,
+    {
+        let match_arms = self.variants.iter().map(|variant| {
             let name = &variant.ident;
-            let arm_pattern = quote! {Self::#name(module)};
+            let full_variant = syn::parse_str::<syn::Path>(&format!("{prefix}{name}")).unwrap();
+            let arm_pattern = quote! { #full_variant(module) };
             let arm_code = func(name.clone());
 
-            match_arms.extend(quote! {#arm_pattern => #arm_code,})
-        }
+            quote! { #arm_pattern => #arm_code, }
+        });
+
+        let arg = Ident::new(arg, Span::call_site());
 
         quote! {
-            match self {
-                #match_arms
+            match #arg {
+                #(#match_arms)*
             }
         }
     }
diff --git a/crates/burn-derive/src/module/codegen_struct.rs b/crates/burn-derive/src/module/codegen_struct.rs
index 964b8a80a2..54f45059dc 100644
--- a/crates/burn-derive/src/module/codegen_struct.rs
+++ b/crates/burn-derive/src/module/codegen_struct.rs
@@ -142,6 +142,30 @@ impl ModuleCodegen for StructModuleCodegen {
         }
     }
 
+    fn gen_from_inner(&self) -> TokenStream {
+        let (names, body) = self.gen_fields_fn_names(|name| {
+            quote! {
+                let #name = burn::module::AutodiffModule::<B>::from_inner(#name);
+            }
+        });
+
+        // Destructure the inner module to move all fields
+        let destructure = quote! {
+            let Self::InnerModule { #(#names),* } = module;
+        };
+
+        quote! {
+            fn from_inner(module: Self::InnerModule) -> Self {
+                #destructure
+                #body
+
+                Self {
+                    #(#names),*
+                }
+            }
+        }
+    }
+
     fn gen_into_record(&self) -> TokenStream {
         let body = self.gen_fields_fn(|name| {
             quote! {
diff --git a/crates/burn-nn/src/modules/norm/batch.rs b/crates/burn-nn/src/modules/norm/batch.rs
index 9cf7a2c0c2..4c9dbb987f 100644
--- a/crates/burn-nn/src/modules/norm/batch.rs
+++ b/crates/burn-nn/src/modules/norm/batch.rs
@@ -226,21 +226,9 @@ mod tests_1d {
 
         let output = module.forward(input_tensor(&device));
 
-        let expected = TensorData::from([
-            [
-                [1.1483e+00, 3.7521e-01],
-                [1.6272e-03, 7.5067e-01],
-                [1.6204e+00, -4.5168e-02],
-            ],
-            [
-                [6.8856e-02, -1.5923e+00],
-                [-1.6318e+00, 8.7949e-01],
-                [-5.3368e-01, -1.0416e+00],
-            ],
-        ]);
         output
             .to_data()
-            .assert_approx_eq::<FT>(&expected, Tolerance::rel_abs(0.1, 0.001));
+            .assert_approx_eq::<FT>(&expected_train(), Tolerance::rel_abs(0.1, 0.001));
     }
 
     #[test]
@@ -252,13 +240,31 @@ mod tests_1d {
         let module = module.valid();
         let output = module.forward(input_tensor(&device));
 
-        let expected = TensorData::from([
-            [[0.9409, 0.6976], [0.5892, 0.8774], [0.9106, 0.6844]],
-            [[0.6012, 0.0782], [-0.0394, 0.9270], [0.6181, 0.5492]],
-        ]);
         output
             .to_data()
-            .assert_approx_eq::<FT>(&expected, Tolerance::default());
+            .assert_approx_eq::<FT>(&expected_valid(), Tolerance::default());
+    }
+
+    fn expected_valid() -> TensorData {
+        TensorData::from([
+            [[0.9409, 0.6976], [0.5892, 0.8774], [0.9106, 0.6844]],
+            [[0.6012, 0.0782], [-0.0394, 0.9270], [0.6181, 0.5492]],
+        ])
+    }
+
+    fn expected_train() -> TensorData {
+        TensorData::from([
+            [
+                [1.1483e+00, 3.7521e-01],
+                [1.6272e-03, 7.5067e-01],
+                [1.6204e+00, -4.5168e-02],
+            ],
+            [
+                [6.8856e-02, -1.5923e+00],
+                [-1.6318e+00, 8.7949e-01],
+                [-5.3368e-01, -1.0416e+00],
+            ],
+        ])
     }
 
     fn input_tensor<B: Backend>(device: &B::Device) -> Tensor<B, 3> {
@@ -270,6 +276,26 @@ mod tests_1d {
             device,
         )
     }
+
+    #[test]
+    fn batch_norm_forward_train_inference() {
+        let device = Default::default();
+        let module = BatchNormConfig::new(3).init::<TestAutodiffBackend, 1>(&device);
+
+        module.forward(input_tensor(&device));
+        let module = module.valid();
+        let output = module.forward(input_tensor(&device));
+
+        output
+            .to_data()
+            .assert_approx_eq::<FT>(&expected_valid(), Tolerance::default());
+
+        let module = module.train::<TestAutodiffBackend>();
+        let output = module.forward(input_tensor(&device));
+        output
+            .to_data()
+            .assert_approx_eq::<FT>(&expected_train(), Tolerance::default());
+    }
 }
 
 #[cfg(feature = "std")]
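For reference, a minimal sketch of the round-trip this patch enables. `MyModel` is a hypothetical module invented for illustration (any `#[derive(Module)]` struct would do), the backend aliases assume burn's `ndarray` and `autodiff` features, and `train::<_>()` / `HasAutodiffModule` are the additions introduced above:

```rust
use burn::backend::{Autodiff, NdArray};
use burn::module::AutodiffModule;
use burn::nn::Linear;
use burn::prelude::*;

// Hypothetical module for illustration only.
#[derive(Module, Debug)]
struct MyModel<B: Backend> {
    linear: Linear<B>,
}

type Train = Autodiff<NdArray>;
type Infer = NdArray; // = <Train as AutodiffBackend>::InnerBackend

fn round_trip(model: MyModel<Train>) -> MyModel<Train> {
    // Drop to the inner backend for validation. With this patch, each
    // `Param` keeps its `require_grad` flag even though the value itself
    // stops tracking gradients (the stateful behaviour tested above).
    let model: MyModel<Infer> = model.valid();

    // ... run inference with `model` ...

    // Move back onto the autodiff backend. Thanks to `HasAutodiffModule`,
    // only the target backend has to be named, not the full module type.
    model.train::<Train>()
}
```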