diff --git a/Cargo.lock b/Cargo.lock index 6ee6c6f481..bd2d874b15 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1046,6 +1046,7 @@ dependencies = [ "burn-tensor", "burn-wgpu", "byteorder", + "ciborium", "divan", "half", "hashbrown 0.15.5", @@ -1411,6 +1412,33 @@ dependencies = [ "phf", ] +[[package]] +name = "ciborium" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "42e69ffd6f0917f5c029256a24d0161db17cea3997d185db0d35926308770f0e" +dependencies = [ + "ciborium-io", + "ciborium-ll", + "serde", +] + +[[package]] +name = "ciborium-io" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "05afea1e0a06c9be33d539b876f1ce3692f4afea2cb41f740e7743225ed1c757" + +[[package]] +name = "ciborium-ll" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "57663b653d948a338bfb3eeba9bb2fd5fcfaecb9e199e87e1eda4d9e8b240fd9" +dependencies = [ + "ciborium-io", + "half", +] + [[package]] name = "cipher" version = "0.4.4" diff --git a/Cargo.toml b/Cargo.toml index 374ff575c5..dd81d9b53d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -81,7 +81,8 @@ regex = { version = "1.11.3", default-features = false, features = [ reqwest = { version = "0.12.23", default-features = false, features = [ "rustls-tls", ] } -rmp-serde = "1.3.0" +ciborium = { version = "0.2", default-features = false } +rmp-serde = { version = "1.3.0", default-features = false } rstest = "0.25.0" rusqlite = "0.37.0" rust-format = "0.3.4" diff --git a/crates/burn-core/src/module/base.rs b/crates/burn-core/src/module/base.rs index 3493dacd4c..099bbb9fd3 100644 --- a/crates/burn-core/src/module/base.rs +++ b/crates/burn-core/src/module/base.rs @@ -1,4 +1,4 @@ -use super::{ParamId, Quantizer}; +use super::{Param, ParamId, Quantizer}; use crate::{ record::Record, tensor::backend::{AutodiffBackend, Backend}, @@ -19,11 +19,12 @@ macro_rules! 
module { impl ModuleMapper for Mapper { fn map_float( &mut self, - _id: ParamId, - tensor: Tensor, - ) -> Tensor { + param: Param>, + ) -> Param> { + let (id, tensor, mapper) = param.consume(); let func = $item; - func(tensor) + let tensor = func(tensor); + Param::from_mapped_value(id, tensor, mapper) } } let mut mapper = Mapper; @@ -35,9 +36,9 @@ macro_rules! module { backend: core::marker::PhantomData, } impl<'a, B: Backend> ModuleVisitor for Visitor<'a, B> { - fn visit_float(&mut self, _id: ParamId, tensor: &Tensor) { + fn visit_float(&mut self, param: &Param>) { let func = $item; - func(tensor, &mut self.state) + func(¶m.val(), &mut self.state) } } #[allow(clippy::redundant_closure_call)] @@ -211,29 +212,26 @@ pub trait Module: Clone + Send + core::fmt::Debug { /// Module visitor trait for traversing and inspecting module parameters. pub trait ModuleVisitor { - /// Visit a float tensor in the module. + /// Visit a float parameter in the module. /// /// # Parameters - /// - `id`: The unique identifier of the parameter - /// - `tensor`: The float tensor to visit + /// - `param`: The float parameter to visit #[allow(unused_variables)] - fn visit_float(&mut self, id: ParamId, tensor: &Tensor) {} + fn visit_float(&mut self, param: &Param>) {} - /// Visit an int tensor in the module. + /// Visit an int parameter in the module. /// /// # Parameters - /// - `id`: The unique identifier of the parameter - /// - `tensor`: The integer tensor to visit + /// - `param`: The integer parameter to visit #[allow(unused_variables)] - fn visit_int(&mut self, id: ParamId, tensor: &Tensor) {} + fn visit_int(&mut self, param: &Param>) {} - /// Visit a bool tensor in the module. + /// Visit a bool parameter in the module. 
/// /// # Parameters - /// - `id`: The unique identifier of the parameter - /// - `tensor`: The boolean tensor to visit + /// - `param`: The boolean parameter to visit #[allow(unused_variables)] - fn visit_bool(&mut self, id: ParamId, tensor: &Tensor) {} + fn visit_bool(&mut self, param: &Param>) {} /// Called when entering a submodule. /// @@ -321,51 +319,49 @@ pub trait ModuleMapper { #[allow(unused_variables)] fn exit_module(&mut self, name: &str, container_type: &str) {} - /// Map a float tensor in the module. + /// Map a float parameter in the module. /// /// # Parameters - /// - `id`: The unique identifier of the parameter - /// - `tensor`: The float tensor to transform + /// - `param`: The float parameter to transform /// /// # Returns - /// The transformed tensor + /// The transformed parameter #[allow(unused_variables)] - fn map_float(&mut self, id: ParamId, tensor: Tensor) -> Tensor { - tensor + fn map_float(&mut self, param: Param>) -> Param> { + let (id, tensor, mapper) = param.consume(); + Param::from_mapped_value(id, tensor, mapper) } - /// Map an int tensor in the module. + /// Map an int parameter in the module. /// /// # Parameters - /// - `id`: The unique identifier of the parameter - /// - `tensor`: The integer tensor to transform + /// - `param`: The integer parameter to transform /// /// # Returns - /// The transformed tensor + /// The transformed parameter #[allow(unused_variables)] fn map_int( &mut self, - id: ParamId, - tensor: Tensor, - ) -> Tensor { - tensor + param: Param>, + ) -> Param> { + let (id, tensor, mapper) = param.consume(); + Param::from_mapped_value(id, tensor, mapper) } - /// Map a bool tensor in the module. + /// Map a bool parameter in the module. 
/// /// # Parameters - /// - `id`: The unique identifier of the parameter - /// - `tensor`: The boolean tensor to transform + /// - `param`: The boolean parameter to transform /// /// # Returns - /// The transformed tensor + /// The transformed parameter #[allow(unused_variables)] fn map_bool( &mut self, - id: ParamId, - tensor: Tensor, - ) -> Tensor { - tensor + param: Param>, + ) -> Param> { + let (id, tensor, mapper) = param.consume(); + Param::from_mapped_value(id, tensor, mapper) } } diff --git a/crates/burn-core/src/module/initializer.rs b/crates/burn-core/src/module/initializer.rs index bdd2c76251..8903bc9361 100644 --- a/crates/burn-core/src/module/initializer.rs +++ b/crates/burn-core/src/module/initializer.rs @@ -107,6 +107,7 @@ impl Initializer { let device = device.clone(); let shape: Shape = shape.into(); let config = self.clone(); + let shape_for_closure = shape.clone(); Param::uninitialized( ParamId::new(), @@ -123,6 +124,7 @@ impl Initializer { }, device, true, + shape_for_closure, ) } diff --git a/crates/burn-core/src/module/param/base.rs b/crates/burn-core/src/module/param/base.rs index 6226e0c764..685ff99ff2 100644 --- a/crates/burn-core/src/module/param/base.rs +++ b/crates/burn-core/src/module/param/base.rs @@ -1,6 +1,7 @@ use super::ParamId; use alloc::{boxed::Box, format}; use burn_common::stub::RwLock; +use burn_tensor::Shape; use core::cell::OnceCell; use core::ops::Deref; @@ -28,61 +29,64 @@ fn new_mapper T + Send + Sync + 'static>(func: F) -> Mapper { /// Parameters are the fundamental building blocks of [modules](crate::module::Module) where they /// serve as containers for [tensors](crate::tensor::Tensor) that can be updated during -/// training, and loaded during inference. If you don't want to save the tensors with a record +/// training, and loaded during inference. If you don't want to save the tensors /// and/or don't want to update it during training, you don't need this type to wrap your tensor. 
/// -/// # Laziness +/// # Core Lazy Initialization Architecture /// -/// The initialization of parameters can be lazy when created using -/// [uninitialized](Self::uninitialized), which can be done using an [initializer](crate::module::Initializer). +/// `Param` has a dual-state design using `OnceCell`: /// -/// This reduces the amount of allocations done when loading a model for inference without having -/// to create a custom initialization function only for inference. +/// ## State Management /// -/// ## Example +/// **Two possible states:** /// -/// ```rust, ignore -/// let device = Device::default(); -/// let config = ModuleConfig::default(); -/// let record = Recorder::new().load("/path/to/module", &device); -/// -/// // No tensor allocation -/// let module = config.init(device); -/// // Will use the tensor allocated for the record if the same device is used. -/// let module = module.load_record(record); -/// ``` +/// 1. **Initialized**: `state: OnceCell` contains value, `initialization: None` +/// 2. **Uninitialized (Lazy)**: `state` is empty, `initialization: Some(RwLock>>)` pub struct Param { /// The unique ID of this parameter. This is used by eg. optimizers to associate a gradient with a specific parameter. pub id: ParamId, - state: OnceCell, - /// The locking is only required because of `lazy_device` and `lazy_is_require_grad`. + /// The OnceCell holding the initialized parameter value. + /// Empty for uninitialized parameters, populated after first access or explicit initialization. + pub(crate) state: OnceCell, + /// The deferred initialization state for lazy parameters. /// - /// Because of once cell, we have a guarantee that the initialization will only be called once, - /// but it may be called at the same time as `lazy_device` and `lazy_is_require_grad`, which is - /// when the lock is actually useful, waiting for the initialization to be completed before - /// returning the value. 
- initialization: Option>>>, - pub(crate) record_mapper: RecordMapper, + /// **State Transitions:** + /// - Initialized params: `None` + /// - Uninitialized params: `Some(RwLock)>)` + /// - After lazy init triggers: `Some(RwLock)` (inner Option is taken) + pub(crate) initialization: Option>>>, + pub(crate) param_mapper: ParamMapper, } #[derive(Clone)] -/// Applies functions when loading and saving parameters. -pub struct RecordMapper { +/// Applies transformations when loading and saving parameters. +/// +/// # Mapper System +/// +/// `ParamMapper` allows applying transformations during serialization and deserialization: +/// - `load: Option>` - transformation during deserialization (applied in `transform_for_load()`) +/// - `save: Option>` - transformation during serialization (applied in `transform_for_save()`) +/// +/// These are commonly used for: +/// - Quantization/dequantization +/// - Precision conversion (e.g., FP32 ↔ FP16) +/// - Custom parameter transformations +pub struct ParamMapper { load: Option>, save: Option>, } -impl core::fmt::Debug for RecordMapper { +impl core::fmt::Debug for ParamMapper { fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { f.write_fmt(format_args!( - "RecordMapper {{ load: {}, save: {} }}", + "ParamMapper {{ load: {}, save: {} }}", self.load.is_some(), self.save.is_some() )) } } -impl RecordMapper { +impl ParamMapper { /// Applies the transformation when loading the given parameter. 
pub fn on_load(&self, param: T) -> T { match &self.load { @@ -99,7 +103,7 @@ impl RecordMapper { } } -impl Default for RecordMapper { +impl Default for ParamMapper { fn default() -> Self { Self { load: None, @@ -116,7 +120,7 @@ impl core::fmt::Display for Param { impl core::fmt::Debug for Param { fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { - f.write_str(format!("Param: {} - {:?}", self.id, self.record_mapper).as_str()) + f.write_str(format!("Param: {} - {:?}", self.id, self.param_mapper).as_str()) } } @@ -135,14 +139,28 @@ pub trait Parameter: Clone + core::fmt::Debug + Send { fn set_require_grad(self, require_grad: bool) -> Self; } +/// The deferred initialization state for lazy parameters. #[allow(clippy::type_complexity)] -struct Uninitialized { +pub(crate) struct Uninitialized { + /// The initialization function. Called with `(device, is_require_grad) -> Parameter`. + /// This function is consumed during initialization via `FnOnce`. init: Box P + Send>, - device: P::Device, - is_require_grad: bool, + /// The target device on which the parameter should be initialized. + /// Used by `lazy_device()` to provide device information without triggering initialization. + pub(crate) device: P::Device, + /// The gradient requirement for the parameter. + /// Used by `lazy_is_require_grad()` to provide gradient settings without triggering initialization. + pub(crate) is_require_grad: bool, + /// The shape of the tensor parameter. + /// Used by `lazy_shape()` to provide shape information without triggering initialization. + pub(crate) shape: Shape, } impl Uninitialized

{ + /// Consumes the uninitialized state and runs the initialization function. + /// + /// This is called by [Param::val] when accessing an uninitialized parameter for the first time. + /// The function is given the stored device and gradient requirement, and returns the initialized parameter. fn initialize(self) -> P { let init = self.init; init(&self.device, self.is_require_grad) @@ -156,12 +174,18 @@ impl Param { id, state: OnceCell::from(value), initialization: None, - record_mapper: Default::default(), + param_mapper: Default::default(), } } /// Create a new parameter that is not already initialized. - pub fn uninitialized(id: ParamId, init: F, device: T::Device, is_require_grad: bool) -> Self + pub fn uninitialized( + id: ParamId, + init: F, + device: T::Device, + is_require_grad: bool, + shape: Shape, + ) -> Self where F: FnOnce(&T::Device, bool) -> T + Send + 'static, { @@ -172,12 +196,16 @@ impl Param { init: Box::new(init), device, is_require_grad, + shape, }))), - record_mapper: Default::default(), + param_mapper: Default::default(), } } - /// Gets the parameter value. + /// Gets the parameter value, initializing it lazily if needed. + /// + /// For initialized parameters, this returns a clone of the cached value. + /// For uninitialized parameters, this triggers initialization: pub fn val(&self) -> T { self.state .get_or_init(|| { @@ -193,43 +221,64 @@ impl Param { .clone() } + /// Check if the parameter has been initialized. + /// + /// Returns `true` if the parameter's value has been computed and cached, + /// `false` if it's still lazy and will be initialized on first access. + pub fn is_initialized(&self) -> bool { + self.state.get().is_some() + } + /// Gets the parameter's value while consuming the parameter. pub fn into_value(self) -> T { self.consume().1 } /// Gets the parameter id and value while consuming the parameter. 
- pub fn consume(self) -> (ParamId, T, RecordMapper) { + pub fn consume(self) -> (ParamId, T, ParamMapper) { let tensor = self.val(); core::mem::drop(self.state); - (self.id, tensor, self.record_mapper) + (self.id, tensor, self.param_mapper) } /// Execute the given function on the inner value. pub fn map T>(self, func: F) -> Self { - let (id, tensor, record_mapper) = self.consume(); + let (id, tensor, param_mapper) = self.consume(); let tensor = func(tensor); Self { id, state: OnceCell::from(tensor), initialization: None, - record_mapper, + param_mapper, } } - /// Runs a transformation on the parameter when loading a saved record. + /// Create an initialized parameter with the given id, value, and param mapper. + /// + /// This is a helper method for creating parameters while preserving the param mapper, + /// typically used in ModuleMapper implementations. + pub fn from_mapped_value(id: ParamId, value: T, param_mapper: ParamMapper) -> Self { + Self { + id, + state: OnceCell::from(value), + initialization: None, + param_mapper, + } + } + + /// Runs a transformation on the parameter when loading. pub fn load_mapper T + Send + Sync + 'static>(mut self, func: F) -> Self { - self.record_mapper.load = Some(new_mapper(func)); + self.param_mapper.load = Some(new_mapper(func)); self } - /// Runs a transformation on the parameter when saving the record. + /// Runs a transformation on the parameter when saving. pub fn save_mapper T + Send + Sync + 'static>(mut self, func: F) -> Self { - self.record_mapper.save = Some(new_mapper(func)); + self.param_mapper.save = Some(new_mapper(func)); self } @@ -267,19 +316,15 @@ impl Param { } } - /// The device on which the parameter is or will be initialized. - /// - /// This should be used instead of [crate::tensor::Tensor::device], since using the tensor - /// function requires a dereference, which triggers the initialization. 
This is only useful - /// when the device is used for updating the tensor value, which has potentially not been - /// initialized yet, like loading a record. + /// The device on which the parameter is or will be initialized, **without triggering initialization**. /// - /// # Notes + /// This is critical for the load optimization: when loading tensors into an uninitialized parameter, + /// we need to know the target device to move the loaded tensor appropriately, but we don't want to + /// trigger the initialization function (which would allocate an unnecessary tensor). /// - /// This is a crate-private function, since users are not expected to use the device of an - /// uninitialized module to then override its value. All low-level functions should be provided - /// by `burn` and should handle those details. - pub(crate) fn lazy_device(&self) -> T::Device { + /// Use this instead of [crate::tensor::Tensor::device] when you need the device but want to + /// preserve lazy initialization. + pub fn lazy_device(&self) -> T::Device { let initialization = match &self.initialization { Some(init) => init, None => return self.device(), @@ -293,12 +338,11 @@ impl Param { } } - /// The gradient requirement on which the parameter is or will be initialized. + /// The gradient requirement on which the parameter is or will be initialized, **without triggering initialization**. /// - /// This should be used instead of [crate::tensor::Tensor::is_require_grad], since using the tensor - /// function requires a dereference, which triggers the initialization. This is only useful - /// when the boolean is used for updating the tensor value, which has potentially not been - /// initialized yet, like loading a record. + /// Similar to [lazy_device](Self::lazy_device), this is critical for the load optimization. + /// When loading tensors into an uninitialized parameter, we need to apply the correct gradient + /// setting to the loaded tensor without triggering the initialization function. 
/// /// # Notes /// @@ -347,7 +391,7 @@ impl Param { impl Clone for Param { fn clone(&self) -> Self { let mut param = Param::initialized(self.id, self.val()); - param.record_mapper = self.record_mapper.clone(); + param.param_mapper = self.param_mapper.clone(); param } } diff --git a/crates/burn-core/src/module/param/running.rs b/crates/burn-core/src/module/param/running.rs index 863930ac42..be494a9485 100644 --- a/crates/burn-core/src/module/param/running.rs +++ b/crates/burn-core/src/module/param/running.rs @@ -80,12 +80,15 @@ impl Module for RunningState> { fn visit>(&self, visitor: &mut V) { let tensor = self.value.lock().unwrap(); - visitor.visit_float(self.id, &tensor) + let param = Param::initialized(self.id, tensor.clone()); + visitor.visit_float(¶m) } fn map>(self, mapper: &mut M) -> Self { let mut tensor = self.value.lock().unwrap(); - let tensor_out = mapper.map_float(self.id, tensor.clone()); + let param = Param::initialized(self.id, tensor.clone()); + let param_out = mapper.map_float(param); + let (_, tensor_out, _) = param_out.consume(); *tensor = tensor_out; core::mem::drop(tensor); diff --git a/crates/burn-core/src/module/param/tensor.rs b/crates/burn-core/src/module/param/tensor.rs index a8c5d24f10..29b14b4a8a 100644 --- a/crates/burn-core/src/module/param/tensor.rs +++ b/crates/burn-core/src/module/param/tensor.rs @@ -72,6 +72,28 @@ impl Param> { Param::initialized(ParamId::new(), value.require_grad()) } + /// The shape of the parameter, **without triggering initialization**. + /// + /// This is critical for shape validation during loading: when applying tensors to an + /// uninitialized parameter, we need to validate the shape without triggering the + /// initialization function (which would allocate an unnecessary tensor). + /// + /// Use this instead of [crate::tensor::Tensor::shape] when you need the shape but want to + /// preserve lazy initialization. 
+ pub fn lazy_shape(&self) -> burn_tensor::Shape { + let initialization = match &self.initialization { + Some(init) => init, + None => return self.shape(), + }; + + let init = initialization.read().unwrap(); + + match init.as_ref() { + Some(value) => value.shape.clone(), + None => self.shape(), + } + } + /// Create a new parameter from data. pub fn from_data(data: T, device: &B::Device) -> Self where @@ -84,51 +106,198 @@ impl Param> { Param::initialized(ParamId::new(), value.require_grad()) }) } -} -impl Module for Param> { - type Record = Param>; + /// Transform a parameter for loading by applying load transformations. + /// + /// This method is used to restore a parameter from a tensor (typically during deserialization). + /// It ensures the tensor is moved to the expected device, applies the param mapper's + /// `on_load` transformation, and preserves the autodiff settings (require_grad). + pub fn transform_for_load(self, tensor: Tensor, param_id: ParamId) -> Self { + let mut new_tensor = tensor; - fn visit>(&self, visitor: &mut V) { - visitor.visit_float(self.id, &self.val()) + let mapper = self.param_mapper.clone(); + + let expected_device = self.lazy_device(); + let expected_require_grad = self.lazy_is_require_grad(); + + // Make sure we load the tensor into the same module device. + if new_tensor.device() != expected_device { + new_tensor = new_tensor.to_device(&expected_device).detach(); + } + + new_tensor = mapper.on_load(new_tensor); + + // Make sure we load the tensor with the same autodiff setting. + new_tensor = new_tensor.set_require_grad(expected_require_grad); + + let mut loaded = Self::initialized(param_id, new_tensor); + loaded.param_mapper = mapper; + loaded } - fn map>(self, mapper: &mut M) -> Self { - let (id, tensor, _mapper) = self.consume(); - let value = mapper.map_float(id, tensor); - Self::initialized(id, value) + /// Transform a parameter for saving by applying save transformations. 
+ /// + /// This method is used to prepare a parameter for saving (typically during serialization). + /// It applies the param mapper's `on_save` transformation, which can be used + /// to modify the tensor before serialization (e.g., quantization, precision conversion). + pub fn transform_for_save(&self) -> Self { + let mut tensor = self.val(); + let mapper = self.param_mapper.clone(); + + tensor = mapper.on_save(tensor); + + Self::initialized(self.id, tensor) } +} - fn into_record(self) -> Self::Record { - let (new_id, mut new_value, mapper) = self.consume(); +impl Param> { + /// The shape of the parameter, **without triggering initialization**. + /// + /// This is critical for shape validation during loading: when applying tensors to an + /// uninitialized parameter, we need to validate the shape without triggering the + /// initialization function (which would allocate an unnecessary tensor). + /// + /// Use this instead of [crate::tensor::Tensor::shape] when you need the shape but want to + /// preserve lazy initialization. + pub fn lazy_shape(&self) -> burn_tensor::Shape { + let initialization = match &self.initialization { + Some(init) => init, + None => return self.shape(), + }; - new_value = mapper.on_save(new_value); + let init = initialization.read().unwrap(); - Self::initialized(new_id, new_value) + match init.as_ref() { + Some(value) => value.shape.clone(), + None => self.shape(), + } } - fn load_record(self, record: Self::Record) -> Self { - let (new_id, mut new_value, _mapper) = record.consume(); - let mapper = self.record_mapper.clone(); + /// Transform a parameter for loading by applying load transformations. + /// + /// This method is used to restore a parameter from a tensor (typically during deserialization). + /// It ensures the tensor is moved to the expected device and applies the param mapper's + /// `on_load` transformation. 
+ pub fn transform_for_load(self, tensor: Tensor, param_id: ParamId) -> Self { + let mut new_tensor = tensor; + + let mapper = self.param_mapper.clone(); let expected_device = self.lazy_device(); - let expected_require_grad = self.lazy_is_require_grad(); - // Make sure we load the record into the same module device. - if new_value.device() != expected_device { - new_value = new_value.to_device(&expected_device).detach(); + // Make sure we load the tensor into the same module device. + if new_tensor.device() != expected_device { + new_tensor = new_tensor.to_device(&expected_device); + } + + new_tensor = mapper.on_load(new_tensor); + + let mut loaded = Self::initialized(param_id, new_tensor); + loaded.param_mapper = mapper; + loaded + } + + /// Transform a parameter for saving by applying save transformations. + /// + /// This method is used to prepare a parameter for saving (typically during serialization). + /// It applies the param mapper's `on_save` transformation, which can be used + /// to modify the tensor before serialization (e.g., quantization, precision conversion). + pub fn transform_for_save(&self) -> Self { + let mut tensor = self.val(); + let mapper = self.param_mapper.clone(); + + tensor = mapper.on_save(tensor); + + Self::initialized(self.id, tensor) + } +} + +impl Param> { + /// The shape of the parameter, **without triggering initialization**. + /// + /// This is critical for shape validation during loading: when applying tensors to an + /// uninitialized parameter, we need to validate the shape without triggering the + /// initialization function (which would allocate an unnecessary tensor). + /// + /// **Returns:** + /// - For uninitialized params: the shape from the `Uninitialized` struct + /// - For initialized params: the actual shape from the tensor + /// + /// Use this instead of [crate::tensor::Tensor::shape] when you need the shape but want to + /// preserve lazy initialization. 
+ pub fn lazy_shape(&self) -> burn_tensor::Shape { + let initialization = match &self.initialization { + Some(init) => init, + None => return self.shape(), + }; + + let init = initialization.read().unwrap(); + + match init.as_ref() { + Some(value) => value.shape.clone(), + None => self.shape(), } + } - new_value = mapper.on_load(new_value); + /// Transform a parameter for loading by applying load transformations. + /// + /// This method is used to restore a parameter from a tensor (typically during deserialization). + /// It ensures the tensor is moved to the expected device and applies the param mapper's + /// `on_load` transformation. + pub fn transform_for_load(self, tensor: Tensor, param_id: ParamId) -> Self { + let mut new_tensor = tensor; - // Make sure we load the record with the same autodiff setting. - new_value = new_value.set_require_grad(expected_require_grad); + let mapper = self.param_mapper.clone(); - let mut loaded = Self::initialized(new_id, new_value); - loaded.record_mapper = mapper; + let expected_device = self.lazy_device(); + + // Make sure we load the tensor into the same module device. + if new_tensor.device() != expected_device { + new_tensor = new_tensor.to_device(&expected_device); + } + + new_tensor = mapper.on_load(new_tensor); + + let mut loaded = Self::initialized(param_id, new_tensor); + loaded.param_mapper = mapper; loaded } + /// Transform a parameter for saving by applying save transformations. + /// + /// This method is used to prepare a parameter for saving (typically during serialization). + /// It applies the param mapper's `on_save` transformation, which can be used + /// to modify the tensor before serialization (e.g., quantization, precision conversion). 
+ pub fn transform_for_save(&self) -> Self { + let mut tensor = self.val(); + let mapper = self.param_mapper.clone(); + + tensor = mapper.on_save(tensor); + + Self::initialized(self.id, tensor) + } +} + +impl Module for Param> { + type Record = Param>; + + fn visit>(&self, visitor: &mut V) { + visitor.visit_float(self) + } + + fn map>(self, mapper: &mut M) -> Self { + mapper.map_float(self) + } + + fn into_record(self) -> Self::Record { + self.transform_for_save() + } + + fn load_record(self, record: Self::Record) -> Self { + let (record_param_id, record_tensor, _) = record.consume(); + self.transform_for_load(record_tensor, record_param_id) + } + fn to_device(self, device: &Device) -> Self { self.map(|tensor| tensor.to_device(device)) } @@ -177,38 +346,20 @@ impl Module for Param> { type Record = Param>; fn visit>(&self, visitor: &mut V) { - visitor.visit_int(self.id, &self.val()) + visitor.visit_int(self) } fn map>(self, mapper: &mut M) -> Self { - let value = mapper.map_int(self.id, self.val()); - Self::initialized(self.id, value) + mapper.map_int(self) } fn into_record(self) -> Self::Record { - let (new_id, mut new_value, mapper) = self.consume(); - - new_value = mapper.on_save(new_value); - - Self::initialized(new_id, new_value) + self.transform_for_save() } fn load_record(self, record: Self::Record) -> Self { - let (new_id, mut new_value, _mapper) = record.consume(); - let mapper = self.record_mapper.clone(); - - let expected_device = self.lazy_device(); - - // Make sure we load the record into the same module device. 
- if new_value.device() != expected_device { - new_value = new_value.to_device(&expected_device); - } - - new_value = mapper.on_load(new_value); - - let mut loaded = Self::initialized(new_id, new_value); - loaded.record_mapper = mapper; - loaded + let (record_param_id, record_tensor, _) = record.consume(); + self.transform_for_load(record_tensor, record_param_id) } fn to_device(self, device: &Device) -> Self { @@ -250,38 +401,20 @@ impl Module for Param> { type Record = Param>; fn visit>(&self, visitor: &mut V) { - visitor.visit_bool(self.id, &self.val()) + visitor.visit_bool(self) } fn map>(self, mapper: &mut M) -> Self { - let value = mapper.map_bool(self.id, self.val()); - Self::initialized(self.id, value) + mapper.map_bool(self) } fn into_record(self) -> Self::Record { - let (new_id, mut new_value, mapper) = self.consume(); - - new_value = mapper.on_save(new_value); - - Self::initialized(new_id, new_value) + self.transform_for_save() } fn load_record(self, record: Self::Record) -> Self { - let (new_id, mut new_value, _mapper) = record.consume(); - let mapper = self.record_mapper.clone(); - - let expected_device = self.lazy_device(); - - // Make sure we load the record into the same module device. 
- if new_value.device() != expected_device { - new_value = new_value.to_device(&expected_device); - } - - new_value = mapper.on_load(new_value); - - let mut loaded = Self::initialized(new_id, new_value); - loaded.record_mapper = mapper; - loaded + let (record_param_id, record_tensor, _) = record.consume(); + self.transform_for_load(record_tensor, record_param_id) } fn to_device(self, device: &Device) -> Self { diff --git a/crates/burn-core/src/module/param/visitor.rs b/crates/burn-core/src/module/param/visitor.rs index 633aa18ff9..169f260e00 100644 --- a/crates/burn-core/src/module/param/visitor.rs +++ b/crates/burn-core/src/module/param/visitor.rs @@ -1,4 +1,4 @@ -use super::ParamId; +use super::{Param, ParamId}; use crate::module::{Module, ModuleVisitor}; use alloc::vec::Vec; use burn_tensor::{Bool, Int, Tensor, backend::Backend}; @@ -14,14 +14,14 @@ where B: Backend, M: Module, { - fn visit_float(&mut self, id: ParamId, _tensor: &Tensor) { - self.param_ids.push(id); + fn visit_float(&mut self, param: &Param>) { + self.param_ids.push(param.id); } - fn visit_int(&mut self, id: ParamId, _tensor: &Tensor) { - self.param_ids.push(id); + fn visit_int(&mut self, param: &Param>) { + self.param_ids.push(param.id); } - fn visit_bool(&mut self, id: ParamId, _tensor: &Tensor) { - self.param_ids.push(id); + fn visit_bool(&mut self, param: &Param>) { + self.param_ids.push(param.id); } } diff --git a/crates/burn-core/src/module/quantize.rs b/crates/burn-core/src/module/quantize.rs index 7ae1e3d10d..6d86418efd 100644 --- a/crates/burn-core/src/module/quantize.rs +++ b/crates/burn-core/src/module/quantize.rs @@ -4,7 +4,7 @@ use burn_tensor::{ quantization::{Calibration, QuantScheme, compute_q_params, compute_range}, }; -use crate::module::{ModuleMapper, ParamId}; +use crate::module::{ModuleMapper, Param}; /// Describes how to quantize a module. 
pub struct Quantizer { @@ -15,10 +15,12 @@ pub struct Quantizer { } impl ModuleMapper for Quantizer { - fn map_float(&mut self, _id: ParamId, tensor: Tensor) -> Tensor { + fn map_float(&mut self, param: Param>) -> Param> { + let (id, tensor, mapper) = param.consume(); let range = compute_range(&self.scheme, &tensor, &self.calibration); let qparams = compute_q_params(&self.scheme, range); - tensor.quantize(&self.scheme, qparams) + let tensor = tensor.quantize(&self.scheme, qparams); + Param::from_mapped_value(id, tensor, mapper) } } diff --git a/crates/burn-core/src/module/reinit.rs b/crates/burn-core/src/module/reinit.rs index 84506dba04..6bdf97ffe7 100644 --- a/crates/burn-core/src/module/reinit.rs +++ b/crates/burn-core/src/module/reinit.rs @@ -1,4 +1,4 @@ -use super::{Module, ModuleMapper, ParamId}; +use super::{Module, ModuleMapper}; use burn_tensor::{ Element, ElementConversion, Tensor, TensorData, backend::Backend, @@ -118,12 +118,16 @@ impl Reinitializer { } impl ModuleMapper for Reinitializer { - fn map_float(&mut self, _id: ParamId, tensor: Tensor) -> Tensor { + fn map_float( + &mut self, + param: super::Param>, + ) -> super::Param> { + let (id, tensor, mapper) = param.consume(); let device = tensor.device(); let shape = tensor.shape(); let num_elements = shape.num_elements(); - match &self.float { + let tensor = match &self.float { ReinitStrategy::Range { min, max } => { let tensor = Tensor::arange(0..num_elements as i64, &device) .reshape(shape) @@ -139,19 +143,21 @@ impl ModuleMapper for Reinitializer { ); Tensor::from_data(data, &device) } - } + }; + + super::Param::from_mapped_value(id, tensor, mapper) } fn map_int( &mut self, - _id: ParamId, - tensor: Tensor, - ) -> Tensor { + param: super::Param>, + ) -> super::Param> { + let (id, tensor, mapper) = param.consume(); let device = tensor.device(); let shape = tensor.shape(); let num_elements = shape.num_elements(); - match &self.int { + let tensor = match &self.int { ReinitStrategy::Range { min, max } 
=> { let tensor = Tensor::arange(0..num_elements as i64, &device).reshape(shape); let (factor, bias) = resolve::>(*min, *max, num_elements); @@ -165,15 +171,17 @@ impl ModuleMapper for Reinitializer { ); Tensor::from_data(data, &device) } - } + }; + + super::Param::from_mapped_value(id, tensor, mapper) } fn map_bool( &mut self, - _id: ParamId, - tensor: Tensor, - ) -> Tensor { - tensor + param: super::Param>, + ) -> super::Param> { + let (id, tensor, mapper) = param.consume(); + super::Param::from_mapped_value(id, tensor, mapper) } } diff --git a/crates/burn-import/src/burn/node/constant.rs b/crates/burn-import/src/burn/node/constant.rs index 7204044c51..46e4671019 100644 --- a/crates/burn-import/src/burn/node/constant.rs +++ b/crates/burn-import/src/burn/node/constant.rs @@ -158,7 +158,8 @@ impl NodeCodegen for ConstantNode { burn::module::ParamId::new(), move |device, _require_grad| Tensor::::zeros(#shape, device), device.clone(), - false + false, + #shape.into(), ); }), crate::burn::TensorKind::Float => Some(quote! { @@ -167,6 +168,7 @@ impl NodeCodegen for ConstantNode { move |device, _require_grad| Tensor::::zeros(#shape, device), device.clone(), false, + #shape.into(), ); }), crate::burn::TensorKind::Bool => Some(quote! 
{ @@ -175,6 +177,7 @@ impl NodeCodegen for ConstantNode { move |device, _require_grad| Tensor::::empty(#shape, device), device.clone(), false, + #shape.into(), ); }), } @@ -474,7 +477,8 @@ mod tests { burn::module::ParamId::new(), move |device, _require_grad| Tensor::::zeros([4], device), device.clone(), - false + false, + [4].into(), ); Self { @@ -530,7 +534,8 @@ mod tests { burn::module::ParamId::new(), move |device, _require_grad| Tensor::::zeros([3], device), device.clone(), - false + false, + [3].into(), ); Self { @@ -586,7 +591,8 @@ mod tests { burn::module::ParamId::new(), move |device, _require_grad| Tensor::::empty([1, 3, 2], device), device.clone(), - false + false, + [1, 3, 2].into(), ); Self { diff --git a/crates/burn-nn/src/activation/activation_wrapper.rs b/crates/burn-nn/src/activation/activation_wrapper.rs index 6878b73c0c..3342d810d5 100644 --- a/crates/burn-nn/src/activation/activation_wrapper.rs +++ b/crates/burn-nn/src/activation/activation_wrapper.rs @@ -83,6 +83,7 @@ impl ActivationConfig { /// Provides support for many in-built `burn::nn` activations. #[derive(Module, Debug)] #[non_exhaustive] +#[allow(clippy::large_enum_variant)] pub enum Activation { /// [`Gelu`] activation layer. 
Gelu(Gelu), diff --git a/crates/burn-no-std-tests/Cargo.toml b/crates/burn-no-std-tests/Cargo.toml index 4608787574..b0bb958405 100644 --- a/crates/burn-no-std-tests/Cargo.toml +++ b/crates/burn-no-std-tests/Cargo.toml @@ -19,4 +19,4 @@ workspace = true burn = { path = "../burn", version = "0.19.0", default-features = false } burn-ndarray = { path = "../burn-ndarray", version = "0.19.0", default-features = false } -burn-store = { path = "../burn-store", version = "0.19.0", default-features = false, features = ["safetensors"]} +burn-store = { path = "../burn-store", version = "0.19.0", default-features = false, features = ["safetensors", "burnpack"]} diff --git a/crates/burn-no-std-tests/src/burnpack.rs b/crates/burn-no-std-tests/src/burnpack.rs new file mode 100644 index 0000000000..7f30981d51 --- /dev/null +++ b/crates/burn-no-std-tests/src/burnpack.rs @@ -0,0 +1,158 @@ +// Test Burnpack storage in no-std environment + +use burn::{ + module::Module, + nn, + tensor::{Tensor, backend::Backend}, +}; + +use burn_store::{BurnpackStore, ModuleSnapshot, PathFilter}; + +/// Simple model for testing Burnpack storage +#[derive(Module, Debug)] +pub struct TestModel { + linear1: nn::Linear, + linear2: nn::Linear, + batch_norm: nn::BatchNorm, +} + +impl TestModel { + pub fn new(device: &B::Device) -> Self { + Self { + linear1: nn::LinearConfig::new(10, 20).init(device), + linear2: nn::LinearConfig::new(20, 10).init(device), + batch_norm: nn::BatchNormConfig::new(10).init(device), + } + } + + pub fn forward(&self, x: Tensor) -> Tensor { + let x = self.linear1.forward(x); + let x = self.linear2.forward(x); + // Apply batch norm (expand to 3D, apply, then squeeze back) + let x: Tensor = x.unsqueeze_dim(2); + let x = self.batch_norm.forward(x); + x.squeeze_dim(2) + } +} + +/// Test basic Burnpack save and load in no-std +pub fn test_burnpack_basic(device: &B::Device) { + // Create a model + let model = TestModel::::new(device); + + // Save to bytes (no file I/O in no-std) + let 
mut save_store = BurnpackStore::from_bytes(None); + model + .save_into(&mut save_store) + .expect("Failed to save model"); + + // Get the serialized bytes + let bytes = save_store.get_bytes().expect("Failed to get bytes"); + + // Load from bytes + let mut load_store = BurnpackStore::from_bytes(Some(bytes)); + let mut loaded_model = TestModel::::new(device); + let result = loaded_model + .load_from(&mut load_store) + .expect("Failed to load model"); + + // Verify all tensors were loaded + assert!(result.is_success(), "Should have no errors"); + assert!(!result.applied.is_empty(), "Should have loaded tensors"); + + // Test that the model still works + let input = Tensor::::ones([2, 10], device); + let _output = loaded_model.forward(input); +} + +/// Test Burnpack with filtering in no-std +pub fn test_burnpack_filtering(device: &B::Device) { + let model = TestModel::::new(device); + + // Save only linear1 weights + let filter = PathFilter::new() + .with_full_path("linear1.weight") + .with_full_path("linear1.bias"); + let mut save_store = BurnpackStore::from_bytes(None).with_filter(filter); + model + .save_into(&mut save_store) + .expect("Failed to save filtered model"); + + let bytes = save_store.get_bytes().expect("Failed to get bytes"); + + // Load with partial loading allowed + let mut load_store = BurnpackStore::from_bytes(Some(bytes)).allow_partial(true); + let mut partial_model = TestModel::::new(device); + let result = partial_model + .load_from(&mut load_store) + .expect("Failed to load partial model"); + + // Verify that only linear1 was loaded + assert_eq!(result.applied.len(), 2, "Should have loaded 2 tensors"); + assert!(!result.missing.is_empty(), "Should have missing tensors"); +} + +/// Test Burnpack with metadata in no-std +pub fn test_burnpack_metadata(device: &B::Device) { + let model = TestModel::::new(device); + + // Save with metadata + let mut save_store = BurnpackStore::from_bytes(None) + .metadata("version", "1.0.0") + .metadata("environment", 
"no-std") + .metadata("model_type", "test"); + model + .save_into(&mut save_store) + .expect("Failed to save model with metadata"); + + let bytes = save_store.get_bytes().expect("Failed to get bytes"); + + // Load and verify it works + let mut load_store = BurnpackStore::from_bytes(Some(bytes)); + let mut loaded_model = TestModel::::new(device); + let result = loaded_model + .load_from(&mut load_store) + .expect("Failed to load model with metadata"); + + assert!(result.is_success(), "Should load successfully"); +} + +// Note: Key remapping test is omitted as KeyRemapper requires std feature + +// Note: Regex filtering test is omitted as with_regex requires std feature + +/// Test Burnpack with match_all in no-std +pub fn test_burnpack_match_all(device: &B::Device) { + let model = TestModel::::new(device); + + // Save with match_all (should save everything) + let mut save_store = BurnpackStore::from_bytes(None).match_all(); + model + .save_into(&mut save_store) + .expect("Failed to save model"); + + let bytes = save_store.get_bytes().expect("Failed to get bytes"); + + // Load everything + let mut load_store = BurnpackStore::from_bytes(Some(bytes)); + let mut loaded_model = TestModel::::new(device); + let result = loaded_model + .load_from(&mut load_store) + .expect("Failed to load model"); + + assert!(result.is_success(), "Should load successfully"); + // linear1 (weight, bias) + linear2 (weight, bias) + batch_norm (4 params) + assert_eq!(result.applied.len(), 8, "Should load all 8 tensors"); + assert!(result.missing.is_empty(), "Should have no missing tensors"); + assert!(result.unused.is_empty(), "Should have no unused tensors"); +} + +/// Run all Burnpack no-std tests +pub fn run_all_tests(device: &B::Device) { + test_burnpack_basic::(device); + test_burnpack_filtering::(device); + test_burnpack_metadata::(device); + // test_burnpack_remapping requires KeyRemapper which needs std + // test_burnpack_regex_filter requires with_regex which needs std + 
test_burnpack_match_all::(device); +} diff --git a/crates/burn-no-std-tests/src/lib.rs b/crates/burn-no-std-tests/src/lib.rs index 4438fa39e5..58de35e102 100644 --- a/crates/burn-no-std-tests/src/lib.rs +++ b/crates/burn-no-std-tests/src/lib.rs @@ -1,5 +1,6 @@ #![no_std] +pub mod burnpack; pub mod conv; pub mod mlp; pub mod model; diff --git a/crates/burn-no-std-tests/src/safetensors.rs b/crates/burn-no-std-tests/src/safetensors.rs index f96b66cab1..b70378d723 100644 --- a/crates/burn-no-std-tests/src/safetensors.rs +++ b/crates/burn-no-std-tests/src/safetensors.rs @@ -37,7 +37,7 @@ pub fn test_safetensors_basic(device: &B::Device) { // Save to bytes (no file I/O in no-std) let mut save_store = SafetensorsStore::from_bytes(None); model - .collect_to(&mut save_store) + .save_into(&mut save_store) .expect("Failed to save model"); // Get the serialized bytes @@ -47,7 +47,7 @@ pub fn test_safetensors_basic(device: &B::Device) { let mut load_store = SafetensorsStore::from_bytes(Some(bytes)); let mut loaded_model = TestModel::::new(device); loaded_model - .apply_from(&mut load_store) + .load_from(&mut load_store) .expect("Failed to load model"); // Test that the model still works @@ -64,7 +64,7 @@ pub fn test_safetensors_filtering(device: &B::Device) { .with_full_path("linear1.weight") .with_full_path("linear1.bias"); model - .collect_to(&mut save_store) + .save_into(&mut save_store) .expect("Failed to save filtered model"); let bytes = save_store.get_bytes().expect("Failed to get bytes"); @@ -73,7 +73,7 @@ pub fn test_safetensors_filtering(device: &B::Device) { let mut load_store = SafetensorsStore::from_bytes(Some(bytes)).allow_partial(true); let mut partial_model = TestModel::::new(device); let result = partial_model - .apply_from(&mut load_store) + .load_from(&mut load_store) .expect("Failed to load partial model"); // Verify that only linear1 was loaded @@ -90,7 +90,7 @@ pub fn test_safetensors_metadata(device: &B::Device) { .metadata("version", "1.0.0") 
.metadata("environment", "no-std"); model - .collect_to(&mut save_store) + .save_into(&mut save_store) .expect("Failed to save model with metadata"); let bytes = save_store.get_bytes().expect("Failed to get bytes"); @@ -99,7 +99,7 @@ pub fn test_safetensors_metadata(device: &B::Device) { let mut load_store = SafetensorsStore::from_bytes(Some(bytes)); let mut loaded_model = TestModel::::new(device); loaded_model - .apply_from(&mut load_store) + .load_from(&mut load_store) .expect("Failed to load model with metadata"); } diff --git a/crates/burn-no-std-tests/tests/burnpack_tests.rs b/crates/burn-no-std-tests/tests/burnpack_tests.rs new file mode 100644 index 0000000000..c050288b3c --- /dev/null +++ b/crates/burn-no-std-tests/tests/burnpack_tests.rs @@ -0,0 +1,12 @@ +extern crate alloc; + +#[test] +fn test_burnpack_no_std() { + use burn_ndarray::NdArray; + use burn_no_std_tests::burnpack; + type Backend = NdArray; + let device = Default::default(); + + // Run all Burnpack tests + burnpack::run_all_tests::(&device); +} diff --git a/crates/burn-optim/src/optim/grad_accum.rs b/crates/burn-optim/src/optim/grad_accum.rs index 95d4a7be76..831746368f 100644 --- a/crates/burn-optim/src/optim/grad_accum.rs +++ b/crates/burn-optim/src/optim/grad_accum.rs @@ -2,7 +2,7 @@ use burn_core as burn; use core::marker::PhantomData; -use burn::module::{AutodiffModule, ModuleVisitor, ParamId}; +use burn::module::{AutodiffModule, ModuleVisitor, Param}; use burn::tensor::{Tensor, backend::AutodiffBackend}; use super::GradientsParams; @@ -56,19 +56,20 @@ struct ModuleGradsAccumulator<'a, M> { } impl> ModuleVisitor for ModuleGradsAccumulator<'_, M> { - fn visit_float(&mut self, id: ParamId, _tensor: &Tensor) { - let grad_updated = match self.grads_new.remove::(id) { - Some(new) => match self.grads.remove::(id) { + fn visit_float(&mut self, param: &Param>) { + let grad_updated = match self.grads_new.remove::(param.id) { + Some(new) => match self.grads.remove::(param.id) { Some(grad) => 
grad.add(new), None => new, }, - None => match self.grads.remove::(id) { + None => match self.grads.remove::(param.id) { Some(grad) => grad, None => return, }, }; - self.grads.register::(id, grad_updated); + self.grads + .register::(param.id, grad_updated); } } diff --git a/crates/burn-optim/src/optim/simple/adaptor.rs b/crates/burn-optim/src/optim/simple/adaptor.rs index 59cf699919..a59f098cfd 100644 --- a/crates/burn-optim/src/optim/simple/adaptor.rs +++ b/crates/burn-optim/src/optim/simple/adaptor.rs @@ -7,7 +7,7 @@ use crate::{ optim::{GradientsParams, Optimizer}, }; -use burn::module::{AutodiffModule, ModuleMapper, ParamId}; +use burn::module::{AutodiffModule, ModuleMapper, Param, ParamId}; use burn::tensor::{Tensor, backend::AutodiffBackend}; use core::marker::PhantomData; use hashbrown::HashMap; @@ -119,10 +119,11 @@ where B: AutodiffBackend, O: SimpleOptimizer, { - fn map_float(&mut self, id: ParamId, tensor: Tensor) -> Tensor { + fn map_float(&mut self, param: Param>) -> Param> { + let (id, tensor, mapper) = param.consume(); let grad = self.grads.remove(id); - if let Some(grad) = grad { + let tensor = if let Some(grad) = grad { let device = grad.device(); let is_require_grad = tensor.is_require_grad(); let (key, record) = self.records.remove_entry(&id).unzip(); @@ -149,9 +150,11 @@ where if is_require_grad { tensor = tensor.require_grad(); } - return tensor; - } + tensor + } else { + tensor + }; - tensor + Param::from_mapped_value(id, tensor, mapper) } } diff --git a/crates/burn-optim/src/optim/visitor.rs b/crates/burn-optim/src/optim/visitor.rs index 87a0b31f61..603cb624ac 100644 --- a/crates/burn-optim/src/optim/visitor.rs +++ b/crates/burn-optim/src/optim/visitor.rs @@ -1,7 +1,7 @@ use burn_core as burn; use super::GradientsParams; -use burn::module::{AutodiffModule, ModuleVisitor, ParamId}; +use burn::module::{AutodiffModule, ModuleVisitor, Param, ParamId}; use burn::tensor::{Tensor, backend::AutodiffBackend}; use core::marker::PhantomData; @@ -28,18 
+28,19 @@ where B: AutodiffBackend, M: AutodiffModule, { - fn visit_float(&mut self, id: ParamId, tensor: &Tensor) { + fn visit_float(&mut self, param: &Param>) { if let Some(filter) = self.filter.as_ref() - && !filter.contains(&id) + && !filter.contains(&param.id) { return; } - let Some(grad) = tensor.grad_remove(self.grads) else { + let Some(grad) = param.val().grad_remove(self.grads) else { return; }; - self.grads_params.register::(id, grad); + self.grads_params + .register::(param.id, grad); } } @@ -48,12 +49,12 @@ where B: AutodiffBackend, M: AutodiffModule, { - fn visit_float(&mut self, id: ParamId, _tensor: &Tensor) { - let Some(grad) = self.grads.remove::(id) else { + fn visit_float(&mut self, param: &Param>) { + let Some(grad) = self.grads.remove::(param.id) else { return; }; self.grads - .register::(id, grad.to_device(self.device)); + .register::(param.id, grad.to_device(self.device)); } } diff --git a/crates/burn-store/Cargo.toml b/crates/burn-store/Cargo.toml index fb4883f96b..fc3bb99605 100644 --- a/crates/burn-store/Cargo.toml +++ b/crates/burn-store/Cargo.toml @@ -21,22 +21,17 @@ version.workspace = true workspace = true [features] -candle = ["burn-candle"] -cuda = ["burn-cuda"] -default = ["std", "pytorch", "safetensors"] +default = ["std", "pytorch", "safetensors", "burnpack", "memmap"] +std = ["dep:memmap2", "safetensors/std", "burn-core/std", "burn-tensor/std", "dep:regex", "byteorder/std"] +memmap = ["std", "dep:memmap2"] +burnpack = ["serde", "ciborium"] +wgpu = ["burn-wgpu"] metal = ["wgpu", "burn-wgpu/metal"] -pytorch = ["burn-core/record-item-custom-serde", "zip", "serde"] -safetensors = [] -std = [ - "dep:memmap2", - "safetensors/std", - "burn-core/std", - "burn-tensor/std", - "dep:regex", - "byteorder/std", -] +cuda = ["burn-cuda"] +candle = ["burn-candle"] tch = ["burn-tch"] -wgpu = ["burn-wgpu"] +safetensors = [] +pytorch = ["burn-core/record-item-custom-serde", "zip", "serde"] [dependencies] burn-core = { path = "../burn-core", version =
"0.19.0", default-features = false } @@ -49,6 +44,7 @@ half = { workspace = true } hashbrown = { workspace = true, features = ["serde"] } memmap2 = { workspace = true, optional = true } regex = { workspace = true, optional = true } +ciborium = { workspace = true, optional = true } serde = { workspace = true, optional = true } zip = { workspace = true, optional = true } @@ -72,9 +68,14 @@ divan = "0.1" tempfile = { workspace = true } [[bench]] -harness = false name = "resnet18_loading" +harness = false [[bench]] -harness = false name = "unified_loading" +harness = false + +[[bench]] +name = "unified_saving" +harness = false + diff --git a/crates/burn-store/README.md b/crates/burn-store/README.md index 1564a0dda9..e56c7da1b2 100644 --- a/crates/burn-store/README.md +++ b/crates/burn-store/README.md @@ -11,6 +11,8 @@ interoperability, and advanced tensor management. ### Core Capabilities +- **Burnpack Format** - Native Burn format with CBOR metadata, memory-mapped loading, ParamId + persistence for stateful training, and no-std support - **SafeTensors Format** - Industry-standard format for secure and efficient tensor serialization - **PyTorch Support** - Direct loading of PyTorch .pth/.pt files with automatic weight transformation @@ -20,7 +22,7 @@ interoperability, and advanced tensor management. - **Flexible Filtering** - Load/save specific model subsets with regex, exact paths, or custom predicates - **Tensor Remapping** - Rename tensors during load/save for framework compatibility -- **No-std Support** - Core functionality available in embedded and WASM environments +- **No-std Support** - Burnpack and SafeTensors formats available in embedded and WASM environments ### Advanced Features @@ -34,16 +36,38 @@ interoperability, and advanced tensor management. 
### Basic Save and Load +#### Burnpack (Native Format) + +```rust +use burn_store::{ModuleSnapshot, BurnpackStore}; + +// Save a model with metadata +let mut store = BurnpackStore::from_file("model.bpk") + .metadata("version", "1.0") + .metadata("description", "My trained model"); +model.save_into(&mut store)?; + +// Load a model (automatically memory-mapped when available) +let mut store = BurnpackStore::from_file("model.bpk"); +model.load_from(&mut store)?; +``` + +**Performance**: Burnpack provides faster loading times and reduced memory overhead compared to other formats. + +**Training State Persistence**: Burnpack automatically preserves parameter identifiers (ParamId) for stateful training continuation. + +#### SafeTensors + ```rust use burn_store::{ModuleSnapshot, SafetensorsStore}; // Save a model let mut store = SafetensorsStore::from_file("model.safetensors"); -model.collect_to(&mut store)?; +model.save_into(&mut store)?; // Load a model let mut store = SafetensorsStore::from_file("model.safetensors"); -model.apply_from(&mut store)?; +model.load_from(&mut store)?; ``` ### Filtering Tensors @@ -54,7 +78,7 @@ let mut store = SafetensorsStore::from_file("encoder.safetensors") .with_regex(r"^encoder\..*") .metadata("subset", "encoder_only"); -model.collect_to(&mut store)?; +model.save_into(&mut store)?; // Load with multiple filter patterns (OR logic) let mut store = SafetensorsStore::from_file("model.safetensors") @@ -62,7 +86,7 @@ let mut store = SafetensorsStore::from_file("model.safetensors") .with_regex(r".*\.bias$") // OR include any bias tensors .with_full_path("decoder.scale"); // OR include specific tensor -model.apply_from(&mut store)?; +model.load_from(&mut store)?; ``` ### PyTorch Interoperability @@ -75,20 +99,20 @@ let mut store = PytorchStore::from_file("pytorch_model.pth") .with_top_level_key("state_dict") // Access nested state dict .allow_partial(true); // Skip unknown tensors -burn_model.apply_from(&mut store)?; +burn_model.load_from(&mut 
store)?; // Load PyTorch model from SafeTensors let mut store = SafetensorsStore::from_file("pytorch_model.safetensors") .with_from_adapter(PyTorchToBurnAdapter) // Auto-transpose linear weights .allow_partial(true); // Skip unknown PyTorch tensors -burn_model.apply_from(&mut store)?; +burn_model.load_from(&mut store)?; // Save Burn model for PyTorch let mut store = SafetensorsStore::from_file("for_pytorch.safetensors") .with_to_adapter(BurnToPyTorchAdapter); // Convert back to PyTorch format -burn_model.collect_to(&mut store)?; +burn_model.save_into(&mut store)?; ``` ### Tensor Name Remapping @@ -119,16 +143,23 @@ let mut store = PytorchStore::from_file("model.pth") ### Memory Operations ```rust -// Save to memory buffer -let mut store = SafetensorsStore::from_bytes(None) - .with_regex(r"^encoder\..*"); -model.collect_to(&mut store)?; +// Burnpack: Save to memory buffer +let mut store = BurnpackStore::from_bytes(None) + .with_regex(r"^encoder\..*") + .metadata("subset", "encoder_only"); +model.save_into(&mut store)?; let bytes = store.get_bytes()?; -// Load from memory buffer -let mut store = SafetensorsStore::from_bytes(Some(bytes)) +// Burnpack: Load from memory buffer (no-std compatible) +let mut store = BurnpackStore::from_bytes(Some(bytes)) .allow_partial(true); -let result = model.apply_from(&mut store)?; +let result = model.load_from(&mut store)?; + +// SafeTensors: Memory operations +let mut store = SafetensorsStore::from_bytes(None) + .with_regex(r"^encoder\..*"); +model.save_into(&mut store)?; +let bytes = store.get_bytes()?; println!("Loaded {} tensors", result.applied.len()); if !result.missing.is_empty() { @@ -136,7 +167,7 @@ if !result.missing.is_empty() { } ``` -SafetensorsStore supports no-std environments when using byte operations +Both BurnpackStore and SafetensorsStore support no-std environments when using byte operations ### Model Surgery and Partial Operations @@ -171,12 +202,12 @@ model2.apply(snapshots, None, None); // Export only specific 
layers let mut store = SafetensorsStore::from_file("encoder_only.safetensors") .with_regex(r"^encoder\..*"); -model.collect_to(&mut store)?; +model.save_into(&mut store)?; // Load with missing tensors allowed let mut store = SafetensorsStore::from_file("pretrained.safetensors") .allow_partial(true); -let result = model.apply_from(&mut store)?; +let result = model.load_from(&mut store)?; println!("Loaded: {}, Missing: {:?}", result.applied.len(), result.missing); ``` @@ -196,12 +227,12 @@ target_model.apply(merged, None, None); // Alternative: Sequential loading from files let mut base_store = SafetensorsStore::from_file("base.safetensors"); -model.apply_from(&mut base_store)?; +model.load_from(&mut base_store)?; let mut encoder_store = SafetensorsStore::from_file("encoder.safetensors") .with_regex(r"^encoder\..*") .allow_partial(true); -model.apply_from(&mut encoder_store)?; // Overlays encoder weights +model.load_from(&mut encoder_store)?; // Overlays encoder weights ``` ### Complete Example: Migrating PyTorch Models @@ -223,7 +254,7 @@ let mut store = PytorchStore::from_file("pytorch_transformer.pth") .allow_partial(true); let mut model = TransformerModel::new(&device); -let result = model.apply_from(&mut store)?; +let result = model.load_from(&mut store)?; println!("Successfully migrated {} tensors", result.applied.len()); if !result.errors.is_empty() { @@ -235,7 +266,7 @@ let mut save_store = SafetensorsStore::from_file("migrated_model.safetensors") .metadata("source", "pytorch") .metadata("converted_by", "burn-store"); -model.collect_to(&mut save_store)?; +model.save_into(&mut save_store)?; ``` ## Advanced Usage @@ -265,7 +296,7 @@ let mut store = SafetensorsStore::from_file("model.safetensors") ### Handling Load Results ```rust -let result = model.apply_from(&mut store)?; +let result = model.load_from(&mut store)?; // Detailed result information println!("Applied: {} tensors", result.applied.len()); @@ -282,12 +313,14 @@ if !result.errors.is_empty() { ## 
Benchmarks +### Loading Benchmarks + ```bash # Generate model files first (one-time setup) cd crates/burn-store uv run benches/generate_unified_models.py -# Run unified benchmark with default backend (NdArray CPU) +# Run unified loading benchmark with default backend (NdArray CPU) cargo bench --bench unified_loading # Run with specific backend @@ -298,7 +331,26 @@ cargo bench --bench unified_loading --features candle # Candle backend cargo bench --bench unified_loading --features tch # LibTorch # Run with multiple backends -cargo bench --bench unified_loading --features "wgpu metal" +cargo bench --bench unified_loading --features wgpu,tch +``` + +### Saving Benchmarks + +Compares 3 saving methods: BurnpackStore, NamedMpkFileRecorder, and SafetensorsStore. + +```bash +# Run unified saving benchmark with default backend (NdArray CPU) +cargo bench --bench unified_saving + +# Run with specific backend +cargo bench --bench unified_saving --features metal # Apple GPU +cargo bench --bench unified_saving --features wgpu # WebGPU +cargo bench --bench unified_saving --features cuda # NVIDIA GPU +cargo bench --bench unified_saving --features candle # Candle backend +cargo bench --bench unified_saving --features tch # LibTorch + +# Run with multiple backends +cargo bench --bench unified_saving --features wgpu,tch ``` ## API Overview @@ -327,10 +379,22 @@ The stores provide a fluent API for configuration: #### Configuration -- `metadata(key, value)` - Add custom metadata (SafeTensors only) +- `metadata(key, value)` - Add custom metadata (Burnpack and SafeTensors) - `allow_partial(bool)` - Continue on missing tensors - `validate(bool)` - Toggle validation - `with_top_level_key(key)` - Access nested dict in PyTorch files +- `overwrite(bool)` - Allow overwriting existing files (Burnpack) + +### Inspecting Burnpack Files + +Generate and examine a sample file: + +```bash +cargo run --example burnpack_inspect sample.bpk +hexdump -C sample.bpk | head -20 +``` + +The example creates a 
sample model and outputs inspection commands for examining the binary format. ## License diff --git a/crates/burn-store/benches/unified_loading.rs b/crates/burn-store/benches/unified_loading.rs index 4885ffe7be..9b86bec355 100644 --- a/crates/burn-store/benches/unified_loading.rs +++ b/crates/burn-store/benches/unified_loading.rs @@ -1,6 +1,8 @@ #![recursion_limit = "256"] //! Unified benchmark comparing all loading methods: +//! - BurnpackStore (new native format) +//! - NamedMpkFileRecorder (old native format) //! - SafetensorsStore (new) //! - SafetensorsFileRecorder (old) //! - PytorchStore (new) @@ -19,11 +21,13 @@ use burn_core::module::Module; use burn_core::prelude::*; -use burn_core::record::{FullPrecisionSettings, Recorder}; +use burn_core::record::{FullPrecisionSettings, NamedMpkFileRecorder, Recorder}; use burn_import::pytorch::{LoadArgs, PyTorchFileRecorder}; use burn_import::safetensors::SafetensorsFileRecorder; use burn_nn as nn; -use burn_store::{ModuleSnapshot, PyTorchToBurnAdapter, PytorchStore, SafetensorsStore}; +use burn_store::{ + BurnpackStore, ModuleSnapshot, PyTorchToBurnAdapter, PytorchStore, SafetensorsStore, +}; use divan::{AllocProfiler, Bencher}; use std::fs; use std::path::PathBuf; @@ -47,7 +51,7 @@ type CandleBackend = burn_candle::Candle; type TchBackend = burn_tch::LibTorch; #[cfg(feature = "metal")] -type MetalBackend = burn_metal::Metal; +type MetalBackend = burn_wgpu::Metal; // Use the same LargeModel as other benchmarks for fair comparison #[derive(Module, Debug)] @@ -72,10 +76,44 @@ fn get_model_dir() -> PathBuf { std::env::temp_dir().join("simple_bench_models") } +/// Generate Burnpack and NamedMpk files from existing SafeTensors file +fn generate_burn_formats(st_path: &PathBuf, bp_path: &PathBuf, mpk_path: &PathBuf) { + type TestBackend = NdArrayBackend; + let device = Default::default(); + + // Load the model from SafeTensors + let mut model = LargeModel::::new(&device); + let mut store = + 
SafetensorsStore::from_file(st_path.clone()).with_from_adapter(PyTorchToBurnAdapter); + model + .load_from(&mut store) + .expect("Failed to load from SafeTensors"); + + // Save as Burnpack + if !bp_path.exists() { + println!(" Creating Burnpack file..."); + let mut burnpack_store = BurnpackStore::from_file(bp_path.clone()); + model + .save_into(&mut burnpack_store) + .expect("Failed to save as Burnpack"); + } + + // Save as NamedMpk + if !mpk_path.exists() { + println!(" Creating NamedMpk file..."); + let recorder = NamedMpkFileRecorder::::default(); + model + .save_file(mpk_path.clone(), &recorder) + .expect("Failed to save as NamedMpk"); + } +} + /// Get paths to the model files -fn get_model_paths() -> (PathBuf, PathBuf) { +fn get_model_paths() -> (PathBuf, PathBuf, PathBuf, PathBuf) { let dir = get_model_dir(); ( + dir.join("large_model.bpk"), + dir.join("large_model.mpk"), dir.join("large_model.safetensors"), dir.join("large_model.pt"), ) @@ -83,8 +121,9 @@ fn get_model_paths() -> (PathBuf, PathBuf) { /// Check if model files exist fn check_model_files() -> Result<(), String> { - let (st_path, pt_path) = get_model_paths(); + let (bp_path, mpk_path, st_path, pt_path) = get_model_paths(); + // For now, only check safetensors and pytorch files (will generate burnpack/mpk later) if !st_path.exists() || !pt_path.exists() { return Err(format!( "\n❌ Model files not found!\n\ @@ -109,21 +148,42 @@ fn main() { // Check if model files exist before running benchmarks match check_model_files() { Ok(()) => { - let (st_path, pt_path) = get_model_paths(); + let (bp_path, mpk_path, st_path, pt_path) = get_model_paths(); + + // First, generate Burnpack and MPK files if they don't exist + if !bp_path.exists() || !mpk_path.exists() { + println!("⏳ Generating Burnpack and NamedMpk files from SafeTensors..."); + generate_burn_formats(&st_path, &bp_path, &mpk_path); + } + + let bp_size = fs::metadata(&bp_path) + .ok() + .map(|m| m.len() as f64 / 1_048_576.0); + let mpk_size = 
fs::metadata(&mpk_path) + .ok() + .map(|m| m.len() as f64 / 1_048_576.0); let st_size = fs::metadata(&st_path).unwrap().len() as f64 / 1_048_576.0; let pt_size = fs::metadata(&pt_path).unwrap().len() as f64 / 1_048_576.0; println!("✅ Found model files:"); + if let Some(size) = bp_size { + println!(" Burnpack: {} ({:.1} MB)", bp_path.display(), size); + } + if let Some(size) = mpk_size { + println!(" NamedMpk: {} ({:.1} MB)", mpk_path.display(), size); + } println!(" SafeTensors: {} ({:.1} MB)", st_path.display(), st_size); println!(" PyTorch: {} ({:.1} MB)", pt_path.display(), pt_size); println!(); println!("🚀 Running unified loading benchmarks..."); println!(); - println!("Comparing 4 loading methods:"); - println!(" 1. SafetensorsStore (new)"); - println!(" 2. SafetensorsFileRecorder (old)"); - println!(" 3. PytorchStore (new)"); - println!(" 4. PyTorchFileRecorder (old)"); + println!("Comparing 6 loading methods:"); + println!(" 1. BurnpackStore (new native format - lazy loading)"); + println!(" 2. NamedMpkFileRecorder (old native format - loads all to memory)"); + println!(" 3. SafetensorsStore (new)"); + println!(" 4. SafetensorsFileRecorder (old)"); + println!(" 5. PytorchStore (new)"); + println!(" 6. PyTorchFileRecorder (old)"); println!(); println!("Available backends:"); println!(" - NdArray (CPU)"); @@ -158,9 +218,41 @@ macro_rules! 
bench_backend { type TestBackend = $backend; type TestDevice = ::Device; + #[divan::bench] + fn burnpack_store(bencher: Bencher) { + let (bp_path, _, _, _) = get_model_paths(); + let file_size = fs::metadata(&bp_path).unwrap().len(); + + bencher + .counter(divan::counter::BytesCount::new(file_size)) + .bench(|| { + let device: TestDevice = Default::default(); + let mut model = LargeModel::::new(&device); + let mut store = BurnpackStore::from_file(bp_path.clone()); + model.load_from(&mut store).expect("Failed to load"); + }); + } + + #[divan::bench] + fn namedmpk_recorder(bencher: Bencher) { + let (_, mpk_path, _, _) = get_model_paths(); + let file_size = fs::metadata(&mpk_path).unwrap().len(); + + bencher + .counter(divan::counter::BytesCount::new(file_size)) + .bench(|| { + let device: TestDevice = Default::default(); + let recorder = NamedMpkFileRecorder::::default(); + let record = recorder + .load(mpk_path.clone().into(), &device) + .expect("Failed to load"); + let _model = LargeModel::::new(&device).load_record(record); + }); + } + #[divan::bench] fn safetensors_store(bencher: Bencher) { - let (st_path, _) = get_model_paths(); + let (_, _, st_path, _) = get_model_paths(); let file_size = fs::metadata(&st_path).unwrap().len(); bencher @@ -170,13 +262,13 @@ macro_rules! bench_backend { let mut model = LargeModel::::new(&device); let mut store = SafetensorsStore::from_file(st_path.clone()) .with_from_adapter(PyTorchToBurnAdapter); - model.apply_from(&mut store).expect("Failed to load"); + model.load_from(&mut store).expect("Failed to load"); }); } #[divan::bench] fn safetensors_recorder(bencher: Bencher) { - let (st_path, _) = get_model_paths(); + let (_, _, st_path, _) = get_model_paths(); let file_size = fs::metadata(&st_path).unwrap().len(); bencher @@ -193,7 +285,7 @@ macro_rules! 
bench_backend { #[divan::bench] fn pytorch_store(bencher: Bencher) { - let (_, pt_path) = get_model_paths(); + let (_, _, _, pt_path) = get_model_paths(); let file_size = fs::metadata(&pt_path).unwrap().len(); bencher @@ -204,13 +296,13 @@ macro_rules! bench_backend { let mut store = PytorchStore::from_file(pt_path.clone()) .with_top_level_key("model_state_dict") .allow_partial(true); - model.apply_from(&mut store).expect("Failed to load"); + model.load_from(&mut store).expect("Failed to load"); }); } #[divan::bench] fn pytorch_recorder(bencher: Bencher) { - let (_, pt_path) = get_model_paths(); + let (_, _, _, pt_path) = get_model_paths(); let file_size = fs::metadata(&pt_path).unwrap().len(); bencher diff --git a/crates/burn-store/benches/unified_saving.rs b/crates/burn-store/benches/unified_saving.rs new file mode 100644 index 0000000000..2e194bb77d --- /dev/null +++ b/crates/burn-store/benches/unified_saving.rs @@ -0,0 +1,190 @@ +#![recursion_limit = "256"] + +//! Unified benchmark comparing all saving methods: +//! - BurnpackStore (new native format) +//! - NamedMpkFileRecorder (old native format) +//! - SafetensorsStore (new) +//! +//! Before running this benchmark, ensure the directory exists: +//! ```bash +//! mkdir -p /tmp/simple_bench_models +//! ``` +//! +//! Then run the benchmark: +//! ```bash +//! cargo bench --bench unified_saving +//! 
``` + +use burn_core::module::Module; +use burn_core::prelude::*; +use burn_core::record::{FullPrecisionSettings, NamedMpkFileRecorder}; +use burn_nn as nn; +use burn_store::{BurnpackStore, ModuleSnapshot, SafetensorsStore}; +use divan::{AllocProfiler, Bencher}; +use std::fs; +use std::path::PathBuf; + +#[global_allocator] +static ALLOC: AllocProfiler = AllocProfiler::system(); + +// Backend type aliases +type NdArrayBackend = burn_ndarray::NdArray; + +#[cfg(feature = "wgpu")] +type WgpuBackend = burn_wgpu::Wgpu; + +#[cfg(feature = "cuda")] +type CudaBackend = burn_cuda::Cuda; + +#[cfg(feature = "candle")] +type CandleBackend = burn_candle::Candle; + +#[cfg(feature = "tch")] +type TchBackend = burn_tch::LibTorch; + +#[cfg(feature = "metal")] +type MetalBackend = burn_wgpu::Metal; + +// Use the same LargeModel as other benchmarks for fair comparison +#[derive(Module, Debug)] +struct LargeModel { + layers: Vec>, +} + +impl LargeModel { + fn new(device: &B::Device) -> Self { + let mut layers = Vec::new(); + // Create a model with 20 layers - same as loading benchmarks + for i in 0..20 { + let in_size = if i == 0 { 1024 } else { 2048 }; + layers.push(nn::LinearConfig::new(in_size, 2048).init(device)); + } + Self { layers } + } +} + +/// Get the path to the output directory +fn get_output_dir() -> PathBuf { + std::env::temp_dir().join("simple_bench_models_saving") +} + +/// Ensure output directory exists +fn ensure_output_dir() -> Result<(), String> { + let dir = get_output_dir(); + if !dir.exists() { + fs::create_dir_all(&dir) + .map_err(|e| format!("Failed to create output directory: {}", e))?; + } + Ok(()) +} + +fn main() { + match ensure_output_dir() { + Ok(()) => { + println!("✅ Output directory ready: {}", get_output_dir().display()); + println!(); + println!("🚀 Running unified saving benchmarks..."); + println!(); + println!("Comparing 3 saving methods:"); + println!(" 1. BurnpackStore (new native format)"); + println!(" 2. 
NamedMpkFileRecorder (old native format)"); + println!(" 3. SafetensorsStore (new)"); + println!(); + println!("Available backends:"); + println!(" - NdArray (CPU)"); + #[cfg(feature = "wgpu")] + println!(" - WGPU (GPU)"); + #[cfg(feature = "cuda")] + println!(" - CUDA (NVIDIA GPU)"); + #[cfg(feature = "candle")] + println!(" - Candle"); + #[cfg(feature = "tch")] + println!(" - LibTorch"); + #[cfg(feature = "metal")] + println!(" - Metal (Apple GPU)"); + println!(); + + divan::main(); + } + Err(msg) => { + eprintln!("❌ {}", msg); + std::process::exit(1); + } + } +} + +// Macro to generate benchmarks for each backend +macro_rules! bench_backend { + ($backend:ty, $mod_name:ident, $backend_name:literal) => { + #[divan::bench_group(name = $backend_name, sample_count = 10)] + mod $mod_name { + use super::*; + + type TestBackend = $backend; + type TestDevice = ::Device; + + #[divan::bench] + fn burnpack_store(bencher: Bencher) { + bencher.bench(|| { + let device: TestDevice = Default::default(); + let model = LargeModel::::new(&device); + let output_path = get_output_dir().join("test_burnpack.bpk"); + let mut store = BurnpackStore::from_file(output_path.clone()).overwrite(true); + model + .save_into(&mut store) + .expect("Failed to save with BurnpackStore"); + // Clean up + let _ = fs::remove_file(output_path); + }); + } + + #[divan::bench] + fn namedmpk_recorder(bencher: Bencher) { + bencher.bench(|| { + let device: TestDevice = Default::default(); + let model = LargeModel::::new(&device); + let output_path = get_output_dir().join("test_namedmpk.mpk"); + let recorder = NamedMpkFileRecorder::::default(); + model + .save_file(output_path.clone(), &recorder) + .expect("Failed to save with NamedMpkFileRecorder"); + // Clean up + let _ = fs::remove_file(output_path); + }); + } + + #[divan::bench] + fn safetensors_store(bencher: Bencher) { + bencher.bench(|| { + let device: TestDevice = Default::default(); + let model = LargeModel::::new(&device); + let output_path = 
get_output_dir().join("test_safetensors_store.safetensors"); + let mut store = SafetensorsStore::from_file(output_path.clone()); + model + .save_into(&mut store) + .expect("Failed to save with SafetensorsStore"); + // Clean up + let _ = fs::remove_file(output_path); + }); + } + } + }; +} + +// Generate benchmarks for each backend +bench_backend!(NdArrayBackend, ndarray_backend, "NdArray Backend (CPU)"); + +#[cfg(feature = "wgpu")] +bench_backend!(WgpuBackend, wgpu_backend, "WGPU Backend (GPU)"); + +#[cfg(feature = "cuda")] +bench_backend!(CudaBackend, cuda_backend, "CUDA Backend (NVIDIA GPU)"); + +#[cfg(feature = "candle")] +bench_backend!(CandleBackend, candle_backend, "Candle Backend"); + +#[cfg(feature = "tch")] +bench_backend!(TchBackend, tch_backend, "LibTorch Backend"); + +#[cfg(feature = "metal")] +bench_backend!(MetalBackend, metal_backend, "Metal Backend (Apple GPU)"); diff --git a/crates/burn-store/examples/burnpack_inspect.rs b/crates/burn-store/examples/burnpack_inspect.rs new file mode 100644 index 0000000000..01eed43c12 --- /dev/null +++ b/crates/burn-store/examples/burnpack_inspect.rs @@ -0,0 +1,147 @@ +//! Example: Generate a Burnpack file for inspection +//! +//! This example creates a simple Burnpack file that you can examine to understand the format. +//! +//! Usage: +//! cargo run --example burnpack-inspect [output_path] +//! +//! Example: +//! cargo run --example burnpack-inspect sample.bpk +//! cargo run --example burnpack-inspect /tmp/test.bpk +//! +//! After generating the file, examine it with: +//! hexdump -C sample.bpk | head -100 +//! xxd sample.bpk | head -100 +//! 
hexyl sample.bpk + +use burn_core::module::Module; +use burn_ndarray::NdArray; +use burn_nn::{Linear, LinearConfig}; +use burn_store::{BurnpackStore, ModuleSnapshot}; +use burn_tensor::backend::Backend; +use std::env; + +// Simple model with a few layers +#[derive(Module, Debug)] +struct SampleModel { + linear1: Linear, + linear2: Linear, + linear3: Linear, +} + +impl SampleModel { + fn new(device: &B::Device) -> Self { + Self { + linear1: LinearConfig::new(128, 64).init(device), + linear2: LinearConfig::new(64, 32).init(device), + linear3: LinearConfig::new(32, 10).init(device), + } + } +} + +fn main() { + type Backend = NdArray; + + // Get output path from command line or use default + let output_path = env::args() + .nth(1) + .unwrap_or_else(|| "sample.bpk".to_string()); + + println!("Creating sample Burnpack file: {}", output_path); + println!(); + + // Create a simple model + let device = Default::default(); + let model = SampleModel::::new(&device); + + // Save to Burnpack format with metadata + let mut store = BurnpackStore::from_file(&output_path) + .overwrite(true) + .metadata("format", "burnpack") + .metadata("description", "Sample file for examining Burnpack format") + .metadata("version", env!("CARGO_PKG_VERSION")) + .metadata("author", "Burn Example"); + + model.save_into(&mut store).expect("Failed to save model"); + + println!("✅ Successfully created: {}", output_path); + println!(); + println!("📋 File Structure:"); + println!(" ┌─────────────────────────────────────┐"); + println!(" │ Header (10 bytes) │"); + println!(" ├─────────────────────────────────────┤"); + println!(" │ - Magic: 0x4E525542 (BURN in LE) │"); + println!(" │ - Version: 0x0001 (2 bytes) │"); + println!(" │ - Metadata size: (4 bytes, u32 LE) │"); + println!(" ├─────────────────────────────────────┤"); + println!(" │ Metadata (CBOR format) │"); + println!(" ├─────────────────────────────────────┤"); + println!(" │ - Tensor descriptors │"); + println!(" │ * name, dtype, shape, 
offsets │"); + println!(" │ - User metadata │"); + println!(" ├─────────────────────────────────────┤"); + println!(" │ Tensor Data (raw bytes, LE) │"); + println!(" ├─────────────────────────────────────┤"); + println!(" │ - linear1.weight [64, 128] │"); + println!(" │ - linear1.bias [64] │"); + println!(" │ - linear2.weight [32, 64] │"); + println!(" │ - linear2.bias [32] │"); + println!(" │ - linear3.weight [10, 32] │"); + println!(" │ - linear3.bias [10] │"); + println!(" └─────────────────────────────────────┘"); + println!(); + println!("📊 Model Contents:"); + println!(" - linear1.weight: [64, 128] = 8,192 params → 32,768 bytes"); + println!(" - linear1.bias: [64] = 64 params → 256 bytes"); + println!(" - linear2.weight: [32, 64] = 2,048 params → 8,192 bytes"); + println!(" - linear2.bias: [32] = 32 params → 128 bytes"); + println!(" - linear3.weight: [10, 32] = 320 params → 1,280 bytes"); + println!(" - linear3.bias: [10] = 10 params → 40 bytes"); + println!(" ───────────────────────────────────────────────────────"); + + let total_params = 8192 + 64 + 2048 + 32 + 320 + 10; + let total_bytes = total_params * 4; + println!( + " Total: {} parameters = {} KB", + total_params, + total_bytes / 1024 + ); + println!(); + + // Get actual file size + if let Ok(metadata) = std::fs::metadata(&output_path) { + let file_size = metadata.len(); + println!( + "📦 File size: {} bytes ({:.2} KB)", + file_size, + file_size as f64 / 1024.0 + ); + } + + println!(); + println!("🔍 Inspection Commands:"); + println!(); + println!(" # View first 100 bytes in hex:"); + println!(" hexdump -C {} | head -20", output_path); + println!(); + println!(" # View header only (10 bytes):"); + println!(" head -c 10 {} | hexdump -C", output_path); + println!(); + println!(" # View with prettier hex viewer (if installed):"); + println!(" hexyl {} | head -50", output_path); + println!(); + println!(" # View in binary format:"); + println!(" xxd -b {} | head -20", output_path); + println!(); + 
println!(" # Extract and examine header:"); + println!(" # Magic (bytes 0-3): Should be 42 55 52 4E (BURN)"); + println!(" # Version (bytes 4-5): Should be 01 00"); + println!(" # Metadata size (bytes 6-9): u32 little-endian"); + println!(); + println!(" # Load back the model:"); + println!( + " # let mut store = BurnpackStore::from_file(\"{}\");", + output_path + ); + println!(" # model.load_from(&mut store)?;"); +} diff --git a/crates/burn-store/src/applier.rs b/crates/burn-store/src/applier.rs index d9c2843fee..2fe5195710 100644 --- a/crates/burn-store/src/applier.rs +++ b/crates/burn-store/src/applier.rs @@ -7,8 +7,8 @@ use alloc::vec::Vec; use hashbrown::{HashMap, HashSet}; -use burn_core::module::{ModuleMapper, ParamId}; -use burn_tensor::{Bool, DType, Int, Tensor, backend::Backend}; +use burn_core::module::{ModuleMapper, Param}; +use burn_tensor::{Bool, DType, Int, Shape, Tensor, backend::Backend}; use crate::{ModuleAdapter, PathFilter, TensorSnapshot}; @@ -84,6 +84,8 @@ impl core::fmt::Display for ApplyError { } } +impl core::error::Error for ApplyError {} + /// Result of applying tensor snapshots to a module #[derive(Debug, Clone)] pub struct ApplyResult { @@ -221,8 +223,13 @@ impl Applier { } } - /// Apply a tensor snapshot to the current tensor (generic over tensor kind) - fn apply_tensor(&mut self, tensor: Tensor) -> Tensor + /// Apply a tensor snapshot with shape validation + /// Returns None if snapshot not found, filtered, or validation fails + fn apply_tensor( + &mut self, + target_device: &B::Device, + target_shape: Shape, + ) -> Option> where K: burn_tensor::TensorKind, K: burn_tensor::BasicOps, @@ -233,13 +240,16 @@ impl Applier { // Check if we have a snapshot for this path let snapshot = match self.snapshots.get(&path) { Some(s) => s, - None => return tensor, + None => { + // No snapshot available - signal caller not to apply + return None; + } }; // Check if we should apply based on filter if !self.should_apply() { - self.skipped.insert(path); 
- return tensor; + self.skipped.insert(path.clone()); + return None; } // Apply adapter with current container context @@ -251,34 +261,22 @@ impl Applier { path: path.clone(), message: format!("Failed to load tensor data: {:?}", e), }); - return tensor; + return None; // Signal caller to fall back to initialization } }; // Validate shape - let expected_shape = tensor.shape().dims; - if data.shape != expected_shape { + if data.shape != target_shape.dims { self.errors.push(ApplyError::ShapeMismatch { path: path.clone(), - expected: expected_shape, + expected: target_shape.dims, found: data.shape.clone(), }); - return tensor; - } - - // Validate dtype - let expected_dtype = tensor.dtype(); - if data.dtype != expected_dtype { - self.errors.push(ApplyError::DTypeMismatch { - path: path.clone(), - expected: expected_dtype, - found: data.dtype, - }); - return tensor; + return None; // Signal caller to fall back to initialization } self.applied.push(path); - Tensor::from_data(data, &tensor.device()) + Some(Tensor::from_data(data, target_device)) } } @@ -293,32 +291,133 @@ impl ModuleMapper for Applier { self.container_stack.pop(); } - fn map_float(&mut self, _id: ParamId, tensor: Tensor) -> Tensor { - if self.path_stack.is_empty() { - return tensor; + fn map_float(&mut self, param: Param>) -> Param> { + let param_id = param.id; + let target_device = param.lazy_device(); + let target_shape = param.lazy_shape(); + + // Try to apply snapshot with shape validation + match self.apply_tensor(&target_device, target_shape) { + Some(tensor) => { + // We have a tensor to apply - load it + param.transform_for_load(tensor, param_id) + } + None => { + // No snapshot, filtered, or validation failed - return param unchanged + param + } } - self.apply_tensor(tensor) } fn map_int( &mut self, - _id: ParamId, - tensor: Tensor, - ) -> Tensor { - if self.path_stack.is_empty() { - return tensor; + param: Param>, + ) -> Param> { + let param_id = param.id; + let target_device = 
param.lazy_device(); + let target_shape = param.lazy_shape(); + + // Try to apply snapshot with shape validation + match self.apply_tensor(&target_device, target_shape) { + Some(tensor) => { + // We have a tensor to apply - load it + param.transform_for_load(tensor, param_id) + } + None => { + // No snapshot, filtered, or validation failed - return param unchanged + param + } } - self.apply_tensor(tensor) } fn map_bool( &mut self, - _id: ParamId, - tensor: Tensor, - ) -> Tensor { - if self.path_stack.is_empty() { - return tensor; + param: Param>, + ) -> Param> { + let param_id = param.id; + let target_device = param.lazy_device(); + let target_shape = param.lazy_shape(); + + // Try to apply snapshot with shape validation + match self.apply_tensor(&target_device, target_shape) { + Some(tensor) => { + // We have a tensor to apply - load it + param.transform_for_load(tensor, param_id) + } + None => { + // No snapshot, filtered, or validation failed - return param unchanged + param + } } - self.apply_tensor(tensor) + } +} + +#[cfg(all(test, feature = "std", target_has_atomic = "ptr"))] +mod tests { + use super::*; + use burn_core::module::{ModuleMapper, Param, ParamId}; + use burn_tensor::Tensor; + + type TestBackend = burn_ndarray::NdArray; + + #[test] + fn root_level_parameters() { + let device = Default::default(); + + // Create root-level parameters (not inside any module) + let weight = Param::>::from_data([[1.0, 2.0], [3.0, 4.0]], &device); + let bias = Param::>::from_data([5.0, 6.0], &device); + + // Create snapshots with root-level paths (single-element path, no nested modules) + let weight_snapshot = crate::TensorSnapshot::from_data( + weight.val().to_data(), + vec!["weight".to_string()], // root-level parameter name + vec![], // no container + ParamId::new(), + ); + + let bias_snapshot = crate::TensorSnapshot::from_data( + bias.val().to_data(), + vec!["bias".to_string()], // root-level parameter name + vec![], // no container + ParamId::new(), + ); + + // 
Create applier with root-level snapshots + let mut applier = + Applier::::new(vec![weight_snapshot, bias_snapshot], None, None); + + // Create new params to load into + let weight_target = Param::initialized( + ParamId::new(), + Tensor::::zeros([2, 2], &device), + ); + let bias_target = Param::initialized( + ParamId::new(), + Tensor::::zeros([2], &device), + ); + + // Apply using the ModuleMapper interface - simulate module traversal + // Enter "weight" path (as if we're visiting a field named "weight") + applier.enter_module("weight", ""); + let weight_loaded = applier.map_float(weight_target); + applier.exit_module("weight", ""); + + // Enter "bias" path (as if we're visiting a field named "bias") + applier.enter_module("bias", ""); + let bias_loaded = applier.map_float(bias_target); + applier.exit_module("bias", ""); + + // Verify values were loaded + let weight_data = weight_loaded.val().to_data().to_vec::().unwrap(); + let bias_data = bias_loaded.val().to_data().to_vec::().unwrap(); + + assert_eq!(weight_data, vec![1.0, 2.0, 3.0, 4.0]); + assert_eq!(bias_data, vec![5.0, 6.0]); + + // Verify applier result + let result = applier.into_result(); + assert_eq!(result.applied.len(), 2); + assert_eq!(result.errors.len(), 0); } } diff --git a/crates/burn-store/src/burnpack/base.rs b/crates/burn-store/src/burnpack/base.rs new file mode 100644 index 0000000000..17e04da73d --- /dev/null +++ b/crates/burn-store/src/burnpack/base.rs @@ -0,0 +1,199 @@ +//! Core types and constants for the Burnpack file format. +//! +//! See the [parent module](crate::burnpack) for the complete file format specification. 
+ +use alloc::collections::BTreeMap; +use alloc::string::String; +use alloc::vec::Vec; +use burn_tensor::DType; +use byteorder::{ByteOrder, LittleEndian}; +use serde::{Deserialize, Serialize}; + +/// Magic number identifying a Burnpack file: "BURN" in ASCII (0x4255524E) +/// When written to file in little-endian format, appears as "NRUB" bytes +pub const MAGIC_NUMBER: u32 = 0x4255524E; + +/// Current format version +pub const FORMAT_VERSION: u16 = 0x0001; + +/// Size of the magic number in bytes +pub const MAGIC_SIZE: usize = 4; + +/// Size of the format version in bytes +pub const VERSION_SIZE: usize = 2; + +/// Size of the metadata size field in bytes +pub const METADATA_SIZE_FIELD_SIZE: usize = 4; + +/// Total header size (computed from components) +pub const HEADER_SIZE: usize = MAGIC_SIZE + VERSION_SIZE + METADATA_SIZE_FIELD_SIZE; + +// Security limits to prevent DoS attacks via resource exhaustion +// These can be adjusted based on your use case + +/// Maximum allowed metadata size (100 MB) +/// Prevents memory exhaustion attacks via oversized metadata claims +pub const MAX_METADATA_SIZE: u32 = 100 * 1024 * 1024; + +/// Maximum allowed tensor size per tensor +/// Prevents memory exhaustion attacks via oversized tensor claims +/// 32-bit platforms: 2 GB limit (to fit within usize range) +/// 64-bit platforms: 10 GB limit +#[cfg(target_pointer_width = "32")] +pub const MAX_TENSOR_SIZE: usize = 2 * 1024 * 1024 * 1024; +#[cfg(not(target_pointer_width = "32"))] +pub const MAX_TENSOR_SIZE: usize = 10 * 1024 * 1024 * 1024; + +/// Maximum allowed number of tensors (100,000) +/// Prevents resource exhaustion via excessive tensor counts +pub const MAX_TENSOR_COUNT: usize = 100_000; + +/// Maximum CBOR deserialization recursion depth (128 levels) +/// Prevents stack overflow attacks via deeply nested CBOR structures +pub const MAX_CBOR_RECURSION_DEPTH: usize = 128; + +/// Maximum allowed file size (100 GB) +/// Prevents resource exhaustion from extremely large files 
+/// This limit applies to file-based loading (mmap and buffered) +#[cfg(feature = "std")] +pub const MAX_FILE_SIZE: u64 = 100 * 1024 * 1024 * 1024; + +/// Byte range for magic number in header +pub const fn magic_range() -> core::ops::Range { + let start = 0; + let end = start + MAGIC_SIZE; + start..end +} + +/// Byte range for format version in header +pub const fn version_range() -> core::ops::Range { + let start = MAGIC_SIZE; + let end = start + VERSION_SIZE; + start..end +} + +/// Byte range for metadata size field in header +pub const fn metadata_size_range() -> core::ops::Range { + let start = MAGIC_SIZE + VERSION_SIZE; + let end = start + METADATA_SIZE_FIELD_SIZE; + start..end +} + +// Compile-time validation that ranges are correct +const _: () = assert!(MAGIC_SIZE + VERSION_SIZE + METADATA_SIZE_FIELD_SIZE == HEADER_SIZE); + +/// Header structure for Burnpack files +#[derive(Debug, Clone, Copy)] +pub struct BurnpackHeader { + /// Magic number (4 bytes): 0x4255524E ("BURN") + pub magic: u32, + /// Format version (2 bytes) + pub version: u16, + /// Size of CBOR metadata in bytes (4 bytes) + pub metadata_size: u32, +} + +impl BurnpackHeader { + /// Create a new header with the given metadata size + #[allow(dead_code)] + pub fn new(metadata_size: u32) -> Self { + Self { + magic: MAGIC_NUMBER, + version: FORMAT_VERSION, + metadata_size, + } + } + + /// Serialize header into bytes + pub fn into_bytes(self) -> [u8; HEADER_SIZE] { + let mut bytes = [0u8; HEADER_SIZE]; + LittleEndian::write_u32(&mut bytes[magic_range()], self.magic); + LittleEndian::write_u16(&mut bytes[version_range()], self.version); + LittleEndian::write_u32(&mut bytes[metadata_size_range()], self.metadata_size); + bytes + } + + /// Deserialize header from bytes + pub fn from_bytes(bytes: &[u8]) -> Result { + if bytes.len() < HEADER_SIZE { + return Err(BurnpackError::InvalidHeader); + } + + let magic = LittleEndian::read_u32(&bytes[magic_range()]); + if magic != MAGIC_NUMBER { + return 
Err(BurnpackError::InvalidMagicNumber); + } + + let version = LittleEndian::read_u16(&bytes[version_range()]); + let metadata_size = LittleEndian::read_u32(&bytes[metadata_size_range()]); + + Ok(Self { + magic, + version, + metadata_size, + }) + } +} + +/// Metadata structure serialized with CBOR +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct BurnpackMetadata { + /// Tensor descriptors mapped by name for efficient lookup + pub tensors: BTreeMap, + /// Optional additional metadata + #[serde(default, skip_serializing_if = "BTreeMap::is_empty")] + pub metadata: BTreeMap, +} + +/// Individual tensor descriptor +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct TensorDescriptor { + /// Data type of the tensor + pub dtype: DType, + /// Tensor shape dimensions + pub shape: Vec, + /// Byte offsets in data section (start, end) + pub data_offsets: (u64, u64), + /// Parameter ID for training state persistence matching. + /// Generated automatically if not present during loading. 
+ #[serde(default, skip_serializing_if = "Option::is_none")] + pub param_id: Option, +} + +/// Error types for Burnpack operations +#[derive(Debug)] +pub enum BurnpackError { + InvalidHeader, + InvalidMagicNumber, + InvalidVersion, + MetadataSerializationError(String), + MetadataDeserializationError(String), + IoError(String), + TensorNotFound(String), + TensorBytesSizeMismatch(String), + ValidationError(String), +} + +impl core::fmt::Display for BurnpackError { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + match self { + BurnpackError::InvalidHeader => write!(f, "Invalid header: insufficient bytes"), + BurnpackError::InvalidMagicNumber => write!(f, "Invalid magic number"), + BurnpackError::InvalidVersion => write!(f, "Unsupported version"), + BurnpackError::MetadataSerializationError(e) => { + write!(f, "Metadata serialization error: {}", e) + } + BurnpackError::MetadataDeserializationError(e) => { + write!(f, "Metadata deserialization error: {}", e) + } + BurnpackError::IoError(e) => write!(f, "I/O error: {}", e), + BurnpackError::TensorNotFound(name) => write!(f, "Tensor not found: {}", name), + BurnpackError::TensorBytesSizeMismatch(e) => { + write!(f, "Tensor bytes size mismatch: {}", e) + } + BurnpackError::ValidationError(e) => write!(f, "Validation error: {}", e), + } + } +} + +impl core::error::Error for BurnpackError {} diff --git a/crates/burn-store/src/burnpack/mod.rs b/crates/burn-store/src/burnpack/mod.rs new file mode 100644 index 0000000000..a332aae3fc --- /dev/null +++ b/crates/burn-store/src/burnpack/mod.rs @@ -0,0 +1,48 @@ +//! # Burnpack - Native Burn Model Storage Format +//! +//! Burnpack is the native binary storage format for Burn models, designed for efficient +//! serialization, fast loading, and cross-platform compatibility. +//! +//! ## Key Features +//! +//! - **CBOR Metadata**: Structured metadata with efficient binary encoding +//! - **Memory-Mapped Loading**: Zero-copy loading for optimal performance +//! 
- **No-std Support**: Works in embedded and WASM environments
+//! - **ParamId Persistence**: Preserves parameter identities for stateful training
+//! - **Lazy Tensor Loading**: Deferred data materialization for efficient memory usage
+//!
+//! ## File Format Structure
+//!
+//! ```text
+//! ┌──────────────────────────────────┐
+//! │ Header (10 bytes) │
+//! ├──────────────────────────────────┤
+//! │ - Magic number (4 bytes) │ 0x4255524E ("BURN"; file bytes "NRUB" in LE)
+//! │ - Version (2 bytes) │ Format version (0x0001)
+//! │ - Metadata size (4 bytes) │ Size of CBOR metadata (u32)
+//! ├──────────────────────────────────┤
+//! │ Metadata (CBOR) │
+//! ├──────────────────────────────────┤
+//! │ - Tensor descriptors (BTreeMap) │ Order-preserving map of tensor metadata
+//! │ Key: tensor name (string) │ e.g., "model.layer1.weight"
+//! │ Value: TensorDescriptor │
+//! │ - dtype: DType │ Data type (F32, F64, I32, etc.)
+//! │ - shape: Vec │ Tensor dimensions
+//! │ - data_offsets: (u64, u64) │ (start, end) byte offsets
+//! │ - param_id: Option │ Parameter ID (for training state)
+//! │ - Additional metadata(BTreeMap) │ User-defined key-value pairs
+//! ├──────────────────────────────────┤
+//! │ Tensor Data Section │
+//! ├──────────────────────────────────┤
+//! │ Raw tensor bytes │ Contiguous tensor data (little-endian)
+//! │ (in order of offsets) │ Each tensor's data at specified offsets
+//! └──────────────────────────────────┘
+//! 
``` + +pub mod base; +pub mod reader; +pub mod store; +pub mod writer; + +#[cfg(test)] +mod tests; diff --git a/crates/burn-store/src/burnpack/reader.rs b/crates/burn-store/src/burnpack/reader.rs new file mode 100644 index 0000000000..6f70e32096 --- /dev/null +++ b/crates/burn-store/src/burnpack/reader.rs @@ -0,0 +1,697 @@ +#[cfg(feature = "std")] +use super::base::MAX_FILE_SIZE; +use super::base::{ + BurnpackError, BurnpackHeader, BurnpackMetadata, FORMAT_VERSION, HEADER_SIZE, MAGIC_NUMBER, + MAX_CBOR_RECURSION_DEPTH, MAX_METADATA_SIZE, MAX_TENSOR_COUNT, MAX_TENSOR_SIZE, +}; +use crate::TensorSnapshot; +use alloc::format; +use alloc::rc::Rc; +use alloc::string::ToString; +use alloc::vec; +use alloc::vec::Vec; +use burn_core::module::ParamId; +use burn_tensor::{Bytes, TensorData}; + +#[cfg(feature = "std")] +use std::cell::RefCell; +#[cfg(feature = "std")] +use std::fs::File; +#[cfg(feature = "std")] +use std::io::{Read, Seek}; +#[cfg(feature = "std")] +use std::path::Path; + +/// Storage backend for BurnpackReader +pub(crate) enum StorageBackend { + /// Memory-based storage + Memory(Rc), + /// Memory-mapped file storage (efficient for large files) + #[cfg(all(feature = "std", feature = "memmap"))] + Mmap(Rc), + /// File-based storage with buffered reading + #[cfg(feature = "std")] + #[allow(dead_code)] + FileBuffered { file: Rc> }, +} + +impl StorageBackend { + /// Read data from storage into the provided buffer at the given offset. + /// + /// # Arguments + /// * `bytes` - The buffer to read into (caller-allocated) + /// * `offset` - Absolute file/data position to start reading from + /// + /// # Errors + /// + /// Returns an error if: + /// - The requested data range is out of bounds + /// - Less data is available than requested (indicates corruption or incorrect offset) + /// - File I/O fails + /// + /// # Notes + /// + /// The caller allocates the buffer, which allows for buffer reuse and future optimizations + /// like memory pools and pinned memory. 
+ /// + /// This method ensures all backends have consistent behavior: if the exact number of + /// requested bytes cannot be read, an error is returned to prevent data corruption. + pub(crate) fn read_into(&self, bytes: &mut [u8], offset: usize) -> Result<(), BurnpackError> { + match self { + StorageBackend::Memory(data) => { + let data_bytes = data.as_ref(); + let end = offset.checked_add(bytes.len()).ok_or_else(|| { + BurnpackError::IoError(format!( + "Offset overflow: offset {} + length {} exceeds maximum", + offset, + bytes.len() + )) + })?; + + if end > data_bytes.len() { + return Err(BurnpackError::IoError(format!( + "Read out of bounds: requested {}..{} but data length is {}", + offset, + end, + data_bytes.len() + ))); + } + + bytes.copy_from_slice(&data_bytes[offset..end]); + Ok(()) + } + #[cfg(all(feature = "std", feature = "memmap"))] + StorageBackend::Mmap(mmap) => { + let mmap_bytes = mmap.as_ref(); + let end = offset.checked_add(bytes.len()).ok_or_else(|| { + BurnpackError::IoError(format!( + "Offset overflow: offset {} + length {} exceeds maximum", + offset, + bytes.len() + )) + })?; + + if end > mmap_bytes.len() { + return Err(BurnpackError::IoError(format!( + "Read out of bounds: requested {}..{} but mmap length is {}", + offset, + end, + mmap_bytes.len() + ))); + } + + bytes.copy_from_slice(&mmap_bytes[offset..end]); + Ok(()) + } + #[cfg(feature = "std")] + StorageBackend::FileBuffered { file } => { + use std::io::SeekFrom; + + let mut file = file.borrow_mut(); + file.seek(SeekFrom::Start(offset as u64)).map_err(|e| { + BurnpackError::IoError(format!("Failed to seek in file: {}", e)) + })?; + + file.read_exact(bytes).map_err(|e| { + BurnpackError::IoError(format!("Failed to read from file: {}", e)) + })?; + Ok(()) + } + } + } + + /// Get full data reference for raw access + #[allow(dead_code)] + pub(crate) fn as_bytes(&self) -> Result<&[u8], BurnpackError> { + match self { + StorageBackend::Memory(data) => Ok(data.as_ref()), + #[cfg(all(feature = 
"std", feature = "memmap"))] + StorageBackend::Mmap(mmap) => Ok(mmap.as_ref()), + #[cfg(feature = "std")] + StorageBackend::FileBuffered { .. } => Err(BurnpackError::IoError( + "Cannot get full bytes reference for FileBuffered backend".into(), + )), + } + } +} + +/// Reader for loading Burnpack files +pub struct BurnpackReader { + /// Parsed metadata + pub(crate) metadata: BurnpackMetadata, + /// Storage backend + pub(crate) storage: StorageBackend, + /// Offset to the start of tensor data + pub(crate) data_offset: usize, +} + +impl BurnpackReader { + /// Load from bytes + pub fn from_bytes(bytes: Bytes) -> Result { + // Validate minimum size + if bytes.len() < HEADER_SIZE { + return Err(BurnpackError::InvalidHeader); + } + + // Parse header + let header = BurnpackHeader::from_bytes(&bytes[..HEADER_SIZE])?; + + // Verify magic number + if header.magic != MAGIC_NUMBER { + return Err(BurnpackError::InvalidMagicNumber); + } + + // Verify version compatibility + if header.version > FORMAT_VERSION { + return Err(BurnpackError::InvalidVersion); + } + + // Validate metadata size against security limit + if header.metadata_size > MAX_METADATA_SIZE { + return Err(BurnpackError::ValidationError(format!( + "Metadata size {} exceeds maximum allowed size of {} bytes (potential DoS attack)", + header.metadata_size, MAX_METADATA_SIZE + ))); + } + + // Parse metadata + let metadata_start = HEADER_SIZE; + let metadata_end = metadata_start + .checked_add(header.metadata_size as usize) + .ok_or_else(|| { + BurnpackError::IoError(format!( + "Metadata size overflow: {} + {}", + metadata_start, header.metadata_size + )) + })?; + + if bytes.len() < metadata_end { + return Err(BurnpackError::InvalidHeader); + } + + let metadata: BurnpackMetadata = ciborium::de::from_reader_with_recursion_limit( + &bytes[metadata_start..metadata_end], + MAX_CBOR_RECURSION_DEPTH, + ) + .map_err(|e| BurnpackError::MetadataDeserializationError(e.to_string()))?; + + // Validate tensor count against security 
limit + if metadata.tensors.len() > MAX_TENSOR_COUNT { + return Err(BurnpackError::ValidationError(format!( + "File contains {} tensors, exceeding maximum of {} (potential DoS attack)", + metadata.tensors.len(), + MAX_TENSOR_COUNT + ))); + } + + // Validate total file size - ensure file is large enough for all claimed tensor data + if !metadata.tensors.is_empty() { + let max_data_offset = metadata + .tensors + .values() + .map(|t| t.data_offsets.1) + .max() + .unwrap_or(0); + + let max_data_offset_usize: usize = max_data_offset.try_into().map_err(|_| { + BurnpackError::ValidationError(format!( + "Data offset {} exceeds platform maximum", + max_data_offset + )) + })?; + + let min_file_size = + metadata_end + .checked_add(max_data_offset_usize) + .ok_or_else(|| { + BurnpackError::ValidationError("File size calculation overflow".into()) + })?; + + if bytes.len() < min_file_size { + return Err(BurnpackError::ValidationError(format!( + "File truncated: expected at least {} bytes, got {} bytes", + min_file_size, + bytes.len() + ))); + } + } + + Ok(Self { + metadata, + storage: StorageBackend::Memory(Rc::new(bytes)), + data_offset: metadata_end, + }) + } + + /// Load from file with memory mapping (most efficient for large files) + #[cfg(all(feature = "std", feature = "memmap"))] + pub(crate) fn from_file_mmap>(path: P) -> Result { + let file = File::open(&path).map_err(|e| BurnpackError::IoError(e.to_string()))?; + + // Validate maximum file size to prevent resource exhaustion + let file_size = file + .metadata() + .map_err(|e| BurnpackError::IoError(e.to_string()))? + .len(); + + if file_size > MAX_FILE_SIZE { + return Err(BurnpackError::ValidationError(format!( + "File size {} bytes exceeds maximum allowed size of {} bytes", + file_size, MAX_FILE_SIZE + ))); + } + + // Memory map the file + let mmap = unsafe { + memmap2::MmapOptions::new() + .map(&file) + .map_err(|e| BurnpackError::IoError(e.to_string()))? 
+ }; + + // Parse header + if mmap.len() < HEADER_SIZE { + return Err(BurnpackError::InvalidHeader); + } + + let header = BurnpackHeader::from_bytes(&mmap[..HEADER_SIZE])?; + + // Verify magic number and version + if header.magic != MAGIC_NUMBER { + return Err(BurnpackError::InvalidMagicNumber); + } + + if header.version > FORMAT_VERSION { + return Err(BurnpackError::InvalidVersion); + } + + // Validate metadata size against security limit + if header.metadata_size > MAX_METADATA_SIZE { + return Err(BurnpackError::ValidationError(format!( + "Metadata size {} exceeds maximum allowed size of {} bytes (potential DoS attack)", + header.metadata_size, MAX_METADATA_SIZE + ))); + } + + // Parse metadata + let metadata_start = HEADER_SIZE; + let metadata_end = metadata_start + .checked_add(header.metadata_size as usize) + .ok_or_else(|| { + BurnpackError::IoError(format!( + "Metadata size overflow: {} + {}", + metadata_start, header.metadata_size + )) + })?; + + if mmap.len() < metadata_end { + return Err(BurnpackError::InvalidHeader); + } + + let metadata: BurnpackMetadata = ciborium::de::from_reader_with_recursion_limit( + &mmap[metadata_start..metadata_end], + MAX_CBOR_RECURSION_DEPTH, + ) + .map_err(|e| BurnpackError::MetadataDeserializationError(e.to_string()))?; + + // Validate tensor count against security limit + if metadata.tensors.len() > MAX_TENSOR_COUNT { + return Err(BurnpackError::ValidationError(format!( + "File contains {} tensors, exceeding maximum of {} (potential DoS attack)", + metadata.tensors.len(), + MAX_TENSOR_COUNT + ))); + } + + // Validate total file size - ensure file is large enough for all claimed tensor data + if !metadata.tensors.is_empty() { + let max_data_offset = metadata + .tensors + .values() + .map(|t| t.data_offsets.1) + .max() + .unwrap_or(0); + + let max_data_offset_usize: usize = max_data_offset.try_into().map_err(|_| { + BurnpackError::ValidationError(format!( + "Data offset {} exceeds platform maximum", + max_data_offset + )) + 
})?; + + let min_file_size = + metadata_end + .checked_add(max_data_offset_usize) + .ok_or_else(|| { + BurnpackError::ValidationError("File size calculation overflow".into()) + })?; + + if mmap.len() < min_file_size { + return Err(BurnpackError::ValidationError(format!( + "File truncated: expected at least {} bytes, got {} bytes", + min_file_size, + mmap.len() + ))); + } + } + + Ok(Self { + metadata, + storage: StorageBackend::Mmap(Rc::new(mmap)), + data_offset: metadata_end, + }) + } + + /// Load from file - automatically uses memory mapping if available, otherwise uses buffered reading + #[cfg(feature = "std")] + pub fn from_file>(path: P) -> Result { + #[cfg(feature = "memmap")] + { + // Use memory mapping for efficient access + Self::from_file_mmap(path) + } + #[cfg(not(feature = "memmap"))] + { + // Fall back to buffered reading for memory efficiency + Self::from_file_buffered(path) + } + } + + /// Load from file with buffered reading (memory efficient but slower) + /// This is less efficient than memory mapping but works everywhere + #[cfg(feature = "std")] + #[allow(dead_code)] + pub(crate) fn from_file_buffered>(path: P) -> Result { + let mut file = File::open(&path).map_err(|e| BurnpackError::IoError(e.to_string()))?; + + // Validate maximum file size to prevent resource exhaustion + let file_size = file + .metadata() + .map_err(|e| BurnpackError::IoError(e.to_string()))? 
+ .len(); + + if file_size > MAX_FILE_SIZE { + return Err(BurnpackError::ValidationError(format!( + "File size {} bytes exceeds maximum allowed size of {} bytes", + file_size, MAX_FILE_SIZE + ))); + } + + // Read header + let mut header_bytes = [0u8; HEADER_SIZE]; + file.read_exact(&mut header_bytes) + .map_err(|e| BurnpackError::IoError(e.to_string()))?; + + let header = BurnpackHeader::from_bytes(&header_bytes)?; + + // Verify version + if header.version > FORMAT_VERSION { + return Err(BurnpackError::InvalidVersion); + } + + // Validate metadata size against security limit + if header.metadata_size > MAX_METADATA_SIZE { + return Err(BurnpackError::ValidationError(format!( + "Metadata size {} exceeds maximum allowed size of {} bytes (potential DoS attack)", + header.metadata_size, MAX_METADATA_SIZE + ))); + } + + // Read metadata + let mut metadata_bytes = vec![0u8; header.metadata_size as usize]; + file.read_exact(&mut metadata_bytes) + .map_err(|e| BurnpackError::IoError(e.to_string()))?; + + let metadata: BurnpackMetadata = ciborium::de::from_reader_with_recursion_limit( + metadata_bytes.as_slice(), + MAX_CBOR_RECURSION_DEPTH, + ) + .map_err(|e| BurnpackError::MetadataDeserializationError(e.to_string()))?; + + // Validate tensor count against security limit + if metadata.tensors.len() > MAX_TENSOR_COUNT { + return Err(BurnpackError::ValidationError(format!( + "File contains {} tensors, exceeding maximum of {} (potential DoS attack)", + metadata.tensors.len(), + MAX_TENSOR_COUNT + ))); + } + + // Calculate metadata end offset + let metadata_end = HEADER_SIZE + .checked_add(header.metadata_size as usize) + .ok_or_else(|| { + BurnpackError::IoError(format!( + "Metadata size overflow: {} + {}", + HEADER_SIZE, header.metadata_size + )) + })?; + + // Validate total file size - ensure file is large enough for all claimed tensor data + if !metadata.tensors.is_empty() { + let max_data_offset = metadata + .tensors + .values() + .map(|t| t.data_offsets.1) + .max() + 
.unwrap_or(0); + + let max_data_offset_usize: usize = max_data_offset.try_into().map_err(|_| { + BurnpackError::ValidationError(format!( + "Data offset {} exceeds platform maximum", + max_data_offset + )) + })?; + + let min_file_size = + metadata_end + .checked_add(max_data_offset_usize) + .ok_or_else(|| { + BurnpackError::ValidationError("File size calculation overflow".into()) + })?; + + // Get actual file size + let file_size = file + .metadata() + .map_err(|e| BurnpackError::IoError(e.to_string()))? + .len() as usize; + + if file_size < min_file_size { + return Err(BurnpackError::ValidationError(format!( + "File truncated: expected at least {} bytes, got {} bytes", + min_file_size, file_size + ))); + } + } + + Ok(Self { + metadata, + storage: StorageBackend::FileBuffered { + file: Rc::new(RefCell::new(file)), + }, + data_offset: metadata_end, + }) + } + + /// Get all tensor snapshots at once for efficient loading + pub fn get_snapshots(&self) -> Result, BurnpackError> { + let mut snapshots = Vec::new(); + + for (name, descriptor) in &self.metadata.tensors { + // Clone metadata for use in closure + // Convert shape dimensions with overflow checking + let shape: Vec = descriptor + .shape + .iter() + .map(|&s| { + s.try_into().map_err(|_| { + BurnpackError::ValidationError(format!( + "Tensor '{}' has corrupted shape data: dimension {} exceeds platform maximum", + name, s + )) + }) + }) + .collect::, BurnpackError>>()?; + + let dtype = descriptor.dtype; + + // Clone storage reference for the closure + let storage = match &self.storage { + StorageBackend::Memory(data) => StorageBackend::Memory(data.clone()), + #[cfg(all(feature = "std", feature = "memmap"))] + StorageBackend::Mmap(mmap) => StorageBackend::Mmap(mmap.clone()), + #[cfg(feature = "std")] + StorageBackend::FileBuffered { file } => { + StorageBackend::FileBuffered { file: file.clone() } + } + }; + + // Always use absolute positions for all backends + // Convert offsets with overflow checking + let 
offset_start: usize = descriptor.data_offsets.0.try_into().map_err(|_| { + BurnpackError::ValidationError(format!( + "Tensor '{}' has corrupted offset data: start offset {} exceeds platform maximum", + name, descriptor.data_offsets.0 + )) + })?; + + let offset_end: usize = descriptor.data_offsets.1.try_into().map_err(|_| { + BurnpackError::ValidationError(format!( + "Tensor '{}' has corrupted offset data: end offset {} exceeds platform maximum", + name, descriptor.data_offsets.1 + )) + })?; + + let start = self.data_offset.checked_add(offset_start).ok_or_else(|| { + BurnpackError::ValidationError(format!( + "Tensor '{}' has corrupted offset data: start offset overflow {} + {}", + name, self.data_offset, offset_start + )) + })?; + + let end = self.data_offset.checked_add(offset_end).ok_or_else(|| { + BurnpackError::ValidationError(format!( + "Tensor '{}' has corrupted offset data: end offset overflow {} + {}", + name, self.data_offset, offset_end + )) + })?; + + // Clone shape for the closure (TensorSnapshot::from_closure will also need it) + let shape_for_closure = shape.clone(); + + // Validate offset range + if end < start { + return Err(BurnpackError::ValidationError(format!( + "Tensor '{}' has corrupted offset data: end offset {} < start offset {}", + name, end, start + ))); + } + + // Validate tensor size against security limit + let tensor_size = end - start; + if tensor_size > MAX_TENSOR_SIZE { + return Err(BurnpackError::ValidationError(format!( + "Tensor '{}' size {} exceeds maximum allowed size of {} bytes (potential DoS attack)", + name, tensor_size, MAX_TENSOR_SIZE + ))); + } + + // Restore param_id if it was saved, otherwise generate + let tensor_id = descriptor + .param_id + .map(ParamId::from) + .unwrap_or_else(ParamId::new); + + // Create lazy TensorSnapshot + let snapshot = TensorSnapshot::from_closure( + Rc::new(move || { + // This closure is only called when data is actually needed + let len = end - start; + // TODO Should be allocated by the 
backend in the future + // See https://github.com/tracel-ai/burn/pull/3792#discussion_r2416812091 + let mut data_bytes = vec![0u8; len]; + storage.read_into(&mut data_bytes, start).map_err(|e| { + crate::TensorSnapshotError::IoError(format!( + "Failed to read tensor data: {}", + e + )) + })?; + Ok(TensorData::from_bytes_vec( + data_bytes, + shape_for_closure.clone(), + dtype, + )) + }), + dtype, + shape, + name.split('.').map(|s| s.to_string()).collect(), + vec![], // empty container_stack + tensor_id, // restored or newly generated param id + ); + + snapshots.push(snapshot); + } + + Ok(snapshots) + } + + // Legacy methods for test compatibility - will be removed + + /// Get tensor as TensorSnapshot with lazy loading + #[allow(dead_code)] + pub(crate) fn get_tensor_snapshot(&self, name: &str) -> Result { + let snapshots = self.get_snapshots()?; + snapshots + .into_iter() + .find(|s| s.full_path() == name) + .ok_or_else(|| BurnpackError::TensorNotFound(name.to_string())) + } + + /// Get list of tensor names + #[allow(dead_code)] + pub(crate) fn tensor_names(&self) -> Vec<&str> { + self.metadata + .tensors + .keys() + .map(|name| name.as_str()) + .collect() + } + + /// Get metadata + #[allow(dead_code)] + pub(crate) fn metadata(&self) -> &BurnpackMetadata { + &self.metadata + } + + /// Get tensor data as raw bytes + #[allow(dead_code)] + pub(crate) fn get_tensor_data(&self, name: &str) -> Result, BurnpackError> { + let descriptor = self + .metadata + .tensors + .get(name) + .ok_or_else(|| BurnpackError::TensorNotFound(name.to_string()))?; + + // Always use absolute positions for all backends + // Convert offsets with overflow checking + let offset_start: usize = descriptor.data_offsets.0.try_into().map_err(|_| { + BurnpackError::IoError(format!( + "Tensor '{}' has corrupted offset data: start offset {} exceeds platform maximum", + name, descriptor.data_offsets.0 + )) + })?; + + let offset_end: usize = descriptor.data_offsets.1.try_into().map_err(|_| { + 
BurnpackError::IoError(format!( + "Tensor '{}' has corrupted offset data: end offset {} exceeds platform maximum", + name, descriptor.data_offsets.1 + )) + })?; + + let start = self.data_offset.checked_add(offset_start).ok_or_else(|| { + BurnpackError::IoError(format!( + "Tensor '{}' has corrupted offset data: start offset overflow {} + {}", + name, self.data_offset, offset_start + )) + })?; + + let end = self.data_offset.checked_add(offset_end).ok_or_else(|| { + BurnpackError::IoError(format!( + "Tensor '{}' has corrupted offset data: end offset overflow {} + {}", + name, self.data_offset, offset_end + )) + })?; + + // Validate offset range + if end < start { + return Err(BurnpackError::IoError(format!( + "Tensor '{}' has corrupted offset data: end offset {} < start offset {}", + name, end, start + ))); + } + + let len = end - start; + let mut buffer = vec![0u8; len]; + self.storage.read_into(&mut buffer, start)?; + Ok(buffer) + } +} diff --git a/crates/burn-store/src/burnpack/store.rs b/crates/burn-store/src/burnpack/store.rs new file mode 100644 index 0000000000..9d4a04c5bf --- /dev/null +++ b/crates/burn-store/src/burnpack/store.rs @@ -0,0 +1,395 @@ +#[cfg(feature = "std")] +use std::path::PathBuf; + +use super::reader::BurnpackReader; +use super::writer::BurnpackWriter; +#[cfg(feature = "std")] +use crate::KeyRemapper; +use crate::burnpack::base::BurnpackError; +use crate::{ModuleSnapshot, ModuleStore, PathFilter}; +use alloc::collections::BTreeMap; +use alloc::format; +use alloc::string::String; +use burn_core::prelude::Backend; +use burn_tensor::Bytes; + +/// Store mode for BurnpackStore +enum StoreMode { + #[cfg(feature = "std")] + File(PathBuf), + Bytes(Option), +} + +/// BurnpackStore - A Burn-specific file format store using CBOR for metadata +pub struct BurnpackStore { + /// Store mode - either file path or bytes + mode: StoreMode, + /// Optional filter for selective loading/saving + filter: Option, + /// Additional metadata + metadata: BTreeMap, + /// 
Allow partial loading (ignore missing tensors) + allow_partial: bool, + /// Validate tensors during loading (check shapes and dtypes) + validate: bool, + /// Allow overwriting existing files (default: false) + overwrite: bool, + /// Automatically append .bpk extension if not present (default: true) + #[cfg(feature = "std")] + auto_extension: bool, + /// Key remapper for tensor name transformations + #[cfg(feature = "std")] + remapper: KeyRemapper, + /// Writer for saving + writer: Option, + /// Reader for loading + reader: Option, +} + +impl BurnpackStore { + /// Get the default metadata that includes Burn framework information. + /// + /// This includes: + /// - `format`: "burnpack" + /// - `producer`: "burn" + /// - `version`: The version of burn-store crate (from CARGO_PKG_VERSION) + /// + /// These metadata fields are automatically added to all saved models. + pub fn default_metadata() -> BTreeMap { + let mut metadata = BTreeMap::new(); + metadata.insert("format".into(), "burnpack".into()); + metadata.insert("producer".into(), "burn".into()); + metadata.insert("version".into(), env!("CARGO_PKG_VERSION").into()); + metadata + } + /// Create a new store from a file path + /// + /// By default, automatically appends `.bpk` extension if the path doesn't have one. + /// Use `.auto_extension(false)` to disable this behavior. 
+ /// + /// # Examples + /// + /// ```no_run + /// # use burn_store::BurnpackStore; + /// // Automatically appends .bpk + /// let store = BurnpackStore::from_file("model"); // creates "model.bpk" + /// + /// // Already has extension, no append + /// let store = BurnpackStore::from_file("model.bpk"); // uses "model.bpk" + /// let store = BurnpackStore::from_file("model.myext"); // uses "model.myext" + /// + /// // Disable auto-extension + /// let store = BurnpackStore::from_file("model").auto_extension(false); // uses "model" + /// ``` + #[cfg(feature = "std")] + pub fn from_file>(path: P) -> Self { + Self { + mode: StoreMode::File(path.as_ref().to_path_buf()), + filter: None, + metadata: Self::default_metadata(), + allow_partial: false, + validate: true, + overwrite: false, + #[cfg(feature = "std")] + auto_extension: true, + #[cfg(feature = "std")] + remapper: KeyRemapper::new(), + writer: None, + reader: None, + } + } + + /// Create a new store from bytes (for reading) or empty (for writing) + pub fn from_bytes(bytes: Option) -> Self { + Self { + mode: StoreMode::Bytes(bytes), + filter: None, + metadata: Self::default_metadata(), + allow_partial: false, + validate: true, + overwrite: false, + #[cfg(feature = "std")] + auto_extension: false, // Not used for bytes mode + #[cfg(feature = "std")] + remapper: KeyRemapper::new(), + writer: None, + reader: None, + } + } + + /// Add metadata key-value pair + pub fn metadata(mut self, key: impl Into, value: impl Into) -> Self { + self.metadata.insert(key.into(), value.into()); + self + } + + /// Clear all metadata (including defaults) + /// + /// This removes all metadata including the default format, producer, and version fields. + /// Use with caution as some tools may expect these fields to be present. 
+ pub fn clear_metadata(mut self) -> Self { + self.metadata.clear(); + self + } + + /// Allow partial loading (ignore missing tensors) + /// + /// When set to `true`, the store will not fail if some tensors are missing + /// during loading. This is useful when loading a subset of a model's parameters. + /// + /// Default: `false` + pub fn allow_partial(mut self, allow: bool) -> Self { + self.allow_partial = allow; + self + } + + /// Enable or disable validation during loading + /// + /// When validation is enabled, the store will check that loaded tensors + /// match the expected shapes and data types. Disabling validation can + /// improve performance but may lead to runtime errors if data is corrupted. + /// + /// Default: `true` + pub fn validate(mut self, validate: bool) -> Self { + self.validate = validate; + self + } + + /// Allow overwriting existing files when saving + /// + /// When set to `false`, attempting to save to an existing file will result in an error. + /// When set to `true`, existing files will be overwritten without warning. + /// + /// Default: `false` + pub fn overwrite(mut self, overwrite: bool) -> Self { + self.overwrite = overwrite; + self + } + + /// Enable or disable automatic .bpk extension appending + /// + /// When enabled (default), automatically appends `.bpk` to the file path + /// if no extension is detected. If an extension is already present, it is preserved. + /// + /// When disabled, uses the exact path provided without modification. 
+ /// + /// Default: `true` + /// + /// # Examples + /// + /// ```no_run + /// # use burn_store::BurnpackStore; + /// // With auto_extension enabled (default) + /// let store = BurnpackStore::from_file("model"); // -> "model.bpk" + /// + /// // With auto_extension disabled + /// let store = BurnpackStore::from_file("model") + /// .auto_extension(false); // -> "model" + /// ``` + #[cfg(feature = "std")] + pub fn auto_extension(mut self, enable: bool) -> Self { + self.auto_extension = enable; + self + } + + /// Set path filter for selective loading/saving + pub fn with_filter(mut self, filter: PathFilter) -> Self { + self.filter = Some(filter); + self + } + + /// Add regex pattern to filter + #[cfg(feature = "std")] + pub fn with_regex(mut self, pattern: &str) -> Self { + let filter = self.filter.unwrap_or_default(); + self.filter = Some(filter.with_regex(pattern)); + self + } + + /// Add exact path to filter + pub fn with_full_path(mut self, path: impl Into) -> Self { + let filter = self.filter.unwrap_or_default(); + self.filter = Some(filter.with_full_path(path)); + self + } + + /// Match all tensors (no filtering) + pub fn match_all(mut self) -> Self { + self.filter = Some(PathFilter::new().match_all()); + self + } + + /// Set key remapper for tensor name transformations during loading + #[cfg(feature = "std")] + pub fn remap(mut self, remapper: KeyRemapper) -> Self { + self.remapper = remapper; + self + } + + /// Add a single regex pattern for key remapping + #[cfg(feature = "std")] + pub fn with_remap_pattern(mut self, from: S1, to: S2) -> Self + where + S1: AsRef, + S2: Into, + { + self.remapper = self + .remapper + .add_pattern(from.as_ref(), to.into()) + .expect("Invalid regex pattern"); + self + } + + /// Set the path filter + pub fn filter(mut self, filter: PathFilter) -> Self { + self.filter = Some(filter); + self + } + + /// Get the bytes after writing (only valid for bytes mode after collecting) + pub fn get_bytes(&self) -> Result { + if let Some(writer) 
= &self.writer { + return writer.to_bytes(); + } + + match &self.mode { + StoreMode::Bytes(Some(bytes)) => Ok(bytes.clone()), + _ => Err(BurnpackError::IoError("No bytes available".into())), + } + } + + /// Process the file path with auto-extension logic + #[cfg(feature = "std")] + fn process_path(&self, path: &std::path::Path) -> PathBuf { + if !self.auto_extension { + return path.to_path_buf(); + } + + // Check if path already has an extension + if path.extension().is_some() { + // Has extension, use as-is + return path.to_path_buf(); + } + + // No extension, append .bpk + let mut new_path = path.to_path_buf(); + new_path.set_extension("bpk"); + new_path + } +} + +impl ModuleStore for BurnpackStore { + type Error = BurnpackError; + + fn collect_from>( + &mut self, + module: &M, + ) -> Result<(), Self::Error> { + // Collect snapshots from module + let snapshots = module.collect(self.filter.clone(), None); + + // Initialize writer with snapshots + let mut writer = BurnpackWriter::new(snapshots); + + // Add metadata using builder pattern + for (key, value) in &self.metadata { + writer = writer.with_metadata(key.as_str(), value.as_str()); + } + + // Store the writer for finalization + self.writer = Some(writer); + + // Write to storage based on mode + if let Some(writer) = &self.writer { + match &self.mode { + #[cfg(feature = "std")] + StoreMode::File(path) => { + // Process path with auto-extension logic + let final_path = self.process_path(path); + + // Check if file exists and overwrite is disabled + if final_path.exists() && !self.overwrite { + return Err(BurnpackError::IoError(format!( + "File already exists: {}. 
Use .overwrite(true) to overwrite.", + final_path.display() + ))); + } + writer.write_to_file(&final_path)?; + } + StoreMode::Bytes(_) => { + // Generate and store the bytes + let bytes_data = writer.to_bytes()?; + // Update mode with bytes - this pattern is irrefutable in no-std mode + #[cfg_attr(not(feature = "std"), allow(irrefutable_let_patterns))] + let StoreMode::Bytes(bytes_ref) = &mut self.mode else { + unreachable!("We just matched Bytes variant"); + }; + *bytes_ref = Some(bytes_data); + } + } + } + + Ok(()) + } + + fn apply_to>( + &mut self, + module: &mut M, + ) -> Result { + // Load reader if not already loaded + if self.reader.is_none() { + let reader = match &self.mode { + #[cfg(feature = "std")] + StoreMode::File(path) => { + // Process path with auto-extension logic + let final_path = self.process_path(path); + BurnpackReader::from_file(&final_path)? + } + StoreMode::Bytes(Some(bytes)) => BurnpackReader::from_bytes(bytes.clone())?, + StoreMode::Bytes(None) => { + return Err(BurnpackError::IoError("No bytes to read from".into())); + } + }; + self.reader = Some(reader); + } + + let reader = self + .reader + .as_ref() + .ok_or_else(|| BurnpackError::IoError("Reader not initialized".into()))?; + + // Get all snapshots at once for efficient loading + #[cfg(feature = "std")] + let snapshots = if !self.remapper.patterns.is_empty() { + let (remapped, _remapped_names) = self.remapper.remap(reader.get_snapshots()?); + // TODO: figure out what to do with the remapped names + remapped + } else { + reader.get_snapshots()?
+ }; + + #[cfg(not(feature = "std"))] + let snapshots = reader.get_snapshots()?; + + // Apply all snapshots at once to the module + let result = module.apply(snapshots, self.filter.clone(), None); + + // Validate if needed + if self.validate && !result.errors.is_empty() { + return Err(BurnpackError::ValidationError(format!( + "Import errors: {:?}", + result.errors + ))); + } + + // Check for missing tensors if partial loading is not allowed + if !self.allow_partial && !result.missing.is_empty() { + return Err(BurnpackError::ValidationError(format!( + "Missing tensors: {:?}", + result.missing + ))); + } + + Ok(result) + } +} diff --git a/crates/burn-store/src/burnpack/tests/edge_cases.rs b/crates/burn-store/src/burnpack/tests/edge_cases.rs new file mode 100644 index 0000000000..59a02101f1 --- /dev/null +++ b/crates/burn-store/src/burnpack/tests/edge_cases.rs @@ -0,0 +1,365 @@ +use crate::TensorSnapshot; +use crate::burnpack::{ + base::{BurnpackHeader, HEADER_SIZE}, + reader::BurnpackReader, + writer::BurnpackWriter, +}; +use burn_core::module::ParamId; +use burn_tensor::{DType, TensorData}; + +#[test] +fn test_maximum_metadata_size() { + // Create metadata that approaches u32::MAX (4GB limit) + // In practice, we'll test with a reasonably large metadata + let large_key = "x".repeat(1000); + let large_value = "y".repeat(10000); + + let mut writer = BurnpackWriter::new(vec![]); + + for i in 0..100 { + writer = writer.with_metadata(&format!("{}_{}", large_key, i), &large_value); + } + + let result = writer.to_bytes(); + assert!(result.is_ok()); + + let bytes = result.unwrap(); + let header = BurnpackHeader::from_bytes(&bytes[..HEADER_SIZE]).unwrap(); + + // Metadata size should be large but within u32 bounds + assert!(header.metadata_size > 1000000); // At least 1MB of metadata + assert!(header.metadata_size < u32::MAX); +} + +#[test] +fn test_zero_size_tensor_shapes() { + // Test various zero-dimensional shapes + let test_cases = vec![ + (vec![0], vec![]), // Empty 1D 
+ (vec![0, 10], vec![]), // Zero rows + (vec![10, 0], vec![]), // Zero columns + (vec![0, 0], vec![]), // Zero both dimensions + (vec![5, 0, 10], vec![]), // Zero in middle dimension + ]; + + let mut snapshots = vec![]; + for (i, (shape, data)) in test_cases.iter().enumerate() { + let name = format!("zero_tensor_{}", i); + let snapshot = TensorSnapshot::from_data( + TensorData::from_bytes_vec(data.clone(), shape.clone(), DType::F32), + vec![name.clone()], + vec![], + ParamId::new(), + ); + snapshots.push(snapshot); + } + + let writer = BurnpackWriter::new(snapshots); + let bytes = writer.to_bytes().unwrap(); + + // Read back and verify + let reader = BurnpackReader::from_bytes(bytes).unwrap(); + let names = reader.tensor_names(); + assert_eq!(names.len(), 5); +} + +#[test] +fn test_extremely_long_tensor_names() { + // Create a tensor with an extremely long name + let long_name = "a".repeat(10000); + + let snapshot = TensorSnapshot::from_data( + TensorData::from_bytes_vec(vec![1, 2, 3, 4], vec![4], DType::U8), + vec![long_name.clone()], + vec![], + ParamId::new(), + ); + + let writer = BurnpackWriter::new(vec![snapshot]); + let bytes = writer.to_bytes().unwrap(); + + let reader = BurnpackReader::from_bytes(bytes).unwrap(); + let names = reader.tensor_names(); + assert_eq!(names[0].len(), 10000); +} + +#[test] +fn test_unicode_in_names_and_metadata() { + // Test various Unicode characters in tensor names and metadata + let unicode_names = vec![ + "测试_tensor", // Chinese + "тест_tensor", // Cyrillic + "テスト_tensor", // Japanese + "🔥_burn_tensor", // Emoji + "αβγδ_tensor", // Greek + "한글_tensor", // Korean + ]; + + let mut snapshots = vec![]; + for name in &unicode_names { + let snapshot = TensorSnapshot::from_data( + TensorData::from_bytes_vec(vec![1], vec![1], DType::U8), + vec![name.to_string()], + vec![], + ParamId::new(), + ); + snapshots.push(snapshot); + } + + let writer = BurnpackWriter::new(snapshots) + .with_metadata("模型名称", "测试模型") + .with_metadata("מודל", 
"בדיקה") + .with_metadata("🔥", "fire"); + + let bytes = writer.to_bytes().unwrap(); + let reader = BurnpackReader::from_bytes(bytes).unwrap(); + + // Verify all Unicode names are preserved + let names = reader.tensor_names(); + assert_eq!(names.len(), unicode_names.len()); + + // Verify metadata + assert_eq!( + reader.metadata().metadata.get("模型名称"), + Some(&"测试模型".to_string()) + ); + assert_eq!( + reader.metadata().metadata.get("🔥"), + Some(&"fire".to_string()) + ); +} + +#[test] +fn test_all_supported_dtypes() { + // Test all DTypes with their boundary values + let dtypes_with_data = vec![ + ( + DType::F32, + vec![ + f32::MIN.to_le_bytes().to_vec(), + f32::MAX.to_le_bytes().to_vec(), + ] + .concat(), + ), + ( + DType::F64, + vec![ + f64::MIN.to_le_bytes().to_vec(), + f64::MAX.to_le_bytes().to_vec(), + ] + .concat(), + ), + ( + DType::I32, + vec![ + i32::MIN.to_le_bytes().to_vec(), + i32::MAX.to_le_bytes().to_vec(), + ] + .concat(), + ), + ( + DType::I64, + vec![ + i64::MIN.to_le_bytes().to_vec(), + i64::MAX.to_le_bytes().to_vec(), + ] + .concat(), + ), + ( + DType::U32, + vec![ + u32::MIN.to_le_bytes().to_vec(), + u32::MAX.to_le_bytes().to_vec(), + ] + .concat(), + ), + ( + DType::U64, + vec![ + u64::MIN.to_le_bytes().to_vec(), + u64::MAX.to_le_bytes().to_vec(), + ] + .concat(), + ), + (DType::U8, vec![u8::MIN, u8::MAX]), + (DType::Bool, vec![0, 1]), + ]; + + let mut snapshots = vec![]; + for (i, (dtype, data)) in dtypes_with_data.iter().enumerate() { + let name = format!("dtype_test_{}", i); + let snapshot = TensorSnapshot::from_data( + TensorData::from_bytes_vec(data.clone(), vec![2], *dtype), + vec![name], + vec![], + ParamId::new(), + ); + snapshots.push(snapshot); + } + + let writer = BurnpackWriter::new(snapshots); + let bytes = writer.to_bytes().unwrap(); + + let reader = BurnpackReader::from_bytes(bytes).unwrap(); + assert_eq!(reader.tensor_names().len(), dtypes_with_data.len()); + + // Verify dtypes are preserved + for (i, (expected_dtype, _)) in 
dtypes_with_data.iter().enumerate() { + let name = format!("dtype_test_{}", i); + let snapshot = reader.get_tensor_snapshot(&name).unwrap(); + assert_eq!(snapshot.dtype, *expected_dtype); + } +} + +#[test] +fn test_special_float_values() { + // Test special floating-point values (NaN, Inf, -Inf) + let special_values = vec![ + f32::NAN, + f32::INFINITY, + f32::NEG_INFINITY, + 0.0_f32, + -0.0_f32, + ]; + + let data: Vec = special_values + .iter() + .flat_map(|f| f.to_le_bytes()) + .collect(); + + let snapshot = TensorSnapshot::from_data( + TensorData::from_bytes_vec(data.clone(), vec![5], DType::F32), + vec!["special_floats".to_string()], + vec![], + ParamId::new(), + ); + + let writer = BurnpackWriter::new(vec![snapshot]); + let bytes = writer.to_bytes().unwrap(); + + let reader = BurnpackReader::from_bytes(bytes).unwrap(); + let tensor_data = reader.get_tensor_data("special_floats").unwrap(); + + // Check data is preserved exactly (bit-for-bit) + assert_eq!(tensor_data, data); +} + +#[test] +fn test_metadata_with_empty_values() { + let writer = BurnpackWriter::new(vec![]) + .with_metadata("empty_value", "") + .with_metadata("", "empty_key") + .with_metadata("normal", "value"); + + let bytes = writer.to_bytes().unwrap(); + let reader = BurnpackReader::from_bytes(bytes).unwrap(); + + let metadata = &reader.metadata().metadata; + assert_eq!(metadata.get("empty_value"), Some(&"".to_string())); + assert_eq!(metadata.get(""), Some(&"empty_key".to_string())); + assert_eq!(metadata.get("normal"), Some(&"value".to_string())); +} + +#[test] +fn test_single_byte_tensor() { + // Test the smallest possible tensor (1 byte) + let snapshot = TensorSnapshot::from_data( + TensorData::from_bytes_vec(vec![42], vec![1], DType::U8), + vec!["single_byte".to_string()], + vec![], + ParamId::new(), + ); + + let writer = BurnpackWriter::new(vec![snapshot]); + let bytes = writer.to_bytes().unwrap(); + + let reader = BurnpackReader::from_bytes(bytes).unwrap(); + let data = 
reader.get_tensor_data("single_byte").unwrap(); + assert_eq!(data, vec![42]); +} + +#[test] +fn test_high_dimensional_tensor() { + // Test a tensor with many dimensions (10D) + let shape = vec![2, 2, 2, 2, 2, 2, 2, 2, 2, 2]; // 10 dimensions, 1024 elements total + let data = vec![1u8; 1024]; + + let snapshot = TensorSnapshot::from_data( + TensorData::from_bytes_vec(data.clone(), shape.clone(), DType::U8), + vec!["high_dim".to_string()], + vec![], + ParamId::new(), + ); + + let writer = BurnpackWriter::new(vec![snapshot]); + let bytes = writer.to_bytes().unwrap(); + + let reader = BurnpackReader::from_bytes(bytes).unwrap(); + let loaded_snapshot = reader.get_tensor_snapshot("high_dim").unwrap(); + assert_eq!(loaded_snapshot.shape, shape); +} + +#[test] +fn test_metadata_key_collision() { + // Test that later values override earlier ones for the same key + let writer = BurnpackWriter::new(vec![]) + .with_metadata("key", "value1") + .with_metadata("key", "value2") + .with_metadata("key", "value3"); + + let bytes = writer.to_bytes().unwrap(); + let reader = BurnpackReader::from_bytes(bytes).unwrap(); + + assert_eq!( + reader.metadata().metadata.get("key"), + Some(&"value3".to_string()) + ); +} + +#[test] +fn test_tensor_name_with_path_separators() { + // Test tensor names that look like file paths + let path_like_names = vec![ + "model/encoder/layer1/weights", + "model\\decoder\\layer1\\bias", + "model::module::param", + "model.submodule.weight", + ]; + + let mut snapshots = vec![]; + for name in &path_like_names { + let snapshot = TensorSnapshot::from_data( + TensorData::from_bytes_vec(vec![1, 2, 3, 4], vec![4], DType::U8), + vec![name.to_string()], + vec![], + ParamId::new(), + ); + snapshots.push(snapshot); + } + + let writer = BurnpackWriter::new(snapshots); + let bytes = writer.to_bytes().unwrap(); + + let reader = BurnpackReader::from_bytes(bytes).unwrap(); + let names = reader.tensor_names(); + + // All names should be preserved exactly + for expected_name in 
&path_like_names { + assert!(names.contains(&expected_name.as_ref())); + } +} + +// The following tests are commented out as they test error conditions +// that might be handled differently in the new API + +// #[test] +// fn test_data_overflow_protection() { +// // Test that we handle potential integer overflows in offset calculations +// ... +// } + +// #[test] +// fn test_reading_corrupted_header() { +// // Test reading files with corrupted headers +// ... +// } diff --git a/crates/burn-store/src/burnpack/tests/header.rs b/crates/burn-store/src/burnpack/tests/header.rs new file mode 100644 index 0000000000..a7ac2ac46c --- /dev/null +++ b/crates/burn-store/src/burnpack/tests/header.rs @@ -0,0 +1,61 @@ +use crate::burnpack::base::*; + +#[test] +fn test_header_serialization() { + let header = BurnpackHeader::new(12345); + + // Check fields + assert_eq!(header.magic, MAGIC_NUMBER); + assert_eq!(header.version, FORMAT_VERSION); + assert_eq!(header.metadata_size, 12345); + + // Serialize to bytes + let bytes = header.into_bytes(); + assert_eq!(bytes.len(), HEADER_SIZE); + + // Deserialize back + let header2 = BurnpackHeader::from_bytes(&bytes).unwrap(); + assert_eq!(header2.magic, header.magic); + assert_eq!(header2.version, header.version); + assert_eq!(header2.metadata_size, header.metadata_size); +} + +#[test] +fn test_header_invalid_magic() { + let mut bytes = [0u8; HEADER_SIZE]; + // Write wrong magic number + bytes[0..4].copy_from_slice(&[0x00, 0x00, 0x00, 0x00]); + + let result = BurnpackHeader::from_bytes(&bytes); + match result { + Err(BurnpackError::InvalidMagicNumber) => {} + _ => panic!("Expected InvalidMagicNumber error"), + } +} + +#[test] +fn test_header_insufficient_bytes() { + let bytes = [0u8; 5]; // Too short + + let result = BurnpackHeader::from_bytes(&bytes); + match result { + Err(BurnpackError::InvalidHeader) => {} + _ => panic!("Expected InvalidHeader error"), + } +} + +#[test] +fn test_version_compatibility() { + // Create a header with 
current version
+    let header = BurnpackHeader::new(100);
+    let bytes = header.into_bytes();
+
+    // Should succeed with current version
+    let result = BurnpackHeader::from_bytes(&bytes);
+    assert!(result.is_ok());
+
+    // Test with future version (should fail in real implementation)
+    // For now, we just verify the version field is correctly set
+    let header = result.unwrap();
+    assert_eq!(header.version, FORMAT_VERSION);
+}
diff --git a/crates/burn-store/src/burnpack/tests/helpers.rs b/crates/burn-store/src/burnpack/tests/helpers.rs
new file mode 100644
index 0000000000..3f13d7b8cc
--- /dev/null
+++ b/crates/burn-store/src/burnpack/tests/helpers.rs
@@ -0,0 +1,19 @@
+use crate::TensorSnapshot;
+use burn_core::module::ParamId;
+use burn_tensor::{DType, TensorData};
+
+/// Helper to create a test TensorSnapshot
+#[allow(dead_code)]
+pub fn create_test_snapshot(
+    name: String,
+    data: Vec<u8>,
+    shape: Vec<usize>,
+    dtype: DType,
+) -> TensorSnapshot {
+    TensorSnapshot::from_data(
+        TensorData::from_bytes_vec(data, shape, dtype),
+        vec![name],
+        vec![],
+        ParamId::new(),
+    )
+}
diff --git a/crates/burn-store/src/burnpack/tests/mod.rs b/crates/burn-store/src/burnpack/tests/mod.rs
new file mode 100644
index 0000000000..75861d8146
--- /dev/null
+++ b/crates/burn-store/src/burnpack/tests/mod.rs
@@ -0,0 +1,9 @@
+use crate::TensorSnapshot;
+
+mod edge_cases;
+mod header;
+mod helpers;
+mod reader;
+mod round_trip;
+mod store;
+mod writer;
diff --git a/crates/burn-store/src/burnpack/tests/reader.rs b/crates/burn-store/src/burnpack/tests/reader.rs
new file mode 100644
index 0000000000..8bbd629b6a
--- /dev/null
+++ b/crates/burn-store/src/burnpack/tests/reader.rs
@@ -0,0 +1,772 @@
+use crate::burnpack::{
+    base::{
+        BurnpackError, FORMAT_VERSION, HEADER_SIZE, MAGIC_NUMBER, magic_range, metadata_size_range,
+        version_range,
+    },
+    reader::BurnpackReader,
+    writer::BurnpackWriter,
+};
+
+use super::*;
+use burn_tensor::{Bytes, DType, TensorData};
+
+#[test]
+fn
test_reader_from_bytes_empty() { + // Create empty burnpack data + let writer = BurnpackWriter::new(Vec::new()); + let bytes = writer.to_bytes().unwrap(); + + // Read it back + let reader = BurnpackReader::from_bytes(bytes).unwrap(); + + assert_eq!(reader.metadata().tensors.len(), 0); + assert!(reader.metadata().metadata.is_empty()); +} + +#[test] +fn test_reader_from_bytes_with_data() { + // Create test data + let snapshot = TensorSnapshot::from_data( + TensorData::from_bytes_vec(vec![1, 2, 3, 4], vec![2, 2], DType::U8), + vec!["test_tensor".to_string()], + vec![], + burn_core::module::ParamId::new(), + ); + + let writer = BurnpackWriter::new(vec![snapshot]).with_metadata("test", "value"); + + let bytes = writer.to_bytes().unwrap(); + + // Read it back + let reader = BurnpackReader::from_bytes(bytes).unwrap(); + + assert_eq!(reader.metadata().tensors.len(), 1); + assert_eq!( + reader.metadata().metadata.get("test"), + Some(&"value".to_string()) + ); + + // Get tensor data + let tensor_data = reader.get_tensor_data("test_tensor").unwrap(); + assert_eq!(tensor_data, &[1, 2, 3, 4]); +} + +#[test] +fn test_reader_invalid_magic_number() { + let mut bytes = vec![0u8; 100]; + // Write invalid magic number + bytes[magic_range()].copy_from_slice(b"NOPE"); + + let result = BurnpackReader::from_bytes(Bytes::from_bytes_vec(bytes)); + assert!(matches!(result, Err(BurnpackError::InvalidMagicNumber))); +} + +#[test] +fn test_reader_invalid_version() { + let mut bytes = vec![0u8; 100]; + // Write correct magic but invalid version + bytes[magic_range()].copy_from_slice(&MAGIC_NUMBER.to_le_bytes()); + bytes[version_range()].copy_from_slice(&999u16.to_le_bytes()); // Invalid version + bytes[metadata_size_range()].copy_from_slice(&10u32.to_le_bytes()); // Metadata size + + let result = BurnpackReader::from_bytes(Bytes::from_bytes_vec(bytes)); + assert!(matches!(result, Err(BurnpackError::InvalidVersion))); +} + +#[test] +fn test_reader_header_too_short() { + let bytes = vec![0u8; 5]; 
// Less than HEADER_SIZE + + let result = BurnpackReader::from_bytes(Bytes::from_bytes_vec(bytes)); + assert!(matches!(result, Err(BurnpackError::InvalidHeader))); +} + +#[test] +fn test_reader_metadata_truncated() { + let mut bytes = vec![0u8; HEADER_SIZE + 10]; + // Write valid header + bytes[magic_range()].copy_from_slice(&MAGIC_NUMBER.to_le_bytes()); + bytes[version_range()].copy_from_slice(&FORMAT_VERSION.to_le_bytes()); + bytes[metadata_size_range()].copy_from_slice(&100u32.to_le_bytes()); // Claims 100 bytes of metadata + + // But only provide 10 bytes after header + let result = BurnpackReader::from_bytes(Bytes::from_bytes_vec(bytes)); + assert!(matches!(result, Err(BurnpackError::InvalidHeader))); +} + +#[test] +fn test_reader_get_tensor_not_found() { + let writer = BurnpackWriter::new(Vec::new()); + let bytes = writer.to_bytes().unwrap(); + let reader = BurnpackReader::from_bytes(bytes).unwrap(); + + let result = reader.get_tensor_data("non_existent"); + assert!(matches!(result, Err(BurnpackError::TensorNotFound(_)))); +} + +#[test] +fn test_reader_get_tensor_snapshot() { + let data = vec![1.0f32, 2.0, 3.0, 4.0]; + let bytes: Vec = data.iter().flat_map(|f| f.to_le_bytes()).collect(); + let snapshot = TensorSnapshot::from_data( + TensorData::from_bytes_vec(bytes, vec![2, 2], DType::F32), + vec!["weights".to_string()], + vec![], + burn_core::module::ParamId::new(), + ); + + let writer = BurnpackWriter::new(vec![snapshot]); + let writer_bytes = writer.to_bytes().unwrap(); + let reader = BurnpackReader::from_bytes(writer_bytes).unwrap(); + + // Get tensor as snapshot + let loaded_snapshot = reader.get_tensor_snapshot("weights").unwrap(); + + // Verify snapshot metadata + assert_eq!(loaded_snapshot.full_path(), "weights"); + assert_eq!(loaded_snapshot.dtype, DType::F32); + assert_eq!(loaded_snapshot.shape, vec![2, 2]); + + // Verify data through closure + let tensor_data = loaded_snapshot.to_data().unwrap(); + assert_eq!(tensor_data.shape, vec![2, 2]); +} + 
+#[test] +fn test_reader_multiple_tensors() { + // Add multiple tensors + let mut snapshots = Vec::new(); + for i in 0..10 { + let name = format!("tensor_{}", i); + let data = vec![i as u8; 100]; + let snapshot = TensorSnapshot::from_data( + TensorData::from_bytes_vec(data, vec![100], DType::U8), + vec![name.clone()], + vec![], + burn_core::module::ParamId::new(), + ); + snapshots.push(snapshot); + } + + let writer = BurnpackWriter::new(snapshots); + let bytes = writer.to_bytes().unwrap(); + let reader = BurnpackReader::from_bytes(bytes).unwrap(); + + // Verify all tensors can be read + for i in 0..10 { + let name = format!("tensor_{}", i); + let data = reader.get_tensor_data(&name).unwrap(); + assert_eq!(data.len(), 100); + assert!(data.iter().all(|&b| b == i as u8)); + } +} + +#[test] +fn test_reader_lazy_loading() { + // Create large tensor + let size = 1024 * 1024; // 1MB + let data = vec![42u8; size]; + let snapshot = TensorSnapshot::from_data( + TensorData::from_bytes_vec(data.clone(), vec![size], DType::U8), + vec!["large".to_string()], + vec![], + burn_core::module::ParamId::new(), + ); + + let writer = BurnpackWriter::new(vec![snapshot]); + let bytes = writer.to_bytes().unwrap(); + let reader = BurnpackReader::from_bytes(bytes).unwrap(); + + // Get snapshot (should be lazy) + let snapshot = reader.get_tensor_snapshot("large").unwrap(); + + // Data should only be accessed when to_data is called + let tensor_data = snapshot.to_data().unwrap(); + assert_eq!(tensor_data.bytes.len(), size); + assert!(tensor_data.bytes.iter().all(|&b| b == 42)); +} + +#[test] +fn test_reader_all_dtypes() { + // Test all data types + let test_data = vec![ + (DType::F32, vec![1.0f32.to_le_bytes().to_vec()].concat()), + (DType::F64, vec![2.0f64.to_le_bytes().to_vec()].concat()), + (DType::I32, vec![3i32.to_le_bytes().to_vec()].concat()), + (DType::I64, vec![4i64.to_le_bytes().to_vec()].concat()), + (DType::U32, vec![5u32.to_le_bytes().to_vec()].concat()), + (DType::U64, 
vec![6u64.to_le_bytes().to_vec()].concat()), + (DType::U8, vec![7u8]), + (DType::Bool, vec![1u8]), + ]; + + let mut snapshots = Vec::new(); + for (i, (dtype, data)) in test_data.iter().enumerate() { + let name = format!("tensor_{}", i); + let snapshot = TensorSnapshot::from_data( + TensorData::from_bytes_vec(data.clone(), vec![1], *dtype), + vec![name.clone()], + vec![], + burn_core::module::ParamId::new(), + ); + snapshots.push(snapshot); + } + + let writer = BurnpackWriter::new(snapshots); + let bytes = writer.to_bytes().unwrap(); + let reader = BurnpackReader::from_bytes(bytes).unwrap(); + + // Verify all dtypes are preserved + for (i, (expected_dtype, expected_data)) in test_data.iter().enumerate() { + let name = format!("tensor_{}", i); + let snapshot = reader.get_tensor_snapshot(&name).unwrap(); + assert_eq!(snapshot.dtype, *expected_dtype); + + let data = reader.get_tensor_data(&name).unwrap(); + assert_eq!(data, expected_data.as_slice()); + } +} + +#[test] +fn test_reader_empty_tensor() { + let snapshot = TensorSnapshot::from_data( + TensorData::from_bytes_vec(vec![], vec![0], DType::F32), + vec!["empty".to_string()], + vec![], + burn_core::module::ParamId::new(), + ); + + let writer = BurnpackWriter::new(vec![snapshot]); + let bytes = writer.to_bytes().unwrap(); + let reader = BurnpackReader::from_bytes(bytes).unwrap(); + + let data = reader.get_tensor_data("empty").unwrap(); + assert_eq!(data.len(), 0); + + let snapshot = reader.get_tensor_snapshot("empty").unwrap(); + assert_eq!(snapshot.shape, vec![0]); +} + +#[cfg(feature = "std")] +#[test] +fn test_reader_from_file() { + use tempfile::tempdir; + + let dir = tempdir().unwrap(); + let file_path = dir.path().join("test.bpk"); + + // Create test file + let snapshot = TensorSnapshot::from_data( + TensorData::from_bytes_vec(vec![10, 20, 30], vec![3], DType::U8), + vec!["file_tensor".to_string()], + vec![], + burn_core::module::ParamId::new(), + ); + + let writer = 
BurnpackWriter::new(vec![snapshot]).with_metadata("from_file_test", "true"); + + writer.write_to_file(&file_path).unwrap(); + + // Read from file + let reader = BurnpackReader::from_file(&file_path).unwrap(); + + assert_eq!( + reader.metadata().metadata.get("from_file_test"), + Some(&"true".to_string()) + ); + + let data = reader.get_tensor_data("file_tensor").unwrap(); + assert_eq!(data, &[10, 20, 30]); +} + +#[cfg(all(feature = "std", feature = "memmap"))] +#[test] +fn test_reader_from_file_mmap() { + use tempfile::tempdir; + + let dir = tempdir().unwrap(); + let file_path = dir.path().join("test_mmap.bpk"); + + // Create large test file + let size = 1024 * 1024; // 1MB + let data = vec![99u8; size]; + let snapshot = TensorSnapshot::from_data( + TensorData::from_bytes_vec(data, vec![size], DType::U8), + vec!["large_mmap".to_string()], + vec![], + burn_core::module::ParamId::new(), + ); + + let writer = BurnpackWriter::new(vec![snapshot]); + writer.write_to_file(&file_path).unwrap(); + + // Read using mmap + let reader = BurnpackReader::from_file_mmap(&file_path).unwrap(); + + let data = reader.get_tensor_data("large_mmap").unwrap(); + assert_eq!(data.len(), size); + assert!(data.iter().all(|&b| b == 99)); +} + +#[cfg(feature = "std")] +#[test] +fn test_reader_from_file_buffered() { + use tempfile::tempdir; + + let dir = tempdir().unwrap(); + let file_path = dir.path().join("test_buffered.bpk"); + + // Create test file + let snapshot = TensorSnapshot::from_data( + TensorData::from_bytes_vec(vec![5, 10, 15], vec![3], DType::U8), + vec!["buffered_tensor".to_string()], + vec![], + burn_core::module::ParamId::new(), + ); + + let writer = BurnpackWriter::new(vec![snapshot]); + writer.write_to_file(&file_path).unwrap(); + + // Read using buffered reader + let reader = BurnpackReader::from_file_buffered(&file_path).unwrap(); + + let data = reader.get_tensor_data("buffered_tensor").unwrap(); + assert_eq!(data, &[5, 10, 15]); +} + +#[test] +fn test_reader_metadata_access() 
{ + // Add various metadata using builder pattern + let writer = BurnpackWriter::new(Vec::new()) + .with_metadata("model_name", "test_model") + .with_metadata("version", "1.2.3") + .with_metadata("author", "test_author") + .with_metadata("description", "A test model"); + + let bytes = writer.to_bytes().unwrap(); + let reader = BurnpackReader::from_bytes(bytes).unwrap(); + + let metadata = reader.metadata(); + assert_eq!(metadata.metadata.len(), 4); + assert_eq!( + metadata.metadata.get("model_name"), + Some(&"test_model".to_string()) + ); + assert_eq!(metadata.metadata.get("version"), Some(&"1.2.3".to_string())); + assert_eq!( + metadata.metadata.get("author"), + Some(&"test_author".to_string()) + ); + assert_eq!( + metadata.metadata.get("description"), + Some(&"A test model".to_string()) + ); +} + +#[test] +fn test_reader_tensor_iteration() { + // Add tensors + let tensor_names = vec!["weights", "bias", "running_mean", "running_var"]; + let mut snapshots = Vec::new(); + for name in &tensor_names { + let snapshot = TensorSnapshot::from_data( + TensorData::from_bytes_vec(vec![1, 2, 3, 4], vec![4], DType::U8), + vec![name.to_string()], + vec![], + burn_core::module::ParamId::new(), + ); + snapshots.push(snapshot); + } + + let writer = BurnpackWriter::new(snapshots); + let bytes = writer.to_bytes().unwrap(); + let reader = BurnpackReader::from_bytes(bytes).unwrap(); + + // Iterate through all tensors + let metadata = reader.metadata(); + assert_eq!(metadata.tensors.len(), 4); + + // Check that all expected tensor names are present + for name in &tensor_names { + let tensor_desc = metadata.tensors.get(*name).unwrap(); + assert_eq!(tensor_desc.shape, vec![4u64]); + assert_eq!(tensor_desc.dtype, DType::U8); + } + + // Verify the keys match the expected names + let mut actual_names: Vec<_> = metadata.tensors.keys().cloned().collect(); + actual_names.sort(); + let mut expected_names = tensor_names + .iter() + .map(|s| s.to_string()) + .collect::>(); + 
expected_names.sort(); + assert_eq!(actual_names, expected_names); +} + +#[test] +fn test_reader_corrupt_metadata() { + let mut bytes = vec![0u8; 100]; + + // Write valid header + bytes[magic_range()].copy_from_slice(&MAGIC_NUMBER.to_le_bytes()); + bytes[version_range()].copy_from_slice(&FORMAT_VERSION.to_le_bytes()); + bytes[metadata_size_range()].copy_from_slice(&50u32.to_le_bytes()); // 50 bytes of metadata + + // Write garbage as metadata + for i in HEADER_SIZE..HEADER_SIZE + 50 { + bytes[i] = 0xFF; + } + + let result = BurnpackReader::from_bytes(Bytes::from_bytes_vec(bytes)); + assert!(result.is_err()); +} + +#[test] +fn test_reader_data_offsets_validation() { + // Add two tensors + let snapshot1 = TensorSnapshot::from_data( + TensorData::from_bytes_vec(vec![1, 2, 3, 4], vec![4], DType::U8), + vec!["tensor1".to_string()], + vec![], + burn_core::module::ParamId::new(), + ); + let snapshot2 = TensorSnapshot::from_data( + TensorData::from_bytes_vec(vec![5, 6, 7, 8], vec![4], DType::U8), + vec!["tensor2".to_string()], + vec![], + burn_core::module::ParamId::new(), + ); + + let writer = BurnpackWriter::new(vec![snapshot1, snapshot2]); + let bytes = writer.to_bytes().unwrap(); + let reader = BurnpackReader::from_bytes(bytes).unwrap(); + + // Verify offsets don't overlap + let metadata = reader.metadata(); + let tensor1_desc = metadata.tensors.get("tensor1").unwrap(); + let tensor2_desc = metadata.tensors.get("tensor2").unwrap(); + + assert_eq!(tensor1_desc.data_offsets, (0, 4)); + assert_eq!(tensor2_desc.data_offsets, (4, 8)); +} + +#[test] +fn test_reader_out_of_bounds_error() { + use crate::burnpack::reader::StorageBackend; + use alloc::rc::Rc; + + // Create a small data buffer + let data = Bytes::from_bytes_vec(vec![1, 2, 3, 4, 5]); + let backend = StorageBackend::Memory(Rc::new(data)); + + // Try to read beyond the available data + let mut buffer = vec![0u8; 10]; + let result = backend.read_into(&mut buffer, 0); + + // Should return an error + 
assert!(result.is_err()); + let err = result.unwrap_err(); + assert!(err.to_string().contains("out of bounds")); +} + +#[test] +fn test_reader_offset_overflow_error() { + use crate::burnpack::reader::StorageBackend; + use alloc::rc::Rc; + + let data = Bytes::from_bytes_vec(vec![1, 2, 3, 4, 5]); + let backend = StorageBackend::Memory(Rc::new(data)); + + // Try to read with an offset that would overflow + let mut buffer = vec![0u8; 10]; + let result = backend.read_into(&mut buffer, usize::MAX - 5); + + // Should return an error about overflow + assert!(result.is_err()); + let err = result.unwrap_err(); + assert!(err.to_string().contains("overflow")); +} + +#[test] +fn test_reader_corrupted_shape_returns_error() { + // Only test this on platforms where usize is smaller than u64 + // On 64-bit platforms, u64 values can fit in usize + #[cfg(target_pointer_width = "32")] + { + use crate::burnpack::base::{BurnpackMetadata, TensorDescriptor}; + use alloc::collections::BTreeMap; + use alloc::rc::Rc; + use burn_tensor::DType; + + // Create metadata with a shape dimension that exceeds usize::MAX on 32-bit platforms + let mut tensors = BTreeMap::new(); + tensors.insert( + "corrupted_tensor".to_string(), + TensorDescriptor { + dtype: DType::F32, + shape: vec![u64::MAX, 2, 3], // First dimension exceeds usize::MAX on 32-bit + data_offsets: (0, 100), + param_id: None, + }, + ); + + let metadata = BurnpackMetadata { + tensors, + metadata: BTreeMap::new(), + }; + + // Create a small data buffer + let data = Bytes::from_bytes_vec(vec![0u8; 1000]); + let backend = crate::burnpack::reader::StorageBackend::Memory(Rc::new(data)); + + let reader = BurnpackReader { + metadata, + storage: backend, + data_offset: 0, + }; + + // This should return an error, not panic + let result = reader.get_snapshots(); + assert!(result.is_err()); + let err = result.unwrap_err(); + assert!(matches!(err, BurnpackError::ValidationError(_))); + assert!( + err.to_string().contains("corrupted shape data") + || 
err.to_string().contains("exceeds platform maximum") + ); + } + + #[cfg(not(target_pointer_width = "32"))] + { + // On 64-bit platforms, just pass the test + // The conversion logic is still correct, but u64 fits in usize + } +} + +#[test] +fn test_reader_corrupted_offsets_returns_error() { + // Only test this on platforms where usize is smaller than u64 + #[cfg(target_pointer_width = "32")] + { + use crate::burnpack::base::{BurnpackMetadata, TensorDescriptor}; + use alloc::collections::BTreeMap; + use alloc::rc::Rc; + use burn_tensor::DType; + + // Create metadata with offsets that would overflow + let mut tensors = BTreeMap::new(); + tensors.insert( + "tensor_bad_offset".to_string(), + TensorDescriptor { + dtype: DType::F32, + shape: vec![2, 2], + data_offsets: (u64::MAX - 10, u64::MAX), // Offsets that exceed usize::MAX on 32-bit + param_id: None, + }, + ); + + let metadata = BurnpackMetadata { + tensors, + metadata: BTreeMap::new(), + }; + + let data = Bytes::from_bytes_vec(vec![0u8; 1000]); + let backend = crate::burnpack::reader::StorageBackend::Memory(Rc::new(data)); + + let reader = BurnpackReader { + metadata, + storage: backend, + data_offset: 0, + }; + + // This should return an error, not panic + let result = reader.get_snapshots(); + assert!(result.is_err()); + let err = result.unwrap_err(); + assert!(matches!(err, BurnpackError::ValidationError(_))); + assert!( + err.to_string().contains("corrupted offset data") + || err.to_string().contains("exceeds platform maximum") + ); + } + + #[cfg(not(target_pointer_width = "32"))] + { + use crate::burnpack::base::{BurnpackMetadata, TensorDescriptor}; + use alloc::collections::BTreeMap; + use alloc::rc::Rc; + use burn_tensor::DType; + + // On 64-bit platforms, test offset overflow during addition + let mut tensors = BTreeMap::new(); + tensors.insert( + "tensor_overflow".to_string(), + TensorDescriptor { + dtype: DType::F32, + shape: vec![2, 2], + data_offsets: (0, 100), + param_id: None, + }, + ); + + let 
metadata = BurnpackMetadata { + tensors, + metadata: BTreeMap::new(), + }; + + let data = Bytes::from_bytes_vec(vec![0u8; 1000]); + let backend = crate::burnpack::reader::StorageBackend::Memory(Rc::new(data)); + + // Use a data_offset that will overflow when added to the tensor offset + let reader = BurnpackReader { + metadata, + storage: backend, + data_offset: usize::MAX - 50, // Will overflow when added to 100 + }; + + // This should return an error, not panic + let result = reader.get_snapshots(); + assert!(result.is_err()); + let err = result.unwrap_err(); + assert!(matches!(err, BurnpackError::ValidationError(_))); + assert!(err.to_string().contains("overflow")); + } +} + +#[test] +fn test_reader_inverted_offsets_returns_error() { + use crate::burnpack::base::{BurnpackMetadata, TensorDescriptor}; + use alloc::collections::BTreeMap; + use alloc::rc::Rc; + use burn_tensor::DType; + + // Create metadata with end offset < start offset (corrupted) + let mut tensors = BTreeMap::new(); + tensors.insert( + "inverted_tensor".to_string(), + TensorDescriptor { + dtype: DType::F32, + shape: vec![2, 2], + data_offsets: (100, 50), // End offset < start offset + param_id: None, + }, + ); + + let metadata = BurnpackMetadata { + tensors, + metadata: BTreeMap::new(), + }; + + let data = Bytes::from_bytes_vec(vec![0u8; 1000]); + let backend = crate::burnpack::reader::StorageBackend::Memory(Rc::new(data)); + + let reader = BurnpackReader { + metadata, + storage: backend, + data_offset: 0, + }; + + // This should return an error, not panic + let result = reader.get_snapshots(); + assert!(result.is_err()); + let err = result.unwrap_err(); + assert!(matches!(err, BurnpackError::ValidationError(_))); + assert!(err.to_string().contains("end offset") && err.to_string().contains("start offset")); +} + +#[test] +fn test_reader_truncated_file_from_bytes() { + // Create a valid burnpack with tensor data + let tensor_size = 1024; // 1KB of data + let data = vec![42u8; tensor_size]; + let 
snapshot = TensorSnapshot::from_data( + TensorData::from_bytes_vec(data, vec![tensor_size], DType::U8), + vec!["large_tensor".to_string()], + vec![], + burn_core::module::ParamId::new(), + ); + + let writer = BurnpackWriter::new(vec![snapshot]); + let full_bytes = writer.to_bytes().unwrap(); + + // Truncate the bytes - remove the last 512 bytes of tensor data + let truncated_len = full_bytes.len() - 512; + let truncated_bytes = Bytes::from_bytes_vec(full_bytes.to_vec()[..truncated_len].to_vec()); + + // This should fail with a validation error indicating file truncation + let result = BurnpackReader::from_bytes(truncated_bytes); + assert!(result.is_err()); + if let Err(err) = result { + assert!(matches!(err, BurnpackError::ValidationError(_))); + assert!(err.to_string().contains("File truncated")); + assert!(err.to_string().contains("expected at least")); + } +} + +#[cfg(feature = "std")] +#[test] +fn test_reader_truncated_file_from_file() { + use std::fs::OpenOptions; + use tempfile::tempdir; + + let dir = tempdir().unwrap(); + let file_path = dir.path().join("truncated.bpk"); + + // Create a valid burnpack file with tensor data + let tensor_size = 2048; // 2KB of data + let data = vec![99u8; tensor_size]; + let snapshot = TensorSnapshot::from_data( + TensorData::from_bytes_vec(data, vec![tensor_size], DType::U8), + vec!["data_tensor".to_string()], + vec![], + burn_core::module::ParamId::new(), + ); + + let writer = BurnpackWriter::new(vec![snapshot]); + writer.write_to_file(&file_path).unwrap(); + + // Read the full file to get its size + let full_size = std::fs::metadata(&file_path).unwrap().len(); + + // Truncate the file - remove the last 1KB + let truncated_size = full_size - 1024; + let truncated_file = OpenOptions::new().write(true).open(&file_path).unwrap(); + truncated_file.set_len(truncated_size).unwrap(); + drop(truncated_file); + + // Try to read the truncated file - should fail with validation error + let result = 
BurnpackReader::from_file(&file_path);
+    assert!(result.is_err());
+    if let Err(err) = result {
+        assert!(matches!(err, BurnpackError::ValidationError(_)));
+        assert!(err.to_string().contains("File truncated"));
+        assert!(err.to_string().contains("expected at least"));
+    }
+}
+
+#[test]
+fn test_reader_file_size_exactly_correct() {
+    // Test that a file with exactly the right size passes validation
+    let tensor_size = 100;
+    let data = vec![77u8; tensor_size];
+    let snapshot = TensorSnapshot::from_data(
+        TensorData::from_bytes_vec(data, vec![tensor_size], DType::U8),
+        vec!["exact_size".to_string()],
+        vec![],
+        burn_core::module::ParamId::new(),
+    );
+
+    let writer = BurnpackWriter::new(vec![snapshot]);
+    let bytes = writer.to_bytes().unwrap();
+
+    // This should succeed - file is exactly the right size
+    let reader = BurnpackReader::from_bytes(bytes);
+    assert!(reader.is_ok());
+
+    // Verify we can read the data
+    let reader = reader.unwrap();
+    let tensor_data = reader.get_tensor_data("exact_size").unwrap();
+    assert_eq!(tensor_data.len(), tensor_size);
+    assert!(tensor_data.iter().all(|&b| b == 77));
+}
diff --git a/crates/burn-store/src/burnpack/tests/round_trip.rs b/crates/burn-store/src/burnpack/tests/round_trip.rs
new file mode 100644
index 0000000000..b187fbbe2e
--- /dev/null
+++ b/crates/burn-store/src/burnpack/tests/round_trip.rs
@@ -0,0 +1,606 @@
+use crate::burnpack::{reader::BurnpackReader, writer::BurnpackWriter};
+
+use super::*;
+use alloc::collections::BTreeMap;
+use alloc::string::String;
+use burn_tensor::{DType, TensorData};
+
+/// Helper function to perform round-trip test
+fn round_trip_test<F>(setup: F)
+where
+    F: FnOnce(&mut Vec<TensorSnapshot>, &mut BTreeMap<String, String>),
+{
+    // Collect snapshots and metadata
+    let mut snapshots = Vec::new();
+    let mut metadata = BTreeMap::new();
+    setup(&mut snapshots, &mut metadata);
+
+    // Sort snapshots by name to ensure consistent ordering
+    // This is necessary because BTreeMap will store them sorted
+    snapshots.sort_by(|a, b|
a.full_path().cmp(&b.full_path())); + + // Create writer with snapshots and metadata + let mut writer = BurnpackWriter::new(snapshots); + for (key, value) in &metadata { + writer = writer.with_metadata(key, value); + } + + let bytes = writer.to_bytes().unwrap(); + let reader = BurnpackReader::from_bytes(bytes.clone()).unwrap(); + + // Write to bytes again from reader data + let mut snapshots2 = Vec::new(); + + // Copy tensors (metadata.tensors is now BTreeMap) + // They will come out in sorted order from tensor_names() + for tensor_name in reader.tensor_names() { + let snapshot = reader.get_tensor_snapshot(tensor_name).unwrap(); + snapshots2.push(snapshot); + } + + // Create writer2 with collected snapshots and metadata + let mut writer2 = BurnpackWriter::new(snapshots2); + for (key, value) in &reader.metadata().metadata { + writer2 = writer2.with_metadata(key, value); + } + + let bytes2 = writer2.to_bytes().unwrap(); + + // Both byte representations should be identical + assert_eq!(bytes, bytes2, "Round-trip produced different bytes"); +} + +#[test] +fn test_round_trip_empty() { + round_trip_test(|_snapshots, _metadata| { + // Empty writer + }); +} + +#[test] +fn test_round_trip_metadata_only() { + round_trip_test(|_snapshots, metadata| { + metadata.insert("key1".to_string(), "value1".to_string()); + metadata.insert("key2".to_string(), "value2".to_string()); + metadata.insert("key3".to_string(), "value3".to_string()); + }); +} + +#[test] +fn test_round_trip_f32() { + round_trip_test(|snapshots, _metadata| { + let data = vec![1.0f32, 2.0, 3.0, 4.0, 5.0]; + let bytes: Vec = data.iter().flat_map(|f| f.to_le_bytes()).collect(); + let snapshot = TensorSnapshot::from_data( + TensorData::from_bytes_vec(bytes, vec![5], DType::F32), + vec!["f32_tensor".to_string()], + vec![], + burn_core::module::ParamId::new(), + ); + snapshots.push(snapshot); + }); +} + +#[test] +fn test_round_trip_f64() { + round_trip_test(|snapshots, _metadata| { + let data = vec![1.0f64, 2.0, 3.0]; + 
let bytes: Vec = data.iter().flat_map(|f| f.to_le_bytes()).collect(); + let snapshot = TensorSnapshot::from_data( + TensorData::from_bytes_vec(bytes, vec![3], DType::F64), + vec!["f64_tensor".to_string()], + vec![], + burn_core::module::ParamId::new(), + ); + snapshots.push(snapshot); + }); +} + +#[test] +fn test_round_trip_i32() { + round_trip_test(|snapshots, _metadata| { + let data = vec![-10i32, 0, 10, 20]; + let bytes: Vec = data.iter().flat_map(|i| i.to_le_bytes()).collect(); + let snapshot = TensorSnapshot::from_data( + TensorData::from_bytes_vec(bytes, vec![4], DType::I32), + vec!["i32_tensor".to_string()], + vec![], + burn_core::module::ParamId::new(), + ); + snapshots.push(snapshot); + }); +} + +#[test] +fn test_round_trip_i64() { + round_trip_test(|snapshots, _metadata| { + let data = vec![i64::MIN, 0, i64::MAX]; + let bytes: Vec = data.iter().flat_map(|i| i.to_le_bytes()).collect(); + let snapshot = TensorSnapshot::from_data( + TensorData::from_bytes_vec(bytes, vec![3], DType::I64), + vec!["i64_tensor".to_string()], + vec![], + burn_core::module::ParamId::new(), + ); + snapshots.push(snapshot); + }); +} + +#[test] +fn test_round_trip_u32() { + round_trip_test(|snapshots, _metadata| { + let data = vec![0u32, 100, 1000, u32::MAX]; + let bytes: Vec = data.iter().flat_map(|u| u.to_le_bytes()).collect(); + let snapshot = TensorSnapshot::from_data( + TensorData::from_bytes_vec(bytes, vec![4], DType::U32), + vec!["u32_tensor".to_string()], + vec![], + burn_core::module::ParamId::new(), + ); + snapshots.push(snapshot); + }); +} + +#[test] +fn test_round_trip_u64() { + round_trip_test(|snapshots, _metadata| { + let data = vec![0u64, u64::MAX / 2, u64::MAX]; + let bytes: Vec = data.iter().flat_map(|u| u.to_le_bytes()).collect(); + let snapshot = TensorSnapshot::from_data( + TensorData::from_bytes_vec(bytes, vec![3], DType::U64), + vec!["u64_tensor".to_string()], + vec![], + burn_core::module::ParamId::new(), + ); + snapshots.push(snapshot); + }); +} + +#[test] 
+fn test_round_trip_u8() { + round_trip_test(|snapshots, _metadata| { + let data = vec![0u8, 127, 255]; + let snapshot = TensorSnapshot::from_data( + TensorData::from_bytes_vec(data, vec![3], DType::U8), + vec!["u8_tensor".to_string()], + vec![], + burn_core::module::ParamId::new(), + ); + snapshots.push(snapshot); + }); +} + +#[test] +fn test_round_trip_bool() { + round_trip_test(|snapshots, _metadata| { + let data = vec![0u8, 1, 0, 1, 1]; + let snapshot = TensorSnapshot::from_data( + TensorData::from_bytes_vec(data, vec![5], DType::Bool), + vec!["bool_tensor".to_string()], + vec![], + burn_core::module::ParamId::new(), + ); + snapshots.push(snapshot); + }); +} + +#[test] +fn test_round_trip_mixed_dtypes() { + round_trip_test(|snapshots, _metadata| { + // F32 + let f32_data = vec![1.0f32, 2.0]; + let f32_bytes: Vec = f32_data.iter().flat_map(|f| f.to_le_bytes()).collect(); + let f32_snapshot = TensorSnapshot::from_data( + TensorData::from_bytes_vec(f32_bytes, vec![2], DType::F32), + vec!["f32".to_string()], + vec![], + burn_core::module::ParamId::new(), + ); + snapshots.push(f32_snapshot); + + // I64 + let i64_data = vec![100i64, 200]; + let i64_bytes: Vec = i64_data.iter().flat_map(|i| i.to_le_bytes()).collect(); + let i64_snapshot = TensorSnapshot::from_data( + TensorData::from_bytes_vec(i64_bytes, vec![2], DType::I64), + vec!["i64".to_string()], + vec![], + burn_core::module::ParamId::new(), + ); + snapshots.push(i64_snapshot); + + // Bool + let bool_snapshot = TensorSnapshot::from_data( + TensorData::from_bytes_vec(vec![1, 0, 1], vec![3], DType::Bool), + vec!["bool".to_string()], + vec![], + burn_core::module::ParamId::new(), + ); + snapshots.push(bool_snapshot); + }); +} + +#[test] +fn test_round_trip_multidimensional() { + round_trip_test(|snapshots, _metadata| { + // 2D tensor + let data_2d = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]; + let bytes_2d: Vec = data_2d.iter().flat_map(|f| f.to_le_bytes()).collect(); + let snapshot_2d = TensorSnapshot::from_data( + 
TensorData::from_bytes_vec(bytes_2d, vec![2, 3], DType::F32), + vec!["tensor_2d".to_string()], + vec![], + burn_core::module::ParamId::new(), + ); + snapshots.push(snapshot_2d); + + // 3D tensor + let data_3d = vec![1.0f32; 24]; + let bytes_3d: Vec = data_3d.iter().flat_map(|f| f.to_le_bytes()).collect(); + let snapshot_3d = TensorSnapshot::from_data( + TensorData::from_bytes_vec(bytes_3d, vec![2, 3, 4], DType::F32), + vec!["tensor_3d".to_string()], + vec![], + burn_core::module::ParamId::new(), + ); + snapshots.push(snapshot_3d); + + // 4D tensor (common for CNNs) + let data_4d = vec![1.0f32; 120]; + let bytes_4d: Vec = data_4d.iter().flat_map(|f| f.to_le_bytes()).collect(); + let snapshot_4d = TensorSnapshot::from_data( + TensorData::from_bytes_vec(bytes_4d, vec![2, 3, 4, 5], DType::F32), + vec!["tensor_4d".to_string()], + vec![], + burn_core::module::ParamId::new(), + ); + snapshots.push(snapshot_4d); + }); +} + +#[test] +fn test_round_trip_with_metadata_and_tensors() { + round_trip_test(|snapshots, metadata| { + // Add metadata + metadata.insert("model_name".to_string(), "test_model".to_string()); + metadata.insert("version".to_string(), "1.0.0".to_string()); + metadata.insert( + "description".to_string(), + "A test model for round-trip testing".to_string(), + ); + + // Add tensors + let weights = vec![0.1f32, 0.2, 0.3, 0.4]; + let weights_bytes: Vec = weights.iter().flat_map(|f| f.to_le_bytes()).collect(); + let weights_snapshot = TensorSnapshot::from_data( + TensorData::from_bytes_vec(weights_bytes, vec![2, 2], DType::F32), + vec!["layer1".to_string(), "weights".to_string()], + vec![], + burn_core::module::ParamId::new(), + ); + snapshots.push(weights_snapshot); + + let bias = vec![0.5f32, 0.6]; + let bias_bytes: Vec = bias.iter().flat_map(|f| f.to_le_bytes()).collect(); + let bias_snapshot = TensorSnapshot::from_data( + TensorData::from_bytes_vec(bias_bytes, vec![2], DType::F32), + vec!["layer1".to_string(), "bias".to_string()], + vec![], + 
burn_core::module::ParamId::new(), + ); + snapshots.push(bias_snapshot); + }); +} + +#[test] +fn test_round_trip_special_values() { + round_trip_test(|snapshots, _metadata| { + // Test special float values + let special_f32 = vec![ + 0.0f32, + -0.0, + f32::INFINITY, + f32::NEG_INFINITY, + f32::NAN, + f32::MIN, + f32::MAX, + f32::EPSILON, + ]; + let f32_bytes: Vec = special_f32.iter().flat_map(|f| f.to_le_bytes()).collect(); + let f32_snapshot = TensorSnapshot::from_data( + TensorData::from_bytes_vec(f32_bytes, vec![8], DType::F32), + vec!["special_f32".to_string()], + vec![], + burn_core::module::ParamId::new(), + ); + snapshots.push(f32_snapshot); + + // Test special f64 values + let special_f64 = vec![ + 0.0f64, + -0.0, + f64::INFINITY, + f64::NEG_INFINITY, + f64::NAN, + f64::MIN, + f64::MAX, + f64::EPSILON, + ]; + let f64_bytes: Vec = special_f64.iter().flat_map(|f| f.to_le_bytes()).collect(); + let f64_snapshot = TensorSnapshot::from_data( + TensorData::from_bytes_vec(f64_bytes, vec![8], DType::F64), + vec!["special_f64".to_string()], + vec![], + burn_core::module::ParamId::new(), + ); + snapshots.push(f64_snapshot); + }); +} + +#[test] +fn test_round_trip_large_tensors() { + round_trip_test(|snapshots, _metadata| { + // Large tensor (100KB) + let size = 25600; // 100KB / 4 bytes per f32 + let data: Vec = (0..size).map(|i| i as f32).collect(); + let bytes: Vec = data.iter().flat_map(|f| f.to_le_bytes()).collect(); + let snapshot = TensorSnapshot::from_data( + TensorData::from_bytes_vec(bytes, vec![size], DType::F32), + vec!["large_tensor".to_string()], + vec![], + burn_core::module::ParamId::new(), + ); + snapshots.push(snapshot); + }); +} + +#[cfg(feature = "std")] +#[test] +fn test_round_trip_file_io() { + use std::fs; + use tempfile::tempdir; + + use crate::burnpack::writer::BurnpackWriter; + + let dir = tempdir().unwrap(); + let file_path = dir.path().join("round_trip.bpk"); + + // Create original data + let data = vec![1.0f32, 2.0, 3.0, 4.0]; + let bytes: 
Vec = data.iter().flat_map(|f| f.to_le_bytes()).collect(); + let snapshot = TensorSnapshot::from_data( + TensorData::from_bytes_vec(bytes, vec![2, 2], DType::F32), + vec!["weights".to_string()], + vec![], + burn_core::module::ParamId::new(), + ); + + let writer = BurnpackWriter::new(vec![snapshot]).with_metadata("test", "round_trip"); + + // Write to file + writer.write_to_file(&file_path).unwrap(); + + // Read from file + let reader = BurnpackReader::from_file(&file_path).unwrap(); + + // Write to another file + let file_path2 = dir.path().join("round_trip2.bpk"); + + // Collect snapshots from reader + let mut snapshots2 = Vec::new(); + for tensor_name in reader.tensor_names() { + let snapshot = reader.get_tensor_snapshot(tensor_name).unwrap(); + snapshots2.push(snapshot); + } + + // Create writer2 with snapshots and metadata + let mut writer2 = BurnpackWriter::new(snapshots2); + for (key, value) in &reader.metadata().metadata { + writer2 = writer2.with_metadata(key, value); + } + + writer2.write_to_file(&file_path2).unwrap(); + + // Compare files + let bytes1 = fs::read(&file_path).unwrap(); + let bytes2 = fs::read(&file_path2).unwrap(); + + assert_eq!( + bytes1, bytes2, + "Round-trip through files produced different content" + ); +} + +#[test] +fn test_round_trip_empty_shapes() { + round_trip_test(|snapshots, _metadata| { + // Scalar (0-dimensional) + let scalar = vec![42.0f32]; + let scalar_bytes: Vec = scalar.iter().flat_map(|f| f.to_le_bytes()).collect(); + let scalar_snapshot = TensorSnapshot::from_data( + TensorData::from_bytes_vec(scalar_bytes, vec![], DType::F32), + vec!["scalar".to_string()], + vec![], + burn_core::module::ParamId::new(), + ); + snapshots.push(scalar_snapshot); + + // Empty tensor + let empty_snapshot = TensorSnapshot::from_data( + TensorData::from_bytes_vec(vec![], vec![0], DType::F32), + vec!["empty".to_string()], + vec![], + burn_core::module::ParamId::new(), + ); + snapshots.push(empty_snapshot); + }); +} + +#[test] +fn 
test_param_id_persistence() { + use burn_core::module::ParamId; + + // Create a specific ParamId with a known value + let original_param_id = ParamId::from(123456789u64); + + let data = vec![1.0f32, 2.0, 3.0, 4.0]; + let bytes: Vec = data.iter().flat_map(|f| f.to_le_bytes()).collect(); + let snapshot = TensorSnapshot::from_data( + TensorData::from_bytes_vec(bytes, vec![2, 2], DType::F32), + vec!["weights".to_string()], + vec![], + original_param_id, + ); + + // Write to burnpack + let writer = BurnpackWriter::new(vec![snapshot]); + let bytes = writer.to_bytes().unwrap(); + + // Read back from burnpack + let reader = BurnpackReader::from_bytes(bytes).unwrap(); + let loaded_snapshot = reader.get_tensor_snapshot("weights").unwrap(); + + // Verify ParamId was preserved + assert!( + loaded_snapshot.tensor_id.is_some(), + "ParamId should be present" + ); + let loaded_param_id = loaded_snapshot.tensor_id.unwrap(); + assert_eq!( + loaded_param_id.val(), + original_param_id.val(), + "ParamId value should be preserved: expected {}, got {}", + original_param_id.val(), + loaded_param_id.val() + ); +} + +#[test] +fn test_param_id_backward_compatibility() { + use crate::burnpack::base::{BurnpackMetadata, TensorDescriptor}; + use alloc::collections::BTreeMap; + + // Create metadata without param_id (simulating old burnpack format) + let mut tensors = BTreeMap::new(); + tensors.insert( + "old_tensor".to_string(), + TensorDescriptor { + dtype: DType::F32, + shape: vec![2, 2], + data_offsets: (0, 16), + param_id: None, // No param_id stored (old format) + }, + ); + + let metadata = BurnpackMetadata { + tensors, + metadata: BTreeMap::new(), + }; + + // Serialize metadata + let mut metadata_bytes = Vec::new(); + ciborium::ser::into_writer(&metadata, &mut metadata_bytes).unwrap(); + + // Create a complete burnpack with header and data + use crate::burnpack::base::{BurnpackHeader, FORMAT_VERSION, MAGIC_NUMBER}; + + let metadata_size = metadata_bytes.len() as u32; + let header = 
BurnpackHeader { + magic: MAGIC_NUMBER, + version: FORMAT_VERSION, + metadata_size, + }; + + let mut full_bytes = Vec::new(); + full_bytes.extend_from_slice(&header.into_bytes()); + full_bytes.extend_from_slice(&metadata_bytes); + + // Add tensor data (4 f32 values = 16 bytes) + let tensor_data = vec![1.0f32, 2.0, 3.0, 4.0]; + for value in tensor_data { + full_bytes.extend_from_slice(&value.to_le_bytes()); + } + + // Read the old format burnpack + let reader = + BurnpackReader::from_bytes(burn_tensor::Bytes::from_bytes_vec(full_bytes)).unwrap(); + let loaded_snapshot = reader.get_tensor_snapshot("old_tensor").unwrap(); + + // Verify that a new ParamId was generated (backward compatibility) + assert!( + loaded_snapshot.tensor_id.is_some(), + "ParamId should be generated for old format" + ); + + // The generated ParamId should be different each time (it's new), but we can't test the exact value + // We just verify it exists and has a valid u64 value + let generated_param_id = loaded_snapshot.tensor_id.unwrap(); + assert!( + generated_param_id.val() > 0, + "Generated ParamId should have a valid value" + ); +} + +#[test] +fn test_multiple_tensors_preserve_distinct_param_ids() { + use burn_core::module::ParamId; + + // Create multiple tensors with distinct ParamIds + let param_id_1 = ParamId::from(111111u64); + let param_id_2 = ParamId::from(222222u64); + let param_id_3 = ParamId::from(333333u64); + + let mut snapshots = Vec::new(); + + let data1 = vec![1.0f32, 2.0]; + let bytes1: Vec = data1.iter().flat_map(|f| f.to_le_bytes()).collect(); + snapshots.push(TensorSnapshot::from_data( + TensorData::from_bytes_vec(bytes1, vec![2], DType::F32), + vec!["tensor1".to_string()], + vec![], + param_id_1, + )); + + let data2 = vec![3.0f32, 4.0]; + let bytes2: Vec = data2.iter().flat_map(|f| f.to_le_bytes()).collect(); + snapshots.push(TensorSnapshot::from_data( + TensorData::from_bytes_vec(bytes2, vec![2], DType::F32), + vec!["tensor2".to_string()], + vec![], + param_id_2, + )); + 
+ let data3 = vec![5.0f32, 6.0]; + let bytes3: Vec = data3.iter().flat_map(|f| f.to_le_bytes()).collect(); + snapshots.push(TensorSnapshot::from_data( + TensorData::from_bytes_vec(bytes3, vec![2], DType::F32), + vec!["tensor3".to_string()], + vec![], + param_id_3, + )); + + // Write to burnpack + let writer = BurnpackWriter::new(snapshots); + let bytes = writer.to_bytes().unwrap(); + + // Read back + let reader = BurnpackReader::from_bytes(bytes).unwrap(); + + let snapshot1 = reader.get_tensor_snapshot("tensor1").unwrap(); + let snapshot2 = reader.get_tensor_snapshot("tensor2").unwrap(); + let snapshot3 = reader.get_tensor_snapshot("tensor3").unwrap(); + + // Verify each ParamId was preserved correctly + assert_eq!(snapshot1.tensor_id.unwrap().val(), param_id_1.val()); + assert_eq!(snapshot2.tensor_id.unwrap().val(), param_id_2.val()); + assert_eq!(snapshot3.tensor_id.unwrap().val(), param_id_3.val()); + + // Verify they are distinct + let id1 = snapshot1.tensor_id.unwrap().val(); + let id2 = snapshot2.tensor_id.unwrap().val(); + let id3 = snapshot3.tensor_id.unwrap().val(); + + assert_ne!(id1, id2, "ParamIds should be distinct"); + assert_ne!(id2, id3, "ParamIds should be distinct"); + assert_ne!(id1, id3, "ParamIds should be distinct"); +} diff --git a/crates/burn-store/src/burnpack/tests/store.rs b/crates/burn-store/src/burnpack/tests/store.rs new file mode 100644 index 0000000000..8372ae2abc --- /dev/null +++ b/crates/burn-store/src/burnpack/tests/store.rs @@ -0,0 +1,854 @@ +#[cfg(feature = "std")] +use crate::KeyRemapper; +use crate::burnpack::store::BurnpackStore; +use crate::{ModuleSnapshot, ModuleStore, PathFilter}; +use burn_core::module::{Module, Param}; +use burn_tensor::{Tensor, backend::Backend}; + +type TestBackend = burn_ndarray::NdArray; + +#[derive(Module, Debug)] +struct TestModule { + weight: Param>, + bias: Param>, + nested: NestedModule, +} + +#[derive(Module, Debug)] +struct NestedModule { + gamma: Param>, + beta: Param>, +} + +impl TestModule 
{ + fn new(device: &B::Device) -> Self { + Self { + weight: Param::from_data([[1.0, 2.0], [3.0, 4.0]], device), + bias: Param::from_data([0.1, 0.2], device), + nested: NestedModule { + gamma: Param::from_data([1.0, 1.0], device), + beta: Param::from_data([0.0, 0.0], device), + }, + } + } + + fn new_zeros(device: &B::Device) -> Self { + Self { + weight: Param::from_tensor(Tensor::zeros([2, 2], device)), + bias: Param::from_tensor(Tensor::zeros([2], device)), + nested: NestedModule { + gamma: Param::from_tensor(Tensor::zeros([2], device)), + beta: Param::from_tensor(Tensor::zeros([2], device)), + }, + } + } + + fn new_uninitialized(device: &B::Device) -> Self { + use burn_core::module::ParamId; + let device_clone = device.clone(); + let device_clone2 = device.clone(); + let device_clone3 = device.clone(); + let device_clone4 = device.clone(); + + Self { + weight: Param::uninitialized( + ParamId::new(), + move |d, _| Tensor::zeros([2, 2], d), + device_clone, + true, + [2, 2].into(), + ), + bias: Param::uninitialized( + ParamId::new(), + move |d, _| Tensor::zeros([2], d), + device_clone2, + true, + [2].into(), + ), + nested: NestedModule { + gamma: Param::uninitialized( + ParamId::new(), + move |d, _| Tensor::zeros([2], d), + device_clone3, + true, + [2].into(), + ), + beta: Param::uninitialized( + ParamId::new(), + move |d, _| Tensor::zeros([2], d), + device_clone4, + true, + [2].into(), + ), + }, + } + } +} + +#[test] +fn test_store_from_bytes_round_trip() { + let device = Default::default(); + let module = TestModule::::new(&device); + + // Save to bytes + let mut save_store = BurnpackStore::from_bytes(None); + save_store.collect_from(&module).unwrap(); + let bytes = save_store.get_bytes().unwrap(); + + // Load from bytes + let mut load_store = BurnpackStore::from_bytes(Some(bytes)); + let mut module2 = TestModule::::new_zeros(&device); + let result = load_store.apply_to(&mut module2).unwrap(); + + // Verify success + assert!(result.is_success()); + 
assert_eq!(result.applied.len(), 4); // weight, bias, nested.gamma, nested.beta + assert!(result.errors.is_empty()); + + // Verify data was loaded correctly + let weight1 = module.weight.val().to_data().to_vec::().unwrap(); + let weight2 = module2.weight.val().to_data().to_vec::().unwrap(); + assert_eq!(weight1, weight2); +} + +#[test] +fn test_store_with_metadata() { + let device = Default::default(); + let module = TestModule::::new(&device); + + // Save with metadata + let mut save_store = BurnpackStore::from_bytes(None) + .metadata("version", "1.0.0") + .metadata("model_name", "test_model") + .metadata("author", "burn_team"); + + save_store.collect_from(&module).unwrap(); + let bytes = save_store.get_bytes().unwrap(); + + // Load and verify metadata is preserved + let mut load_store = BurnpackStore::from_bytes(Some(bytes)); + let mut module2 = TestModule::::new_zeros(&device); + let result = load_store.apply_to(&mut module2).unwrap(); + + assert!(result.is_success()); + assert_eq!(result.applied.len(), 4); +} + +#[test] +#[cfg(feature = "std")] +fn test_store_with_path_filter() { + let device = Default::default(); + let module = TestModule::::new(&device); + + // Save all tensors + let mut save_store = BurnpackStore::from_bytes(None); + save_store.collect_from(&module).unwrap(); + let bytes = save_store.get_bytes().unwrap(); + + // Load with filter - only load weight and bias (not nested) + let mut load_store = BurnpackStore::from_bytes(Some(bytes)).with_regex("^(weight|bias)$"); + + let mut module2 = TestModule::::new_zeros(&device); + let result = load_store.apply_to(&mut module2).unwrap(); + + assert!(result.is_success()); + assert_eq!(result.applied.len(), 2); // Only weight and bias + assert_eq!(result.skipped.len(), 2); // nested.gamma and nested.beta skipped + + // Verify weight and bias were loaded + let weight2 = module2.weight.val().to_data().to_vec::().unwrap(); + assert_eq!(weight2, vec![1.0, 2.0, 3.0, 4.0]); + + // Verify nested module was NOT 
loaded (should still be zeros) + let gamma2 = module2 + .nested + .gamma + .val() + .to_data() + .to_vec::() + .unwrap(); + assert_eq!(gamma2, vec![0.0, 0.0]); +} + +#[test] +#[cfg(feature = "std")] +fn test_store_with_key_remapping() { + let device = Default::default(); + let module = TestModule::::new(&device); + + // Save with original names + let mut save_store = BurnpackStore::from_bytes(None); + save_store.collect_from(&module).unwrap(); + let bytes = save_store.get_bytes().unwrap(); + + // Load with remapping: nested.gamma -> nested.new_gamma, nested.beta -> nested.new_beta + let remapper = KeyRemapper::new() + .add_pattern(r"nested\.gamma", "nested.new_gamma") + .unwrap() + .add_pattern(r"nested\.beta", "nested.new_beta") + .unwrap(); + + let mut load_store = BurnpackStore::from_bytes(Some(bytes)) + .remap(remapper) + .allow_partial(true); + + let mut module2 = TestModule::::new_zeros(&device); + let result = load_store.apply_to(&mut module2).unwrap(); + + // The remapping should cause missing tensors since names don't match + assert_eq!(result.applied.len(), 2); // Only weight and bias match + assert_eq!(result.unused.len(), 2); // nested.new_gamma and nested.new_beta are unused + assert_eq!(result.missing.len(), 2); // nested.gamma and nested.beta are missing +} + +#[test] +fn test_store_allow_partial() { + let device = Default::default(); + let module = TestModule::::new(&device); + + // Save only weight and bias + let filter = PathFilter::new() + .with_full_path("weight") + .with_full_path("bias"); + let mut save_store = BurnpackStore::from_bytes(None).with_filter(filter); + save_store.collect_from(&module).unwrap(); + let bytes = save_store.get_bytes().unwrap(); + + // Load with allow_partial + let mut load_store = BurnpackStore::from_bytes(Some(bytes)).allow_partial(true); + + let mut module2 = TestModule::::new_zeros(&device); + let result = load_store.apply_to(&mut module2).unwrap(); + + assert!(result.is_success()); + 
assert_eq!(result.applied.len(), 2); + assert_eq!(result.missing.len(), 2); // nested.gamma and nested.beta are missing but that's OK + + // Verify loaded tensors + let weight2 = module2.weight.val().to_data().to_vec::().unwrap(); + assert_eq!(weight2, vec![1.0, 2.0, 3.0, 4.0]); +} + +#[test] +fn test_store_match_all() { + let device = Default::default(); + let module = TestModule::::new(&device); + + // Save with match_all filter (should save everything) + let mut save_store = BurnpackStore::from_bytes(None).match_all(); + save_store.collect_from(&module).unwrap(); + let bytes = save_store.get_bytes().unwrap(); + + // Load everything + let mut load_store = BurnpackStore::from_bytes(Some(bytes)); + let mut module2 = TestModule::::new_zeros(&device); + let result = load_store.apply_to(&mut module2).unwrap(); + + assert!(result.is_success()); + assert_eq!(result.applied.len(), 4); + assert!(result.errors.is_empty()); + assert!(result.missing.is_empty()); + assert!(result.unused.is_empty()); +} + +#[test] +fn test_store_with_full_path() { + let device = Default::default(); + let module = TestModule::::new(&device); + + // Save everything + let mut save_store = BurnpackStore::from_bytes(None); + save_store.collect_from(&module).unwrap(); + let bytes = save_store.get_bytes().unwrap(); + + // Load only specific tensors by full path + let mut load_store = BurnpackStore::from_bytes(Some(bytes)) + .with_full_path("weight") + .with_full_path("nested.gamma"); + + let mut module2 = TestModule::::new_zeros(&device); + let result = load_store.apply_to(&mut module2).unwrap(); + + assert!(result.is_success()); + assert_eq!(result.applied.len(), 2); // Only weight and nested.gamma + assert_eq!(result.skipped.len(), 2); // bias and nested.beta skipped +} + +#[test] +#[cfg(feature = "std")] +fn test_store_chain_multiple_patterns() { + let device = Default::default(); + let module = TestModule::::new(&device); + + // Save with chained metadata and filters + let mut save_store = 
BurnpackStore::from_bytes(None) + .metadata("version", "1.0") + .metadata("format", "burnpack") + .with_regex(r"^(weight|nested\.)") + .match_all(); // This overrides the previous filter + + save_store.collect_from(&module).unwrap(); + let bytes = save_store.get_bytes().unwrap(); + + // Load everything since match_all was called last + let mut load_store = BurnpackStore::from_bytes(Some(bytes)); + let mut module2 = TestModule::::new_zeros(&device); + let result = load_store.apply_to(&mut module2).unwrap(); + + assert!(result.is_success()); + assert_eq!(result.applied.len(), 4); // All tensors loaded +} + +#[test] +#[cfg(feature = "std")] +fn test_store_with_remap_pattern() { + let device = Default::default(); + let module = TestModule::::new(&device); + + // Save normally + let mut save_store = BurnpackStore::from_bytes(None); + save_store.collect_from(&module).unwrap(); + let bytes = save_store.get_bytes().unwrap(); + + // Load with single remap pattern using the convenience method + let mut load_store = BurnpackStore::from_bytes(Some(bytes)) + .with_remap_pattern(r"^nested\.", "sub_module.") + .allow_partial(true); + + let mut module2 = TestModule::::new_zeros(&device); + let result = load_store.apply_to(&mut module2).unwrap(); + + // After remapping, nested.* becomes sub_module.*, which won't match + assert_eq!(result.applied.len(), 2); // Only weight and bias + assert_eq!(result.unused.len(), 2); // sub_module.gamma and sub_module.beta unused +} + +#[test] +fn test_store_default_metadata() { + let device = Default::default(); + let module = TestModule::::new(&device); + + // Save without adding custom metadata + let mut save_store = BurnpackStore::from_bytes(None); + save_store.collect_from(&module).unwrap(); + let bytes = save_store.get_bytes().unwrap(); + + // Verify default metadata is included + // We can't directly inspect metadata from bytes, but we can verify + // that the model loads successfully which means metadata was written correctly + let mut 
load_store = BurnpackStore::from_bytes(Some(bytes)); + let mut module2 = TestModule::::new_zeros(&device); + let result = load_store.apply_to(&mut module2).unwrap(); + + assert!(result.is_success()); +} + +#[test] +fn test_store_default_metadata_with_custom() { + let device = Default::default(); + let module = TestModule::::new(&device); + + // Save with custom metadata (should preserve defaults) + let mut save_store = BurnpackStore::from_bytes(None) + .metadata("custom_field", "custom_value") + .metadata("author", "test_author"); + save_store.collect_from(&module).unwrap(); + let bytes = save_store.get_bytes().unwrap(); + + // Load and verify it works (metadata including defaults was saved) + let mut load_store = BurnpackStore::from_bytes(Some(bytes)); + let mut module2 = TestModule::::new_zeros(&device); + let result = load_store.apply_to(&mut module2).unwrap(); + + assert!(result.is_success()); +} + +#[test] +fn test_store_clear_metadata() { + let device = Default::default(); + let module = TestModule::::new(&device); + + // Save with cleared metadata (no defaults) + let mut save_store = BurnpackStore::from_bytes(None).clear_metadata(); + save_store.collect_from(&module).unwrap(); + let bytes = save_store.get_bytes().unwrap(); + + // Verify it still loads correctly + let mut load_store = BurnpackStore::from_bytes(Some(bytes)); + let mut module2 = TestModule::::new_zeros(&device); + let result = load_store.apply_to(&mut module2).unwrap(); + + assert!(result.is_success()); +} + +#[test] +fn test_store_validate_enabled() { + let device = Default::default(); + let module = TestModule::::new(&device); + + // Save normally + let mut save_store = BurnpackStore::from_bytes(None); + save_store.collect_from(&module).unwrap(); + let bytes = save_store.get_bytes().unwrap(); + + // Load with validation enabled (default) + let mut load_store = BurnpackStore::from_bytes(Some(bytes)); + let mut module2 = TestModule::::new_zeros(&device); + let result = load_store.apply_to(&mut 
module2).unwrap(); + + assert!(result.is_success()); + assert!(result.errors.is_empty()); +} + +#[test] +fn test_store_validate_disabled() { + let device = Default::default(); + let module = TestModule::::new(&device); + + // Save normally + let mut save_store = BurnpackStore::from_bytes(None); + save_store.collect_from(&module).unwrap(); + let bytes = save_store.get_bytes().unwrap(); + + // Load with validation disabled + let mut load_store = BurnpackStore::from_bytes(Some(bytes)).validate(false); + let mut module2 = TestModule::::new_zeros(&device); + let result = load_store.apply_to(&mut module2).unwrap(); + + // Should still succeed + assert!(result.is_success()); +} + +#[test] +fn test_store_allow_partial_missing_tensors() { + let device = Default::default(); + let module = TestModule::::new(&device); + + // Save only weight (not bias or nested) + let filter = PathFilter::new().with_full_path("weight"); + let mut save_store = BurnpackStore::from_bytes(None).with_filter(filter); + save_store.collect_from(&module).unwrap(); + let bytes = save_store.get_bytes().unwrap(); + + // Try to load without allow_partial - should fail due to missing tensors + let mut load_store = BurnpackStore::from_bytes(Some(bytes.clone())); + let mut module2 = TestModule::::new_zeros(&device); + let result = load_store.apply_to(&mut module2); + + // Should fail because of missing tensors + assert!(result.is_err()); + + // Now try with allow_partial - should succeed + let mut load_store = BurnpackStore::from_bytes(Some(bytes)).allow_partial(true); + let mut module3 = TestModule::::new_zeros(&device); + let result = load_store.apply_to(&mut module3).unwrap(); + + assert!(result.is_success()); + assert_eq!(result.applied.len(), 1); // Only weight + assert!(!result.missing.is_empty()); // Has missing tensors +} + +#[test] +#[cfg(feature = "std")] +fn test_store_file_round_trip() { + use tempfile::tempdir; + + let device = Default::default(); + let module = TestModule::::new(&device); + + // 
Create temp directory and file path + let temp_dir = tempdir().unwrap(); + let path = temp_dir.path().join("test_file_round_trip.bpk"); + + // Save to file + let mut save_store = BurnpackStore::from_file(&path).metadata("test", "value"); + save_store.collect_from(&module).unwrap(); + + // Verify file exists + assert!(path.exists()); + + // Load from file + let mut load_store = BurnpackStore::from_file(&path); + let mut module2 = TestModule::::new_zeros(&device); + let result = load_store.apply_to(&mut module2).unwrap(); + + assert!(result.is_success()); + assert_eq!(result.applied.len(), 4); + + // Verify data + let weight1 = module.weight.val().to_data().to_vec::().unwrap(); + let weight2 = module2.weight.val().to_data().to_vec::().unwrap(); + assert_eq!(weight1, weight2); +} + +#[test] +#[cfg(feature = "std")] +fn test_store_overwrite_protection() { + use tempfile::tempdir; + + let device = Default::default(); + let module = TestModule::::new(&device); + + // Create temp directory and file path (file doesn't exist yet) + let temp_dir = tempdir().unwrap(); + let path = temp_dir.path().join("test_model.bpk"); + + // First save - should succeed + let mut save_store = BurnpackStore::from_file(&path); + save_store.collect_from(&module).unwrap(); + assert!(path.exists()); + + // Second save without overwrite flag - should fail + let mut save_store2 = BurnpackStore::from_file(&path); + let result = save_store2.collect_from(&module); + assert!(result.is_err()); + assert!( + result + .unwrap_err() + .to_string() + .contains("File already exists") + ); + + // Third save with overwrite flag - should succeed + let mut save_store3 = BurnpackStore::from_file(&path).overwrite(true); + save_store3.collect_from(&module).unwrap(); + + // Verify file still exists and is valid + let mut load_store = BurnpackStore::from_file(&path); + let mut module2 = TestModule::::new_zeros(&device); + let result = load_store.apply_to(&mut module2).unwrap(); + assert!(result.is_success()); +} + 
+#[test] +#[cfg(feature = "std")] +fn test_store_overwrite_with_metadata() { + use tempfile::tempdir; + + let device = Default::default(); + let module = TestModule::::new(&device); + + // Create temp directory and file path + let temp_dir = tempdir().unwrap(); + let path = temp_dir.path().join("test_model_metadata.bpk"); + + // First save with v1 metadata + let mut save_store = BurnpackStore::from_file(&path) + .metadata("version", "1.0") + .overwrite(true); + save_store.collect_from(&module).unwrap(); + + // Second save with v2 metadata and overwrite enabled + let mut save_store2 = BurnpackStore::from_file(&path) + .metadata("version", "2.0") + .overwrite(true); + save_store2.collect_from(&module).unwrap(); + + // Verify file loads correctly + let mut load_store = BurnpackStore::from_file(&path); + let mut module2 = TestModule::::new_zeros(&device); + let result = load_store.apply_to(&mut module2).unwrap(); + assert!(result.is_success()); +} + +#[test] +#[cfg(feature = "std")] +fn test_store_auto_extension_default() { + use tempfile::tempdir; + + let device = Default::default(); + let module = TestModule::::new(&device); + + // Create temp directory + let temp_dir = tempdir().unwrap(); + let path = temp_dir.path().join("model"); + + // Save without extension - should auto-append .bpk + let mut save_store = BurnpackStore::from_file(&path); + save_store.collect_from(&module).unwrap(); + + // Verify that model.bpk was created + let expected_path = temp_dir.path().join("model.bpk"); + assert!(expected_path.exists()); + assert!(!path.exists()); // Original path without extension should not exist + + // Load using the path without extension - should work + let mut load_store = BurnpackStore::from_file(&path); + let mut module2 = TestModule::::new_zeros(&device); + let result = load_store.apply_to(&mut module2).unwrap(); + assert!(result.is_success()); +} + +#[test] +#[cfg(feature = "std")] +fn test_store_auto_extension_with_existing_extension() { + use 
tempfile::tempdir; + + let device = Default::default(); + let module = TestModule::::new(&device); + + // Create temp directory + let temp_dir = tempdir().unwrap(); + let path = temp_dir.path().join("model.bpk"); + + // Save with .bpk extension - should not double append + let mut save_store = BurnpackStore::from_file(&path); + save_store.collect_from(&module).unwrap(); + + // Verify that only model.bpk was created + assert!(path.exists()); + let double_ext_path = temp_dir.path().join("model.bpk.bpk"); + assert!(!double_ext_path.exists()); + + // Load and verify + let mut load_store = BurnpackStore::from_file(&path); + let mut module2 = TestModule::::new_zeros(&device); + let result = load_store.apply_to(&mut module2).unwrap(); + assert!(result.is_success()); +} + +#[test] +#[cfg(feature = "std")] +fn test_store_auto_extension_with_custom_extension() { + use tempfile::tempdir; + + let device = Default::default(); + let module = TestModule::::new(&device); + + // Create temp directory + let temp_dir = tempdir().unwrap(); + let path = temp_dir.path().join("model.mpk"); + + // Save with .mpk extension - should preserve it + let mut save_store = BurnpackStore::from_file(&path); + save_store.collect_from(&module).unwrap(); + + // Verify that model.mpk was created (not model.mpk.bpk) + assert!(path.exists()); + let burnpack_path = temp_dir.path().join("model.mpk.bpk"); + assert!(!burnpack_path.exists()); + + // Load and verify + let mut load_store = BurnpackStore::from_file(&path); + let mut module2 = TestModule::::new_zeros(&device); + let result = load_store.apply_to(&mut module2).unwrap(); + assert!(result.is_success()); +} + +#[test] +#[cfg(feature = "std")] +fn test_store_auto_extension_disabled() { + use tempfile::tempdir; + + let device = Default::default(); + let module = TestModule::::new(&device); + + // Create temp directory + let temp_dir = tempdir().unwrap(); + let path = temp_dir.path().join("model"); + + // Save with auto_extension disabled - should use 
exact path + let mut save_store = BurnpackStore::from_file(&path).auto_extension(false); + save_store.collect_from(&module).unwrap(); + + // Verify that "model" (without extension) was created + assert!(path.exists()); + let burnpack_path = temp_dir.path().join("model.bpk"); + assert!(!burnpack_path.exists()); + + // Load with auto_extension disabled + let mut load_store = BurnpackStore::from_file(&path).auto_extension(false); + let mut module2 = TestModule::::new_zeros(&device); + let result = load_store.apply_to(&mut module2).unwrap(); + assert!(result.is_success()); +} + +#[test] +#[cfg(feature = "std")] +fn test_partial_loading_preserves_lazy_initialization() { + use tempfile::tempdir; + + let device = Default::default(); + + // Create and save a full module + let module = TestModule::::new(&device); + let temp_dir = tempdir().unwrap(); + let path = temp_dir.path().join("model.bpk"); + + let mut save_store = BurnpackStore::from_file(&path); + save_store.collect_from(&module).unwrap(); + + // Create an uninitialized module (all params lazy) + let mut load_module = TestModule::::new_uninitialized(&device); + + // Before loading: verify ALL params are uninitialized (lazy) + assert!( + !load_module.weight.is_initialized(), + "weight should be uninitialized before loading" + ); + assert!( + !load_module.bias.is_initialized(), + "bias should be uninitialized before loading" + ); + assert!( + !load_module.nested.gamma.is_initialized(), + "nested.gamma should be uninitialized before loading" + ); + assert!( + !load_module.nested.beta.is_initialized(), + "nested.beta should be uninitialized before loading" + ); + + // Partial load: only load weight and bias (skip nested.*) + let filter = PathFilter::new().with_regex("^(weight|bias)$"); + let mut load_store = BurnpackStore::from_file(&path).filter(filter); + let result = load_module.load_from(&mut load_store).unwrap(); + + // Verify only weight and bias were loaded + assert_eq!(result.applied.len(), 2); + 
assert!(result.applied.contains(&"weight".to_string())); + assert!(result.applied.contains(&"bias".to_string())); + assert_eq!(result.skipped.len(), 2); + assert!(result.skipped.contains(&"nested.gamma".to_string())); + assert!(result.skipped.contains(&"nested.beta".to_string())); + + // After loading: verify loaded params are initialized, skipped remain lazy + assert!( + load_module.weight.is_initialized(), + "weight should be initialized after loading" + ); + assert!( + load_module.bias.is_initialized(), + "bias should be initialized after loading" + ); + assert!( + !load_module.nested.gamma.is_initialized(), + "nested.gamma should remain uninitialized (was skipped)" + ); + assert!( + !load_module.nested.beta.is_initialized(), + "nested.beta should remain uninitialized (was skipped)" + ); + + // Verify the loaded values are correct + let weight_data = load_module.weight.val().to_data().to_vec::().unwrap(); + assert_eq!(weight_data, vec![1.0, 2.0, 3.0, 4.0]); + + let bias_data = load_module.bias.val().to_data().to_vec::().unwrap(); + assert_eq!(bias_data, vec![0.1, 0.2]); + + // Now check that nested params can still be initialized on first access + let gamma_data = load_module + .nested + .gamma + .val() + .to_data() + .to_vec::() + .unwrap(); + assert_eq!(gamma_data, vec![0.0, 0.0]); // Initialized to zeros via the init function + + // After accessing, they should be initialized + assert!( + load_module.nested.gamma.is_initialized(), + "nested.gamma should be initialized after first access" + ); +} + +// Model with forward pass for testing weight preservation +#[derive(Module, Debug)] +struct ForwardTestModel { + linear1: burn::nn::Linear, + linear2: burn::nn::Linear, +} + +impl ForwardTestModel { + /// Forward pass: input -> linear1 -> gelu -> linear2 + fn forward(&self, input: Tensor) -> Tensor { + let x = self.linear1.forward(input); + let x = burn::tensor::activation::gelu(x); + self.linear2.forward(x) + } +} + +#[derive(burn::config::Config, Debug)] +struct 
ForwardTestModelConfig { + input_size: usize, + hidden_size: usize, + output_size: usize, +} + +impl ForwardTestModelConfig { + fn init(&self, device: &B::Device) -> ForwardTestModel { + ForwardTestModel { + linear1: burn::nn::LinearConfig::new(self.input_size, self.hidden_size) + .with_bias(true) + .init(device), + linear2: burn::nn::LinearConfig::new(self.hidden_size, self.output_size) + .with_bias(true) + .init(device), + } + } +} + +#[test] +#[cfg(feature = "std")] +fn test_forward_pass_preservation_after_save_load() { + use tempfile::tempdir; + + let device = Default::default(); + + // Create model config + let config = ForwardTestModelConfig { + input_size: 4, + hidden_size: 8, + output_size: 2, + }; + + // Initialize model1 with random weights + let model1 = config.init::(&device); + + // Create random input + let input = Tensor::::random( + [1, 4], + burn_tensor::Distribution::Uniform(-1.0, 1.0), + &device, + ); + + // Forward pass with model1 -> output1 + let output1 = model1.forward(input.clone()); + + // Save model1 weights + let temp_dir = tempdir().unwrap(); + let path = temp_dir.path().join("forward_test_model.bpk"); + let mut save_store = BurnpackStore::from_file(&path); + save_store.collect_from(&model1).unwrap(); + + // Initialize model2 with different random weights + let mut model2 = config.init::(&device); + + // Forward pass with model2 -> output2 (should differ from output1) + let output2 = model2.forward(input.clone()); + + // Verify output2 differs from output1 (different random weights) + assert!( + !output1 + .clone() + .all_close(output2.clone(), Some(1e-6), Some(1e-6)), + "output2 should differ from output1 (different random initializations)" + ); + + // Load model1 weights into model2 + let mut load_store = BurnpackStore::from_file(&path); + let result = load_store.apply_to(&mut model2).unwrap(); + assert!(result.is_success()); + assert_eq!(result.applied.len(), 4); // 2 weights + 2 biases + + // Forward pass with model2 (now has model1 
weights) -> output3 + let output3 = model2.forward(input.clone()); + + // Verify output3 equals output1 (same weights) + assert!( + output1.all_close(output3, Some(1e-6), Some(1e-6)), + "output3 should equal output1 after loading weights" + ); +} diff --git a/crates/burn-store/src/burnpack/tests/writer.rs b/crates/burn-store/src/burnpack/tests/writer.rs new file mode 100644 index 0000000000..3cbadaf3a0 --- /dev/null +++ b/crates/burn-store/src/burnpack/tests/writer.rs @@ -0,0 +1,535 @@ +use crate::burnpack::{ + base::{ + BurnpackHeader, BurnpackMetadata, FORMAT_VERSION, HEADER_SIZE, MAGIC_NUMBER, magic_range, + }, + writer::BurnpackWriter, +}; + +use super::*; +use burn_core::module::ParamId; +use burn_tensor::{DType, TensorData}; +use std::rc::Rc; + +#[test] +fn test_writer_new() { + let writer = BurnpackWriter::new(vec![]); + assert_eq!(writer.snapshots.len(), 0); + assert!(writer.metadata.is_empty()); +} + +#[test] +fn test_writer_add_metadata() { + let writer = BurnpackWriter::new(vec![]) + .with_metadata("model_name", "test_model") + .with_metadata("version", "1.0.0") + .with_metadata("author", "test_author"); + + assert_eq!(writer.metadata.len(), 3); + assert_eq!( + writer.metadata.get("model_name"), + Some(&"test_model".to_string()) + ); + assert_eq!(writer.metadata.get("version"), Some(&"1.0.0".to_string())); + assert_eq!( + writer.metadata.get("author"), + Some(&"test_author".to_string()) + ); +} + +#[test] +fn test_writer_add_tensor_snapshot() { + // Create test tensor snapshots + let snapshot1 = TensorSnapshot::from_data( + TensorData::from_bytes_vec(vec![1, 2, 3, 4], vec![2, 2], DType::U8), + vec!["layer1".to_string(), "weights".to_string()], + vec![], + burn_core::module::ParamId::new(), + ); + + let snapshot2 = TensorSnapshot::from_data( + TensorData::from_bytes_vec(vec![5, 6, 7, 8], vec![4], DType::U8), + vec!["layer1".to_string(), "bias".to_string()], + vec![], + burn_core::module::ParamId::new(), + ); + + let writer = 
BurnpackWriter::new(vec![snapshot1, snapshot2]); + + assert_eq!(writer.snapshots.len(), 2); + assert_eq!(writer.snapshots[0].full_path(), "layer1.weights"); + assert_eq!(writer.snapshots[1].full_path(), "layer1.bias"); +} + +#[test] +fn test_writer_to_bytes_empty() { + let writer = BurnpackWriter::new(vec![]); + let bytes = writer.to_bytes().unwrap(); + + // Verify header + assert!(bytes.len() >= HEADER_SIZE); + assert_eq!(&bytes[magic_range()], &MAGIC_NUMBER.to_le_bytes()); + + // Parse header + let header = BurnpackHeader::from_bytes(&bytes[..HEADER_SIZE]).unwrap(); + assert_eq!(header.magic, MAGIC_NUMBER); + assert_eq!(header.version, FORMAT_VERSION); + + // Verify metadata + let metadata_end = HEADER_SIZE + header.metadata_size as usize; + let metadata_bytes = &bytes[HEADER_SIZE..metadata_end]; + let metadata: BurnpackMetadata = ciborium::de::from_reader(metadata_bytes).unwrap(); + + assert_eq!(metadata.tensors.len(), 0); + assert!(metadata.metadata.is_empty()); +} + +#[test] +fn test_writer_to_bytes_with_tensors() { + // Add tensors with different data types + let f32_data = vec![1.0f32, 2.0, 3.0, 4.0]; + let f32_bytes: Vec = f32_data.iter().flat_map(|f| f.to_le_bytes()).collect(); + let snapshot_f32 = TensorSnapshot::from_data( + TensorData::from_bytes_vec(f32_bytes.clone(), vec![2, 2], DType::F32), + vec!["weights".to_string()], + vec![], + burn_core::module::ParamId::new(), + ); + + let i64_data = vec![10i64, 20, 30]; + let i64_bytes: Vec = i64_data.iter().flat_map(|i| i.to_le_bytes()).collect(); + let snapshot_i64 = TensorSnapshot::from_data( + TensorData::from_bytes_vec(i64_bytes.clone(), vec![3], DType::I64), + vec!["bias".to_string()], + vec![], + burn_core::module::ParamId::new(), + ); + + let writer = BurnpackWriter::new(vec![snapshot_f32, snapshot_i64]) + .with_metadata("test_key", "test_value"); + + let bytes = writer.to_bytes().unwrap(); + + // Parse and verify + let header = BurnpackHeader::from_bytes(&bytes[..HEADER_SIZE]).unwrap(); + let 
metadata_end = HEADER_SIZE + header.metadata_size as usize; + let metadata: BurnpackMetadata = + ciborium::de::from_reader(&bytes[HEADER_SIZE..metadata_end]).unwrap(); + + // Verify metadata + assert_eq!( + metadata.metadata.get("test_key"), + Some(&"test_value".to_string()) + ); + + // Verify tensors + assert_eq!(metadata.tensors.len(), 2); + + let weights = metadata.tensors.get("weights").unwrap(); + assert_eq!(weights.dtype, DType::F32); + assert_eq!(weights.shape, vec![2, 2]); + assert_eq!(weights.data_offsets.1 - weights.data_offsets.0, 16); // 4 * 4 bytes + + let bias = metadata.tensors.get("bias").unwrap(); + assert_eq!(bias.dtype, DType::I64); + assert_eq!(bias.shape, vec![3]); + assert_eq!(bias.data_offsets.1 - bias.data_offsets.0, 24); // 3 * 8 bytes + + // Verify actual tensor data + let weights = metadata.tensors.get("weights").unwrap(); + let bias = metadata.tensors.get("bias").unwrap(); + let weights_data = &bytes[metadata_end + weights.data_offsets.0 as usize + ..metadata_end + weights.data_offsets.1 as usize]; + assert_eq!(weights_data, f32_bytes); + + let bias_data = &bytes + [metadata_end + bias.data_offsets.0 as usize..metadata_end + bias.data_offsets.1 as usize]; + assert_eq!(bias_data, i64_bytes); +} + +#[test] +fn test_writer_all_dtypes() { + // Test all supported data types + let test_cases = vec![ + (DType::F32, 4, vec![1.0f32.to_le_bytes().to_vec()].concat()), + (DType::F64, 8, vec![1.0f64.to_le_bytes().to_vec()].concat()), + (DType::I32, 4, vec![1i32.to_le_bytes().to_vec()].concat()), + (DType::I64, 8, vec![1i64.to_le_bytes().to_vec()].concat()), + (DType::U32, 4, vec![1u32.to_le_bytes().to_vec()].concat()), + (DType::U64, 8, vec![1u64.to_le_bytes().to_vec()].concat()), + (DType::U8, 1, vec![255u8]), + (DType::Bool, 1, vec![1u8]), + ]; + + let mut snapshots = vec![]; + for (i, (dtype, _expected_size, data)) in test_cases.into_iter().enumerate() { + let name = format!("tensor_{}", i); + let snapshot = TensorSnapshot::from_data( + 
TensorData::from_bytes_vec(data.clone(), vec![1], dtype), + vec![name.clone()], + vec![], + burn_core::module::ParamId::new(), + ); + snapshots.push(snapshot); + } + + let writer = BurnpackWriter::new(snapshots); + + let bytes = writer.to_bytes().unwrap(); + + // Parse and verify + let header = BurnpackHeader::from_bytes(&bytes[..HEADER_SIZE]).unwrap(); + let metadata: BurnpackMetadata = + ciborium::de::from_reader(&bytes[HEADER_SIZE..HEADER_SIZE + header.metadata_size as usize]) + .unwrap(); + + assert_eq!(metadata.tensors.len(), 8); + + for i in 0..8 { + let name = format!("tensor_{}", i); + let tensor = metadata.tensors.get(&name).unwrap(); + assert_eq!(tensor.shape, vec![1]); + } +} + +#[test] +fn test_writer_large_tensor() { + // Create a large tensor (1MB) + let size = 256 * 1024; // 256K floats = 1MB + let data: Vec = (0..size).map(|i| i as f32).collect(); + let bytes: Vec = data.iter().flat_map(|f| f.to_le_bytes()).collect(); + + let snapshot = TensorSnapshot::from_data( + TensorData::from_bytes_vec(bytes.clone(), vec![size], DType::F32), + vec!["large_tensor".to_string()], + vec![], + burn_core::module::ParamId::new(), + ); + + let writer = BurnpackWriter::new(vec![snapshot]); + + let result = writer.to_bytes().unwrap(); + + // Verify the large tensor is correctly stored + let header = BurnpackHeader::from_bytes(&result[..HEADER_SIZE]).unwrap(); + let metadata: BurnpackMetadata = ciborium::de::from_reader( + &result[HEADER_SIZE..HEADER_SIZE + header.metadata_size as usize], + ) + .unwrap(); + + assert_eq!(metadata.tensors.len(), 1); + let tensor = metadata.tensors.get("large_tensor").unwrap(); + assert_eq!(tensor.shape, vec![size as u64]); + assert_eq!( + tensor.data_offsets.1 - tensor.data_offsets.0, + (size * 4) as u64 + ); +} + +#[test] +fn test_writer_empty_tensors() { + // Add tensor with empty data + let snapshot = TensorSnapshot::from_data( + TensorData::from_bytes_vec(vec![], vec![0], DType::F32), + vec!["empty".to_string()], + vec![], + 
ParamId::new(), + ); + + let writer = BurnpackWriter::new(vec![snapshot]); + + let bytes = writer.to_bytes().unwrap(); + + let header = BurnpackHeader::from_bytes(&bytes[..HEADER_SIZE]).unwrap(); + let metadata: BurnpackMetadata = + ciborium::de::from_reader(&bytes[HEADER_SIZE..HEADER_SIZE + header.metadata_size as usize]) + .unwrap(); + + assert_eq!(metadata.tensors.len(), 1); + let tensor = metadata.tensors.get("empty").unwrap(); + assert_eq!(tensor.shape, vec![0]); + assert_eq!(tensor.data_offsets.1 - tensor.data_offsets.0, 0); +} + +#[test] +fn test_writer_special_characters_in_names() { + // Test various special characters in tensor names + let special_names = vec![ + "layer.0.weight", + "model/encoder/layer1", + "model::layer::weight", + "layer[0].bias", + "layer_1_weight", + "layer-1-bias", + "layer@1#weight", + "emoji_😀_tensor", + "unicode_测试_tensor", + "spaces in name", + ]; + + let mut snapshots = vec![]; + for name in &special_names { + let snapshot = TensorSnapshot::from_data( + TensorData::from_bytes_vec(vec![1, 2, 3, 4], vec![4], DType::U8), + vec![name.to_string()], + vec![], + ParamId::new(), + ); + snapshots.push(snapshot); + } + + let writer = BurnpackWriter::new(snapshots); + + let bytes = writer.to_bytes().unwrap(); + + let header = BurnpackHeader::from_bytes(&bytes[..HEADER_SIZE]).unwrap(); + let metadata: BurnpackMetadata = + ciborium::de::from_reader(&bytes[HEADER_SIZE..HEADER_SIZE + header.metadata_size as usize]) + .unwrap(); + + assert_eq!(metadata.tensors.len(), 10); + for (tensor_name, _tensor) in metadata.tensors.iter() { + assert!(tensor_name.len() > 0); + // Names should be preserved exactly + assert!( + tensor_name.contains("layer") + || tensor_name.contains("model") + || tensor_name.contains("emoji") + || tensor_name.contains("unicode") + || tensor_name.contains("spaces") + ); + } +} + +#[test] +fn test_writer_metadata_overwrite() { + let writer = BurnpackWriter::new(vec![]) + .with_metadata("key", "value1") + .with_metadata("key", 
"value2"); + + assert_eq!(writer.metadata.get("key"), Some(&"value2".to_string())); + assert_eq!(writer.metadata.len(), 1); +} + +#[test] +fn test_writer_tensor_order_preserved() { + // Add tensors in specific order + let names = vec!["z_tensor", "a_tensor", "m_tensor", "b_tensor"]; + + let mut snapshots = vec![]; + for name in &names { + let snapshot = TensorSnapshot::from_data( + TensorData::from_bytes_vec(vec![1], vec![1], DType::U8), + vec![name.to_string()], + vec![], + ParamId::new(), + ); + snapshots.push(snapshot); + } + + let writer = BurnpackWriter::new(snapshots); + + let bytes = writer.to_bytes().unwrap(); + + let header = BurnpackHeader::from_bytes(&bytes[..HEADER_SIZE]).unwrap(); + let metadata: BurnpackMetadata = + ciborium::de::from_reader(&bytes[HEADER_SIZE..HEADER_SIZE + header.metadata_size as usize]) + .unwrap(); + + // Verify all tensors are present (BTreeMap stores in sorted order by key) + let expected_sorted = vec!["a_tensor", "b_tensor", "m_tensor", "z_tensor"]; + let actual_names: Vec<_> = metadata.tensors.keys().collect(); + assert_eq!(actual_names, expected_sorted); +} + +#[test] +fn test_writer_lazy_snapshot_evaluation() { + // Create a lazy snapshot using closure + let data = Rc::new(vec![1.0f32, 2.0, 3.0, 4.0]); + let data_clone = data.clone(); + + let snapshot = TensorSnapshot::from_closure( + Rc::new(move || { + let bytes: Vec = data_clone.iter().flat_map(|f| f.to_le_bytes()).collect(); + Ok(TensorData::from_bytes_vec(bytes, vec![2, 2], DType::F32)) + }), + DType::F32, + vec![2, 2], + vec!["lazy".to_string()], + vec![], + ParamId::new(), + ); + + let writer = BurnpackWriter::new(vec![snapshot]); + + // The closure should only be evaluated when to_bytes is called + let bytes = writer.to_bytes().unwrap(); + + let header = BurnpackHeader::from_bytes(&bytes[..HEADER_SIZE]).unwrap(); + let metadata_end = HEADER_SIZE + header.metadata_size as usize; + let metadata: BurnpackMetadata = + 
ciborium::de::from_reader(&bytes[HEADER_SIZE..metadata_end]).unwrap(); + + assert_eq!(metadata.tensors.len(), 1); + let tensor = metadata.tensors.get("lazy").unwrap(); + assert_eq!(tensor.dtype, DType::F32); + assert_eq!(tensor.shape, vec![2, 2]); + + // Verify the data was correctly written + let tensor_data = &bytes[metadata_end..metadata_end + 16]; + let expected: Vec = vec![1.0f32, 2.0, 3.0, 4.0] + .iter() + .flat_map(|f| f.to_le_bytes()) + .collect(); + assert_eq!(tensor_data, expected.as_slice()); +} + +#[cfg(feature = "std")] +#[test] +fn test_writer_write_to_file() { + use std::fs; + use tempfile::tempdir; + + let dir = tempdir().unwrap(); + let file_path = dir.path().join("test.bpk"); + + let snapshot = TensorSnapshot::from_data( + TensorData::from_bytes_vec(vec![1, 2, 3, 4], vec![2, 2], DType::U8), + vec!["test".to_string()], + vec![], + ParamId::new(), + ); + + let writer = BurnpackWriter::new(vec![snapshot]).with_metadata("file_test", "true"); + + writer.write_to_file(&file_path).unwrap(); + + // Verify file exists and has correct content + assert!(file_path.exists()); + + let file_bytes = fs::read(&file_path).unwrap(); + let memory_bytes = writer.to_bytes().unwrap(); + + assert_eq!(file_bytes.as_slice(), &*memory_bytes); +} + +#[test] +fn test_writer_size() { + let snapshot = TensorSnapshot::from_data( + TensorData::from_bytes_vec(vec![1, 2, 3, 4], vec![2, 2], DType::U8), + vec!["test".to_string()], + vec![], + ParamId::new(), + ); + + let writer = BurnpackWriter::new(vec![snapshot]).with_metadata("test", "value"); + + let size = writer.size().unwrap(); + let bytes = writer.to_bytes().unwrap(); + + // Size should match actual bytes length + assert_eq!(size, bytes.len()); +} + +#[test] +fn test_writer_write_into() { + let snapshot = TensorSnapshot::from_data( + TensorData::from_bytes_vec(vec![1, 2, 3, 4], vec![2, 2], DType::U8), + vec!["test".to_string()], + vec![], + ParamId::new(), + ); + + let writer = 
BurnpackWriter::new(vec![snapshot]).with_metadata("test", "value"); + + // Get size and allocate buffer + let size = writer.size().unwrap(); + let mut buffer = vec![0u8; size]; + + // Write into buffer + writer.write_into(&mut buffer).unwrap(); + + // Compare with to_bytes() + let bytes = writer.to_bytes().unwrap(); + assert_eq!(buffer.as_slice(), &*bytes); +} + +#[test] +fn test_writer_write_into_buffer_too_small() { + let snapshot = TensorSnapshot::from_data( + TensorData::from_bytes_vec(vec![1, 2, 3, 4], vec![2, 2], DType::U8), + vec!["test".to_string()], + vec![], + ParamId::new(), + ); + + let writer = BurnpackWriter::new(vec![snapshot]); + + // Allocate a buffer that's too small + let mut buffer = vec![0u8; 10]; + + // Should fail with buffer too small error + let result = writer.write_into(&mut buffer); + assert!(result.is_err()); + assert!(result.unwrap_err().to_string().contains("Buffer too small")); +} + +#[test] +fn test_writer_write_into_buffer_larger_than_needed() { + let snapshot = TensorSnapshot::from_data( + TensorData::from_bytes_vec(vec![1, 2, 3, 4], vec![2, 2], DType::U8), + vec!["test".to_string()], + vec![], + ParamId::new(), + ); + + let writer = BurnpackWriter::new(vec![snapshot]); + + // Allocate a larger buffer + let size = writer.size().unwrap(); + let mut buffer = vec![0u8; size + 100]; // Extra 100 bytes + + // Should succeed and only write the necessary bytes + writer.write_into(&mut buffer).unwrap(); + + // Compare the written portion with to_bytes() + let bytes = writer.to_bytes().unwrap(); + assert_eq!(&buffer[..size], &*bytes); +} + +#[test] +fn test_writer_write_into_multiple_tensors() { + let snapshot1 = TensorSnapshot::from_data( + TensorData::from_bytes_vec(vec![1, 2, 3, 4], vec![2, 2], DType::U8), + vec!["tensor1".to_string()], + vec![], + ParamId::new(), + ); + + let snapshot2 = TensorSnapshot::from_data( + TensorData::from_bytes_vec(vec![5, 6, 7, 8, 9, 10], vec![2, 3], DType::U8), + vec!["tensor2".to_string()], + vec![], + 
ParamId::new(), + ); + + let writer = BurnpackWriter::new(vec![snapshot1, snapshot2]).with_metadata("test", "multiple"); + + let size = writer.size().unwrap(); + let mut buffer = vec![0u8; size]; + writer.write_into(&mut buffer).unwrap(); + + let bytes = writer.to_bytes().unwrap(); + assert_eq!(buffer.as_slice(), &*bytes); +} + +#[test] +fn test_writer_write_into_empty() { + let writer = BurnpackWriter::new(vec![]); + + let size = writer.size().unwrap(); + let mut buffer = vec![0u8; size]; + writer.write_into(&mut buffer).unwrap(); + + let bytes = writer.to_bytes().unwrap(); + assert_eq!(buffer.as_slice(), &*bytes); +} diff --git a/crates/burn-store/src/burnpack/writer.rs b/crates/burn-store/src/burnpack/writer.rs new file mode 100644 index 0000000000..a44587aa3b --- /dev/null +++ b/crates/burn-store/src/burnpack/writer.rs @@ -0,0 +1,244 @@ +use super::base::{ + BurnpackError, BurnpackHeader, BurnpackMetadata, FORMAT_VERSION, HEADER_SIZE, MAGIC_NUMBER, + TensorDescriptor, +}; +use crate::TensorSnapshot; +use alloc::collections::BTreeMap; +use alloc::format; +use alloc::string::{String, ToString}; +use alloc::vec; +use alloc::vec::Vec; +use burn_tensor::Bytes; + +#[cfg(feature = "std")] +use std::fs::File; +#[cfg(feature = "std")] +use std::io::Write; +#[cfg(feature = "std")] +use std::path::Path; + +/// Writer for creating Burnpack files +pub struct BurnpackWriter { + /// Tensors to write + pub(crate) snapshots: Vec, + /// Metadata key-value pairs + pub(crate) metadata: BTreeMap, +} + +impl BurnpackWriter { + /// Create a new writer + pub fn new(snapshots: Vec) -> Self { + Self { + snapshots, + metadata: BTreeMap::new(), + } + } + + /// Builder pattern: add metadata and return self + pub fn with_metadata(mut self, key: &str, value: &str) -> Self { + self.metadata.insert(key.to_string(), value.to_string()); + self + } + + /// Build tensor descriptors and metadata + fn build_metadata(&self) -> Result<(BurnpackMetadata, Vec), BurnpackError> { + // Build tensor 
descriptors and calculate offsets + let mut tensors = BTreeMap::new(); + let mut current_offset = 0u64; + + for snapshot in &self.snapshots { + let data_len = snapshot.data_len() as u64; + let start = current_offset; + let end = start.checked_add(data_len).ok_or_else(|| { + BurnpackError::IoError(format!( + "Tensor offset overflow: {} + {} exceeds maximum", + start, data_len + )) + })?; + + tensors.insert( + snapshot.full_path(), + TensorDescriptor { + dtype: snapshot.dtype, + shape: snapshot.shape.iter().map(|&s| s as u64).collect(), + data_offsets: (start, end), + param_id: snapshot.tensor_id.map(|id| id.val()), + }, + ); + + current_offset = end; + } + + // Create metadata structure + let metadata = BurnpackMetadata { + tensors, + metadata: self.metadata.clone(), + }; + + // Serialize metadata with CBOR + let mut metadata_bytes = Vec::new(); + ciborium::ser::into_writer(&metadata, &mut metadata_bytes) + .map_err(|e| BurnpackError::IoError(e.to_string()))?; + + Ok((metadata, metadata_bytes)) + } + + /// Calculate the total size needed for the burnpack data + /// + /// This is useful when you want to pre-allocate a buffer for `write_into()`. + pub fn size(&self) -> Result { + let (_, metadata_bytes) = self.build_metadata()?; + let data_size = self.snapshots.iter().map(|s| s.data_len()).sum::(); + Ok(HEADER_SIZE + metadata_bytes.len() + data_size) + } + + /// Write burnpack data into a caller-provided buffer + /// + /// The buffer must be large enough to hold all data. Use `size()` to determine + /// the required buffer size. If the buffer is too small, this will return an error. + /// + /// This allows the caller to control buffer allocation, enabling optimizations like: + /// - Buffer reuse across multiple writes + /// - Custom allocators + /// - Pinned memory for GPU transfers + /// + /// # Arguments + /// + /// * `buffer` - Mutable slice to write data into. Must be at least `size()` bytes. 
+ pub fn write_into(&self, buffer: &mut [u8]) -> Result<(), BurnpackError> { + let (_, metadata_bytes) = self.build_metadata()?; + + // Check metadata size fits in u32 + let metadata_size: u32 = metadata_bytes.len().try_into().map_err(|_| { + BurnpackError::IoError(format!( + "Metadata size {} exceeds maximum of {} bytes", + metadata_bytes.len(), + u32::MAX + )) + })?; + + // Create header + let header = BurnpackHeader { + magic: MAGIC_NUMBER, + version: FORMAT_VERSION, + metadata_size, + }; + + // Calculate required size + let data_size = self.snapshots.iter().map(|s| s.data_len()).sum::(); + let total_size = HEADER_SIZE + metadata_bytes.len() + data_size; + + // Check buffer size + if buffer.len() < total_size { + return Err(BurnpackError::IoError(format!( + "Buffer too small: need {} bytes, got {} bytes", + total_size, + buffer.len() + ))); + } + + let mut offset = 0; + + // Write header + let header_bytes = header.into_bytes(); + buffer[offset..offset + HEADER_SIZE].copy_from_slice(&header_bytes); + offset += HEADER_SIZE; + + // Write metadata + buffer[offset..offset + metadata_bytes.len()].copy_from_slice(&metadata_bytes); + offset += metadata_bytes.len(); + + // Write tensor data + for snapshot in &self.snapshots { + let expected_len = snapshot.data_len(); + let data = snapshot.to_data().map_err(|e| { + BurnpackError::IoError(format!("Failed to get tensor data: {:?}", e)) + })?; + let actual_len = data.bytes.len(); + + // Validate data length consistency + if actual_len != expected_len { + return Err(BurnpackError::IoError(format!( + "Data corruption: tensor '{}' has inconsistent length (expected {}, got {})", + snapshot.full_path(), + expected_len, + actual_len + ))); + } + + buffer[offset..offset + actual_len].copy_from_slice(&data.bytes); + offset += actual_len; + } + + Ok(()) + } + + /// Write to a byte buffer (convenience method) + /// + /// This allocates a buffer internally and writes the burnpack data. 
+ /// For more control over buffer allocation, use `size()` + `write_into()`. + pub fn to_bytes(&self) -> Result { + let size = self.size()?; + let mut buffer = vec![0u8; size]; + self.write_into(&mut buffer)?; + Ok(Bytes::from_bytes_vec(buffer)) + } + + /// Write directly to a file (more memory efficient for large models) + #[cfg(feature = "std")] + pub fn write_to_file>(&self, path: P) -> Result<(), BurnpackError> { + let mut file = File::create(path).map_err(|e| BurnpackError::IoError(e.to_string()))?; + + let (_, metadata_bytes) = self.build_metadata()?; + + // Check metadata size fits in u32 + let metadata_size: u32 = metadata_bytes.len().try_into().map_err(|_| { + BurnpackError::IoError(format!( + "Metadata size {} exceeds maximum of {} bytes", + metadata_bytes.len(), + u32::MAX + )) + })?; + + // Create and write header + let header = BurnpackHeader { + magic: MAGIC_NUMBER, + version: FORMAT_VERSION, + metadata_size, + }; + + file.write_all(&header.into_bytes()) + .map_err(|e| BurnpackError::IoError(e.to_string()))?; + + // Write metadata + file.write_all(&metadata_bytes) + .map_err(|e| BurnpackError::IoError(e.to_string()))?; + + // Stream tensor data directly to file + for snapshot in &self.snapshots { + let expected_len = snapshot.data_len(); + let data = snapshot.to_data().map_err(|e| { + BurnpackError::IoError(format!("Failed to get tensor data: {:?}", e)) + })?; + let actual_len = data.bytes.len(); + + // Validate data length consistency + if actual_len != expected_len { + return Err(BurnpackError::IoError(format!( + "Data corruption: tensor '{}' has inconsistent length (expected {}, got {})", + snapshot.full_path(), + expected_len, + actual_len + ))); + } + + file.write_all(&data.bytes) + .map_err(|e| BurnpackError::IoError(e.to_string()))?; + } + + file.flush() + .map_err(|e| BurnpackError::IoError(e.to_string()))?; + + Ok(()) + } +} diff --git a/crates/burn-store/src/collector.rs b/crates/burn-store/src/collector.rs index 132c3a8de6..4ff2a0263e 
100644 --- a/crates/burn-store/src/collector.rs +++ b/crates/burn-store/src/collector.rs @@ -5,7 +5,7 @@ use alloc::vec::Vec; use burn_tensor::{Bool, Int, Tensor, backend::Backend}; use crate::{ModuleAdapter, PathFilter, TensorSnapshot}; -use burn_core::module::{ModuleVisitor, ParamId}; +use burn_core::module::{ModuleVisitor, Param, ParamId}; /// Collects tensor views from modules without copying data. /// @@ -15,26 +15,30 @@ use burn_core::module::{ModuleVisitor, ParamId}; /// # Examples /// /// ## Collect all tensors -/// ```rust,ignore -/// let collector = Collector::new(); -/// module.visit(&mut collector); +/// ```rust,no_run +/// # use burn_store::Collector; +/// let collector = Collector::new(None, None); +/// // Use with module.visit(&mut collector); /// let all_tensors = collector.tensors; /// ``` /// /// ## Filter with single pattern -/// ```rust,ignore -/// let collector = Collector::with_filter(PathFilter::new().with_regex(r"^encoder\..*")); -/// module.visit(&mut collector); +/// ```rust,no_run +/// # use burn_store::{Collector, PathFilter}; +/// let filter = PathFilter::new().with_regex(r"^encoder\..*"); +/// let collector = Collector::new(Some(filter), None); +/// // Use with module.visit(&mut collector); /// // Only collects tensors starting with "encoder." 
/// ``` /// /// ## Filter with multiple patterns (OR union) -/// ```rust,ignore +/// ```rust,no_run +/// # use burn_store::{Collector, PathFilter}; /// let filter = PathFilter::new() /// .with_regex(r"^encoder\..*") // Match all encoder tensors /// .with_regex(r".*\.bias$"); // OR match any bias tensors -/// let collector = Collector::with_filter(filter); -/// module.visit(&mut collector); +/// let collector = Collector::new(Some(filter), None); +/// // Use with module.visit(&mut collector); /// // Collects tensors matching ANY of the patterns /// ``` pub struct Collector { @@ -64,17 +68,16 @@ impl Collector { /// /// # Examples /// - /// ```rust,ignore - /// use burn_store::{Collector, PathFilter}; - /// + /// ```rust,no_run + /// # use burn_store::{Collector, PathFilter}; /// // Collect all tensors without adapter /// let collector = Collector::new(None, None); /// - /// // Use PathFilter builder with adapter + /// // Use PathFilter builder /// let filter = PathFilter::new() /// .with_regex(r"^encoder\..*") /// .with_full_path("decoder.weight"); - /// let collector = Collector::new(Some(filter), Some(adapter)); + /// let collector = Collector::new(Some(filter), None); /// ``` pub fn new(filter: Option, adapter: Option>) -> Self { Self { @@ -118,41 +121,35 @@ impl ModuleVisitor for Collector { self.container_stack.pop(); } - fn visit_float(&mut self, id: ParamId, tensor: &Tensor) { - if !self.path_stack.is_empty() - && self.should_collect(&self.path_stack, &self.container_stack) - { + fn visit_float(&mut self, param: &Param>) { + if self.should_collect(&self.path_stack, &self.container_stack) { self.tensors.push(TensorSnapshot::from_float( - tensor, + ¶m.transform_for_save().val(), self.path_stack.clone(), self.container_stack.clone(), - id, + param.id, )); } } - fn visit_int(&mut self, id: ParamId, tensor: &Tensor) { - if !self.path_stack.is_empty() - && self.should_collect(&self.path_stack, &self.container_stack) - { + fn visit_int(&mut self, param: &Param>) { + 
if self.should_collect(&self.path_stack, &self.container_stack) { self.tensors.push(TensorSnapshot::from_int( - tensor, + ¶m.transform_for_save().val(), self.path_stack.clone(), self.container_stack.clone(), - id, + param.id, )); } } - fn visit_bool(&mut self, id: ParamId, tensor: &Tensor) { - if !self.path_stack.is_empty() - && self.should_collect(&self.path_stack, &self.container_stack) - { + fn visit_bool(&mut self, param: &Param>) { + if self.should_collect(&self.path_stack, &self.container_stack) { self.tensors.push(TensorSnapshot::from_bool( - tensor, + ¶m.transform_for_save().val(), self.path_stack.clone(), self.container_stack.clone(), - id, + param.id, )); } } @@ -236,6 +233,52 @@ mod tests { assert_eq!(data.shape, vec![2, 2]); } + #[test] + fn root_level_parameters() { + use burn_core::module::ModuleVisitor; + + let device = Default::default(); + + // Create root-level parameters (single-element path, not nested in modules) + let weight = Param::>::from_data([[1.0, 2.0], [3.0, 4.0]], &device); + let bias = Param::>::from_data([5.0, 6.0], &device); + + let mut collector = Collector::new(None, None); + + // Simulate module traversal for root-level parameters + // Enter "weight" path (as if we're visiting a field named "weight") + ModuleVisitor::::enter_module(&mut collector, "weight", ""); + ModuleVisitor::::visit_float(&mut collector, &weight); + ModuleVisitor::::exit_module(&mut collector, "weight", ""); + + // Enter "bias" path (as if we're visiting a field named "bias") + ModuleVisitor::::enter_module(&mut collector, "bias", ""); + ModuleVisitor::::visit_float(&mut collector, &bias); + ModuleVisitor::::exit_module(&mut collector, "bias", ""); + + // Verify both parameters were collected + assert_eq!(collector.tensors.len(), 2); + + // Verify paths are correct (single-element paths) + assert_eq!(collector.tensors[0].full_path(), "weight"); + assert_eq!(collector.tensors[1].full_path(), "bias"); + + // Verify data is correct + let weight_data = 
collector.tensors[0] + .to_data() + .unwrap() + .to_vec::() + .unwrap(); + let bias_data = collector.tensors[1] + .to_data() + .unwrap() + .to_vec::() + .unwrap(); + + assert_eq!(weight_data, vec![1.0, 2.0, 3.0, 4.0]); + assert_eq!(bias_data, vec![5.0, 6.0]); + } + #[test] #[cfg(target_has_atomic = "ptr")] fn tensor_snapshot_collector_with_filter() { @@ -439,24 +482,33 @@ mod tests { self.path_stack.pop(); } - fn visit_float(&mut self, id: ParamId, tensor: &Tensor) { + fn visit_float(&mut self, param: &Param>) { let path = self.current_path(); if !path.is_empty() { - self.paths.insert(path, (id, tensor.shape().to_vec())); + self.paths.insert( + path, + (param.id, param.transform_for_save().val().shape().to_vec()), + ); } } - fn visit_int(&mut self, id: ParamId, tensor: &Tensor) { + fn visit_int(&mut self, param: &Param>) { let path = self.current_path(); if !path.is_empty() { - self.paths.insert(path, (id, tensor.shape().to_vec())); + self.paths.insert( + path, + (param.id, param.transform_for_save().val().shape().to_vec()), + ); } } - fn visit_bool(&mut self, id: ParamId, tensor: &Tensor) { + fn visit_bool(&mut self, param: &Param>) { let path = self.current_path(); if !path.is_empty() { - self.paths.insert(path, (id, tensor.shape().to_vec())); + self.paths.insert( + path, + (param.id, param.transform_for_save().val().shape().to_vec()), + ); } } } diff --git a/crates/burn-store/src/filter.rs b/crates/burn-store/src/filter.rs index 6824a79fd0..1d995b6219 100644 --- a/crates/burn-store/src/filter.rs +++ b/crates/burn-store/src/filter.rs @@ -13,7 +13,8 @@ use regex::Regex; /// /// # Examples /// -/// ```rust,ignore +/// ```rust,no_run +/// # use burn_store::PathFilter; /// // Create a filter that matches encoder paths or any weight path /// let filter = PathFilter::new() /// .with_regex(r"^encoder\..*") diff --git a/crates/burn-store/src/keyremapper.rs b/crates/burn-store/src/keyremapper.rs index c15a9be7c0..7b9f224e91 100644 --- 
a/crates/burn-store/src/keyremapper.rs +++ b/crates/burn-store/src/keyremapper.rs @@ -12,16 +12,18 @@ use crate::TensorSnapshot; /// /// # Examples /// -/// ```rust,ignore -/// use burn_store::KeyRemapper; -/// +/// ```rust,no_run +/// # use burn_store::KeyRemapper; +/// # fn example() -> Result<(), Box> { /// // Create a key remapper /// let remapper = KeyRemapper::new() /// .add_pattern(r"^pytorch\.(.*)", "burn.$1")? // pytorch.layer -> burn.layer /// .add_pattern(r"\.gamma$", ".weight")?; // layer.gamma -> layer.weight /// -/// // Apply to tensor views -/// let (remapped_tensors, transformations) = remapper.remap(tensors); +/// // Use remapper with stores +/// // store.remap(remapper) +/// # Ok(()) +/// # } /// ``` #[derive(Debug, Clone, Default)] pub struct KeyRemapper { diff --git a/crates/burn-store/src/lib.rs b/crates/burn-store/src/lib.rs index ffb4655b0c..ba6ecf2b91 100644 --- a/crates/burn-store/src/lib.rs +++ b/crates/burn-store/src/lib.rs @@ -10,6 +10,7 @@ //! //! ## Key Features //! +//! - **Burnpack Format**: Native Burn format with CBOR metadata, ParamId persistence for stateful training, and no-std support //! - **SafeTensors Format**: Industry-standard format for secure and efficient tensor serialization //! - **PyTorch Compatibility**: Load PyTorch models directly into Burn with automatic weight transformation //! - **Zero-Copy Loading**: Memory-mapped files and lazy tensor materialization for optimal performance @@ -26,11 +27,11 @@ //! //! // Save a model //! let mut store = SafetensorsStore::from_file("model.safetensors"); -//! model.collect_to(&mut store)?; +//! model.save_into(&mut store)?; //! //! // Load a model //! let mut store = SafetensorsStore::from_file("model.safetensors"); -//! model.apply_from(&mut store)?; +//! model.load_from(&mut store)?; //! ``` //! //! ### Loading PyTorch Models @@ -43,26 +44,26 @@ //! .with_top_level_key("state_dict") // Access nested state dict if needed //! .allow_partial(true); // Skip unknown tensors //! 
-//! model.apply_from(&mut store)?; +//! model.load_from(&mut store)?; //! ``` //! //! ### Filtering and Remapping //! -//! ```rust,ignore -//! use burn_store::SafetensorsStore; -//! +//! ```rust,no_run +//! # use burn_store::SafetensorsStore; //! // Save only specific layers with renaming //! let mut store = SafetensorsStore::from_file("encoder.safetensors") //! .with_regex(r"^encoder\..*") // Filter: only encoder layers //! .with_key_remapping(r"^encoder\.", "transformer.") // Rename: encoder.X -> transformer.X //! .metadata("subset", "encoder_only"); //! -//! model.collect_to(&mut store)?; +//! // Use store with model.save_into(&mut store)?; //! ``` //! //! ## Core Components //! //! - [`ModuleSnapshot`]: Extension trait for Burn modules providing `collect()` and `apply()` methods +//! - [`BurnpackStore`]: Native Burn format with ParamId persistence for stateful training workflows //! - [`SafetensorsStore`]: Primary storage implementation supporting the SafeTensors format //! - [`PytorchStore`]: PyTorch model loader supporting .pth and .pt files //! 
- [`PathFilter`]: Flexible filtering system for selective tensor loading/saving @@ -88,7 +89,7 @@ pub use applier::{Applier, ApplyError, ApplyResult}; pub use collector::Collector; pub use filter::PathFilter; pub use tensor_snapshot::{TensorSnapshot, TensorSnapshotError}; -pub use traits::{ModuleSnapshot, ModuleSnapshoter}; +pub use traits::{ModuleSnapshot, ModuleStore}; #[cfg(feature = "std")] mod keyremapper; @@ -104,3 +105,8 @@ pub use pytorch::{PytorchStore, PytorchStoreError}; mod safetensors; #[cfg(feature = "safetensors")] pub use safetensors::{SafetensorsStore, SafetensorsStoreError}; + +#[cfg(feature = "burnpack")] +mod burnpack; +#[cfg(feature = "burnpack")] +pub use burnpack::store::BurnpackStore; diff --git a/crates/burn-store/src/pytorch/mod.rs b/crates/burn-store/src/pytorch/mod.rs index f289f440ca..5eecf6b452 100644 --- a/crates/burn-store/src/pytorch/mod.rs +++ b/crates/burn-store/src/pytorch/mod.rs @@ -27,7 +27,7 @@ //! .allow_partial(true); // Skip missing tensors //! //! let mut model = MyModel::new(&device); -//! let result = model.apply_from(&mut store)?; +//! let result = model.load_from(&mut store)?; //! //! println!("Loaded {} tensors", result.applied.len()); //! if !result.missing.is_empty() { diff --git a/crates/burn-store/src/pytorch/reader.rs b/crates/burn-store/src/pytorch/reader.rs index 3edd9a352d..f573b04dbe 100644 --- a/crates/burn-store/src/pytorch/reader.rs +++ b/crates/burn-store/src/pytorch/reader.rs @@ -166,9 +166,9 @@ pub enum ByteOrder { /// legacy format (0.1.10-1.5), and simple pickle files. 
/// /// # Example -/// ```ignore -/// use burn_store::pytorch::PytorchReader; -/// +/// ```rust,no_run +/// # use burn_store::pytorch::PytorchReader; +/// # fn example() -> Result<(), Box> { /// // Load a checkpoint file /// let reader = PytorchReader::new("model.pt")?; /// @@ -181,8 +181,10 @@ pub enum ByteOrder { /// } /// /// // Check file metadata -/// println!("Format: {:?}", reader.format_type()); +/// println!("Format: {:?}", reader.metadata().format_type); /// println!("Tensor count: {}", reader.metadata().tensor_count); +/// # Ok(()) +/// # } /// ``` pub struct PytorchReader { tensors: HashMap, @@ -212,8 +214,12 @@ impl PytorchReader { /// * `key` - Top-level key to extract (e.g., "state_dict") /// /// # Example - /// ```ignore + /// ```rust,no_run + /// # use burn_store::pytorch::PytorchReader; + /// # fn example() -> Result<(), Box> { /// let reader = PytorchReader::with_top_level_key("checkpoint.pt", "state_dict")?; + /// # Ok(()) + /// # } /// ``` pub fn with_top_level_key>(path: P, key: &str) -> Result { let (tensors, metadata) = load_pytorch_file_with_metadata(path.as_ref(), Some(key))?; @@ -317,10 +323,10 @@ impl PytorchReader { /// reading or deserialization fails. 
/// /// # Example - /// ```ignore - /// use burn_store::pytorch::PytorchReader; - /// use serde::Deserialize; - /// + /// ```rust,no_run + /// # use burn_store::pytorch::PytorchReader; + /// # use serde::Deserialize; + /// # fn example() -> Result<(), Box> { /// #[derive(Debug, Deserialize)] /// struct ModelConfig { /// hidden_size: usize, @@ -328,6 +334,8 @@ impl PytorchReader { /// } /// /// let config: ModelConfig = PytorchReader::load_config("model.pth", Some("config"))?; + /// # Ok(()) + /// # } /// ``` pub fn load_config(path: P, top_level_key: Option<&str>) -> Result where diff --git a/crates/burn-store/src/pytorch/store.rs b/crates/burn-store/src/pytorch/store.rs index fb118df8d4..537b55c855 100644 --- a/crates/burn-store/src/pytorch/store.rs +++ b/crates/burn-store/src/pytorch/store.rs @@ -1,7 +1,7 @@ //! PyTorch store implementation for saving and loading models in PyTorch format. use crate::{ - ApplyResult, KeyRemapper, ModuleSnapshot, ModuleSnapshoter, PathFilter, PyTorchToBurnAdapter, + ApplyResult, KeyRemapper, ModuleSnapshot, ModuleStore, PathFilter, PyTorchToBurnAdapter, TensorSnapshot, }; @@ -83,8 +83,8 @@ impl PytorchStore { /// * `path` - Path to the PyTorch checkpoint file (.pt or .pth) /// /// # Example - /// ```rust,ignore - /// use burn_store::pytorch::PytorchStore; + /// ```rust,no_run + /// use burn_store::PytorchStore; /// /// let store = PytorchStore::from_file("model.pth"); /// ``` @@ -105,7 +105,8 @@ impl PytorchStore { /// tensors from a specific top-level key like "state_dict" or "model_state_dict". /// /// # Example - /// ```rust,ignore + /// ```rust,no_run + /// # use burn_store::PytorchStore; /// let store = PytorchStore::from_file("checkpoint.pth") /// .with_top_level_key("model_state_dict"); /// ``` @@ -125,7 +126,8 @@ impl PytorchStore { /// Multiple patterns can be added and they work with OR logic. 
/// /// # Example - /// ```rust,ignore + /// ```rust,no_run + /// # use burn_store::PytorchStore; /// let store = PytorchStore::from_file("model.pth") /// .with_regex(r"^encoder\..*") // Match all encoder tensors /// .with_regex(r".*\.weight$"); // OR match any weight tensors @@ -148,7 +150,8 @@ impl PytorchStore { /// Add an exact full path to match. /// /// # Example - /// ```rust,ignore + /// ```rust,no_run + /// # use burn_store::PytorchStore; /// let store = PytorchStore::from_file("model.pth") /// .with_full_path("encoder.layer1.weight") /// .with_full_path("decoder.output.bias"); @@ -173,7 +176,8 @@ impl PytorchStore { /// The predicate receives the tensor path and container path. /// /// # Example - /// ```rust,ignore + /// ```rust,no_run + /// # use burn_store::PytorchStore; /// let store = PytorchStore::from_file("model.pth") /// .with_predicate(|path, _| path.starts_with("encoder.") || path.ends_with(".bias")); /// ``` @@ -206,7 +210,8 @@ impl PytorchStore { /// Add a regex pattern to remap tensor names during load. 
/// /// # Example - /// ```rust,ignore + /// ```rust,no_run + /// # use burn_store::PytorchStore; /// let store = PytorchStore::from_file("model.pth") /// .with_key_remapping(r"^encoder\.", "transformer.encoder.") // encoder.X -> transformer.encoder.X /// .with_key_remapping(r"\.gamma$", ".weight"); // X.gamma -> X.weight @@ -260,7 +265,7 @@ impl PytorchStore { } } -impl ModuleSnapshoter for PytorchStore { +impl ModuleStore for PytorchStore { type Error = PytorchStoreError; fn collect_from>( diff --git a/crates/burn-store/src/pytorch/tests/store/mod.rs b/crates/burn-store/src/pytorch/tests/store/mod.rs index 5843df6e22..3d0f477d82 100644 --- a/crates/burn-store/src/pytorch/tests/store/mod.rs +++ b/crates/burn-store/src/pytorch/tests/store/mod.rs @@ -2,7 +2,7 @@ use std::path::PathBuf; -use crate::ModuleSnapshoter; +use crate::ModuleStore; use crate::pytorch::PytorchStore; use burn_core::module::Module; use burn_nn::conv::{Conv2d, Conv2dConfig}; diff --git a/crates/burn-store/src/safetensors/mod.rs b/crates/burn-store/src/safetensors/mod.rs index 6850cc1f84..23be75b7bb 100644 --- a/crates/burn-store/src/safetensors/mod.rs +++ b/crates/burn-store/src/safetensors/mod.rs @@ -19,44 +19,40 @@ //! ## Basic Save and Load //! //! ```rust,ignore -//! use burn_store::safetensors::SafetensorsStore; -//! use burn_store::ModuleSnapshot; +//! use burn_store::{SafetensorsStore, ModuleSnapshot}; //! //! // Save a model to a file //! let mut store = SafetensorsStore::from_file("model.safetensors"); -//! model.collect_to(&mut store)?; +//! model.save_into(&mut store)?; //! //! // Load a model from a file //! let mut store = SafetensorsStore::from_file("model.safetensors"); //! let mut model = Model::new(&device); -//! model.apply_from(&mut store)?; +//! model.load_from(&mut store)?; //! ``` //! //! ## Memory-Based Operations //! //! ```rust,ignore -//! use burn_store::safetensors::SafetensorsStore; -//! use burn_store::ModuleSnapshot; +//! 
use burn_store::{SafetensorsStore, ModuleSnapshot}; //! //! // Save to memory buffer //! let mut store = SafetensorsStore::from_bytes(None); -//! model.collect_to(&mut store)?; +//! model.save_into(&mut store)?; //! let bytes = store.get_bytes()?; //! //! // Load from memory buffer //! let mut store = SafetensorsStore::from_bytes(Some(bytes)); //! let mut model = Model::new(&device); -//! model.apply_from(&mut store)?; +//! model.load_from(&mut store)?; //! ``` //! //! ## Advanced Features //! //! ### Filter Configuration with Builder Pattern //! -//! ```rust,ignore -//! use burn_store::ModuleSnapshot; -//! use burn_store::safetensors::SafetensorsStore; -//! +//! ```rust,no_run +//! # use burn_store::SafetensorsStore; //! // Filter with regex patterns (OR logic - matches any pattern) //! let mut store = SafetensorsStore::from_file("model.safetensors") //! .with_regex(r"^encoder\..*") // Match all encoder tensors @@ -92,11 +88,9 @@ //! //! Remap tensor names during load/save operations for compatibility between different frameworks: //! -//! ```rust,ignore -//! use burn_store::ModuleSnapshot; -//! use burn_store::safetensors::SafetensorsStore; -//! use burn_store::KeyRemapper; -//! +//! ```rust,no_run +//! # use burn_store::{SafetensorsStore, KeyRemapper}; +//! # fn example() -> Result<(), Box> { //! // Using builder pattern for common remapping patterns //! let mut store = SafetensorsStore::from_file("model.safetensors") //! .with_key_remapping(r"^encoder\.", "transformer.encoder.") // encoder.X -> transformer.encoder.X @@ -111,6 +105,8 @@ //! //! let mut store = SafetensorsStore::from_file("model.safetensors") //! .remap(remapper); +//! # Ok(()) +//! # } //! ``` //! //! ### Framework Adapters @@ -118,8 +114,7 @@ //! Use adapters for automatic framework-specific transformations: //! //! ```rust,ignore -//! use burn_store::{ModuleSnapshot, PyTorchToBurnAdapter, BurnToPyTorchAdapter}; -//! use burn_store::safetensors::SafetensorsStore; +//! 
use burn_store::{SafetensorsStore, ModuleSnapshot, PyTorchToBurnAdapter, BurnToPyTorchAdapter}; //! //! // Loading PyTorch model into Burn //! let mut store = SafetensorsStore::from_file("pytorch_model.safetensors") @@ -127,20 +122,19 @@ //! .allow_partial(true); // PyTorch models may have extra tensors //! //! let mut burn_model = Model::new(&device); -//! burn_model.apply_from(&mut store)?; +//! burn_model.load_from(&mut store)?; //! //! // Saving Burn model for PyTorch //! let mut store = SafetensorsStore::from_file("for_pytorch.safetensors") //! .with_to_adapter(BurnToPyTorchAdapter); // Transposes weights back, renames for PyTorch //! -//! burn_model.collect_to(&mut store)?; +//! burn_model.save_into(&mut store)?; //! ``` //! //! ### Additional Configuration Options //! //! ```rust,ignore -//! use burn_store::ModuleSnapshot; -//! use burn_store::safetensors::SafetensorsStore; +//! use burn_store::{SafetensorsStore, ModuleSnapshot}; //! //! let mut store = SafetensorsStore::from_file("model.safetensors") //! // Add custom metadata @@ -152,9 +146,9 @@ //! .validate(false); //! //! // Use the configured store -//! model.collect_to(&mut store)?; // For saving +//! model.save_into(&mut store)?; // For saving //! // or -//! model.apply_from(&mut store)?; // For loading +//! model.load_from(&mut store)?; // For loading //! ``` //! //! # Efficient Loading with SafeTensors @@ -162,11 +156,13 @@ //! SafeTensors provides efficient tensor loading through its zero-copy design: //! //! ```rust,ignore +//! use burn_store::{SafetensorsStore, ModuleSnapshot}; +//! //! let mut store = SafetensorsStore::from_file("large_model.safetensors"); //! // Uses memory mapping (when available) for zero-copy access //! // Falls back to buffered reading when mmap is not available //! let mut model = Model::new(&device); -//! model.apply_from(&mut store)?; +//! model.load_from(&mut store)?; //! ``` //! //! The safetensors approach provides: @@ -180,7 +176,7 @@ //! 
zero-copy design and built-in metadata handling: //! //! ```rust,ignore -//! use burn_store::safetensors::SafetensorsStore; +//! use burn_store::SafetensorsStore; //! //! // Open a file - uses safetensors' efficient header reading //! let store = SafetensorsStore::from_file("large_model.safetensors"); @@ -242,7 +238,9 @@ //! //! All methods return `Self` for chaining: //! -//! ```rust,ignore +//! ```rust,no_run +//! use burn_store::{SafetensorsStore, PyTorchToBurnAdapter}; +//! //! let store = SafetensorsStore::from_file("model.safetensors") //! .with_regex(r"^encoder\..*") //! .with_key_remapping(r"\.gamma$", ".weight") @@ -256,15 +254,14 @@ //! For direct byte operations without files: //! //! ```rust,ignore -//! use burn_store::ModuleSnapshot; -//! use burn_store::safetensors::SafetensorsStore; +//! use burn_store::{SafetensorsStore, ModuleSnapshot}; //! //! // Save to bytes with filtering and remapping //! let mut store = SafetensorsStore::from_bytes(None) //! .with_regex(r"^encoder\..*") // Only save encoder tensors //! .with_key_remapping(r"^encoder\.", "transformer.") // Rename encoder.X -> transformer.X //! .metadata("subset", "encoder_only"); -//! model.collect_to(&mut store)?; +//! model.save_into(&mut store)?; //! let bytes = store.get_bytes()?; //! //! // Load from bytes (allow partial since we only saved encoder) @@ -272,7 +269,7 @@ //! .with_key_remapping(r"^transformer\.", "encoder.") // Rename back: transformer.X -> encoder.X //! .allow_partial(true); //! let mut model = Model::new(&device); -//! let result = model.apply_from(&mut store)?; +//! let result = model.load_from(&mut store)?; //! println!("Applied {} tensors", result.applied.len()); //! ``` //! @@ -281,8 +278,7 @@ //! Migrating a PyTorch model to Burn with filtering, remapping, and adapters: //! //! ```rust,ignore -//! use burn_store::{ModuleSnapshot, PyTorchToBurnAdapter}; -//! use burn_store::safetensors::SafetensorsStore; +//! 
use burn_store::{SafetensorsStore, ModuleSnapshot, PyTorchToBurnAdapter}; //! //! // Load PyTorch model with all transformations //! let mut store = SafetensorsStore::from_file("pytorch_model.safetensors") @@ -299,7 +295,7 @@ //! .metadata("converted_by", "burn-store"); //! //! let mut model = TransformerModel::new(&device); -//! let result = model.apply_from(&mut store)?; +//! let result = model.load_from(&mut store)?; //! //! println!("Successfully loaded {} tensors", result.applied.len()); //! if !result.missing.is_empty() { diff --git a/crates/burn-store/src/safetensors/store.rs b/crates/burn-store/src/safetensors/store.rs index 36fcc6abd7..4ed543b684 100644 --- a/crates/burn-store/src/safetensors/store.rs +++ b/crates/burn-store/src/safetensors/store.rs @@ -1,7 +1,7 @@ //! SafeTensors store implementation using the official safetensors crate. use crate::{ - ApplyResult, IdentityAdapter, ModuleAdapter, ModuleSnapshot, ModuleSnapshoter, PathFilter, + ApplyResult, IdentityAdapter, ModuleAdapter, ModuleSnapshot, ModuleStore, PathFilter, TensorSnapshot, }; @@ -119,6 +119,7 @@ impl SafetensorsStore { metadata: Self::default_metadata(), validate: true, allow_partial: false, + overwrite: false, from_adapter: Box::new(IdentityAdapter), to_adapter: Box::new(IdentityAdapter), }) @@ -154,7 +155,8 @@ impl SafetensorsStore { /// Multiple patterns can be added and they work with OR logic. /// /// # Example - /// ```rust,ignore + /// ```rust,no_run + /// # use burn_store::SafetensorsStore; /// let store = SafetensorsStore::from_file("model.safetensors") /// .with_regex(r"^encoder\..*") // Match all encoder tensors /// .with_regex(r".*\.weight$"); // OR match any weight tensors @@ -187,7 +189,8 @@ impl SafetensorsStore { /// Add an exact full path to match. 
/// /// # Example - /// ```rust,ignore + /// ```rust,no_run + /// # use burn_store::SafetensorsStore; /// let store = SafetensorsStore::from_file("model.safetensors") /// .with_full_path("encoder.layer1.weight") /// .with_full_path("decoder.output.bias"); @@ -220,7 +223,8 @@ impl SafetensorsStore { /// The predicate receives the tensor path and container path. /// /// # Example - /// ```rust,ignore + /// ```rust,no_run + /// # use burn_store::SafetensorsStore; /// let store = SafetensorsStore::from_file("model.safetensors") /// .with_predicate(|path, _| path.starts_with("encoder.") || path.ends_with(".bias")); /// ``` @@ -269,7 +273,8 @@ impl SafetensorsStore { /// Add a regex pattern to remap tensor names during load/save. /// /// # Example - /// ```rust,ignore + /// ```rust,no_run + /// # use burn_store::SafetensorsStore; /// let store = SafetensorsStore::from_file("model.safetensors") /// .with_key_remapping(r"^encoder\.", "transformer.encoder.") // encoder.X -> transformer.encoder.X /// .with_key_remapping(r"\.gamma$", ".weight"); // X.gamma -> X.weight @@ -353,6 +358,31 @@ impl SafetensorsStore { self } + /// Set whether to overwrite existing files when saving (default: false). + /// + /// When set to `false`, attempting to save to an existing file will result in an error. + /// When set to `true`, existing files will be overwritten without warning. + /// + /// This setting only applies to file-based stores. 
+ /// + /// # Example + /// ```rust,no_run + /// # use burn_store::SafetensorsStore; + /// let mut store = SafetensorsStore::from_file("model.safetensors") + /// .overwrite(true); + /// // Will overwrite if file exists when saving + /// ``` + #[cfg(feature = "std")] + pub fn overwrite(mut self, overwrite: bool) -> Self { + match &mut self { + Self::File(p) => p.overwrite = overwrite, + Self::Memory(_) => { + // Memory stores don't have overwrite semantics, ignore + } + } + self + } + /// Set the adapter for loading tensors (converting from source format to Burn). pub fn with_from_adapter(mut self, adapter: impl ModuleAdapter + 'static) -> Self { match &mut self { @@ -376,10 +406,14 @@ impl SafetensorsStore { /// Get saved bytes from memory-based store. /// /// # Example - /// ```rust,ignore + /// ```rust,no_run + /// # use burn_store::SafetensorsStore; + /// # fn example() -> Result<(), Box> { /// let mut store = SafetensorsStore::from_bytes(None); - /// model.collect_to(&mut store)?; + /// // After saving model with collect_to()... /// let bytes = store.get_bytes()?; + /// # Ok(()) + /// # } /// ``` pub fn get_bytes(&self) -> Result, SafetensorsStoreError> { match self { @@ -404,6 +438,7 @@ pub struct FileStore { metadata: HashMap, validate: bool, allow_partial: bool, + overwrite: bool, from_adapter: Box, to_adapter: Box, } @@ -482,7 +517,7 @@ impl safetensors::View for TensorSnapshotAdapter { } } -impl ModuleSnapshoter for SafetensorsStore { +impl ModuleStore for SafetensorsStore { type Error = SafetensorsStoreError; fn collect_from>( @@ -520,6 +555,14 @@ impl ModuleSnapshoter for SafetensorsStore { match self { #[cfg(feature = "std")] Self::File(p) => { + // Check if file exists and overwrite is disabled + if p.path.exists() && !p.overwrite { + return Err(SafetensorsStoreError::Other(format!( + "File already exists: {}. 
Use .overwrite(true) to overwrite.", + p.path.display() + ))); + } + // Convert to safetensors format let tensors = snapshots_to_safetensors(snapshots)?; diff --git a/crates/burn-store/src/safetensors/tests/adapter.rs b/crates/burn-store/src/safetensors/tests/adapter.rs index d7f68b26e1..8b4b7cf8f4 100644 --- a/crates/burn-store/src/safetensors/tests/adapter.rs +++ b/crates/burn-store/src/safetensors/tests/adapter.rs @@ -30,7 +30,7 @@ fn pytorch_to_burn_adapter_linear_transpose() { // Save with BurnToPyTorch adapter (will transpose linear weights) let mut save_store = SafetensorsStore::from_bytes(None).with_to_adapter(BurnToPyTorchAdapter); - model.collect_to(&mut save_store).unwrap(); + model.save_into(&mut save_store).unwrap(); // Load with PyTorchToBurn adapter (will transpose back) let mut load_store = SafetensorsStore::from_bytes(None).with_from_adapter(PyTorchToBurnAdapter); @@ -41,7 +41,7 @@ fn pytorch_to_burn_adapter_linear_transpose() { } let mut model2 = TestModel::::new(&device); - let result = model2.apply_from(&mut load_store).unwrap(); + let result = model2.load_from(&mut load_store).unwrap(); // Should successfully load all tensors assert!(!result.applied.is_empty()); @@ -86,7 +86,7 @@ fn pytorch_to_burn_adapter_norm_rename() { // Save with BurnToPyTorch adapter (will rename gamma->weight, beta->bias) let mut save_store = SafetensorsStore::from_bytes(None).with_to_adapter(BurnToPyTorchAdapter); - model.collect_to(&mut save_store).unwrap(); + model.save_into(&mut save_store).unwrap(); // The saved data should have PyTorch naming convention // We can't directly verify the internal names, but we can verify round-trip works @@ -100,7 +100,7 @@ fn pytorch_to_burn_adapter_norm_rename() { } let mut model2 = NormModel::::new(&device); - let result = model2.apply_from(&mut load_store).unwrap(); + let result = model2.load_from(&mut load_store).unwrap(); // Should load successfully assert!(!result.applied.is_empty()); @@ -122,7 +122,7 @@ fn 
no_adapter_preserves_original() { // Save without adapter let mut save_store = SafetensorsStore::from_bytes(None); - model.collect_to(&mut save_store).unwrap(); + model.save_into(&mut save_store).unwrap(); // Load without adapter let mut load_store = SafetensorsStore::from_bytes(None); @@ -133,7 +133,7 @@ fn no_adapter_preserves_original() { } let mut model2 = TestModel::::new(&device); - let result = model2.apply_from(&mut load_store).unwrap(); + let result = model2.load_from(&mut load_store).unwrap(); assert!(result.is_success()); assert!(!result.applied.is_empty()); @@ -183,7 +183,7 @@ fn adapter_with_pytorch_import() { .allow_partial(true); let mut model = SimpleNet::::new(&device); - let result = model.apply_from(&mut store).unwrap(); + let result = model.load_from(&mut store).unwrap(); // Should load some tensors (fc1 if it exists in the file) // This mainly tests that the adapter works with real PyTorch files diff --git a/crates/burn-store/src/safetensors/tests/error_handling.rs b/crates/burn-store/src/safetensors/tests/error_handling.rs index 50ca06a918..303353a5d1 100644 --- a/crates/burn-store/src/safetensors/tests/error_handling.rs +++ b/crates/burn-store/src/safetensors/tests/error_handling.rs @@ -14,7 +14,7 @@ fn shape_mismatch_errors() { // Save module let mut save_store = SafetensorsStore::from_bytes(None); - module.collect_to(&mut save_store).unwrap(); + module.save_into(&mut save_store).unwrap(); // Try to load into incompatible module (different dimensions) let mut incompatible_module = LinearConfig::new(3, 3) @@ -31,7 +31,7 @@ fn shape_mismatch_errors() { p.set_data(data_arc.as_ref().clone()); } - let result = incompatible_module.apply_from(&mut load_store).unwrap(); + let result = incompatible_module.load_from(&mut load_store).unwrap(); // Should have errors due to shape mismatch assert!(!result.errors.is_empty()); @@ -46,6 +46,6 @@ fn shape_mismatch_errors() { p.set_data(data_arc.as_ref().clone()); } - let validation_result = 
incompatible_module.apply_from(&mut load_store_with_validation); + let validation_result = incompatible_module.load_from(&mut load_store_with_validation); assert!(validation_result.is_err()); } diff --git a/crates/burn-store/src/safetensors/tests/file_io.rs b/crates/burn-store/src/safetensors/tests/file_io.rs index 6deb8c1d56..ae9452ea56 100644 --- a/crates/burn-store/src/safetensors/tests/file_io.rs +++ b/crates/burn-store/src/safetensors/tests/file_io.rs @@ -1,4 +1,4 @@ -use crate::{ModuleSnapshot, SafetensorsStore}; +use crate::{ModuleSnapshot, ModuleStore, SafetensorsStore}; use burn::nn::LinearConfig; type TestBackend = burn_ndarray::NdArray; @@ -20,7 +20,7 @@ fn file_based_loading() { // Save to file let mut save_store = SafetensorsStore::from_file(&file_path).metadata("test", "file_loading"); - module.collect_to(&mut save_store).unwrap(); + module.save_into(&mut save_store).unwrap(); // Verify file exists assert!(file_path.exists()); @@ -32,7 +32,7 @@ fn file_based_loading() { .with_bias(true) .init::(&device); - let result = loaded_module.apply_from(&mut load_store).unwrap(); + let result = loaded_module.load_from(&mut load_store).unwrap(); assert!(result.is_success()); assert_eq!(result.applied.len(), 2); // weight and bias @@ -40,3 +40,186 @@ fn file_based_loading() { // Clean up fs::remove_file(file_path).ok(); } + +#[test] +#[cfg(feature = "std")] +fn test_store_overwrite_protection() { + use tempfile::tempdir; + + let device = Default::default(); + let module = LinearConfig::new(4, 2) + .with_bias(true) + .init::(&device); + + // Create temp directory and file path (file doesn't exist yet) + let temp_dir = tempdir().unwrap(); + let path = temp_dir.path().join("test_model.safetensors"); + + // First save - should succeed + let mut save_store = SafetensorsStore::from_file(&path); + save_store.collect_from(&module).unwrap(); + assert!(path.exists()); + + // Second save without overwrite flag - should fail + let mut save_store2 = 
SafetensorsStore::from_file(&path); + let result = save_store2.collect_from(&module); + assert!(result.is_err()); + assert!( + result + .unwrap_err() + .to_string() + .contains("File already exists") + ); + + // Third save with overwrite flag - should succeed + let mut save_store3 = SafetensorsStore::from_file(&path).overwrite(true); + save_store3.collect_from(&module).unwrap(); + + // Verify file still exists and is valid + let mut load_store = SafetensorsStore::from_file(&path); + let mut module2 = LinearConfig::new(4, 2) + .with_bias(true) + .init::(&device); + let result = load_store.apply_to(&mut module2).unwrap(); + assert!(result.is_success()); +} + +#[test] +#[cfg(feature = "std")] +fn test_store_overwrite_with_metadata() { + use tempfile::tempdir; + + let device = Default::default(); + let module = LinearConfig::new(4, 2) + .with_bias(true) + .init::(&device); + + // Create temp directory and file path + let temp_dir = tempdir().unwrap(); + let path = temp_dir.path().join("test_model_metadata.safetensors"); + + // First save with v1 metadata and overwrite enabled + let mut save_store = SafetensorsStore::from_file(&path) + .metadata("model_version", "v1") + .overwrite(true); + save_store.collect_from(&module).unwrap(); + + // Second save with v2 metadata and overwrite enabled + let mut save_store2 = SafetensorsStore::from_file(&path) + .metadata("model_version", "v2") + .overwrite(true); + save_store2.collect_from(&module).unwrap(); + + // Load and verify the metadata was updated to v2 + let mut load_store = SafetensorsStore::from_file(&path); + // Since we can't easily access metadata after loading, we just verify the file loads successfully + let mut module2 = LinearConfig::new(4, 2) + .with_bias(true) + .init::(&device); + let result = module2.load_from(&mut load_store).unwrap(); + assert!(result.is_success()); +} + +#[test] +#[cfg(feature = "std")] +fn test_forward_pass_preservation_after_save_load() { + use burn_core::module::Module; + use 
burn_tensor::Tensor; + use tempfile::tempdir; + + // Define a test model with forward pass + #[derive(Module, Debug)] + struct ForwardTestModel { + linear1: burn::nn::Linear, + linear2: burn::nn::Linear, + } + + impl ForwardTestModel { + fn forward(&self, input: Tensor) -> Tensor { + let x = self.linear1.forward(input); + let x = burn::tensor::activation::gelu(x); + self.linear2.forward(x) + } + } + + // Define config for the model + #[derive(burn::config::Config, Debug)] + struct ForwardTestModelConfig { + input_size: usize, + hidden_size: usize, + output_size: usize, + } + + impl ForwardTestModelConfig { + fn init( + &self, + device: &B::Device, + ) -> ForwardTestModel { + ForwardTestModel { + linear1: LinearConfig::new(self.input_size, self.hidden_size) + .with_bias(true) + .init(device), + linear2: LinearConfig::new(self.hidden_size, self.output_size) + .with_bias(true) + .init(device), + } + } + } + + let device = Default::default(); + + // Create model config + let config = ForwardTestModelConfig { + input_size: 4, + hidden_size: 8, + output_size: 2, + }; + + // Initialize model1 with random weights + let model1 = config.init::(&device); + + // Create random input + let input = Tensor::::random( + [1, 4], + burn_tensor::Distribution::Uniform(-1.0, 1.0), + &device, + ); + + // Forward pass with model1 -> output1 + let output1 = model1.forward(input.clone()); + + // Save model1 weights + let temp_dir = tempdir().unwrap(); + let path = temp_dir.path().join("forward_test_model.safetensors"); + let mut save_store = SafetensorsStore::from_file(&path); + save_store.collect_from(&model1).unwrap(); + + // Initialize model2 with different random weights + let mut model2 = config.init::(&device); + + // Forward pass with model2 -> output2 (should differ from output1) + let output2 = model2.forward(input.clone()); + + // Verify output2 differs from output1 (different random weights) + assert!( + !output1 + .clone() + .all_close(output2.clone(), Some(1e-6), Some(1e-6)), + 
"output2 should differ from output1 (different random initializations)" + ); + + // Load model1 weights into model2 + let mut load_store = SafetensorsStore::from_file(&path); + let result = load_store.apply_to(&mut model2).unwrap(); + assert!(result.is_success()); + assert_eq!(result.applied.len(), 4); // 2 weights + 2 biases + + // Forward pass with model2 (now has model1 weights) -> output3 + let output3 = model2.forward(input.clone()); + + // Verify output3 equals output1 (same weights) + assert!( + output1.all_close(output3, Some(1e-6), Some(1e-6)), + "output3 should equal output1 after loading weights" + ); +} diff --git a/crates/burn-store/src/safetensors/tests/filtering.rs b/crates/burn-store/src/safetensors/tests/filtering.rs index c5aaf8b7c0..0b4d2effb2 100644 --- a/crates/burn-store/src/safetensors/tests/filtering.rs +++ b/crates/burn-store/src/safetensors/tests/filtering.rs @@ -13,7 +13,7 @@ fn filtered_export_import() { // Export only encoder tensors using the builder pattern let mut save_store = SafetensorsStore::from_bytes(None).with_regex(r"^encoder\..*"); - module1.collect_to(&mut save_store).unwrap(); + module1.save_into(&mut save_store).unwrap(); // Import filtered tensors - need to allow partial since we only saved encoder tensors let mut load_store = SafetensorsStore::from_bytes(None).allow_partial(true); @@ -24,7 +24,7 @@ fn filtered_export_import() { let data_arc = p_save.data().unwrap(); p.set_data(data_arc.as_ref().clone()); } - let result = module2.apply_from(&mut load_store).unwrap(); + let result = module2.load_from(&mut load_store).unwrap(); assert!(result.is_success()); assert_eq!(result.applied.len(), 3); // encoder.weight, encoder.bias, encoder.norm @@ -51,7 +51,7 @@ fn builder_pattern_filtering() { }) .count(); - module.collect_to(&mut store).unwrap(); + module.save_into(&mut store).unwrap(); // Verify we saved the expected number of tensors if let SafetensorsStore::Memory(ref p) = store { @@ -72,7 +72,7 @@ fn 
builder_pattern_exact_paths() { .with_full_path("encoder.norm") .with_full_paths(paths.clone()); - module.collect_to(&mut store).unwrap(); + module.save_into(&mut store).unwrap(); // Verify only specified tensors were saved if let SafetensorsStore::Memory(ref p) = store { @@ -97,7 +97,7 @@ fn builder_pattern_with_predicate() { path.contains("layer") && path.ends_with("weight") }); - module.collect_to(&mut store).unwrap(); + module.save_into(&mut store).unwrap(); // Verify only layer weights were saved if let SafetensorsStore::Memory(ref p) = store { @@ -127,7 +127,7 @@ fn builder_pattern_combined() { path.contains("projection") }); - module.collect_to(&mut store).unwrap(); + module.save_into(&mut store).unwrap(); if let SafetensorsStore::Memory(ref p) = store { let data = p.data().unwrap(); @@ -159,7 +159,7 @@ fn builder_pattern_match_all() { // Test match_all - should save everything let mut store = SafetensorsStore::from_bytes(None).match_all(); - module.collect_to(&mut store).unwrap(); + module.save_into(&mut store).unwrap(); if let SafetensorsStore::Memory(ref p) = store { let data = p.data().unwrap(); diff --git a/crates/burn-store/src/safetensors/tests/integration.rs b/crates/burn-store/src/safetensors/tests/integration.rs index 522bade530..d547b508b7 100644 --- a/crates/burn-store/src/safetensors/tests/integration.rs +++ b/crates/burn-store/src/safetensors/tests/integration.rs @@ -117,7 +117,7 @@ fn basic_usage() { let mut save_store = SafetensorsStore::from_bytes(None).metadata("model_name", "test_model"); // Use collect_to method - model.collect_to(&mut save_store).unwrap(); + model.save_into(&mut save_store).unwrap(); // Load using new API let mut load_store = SafetensorsStore::from_bytes(None); @@ -128,7 +128,7 @@ fn basic_usage() { } let mut target_model = IntegrationTestModel::::new(&device); - let result = target_model.apply_from(&mut load_store).unwrap(); + let result = target_model.load_from(&mut load_store).unwrap(); assert!(result.is_success()); 
assert_eq!(result.applied.len(), 14); // All tensors should be applied @@ -147,7 +147,7 @@ fn with_filtering() { .with_regex(r"^encoder\..*") .metadata("subset", "encoder_only"); - model.collect_to(&mut save_store).unwrap(); + model.save_into(&mut save_store).unwrap(); // Load into new model - need to allow partial loading since we only saved encoder tensors let mut load_store = SafetensorsStore::from_bytes(None).allow_partial(true); @@ -158,7 +158,7 @@ fn with_filtering() { } let mut target_model = IntegrationTestModel::::new(&device); - let result = target_model.apply_from(&mut load_store).unwrap(); + let result = target_model.load_from(&mut load_store).unwrap(); // Only encoder tensors should be applied assert_eq!(result.applied.len(), 6); // encoder has 6 tensors (2 layers × 2 + norm × 2) diff --git a/crates/burn-store/src/safetensors/tests/metadata.rs b/crates/burn-store/src/safetensors/tests/metadata.rs index e664041903..3838d4189d 100644 --- a/crates/burn-store/src/safetensors/tests/metadata.rs +++ b/crates/burn-store/src/safetensors/tests/metadata.rs @@ -30,7 +30,7 @@ fn metadata_preservation() { .metadata("model_type", "linear") .metadata("custom_field", "test_value"); - module.collect_to(&mut save_store).unwrap(); + module.save_into(&mut save_store).unwrap(); // Verify metadata was saved (would need to add a method to check metadata) // For now, just verify the round trip works @@ -46,7 +46,7 @@ fn metadata_preservation() { let mut module2 = LinearConfig::new(4, 2) .with_bias(true) .init::(&device); - let result = module2.apply_from(&mut load_store).unwrap(); + let result = module2.load_from(&mut load_store).unwrap(); assert!(result.is_success()); } @@ -64,7 +64,7 @@ fn clear_metadata_removes_all() { .metadata("custom_field", "test_value") .clear_metadata(); // Should remove all metadata including defaults - module.collect_to(&mut save_store).unwrap(); + module.save_into(&mut save_store).unwrap(); // Load and verify the module still works (metadata is 
optional) let mut load_store = SafetensorsStore::from_bytes(None); @@ -78,7 +78,7 @@ fn clear_metadata_removes_all() { let mut module2 = LinearConfig::new(4, 2) .with_bias(true) .init::(&device); - let result = module2.apply_from(&mut load_store).unwrap(); + let result = module2.load_from(&mut load_store).unwrap(); assert!(result.is_success()); } @@ -95,7 +95,7 @@ fn clear_then_add_custom_metadata() { .clear_metadata() .metadata("only_custom", "value"); - module.collect_to(&mut save_store).unwrap(); + module.save_into(&mut save_store).unwrap(); // Verify round-trip works let mut load_store = SafetensorsStore::from_bytes(None); @@ -109,7 +109,7 @@ fn clear_then_add_custom_metadata() { let mut module2 = LinearConfig::new(4, 2) .with_bias(true) .init::(&device); - let result = module2.apply_from(&mut load_store).unwrap(); + let result = module2.load_from(&mut load_store).unwrap(); assert!(result.is_success()); } diff --git a/crates/burn-store/src/safetensors/tests/mixed_datatypes.rs b/crates/burn-store/src/safetensors/tests/mixed_datatypes.rs index 4ebc16b061..359f64a92a 100644 --- a/crates/burn-store/src/safetensors/tests/mixed_datatypes.rs +++ b/crates/burn-store/src/safetensors/tests/mixed_datatypes.rs @@ -67,14 +67,14 @@ mod tests { // Save to bytes let mut save_store = SafetensorsStore::from_bytes(None); - model.collect_to(&mut save_store).expect("Failed to save"); + model.save_into(&mut save_store).expect("Failed to save"); let bytes = save_store.get_bytes().expect("Failed to get bytes"); // Load into a new model let mut load_store = SafetensorsStore::from_bytes(Some(bytes)); let mut loaded_model = MixedDtypeModel::::new(&device); loaded_model - .apply_from(&mut load_store) + .load_from(&mut load_store) .expect("Failed to load"); // Verify float tensor is preserved @@ -171,13 +171,13 @@ mod tests { // Save and load let mut save_store = SafetensorsStore::from_bytes(None); - model.collect_to(&mut save_store).expect("Failed to save"); + model.save_into(&mut 
save_store).expect("Failed to save"); let bytes = save_store.get_bytes().expect("Failed to get bytes"); let mut load_store = SafetensorsStore::from_bytes(Some(bytes)); let mut loaded_model = ExtremeValueModel::::new(&device); loaded_model - .apply_from(&mut load_store) + .load_from(&mut load_store) .expect("Failed to load"); // Check exact preservation @@ -218,14 +218,14 @@ mod tests { // Save to bytes let mut save_store = SafetensorsStore::from_bytes(None); - model.collect_to(&mut save_store).expect("Failed to save"); + model.save_into(&mut save_store).expect("Failed to save"); let bytes = save_store.get_bytes().expect("Failed to get bytes"); // Load and verify let mut load_store = SafetensorsStore::from_bytes(Some(bytes)); let mut loaded_model = MixedDtypeModel::::new(&device); loaded_model - .apply_from(&mut load_store) + .load_from(&mut load_store) .expect("Failed to load"); assert_eq!( @@ -259,7 +259,7 @@ mod tests { // Save to bytes let mut save_store = SafetensorsStore::from_bytes(None); - model.collect_to(&mut save_store).expect("Failed to save"); + model.save_into(&mut save_store).expect("Failed to save"); let bytes = save_store.get_bytes().expect("Failed to get bytes"); // Load and verify @@ -269,7 +269,7 @@ mod tests { double_precision: Param::from_tensor(Tensor::zeros([2, 2], &device)), }; loaded_model - .apply_from(&mut load_store) + .load_from(&mut load_store) .expect("Failed to load"); let orig = model.double_precision.val().into_data(); @@ -310,7 +310,7 @@ mod tests { // Save to bytes let mut save_store = SafetensorsStore::from_bytes(None); - model.collect_to(&mut save_store).expect("Failed to save"); + model.save_into(&mut save_store).expect("Failed to save"); let bytes = save_store.get_bytes().expect("Failed to get bytes"); // Load and verify @@ -321,7 +321,7 @@ mod tests { large_ints: Param::initialized(ParamId::new(), Tensor::zeros([4], &device)), }; loaded_model - .apply_from(&mut load_store) + .load_from(&mut load_store) .expect("Failed to 
load"); assert_eq!( @@ -400,7 +400,7 @@ mod tests { // Save to bytes let mut save_store = SafetensorsStore::from_bytes(None); - model.collect_to(&mut save_store).expect("Failed to save"); + model.save_into(&mut save_store).expect("Failed to save"); let bytes = save_store.get_bytes().expect("Failed to get bytes"); // Load into fresh model @@ -419,7 +419,7 @@ mod tests { ), }; loaded_model - .apply_from(&mut load_store) + .load_from(&mut load_store) .expect("Failed to load"); // Verify all data is preserved diff --git a/crates/burn-store/src/safetensors/tests/multi_layer_verify.rs b/crates/burn-store/src/safetensors/tests/multi_layer_verify.rs index 009fa1d3e5..95bbf1ef91 100644 --- a/crates/burn-store/src/safetensors/tests/multi_layer_verify.rs +++ b/crates/burn-store/src/safetensors/tests/multi_layer_verify.rs @@ -66,7 +66,7 @@ fn multi_layer_model() { .allow_partial(true); let mut model = Net::::new(&device); - let result = model.apply_from(&mut store).unwrap(); + let result = model.load_from(&mut store).unwrap(); // Verify loading was successful assert!( diff --git a/crates/burn-store/src/safetensors/tests/pytorch_import.rs b/crates/burn-store/src/safetensors/tests/pytorch_import.rs index 7380e6f4aa..a0aaf1039a 100644 --- a/crates/burn-store/src/safetensors/tests/pytorch_import.rs +++ b/crates/burn-store/src/safetensors/tests/pytorch_import.rs @@ -56,17 +56,19 @@ fn multi_layer_model_import() { // PyTorch stores as [out_features, in_features], Burn as [in_features, out_features] // Also, tensor names may differ (e.g., PyTorch uses different names for BatchNorm params) let mut store = SafetensorsStore::from_file(safetensors_path) - .validate(false) // Disable validation due to shape differences + .with_from_adapter(crate::PyTorchToBurnAdapter) // Use adapter to handle PyTorch format .allow_partial(true); // Allow partial loading due to naming differences let mut model = Net::::new(&device); - let result = model.apply_from(&mut store).unwrap(); + let result = 
model.load_from(&mut store).unwrap(); - // Since we have shape mismatches with PyTorch model (transposed weights), - // we expect some errors but should still load what we can + // With the adapter, weights should load correctly assert!(!result.applied.is_empty()); - // fc1.weight will have errors due to shape mismatch - assert!(!result.errors.is_empty()); + assert!( + result.errors.is_empty(), + "Should have no errors with adapter: {:?}", + result.errors + ); // Test forward pass with the loaded weights // Note: Due to shape mismatches (PyTorch vs Burn conventions for linear layers), @@ -94,17 +96,22 @@ fn safetensors_round_trip_with_pytorch_model() { // Load the model from PyTorch safetensors let mut load_store = SafetensorsStore::from_file(safetensors_path) - .validate(false) // Disable validation due to shape differences + .with_from_adapter(crate::PyTorchToBurnAdapter) // Use adapter to handle PyTorch format .allow_partial(true); // Allow partial loading due to naming differences let mut model = Net::::new(&device); - let load_result = model.apply_from(&mut load_store).unwrap(); - // We expect some errors due to shape mismatch but some tensors should load + let load_result = model.load_from(&mut load_store).unwrap(); + // With the adapter, weights should load correctly assert!(!load_result.applied.is_empty()); + assert!( + load_result.errors.is_empty(), + "Should have no errors with adapter: {:?}", + load_result.errors + ); // Save the model to memory // Note: format, producer and version are automatically added let mut save_store = SafetensorsStore::from_bytes(None).metadata("source", "pytorch"); - model.collect_to(&mut save_store).unwrap(); + model.save_into(&mut save_store).unwrap(); // Load into a new model let mut model2 = Net::::new(&device); @@ -115,7 +122,7 @@ fn safetensors_round_trip_with_pytorch_model() { p.set_data(p_save.data().unwrap().as_ref().clone()); } - let result = model2.apply_from(&mut load_store2).unwrap(); + let result = 
model2.load_from(&mut load_store2).unwrap(); assert!(!result.applied.is_empty()); // Verify both models produce the same output @@ -153,7 +160,7 @@ fn partial_load_from_pytorch_model() { // Save initial fc1 weights for comparison let _initial_fc1_weight = model.fc1.weight.val().to_data(); - let result = model.apply_from(&mut store).unwrap(); + let result = model.load_from(&mut store).unwrap(); // Should load available tensors (with some errors due to shape mismatch) assert!(!result.applied.is_empty()); @@ -179,7 +186,7 @@ fn verify_tensor_names_from_pytorch() { let mut store = SafetensorsStore::from_file(safetensors_path) .validate(false) // Disable validation due to shape differences .allow_partial(true); // Allow partial loading due to naming differences - let result = model.apply_from(&mut store).unwrap(); + let result = model.load_from(&mut store).unwrap(); // Check that we loaded some tensors (with errors due to shape mismatch) assert!(!result.applied.is_empty()); diff --git a/crates/burn-store/src/safetensors/tests/round_trip.rs b/crates/burn-store/src/safetensors/tests/round_trip.rs index 45922bdc3c..c2c0b953f2 100644 --- a/crates/burn-store/src/safetensors/tests/round_trip.rs +++ b/crates/burn-store/src/safetensors/tests/round_trip.rs @@ -75,7 +75,7 @@ fn complex_module_round_trip() { // Save module1 using new store API let mut save_store = SafetensorsStore::from_bytes(None); - module1.collect_to(&mut save_store).unwrap(); + module1.save_into(&mut save_store).unwrap(); // Load into module2 let mut load_store = SafetensorsStore::from_bytes(None); @@ -86,7 +86,7 @@ fn complex_module_round_trip() { let data_arc = p_save.data().unwrap(); p.set_data(data_arc.as_ref().clone()); } - let result = module2.apply_from(&mut load_store).unwrap(); + let result = module2.load_from(&mut load_store).unwrap(); assert!(result.is_success()); assert!(result.applied.len() > 5); diff --git a/crates/burn-store/src/tensor_snapshot.rs b/crates/burn-store/src/tensor_snapshot.rs index 
b493eeb714..8271dc409f 100644 --- a/crates/burn-store/src/tensor_snapshot.rs +++ b/crates/burn-store/src/tensor_snapshot.rs @@ -16,6 +16,18 @@ pub enum TensorSnapshotError { PanicError(String), } +impl core::fmt::Display for TensorSnapshotError { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + match self { + Self::IoError(e) => write!(f, "I/O error: {}", e), + Self::DataError(e) => write!(f, "Data error: {}", e), + Self::PanicError(e) => write!(f, "Panic error: {}", e), + } + } +} + +impl core::error::Error for TensorSnapshotError {} + /// A lightweight snapshot of a tensor that can lazily produce TensorData. /// /// TensorSnapshot stores a cloned tensor internally (which is cheap due to reference counting) @@ -410,6 +422,31 @@ mod tests { } } + #[test] + fn error_propagation_in_closure() { + use alloc::rc::Rc; + + // Create a snapshot with a closure that returns an error + let snapshot = TensorSnapshot::from_closure( + Rc::new(|| Err(TensorSnapshotError::IoError("Simulated IO error".into()))), + DType::F32, + vec![2, 2], + vec!["error_test".into()], + vec![], + ParamId::new(), + ); + + // Should return an error when trying to get data + let result = snapshot.to_data(); + assert!(result.is_err()); + match result { + Err(TensorSnapshotError::IoError(msg)) => { + assert!(msg.contains("Simulated IO error")); + } + _ => panic!("Expected IoError"), + } + } + #[test] fn container_type_extraction() { let device = Default::default(); diff --git a/crates/burn-store/src/traits.rs b/crates/burn-store/src/traits.rs index 9fa9229679..fe4515cde2 100644 --- a/crates/burn-store/src/traits.rs +++ b/crates/burn-store/src/traits.rs @@ -99,32 +99,32 @@ pub trait ModuleSnapshot: Module { applier.into_result() } - /// Collects tensor snapshots into a [`ModuleSnapshoter`] for saving. + /// Saves tensor snapshots into a [`ModuleStore`]. 
/// - /// This method allows using a `ModuleSnapshoter` implementation to handle the + /// This method allows using a `ModuleStore` implementation to handle the /// collection and writing logic in a configurable way. /// /// # Arguments /// - /// * `store` - A mutable reference to a [`ModuleSnapshoter`] that will collect and save the tensors - fn collect_to

(&self, store: &mut P) -> Result<(), P::Error> + /// * `store` - A mutable reference to a [`ModuleStore`] that will collect and save the tensors + fn save_into

(&self, store: &mut P) -> Result<(), P::Error> where - P: ModuleSnapshoter, + P: ModuleStore, { store.collect_from(self) } - /// Applies tensor data from a [`ModuleSnapshoter`] for loading. + /// Loads tensor data from a [`ModuleStore`]. /// - /// This method allows using a `ModuleSnapshoter` implementation to handle the + /// This method allows using a `ModuleStore` implementation to handle the /// loading and application logic in a configurable way. /// /// # Arguments /// - /// * `store` - A mutable reference to a [`ModuleSnapshoter`] that will load and apply tensors - fn apply_from

(&mut self, store: &mut P) -> Result + /// * `store` - A mutable reference to a [`ModuleStore`] that will load and apply tensors + fn load_from

(&mut self, store: &mut P) -> Result where - P: ModuleSnapshoter, + P: ModuleStore, { store.apply_to(self) } @@ -132,10 +132,10 @@ pub trait ModuleSnapshot: Module { /// A trait for handling module storage operations. /// -/// `ModuleSnapshoter` provides a unified interface for saving and loading module +/// `ModuleStore` provides a unified interface for saving and loading module /// tensor data with support for various storage formats and advanced features like filtering, /// remapping, and metadata handling. -pub trait ModuleSnapshoter { +pub trait ModuleStore { /// The error type that can be returned during storage operations. /// /// This should be a format-specific error type that provides detailed diff --git a/examples/wgan/src/model.rs b/examples/wgan/src/model.rs index dc85e1b606..c839984c6c 100644 --- a/examples/wgan/src/model.rs +++ b/examples/wgan/src/model.rs @@ -1,5 +1,5 @@ use burn::{ - module::{Module, ModuleMapper, ParamId}, + module::{Module, ModuleMapper, Param}, prelude::*, tensor::backend::AutodiffBackend, }; @@ -143,7 +143,8 @@ pub struct Clip { } impl ModuleMapper for Clip { - fn map_float(&mut self, _id: ParamId, tensor: Tensor) -> Tensor { + fn map_float(&mut self, param: Param>) -> Param> { + let (id, tensor, mapper) = param.consume(); let is_require_grad = tensor.is_require_grad(); let mut tensor = Tensor::from_inner(tensor.inner().clamp(self.min, self.max)); @@ -151,6 +152,6 @@ impl ModuleMapper for Clip { if is_require_grad { tensor = tensor.require_grad(); } - tensor + Param::from_mapped_value(id, tensor, mapper) } }