Merged

127 commits
cea0804
Add copy Candle's PyTorch pickle reader to burn-store crate
antimora Sep 16, 2025
581e0e1
Refactor PyTorch pickle reader for TensorSnapshot use
antimora Sep 16, 2025
f59a894
Refactor PyTorch import to use burn-store reader
antimora Sep 17, 2025
dc93872
Handle Class object in tensor storage type parsing
antimora Sep 17, 2025
1484d74
Improve storage file matching in pickle_reader
antimora Sep 17, 2025
ce1347a
Refactor tensor extraction to remove unused data_files param
antimora Sep 17, 2025
08b7647
Add support for loading boolean tensors from PyTorch
antimora Sep 17, 2025
1f01458
Add comprehensive PyTorch reader tests and fix int tensor parsing
antimora Sep 17, 2025
e5f3f4b
Add support for legacy PyTorch file formats
antimora Sep 17, 2025
6e11085
Add PyTorch file metadata extraction to reader
antimora Sep 17, 2025
f466aeb
Refactor PyTorch config and reader APIs for pickle support
antimora Sep 17, 2025
d3e5acc
Add config deserialization to PytorchReader
antimora Sep 17, 2025
f142bfd
Remove unused read_pytorch_tensors function
antimora Sep 17, 2025
c8781ae
Remove PyTorch version warning in reader
antimora Sep 17, 2025
15e2511
Remove unused metastack field from Stack struct
antimora Sep 17, 2025
a7afbe9
Update script usage instruction in test_data.py
antimora Sep 17, 2025
74c1a9e
Add PyTorchStore for loading PyTorch model files
antimora Sep 17, 2025
7d4366c
Refactor adapter and applier for module-aware tensor transforms
antimora Sep 17, 2025
6a7a124
Merge remote-tracking branch 'upstream/main' into pytorch-pt-store
antimora Sep 17, 2025
6dff8bf
Refactor tensor collection and application APIs
antimora Sep 17, 2025
2ccf55e
Add adapter support to tensor collection
antimora Sep 17, 2025
9a4be60
Add multi-layer SafeTensors model loading test
antimora Sep 17, 2025
1e46470
Remove unused test modules from safetensors tests
antimora Sep 17, 2025
306877f
Rename with_key_pattern to with_key_remapping
antimora Sep 17, 2025
278ee5a
Expand documentation for Burn Store and PyTorch support
antimora Sep 17, 2025
beaf311
Update PyTorchStore docs for automatic adapter usage
antimora Sep 17, 2025
2df6402
Add PyTorch model loading benchmarks
antimora Sep 18, 2025
83d51c1
Add lazy data loading for PyTorch file formats
antimora Sep 18, 2025
c69fe14
Refactor legacy PyTorch format handling and metadata
antimora Sep 18, 2025
f1603bb
Add ResNet18 loading benchmark and improve lazy loading
antimora Sep 18, 2025
a13ed71
Refactor feature gating and error types in burn-store
antimora Sep 18, 2025
831dc09
Refactor PyTorch model definitions and benchmarks
antimora Sep 18, 2025
b547e25
Add unified loading benchmark and model generator
antimora Sep 18, 2025
a5665f1
Remove Clone bound from ModuleSnapshot trait
antimora Sep 18, 2025
af633a1
Refactor benchmarks to unified large model format
antimora Sep 18, 2025
e4bebc6
Merge remote-tracking branch 'upstream/main' into pytorch-pt-store
antimora Sep 18, 2025
cd1423c
Make serde and zip optional dependencies
antimora Sep 18, 2025
3a32330
Update burn-store README with advanced usage examples
antimora Sep 19, 2025
a7b7d57
Document unsafe memory handling in traits.rs
antimora Sep 21, 2025
ddde0ed
Merge remote-tracking branch 'upstream/main' into pytorch-pt-store
antimora Sep 23, 2025
46d4358
Migrate nn imports from burn_core to burn_nn
antimora Sep 23, 2025
cc5c31f
Merge remote-tracking branch 'upstream/main' into pytorch-pt-store
antimora Sep 25, 2025
4b4fe5a
Refactor data_len to use DType::size() method
antimora Sep 26, 2025
7879d1b
Refactor TensorSnapshot to handle errors in data loading
antimora Sep 28, 2025
b2b1108
Import alloc::format in applier.rs
antimora Sep 29, 2025
96661a5
Merge remote-tracking branch 'upstream/main' into pytorch-pt-store
antimora Oct 2, 2025
238008f
Improve PyTorch legacy format detection
antimora Oct 2, 2025
3da7332
Fix benchmark
antimora Oct 2, 2025
7d0c756
Optimize tensor extraction by using path vector
antimora Oct 2, 2025
95fba49
Update pickle_reader docs for PyTorch compatibility
antimora Oct 2, 2025
93c9ff5
Refactor PyTorch adapter logic into shared function
antimora Oct 2, 2025
0a3e460
Fix opcode mapping for 'd' to Dict in PickleReader
antimora Oct 2, 2025
5bdf6ae
Add support for List and Memoize opcodes in pickle reader
antimora Oct 2, 2025
4c810a6
Handle poisoned locks in PyTorch lazy data sources
antimora Oct 2, 2025
8519420
Fix no-std
antimora Oct 2, 2025
78f4b30
Initial commit
antimora Sep 25, 2025
ae8bf56
Refactor Burnpack store for lazy tensor loading
antimora Sep 26, 2025
a6d2c39
Update Burnpack format specification comment
antimora Sep 26, 2025
ceb885f
Refactor Burnpack header and tensor offset handling
antimora Sep 26, 2025
3c8fa6b
Refactor and add tests
antimora Sep 28, 2025
0d5ab80
Add key remapping support to BurnpackStore
antimora Sep 28, 2025
e6da5bd
Improve error handling in tensor snapshot IO
antimora Sep 29, 2025
279cb3d
Update Burnpack format documentation layout
antimora Sep 29, 2025
7f0d011
Refactor test modules and rename store.rs to header.rs
antimora Sep 29, 2025
6b83b5c
Add tests for BurnpackStore functionality
antimora Sep 29, 2025
1a9ccdb
Add no-std tests for Burnpack storage
antimora Sep 29, 2025
5ca71e9
Fix format
antimora Sep 29, 2025
d6171a3
Add missing newline at end of files
antimora Sep 29, 2025
14dc4d3
Add metadata and validation options to BurnpackStore
antimora Sep 29, 2025
488d309
Switch burnpack metadata serialization to CBOR
antimora Sep 29, 2025
cc20930
Fix import paths
antimora Sep 29, 2025
06f81ab
Add Burnpack and NamedMpk formats to benchmarks
antimora Sep 29, 2025
34ebd62
Update default features and improve file loading logic
antimora Sep 29, 2025
11a084a
Fix formatting
antimora Sep 30, 2025
46a0a1f
Add unified saving benchmark and update docs
antimora Sep 30, 2025
faeba40
Add overwrite protection to BurnpackStore file saving
antimora Sep 30, 2025
065ee2d
Add auto-extension logic for BurnpackStore file paths
antimora Sep 30, 2025
85523e5
Add Burnpack format documentation and inspection example
antimora Sep 30, 2025
40bf3d4
Merge remote-tracking branch 'upstream/main' into burnpackstore
antimora Oct 2, 2025
d734303
Merge remote-tracking branch 'upstream/main' into burnpackstore
antimora Oct 3, 2025
99e0e63
Refactor param record mapping and add load/save
antimora Oct 4, 2025
6eb1188
Refactor ModuleMapper to use Param for parameter mapping
antimora Oct 4, 2025
7313b56
Refactor ModuleVisitor to use Param for parameter visits
antimora Oct 4, 2025
5b48c67
Refactor Param::save to take &self and update Collector
antimora Oct 4, 2025
868b973
Support applying snapshots to uninitialized parameters
antimora Oct 5, 2025
07c127a
Improve Param documentation and clarify lazy init design
antimora Oct 5, 2025
c27a506
Add lazy_shape to Param for shape access without init
antimora Oct 5, 2025
b7e1789
Refactor tensor snapshot application with shape validation
antimora Oct 5, 2025
e4ced35
Use PyTorchToBurnAdapter for safetensors import tests
antimora Oct 5, 2025
bc15bff
Remove path_stack check in tensor collection
antimora Oct 5, 2025
7ff0b66
Allow large enum variant in Activation enum
antimora Oct 5, 2025
0e4b696
Add shape argument to Tensor import in ConstantNode
antimora Oct 5, 2025
5e6679e
Add overwrite protection to SafetensorsStore file saves
antimora Oct 5, 2025
6fa0420
Add test for partial loading with lazy initialization
antimora Oct 5, 2025
69ae96a
Update documentation to use runnable Rust examples
antimora Oct 5, 2025
4bc02bc
Add tests for forward pass preservation after save/load
antimora Oct 5, 2025
9cb21fb
Update metadata format references to CBOR
antimora Oct 6, 2025
dda67f6
Restrict visibility of legacy and file methods to crate
antimora Oct 6, 2025
4792e13
Merge remote-tracking branch 'upstream/main' into burnpackstore
antimora Oct 6, 2025
d0fb481
Merge branch 'main' into burnpackstore
antimora Oct 6, 2025
b6d00d1
Merge remote-tracking branch 'upstream/main' into burnpackstore
antimora Oct 9, 2025
a8cbec5
Fix merge issues
antimora Oct 9, 2025
1ec0e8e
Remove extensive inline documentation from param/base.rs
antimora Oct 9, 2025
60c8639
Rename Param::into_initialized to from_mapped_value
antimora Oct 9, 2025
19ed28d
Remove redundant return value docs from Param methods
antimora Oct 9, 2025
0a82871
Rename Param save/load methods for clarity
antimora Oct 9, 2025
8624593
Rename collect_to/apply_from to save_into/load_from APIs
antimora Oct 9, 2025
bb6ead7
Remove outdated loading benchmark results from README
antimora Oct 9, 2025
81ceb5a
Rename ModuleSnapshoter trait to ModuleStore
antimora Oct 9, 2025
68e0eb6
Rename BurnpackHeader::to_bytes to into_bytes
antimora Oct 10, 2025
5449ab0
Add buffer-based write support to BurnpackWriter
antimora Oct 10, 2025
0f460fb
Switch Burnpack file extension from .burnpack to .bpk
antimora Oct 10, 2025
18019ea
Merge remote-tracking branch 'upstream/main' into burnpackstore
antimora Oct 10, 2025
a4ea166
Update ParamMapper doc to reflect method name changes
antimora Oct 10, 2025
c84c2c3
Fix build
antimora Oct 10, 2025
a9385e4
Move burnpack tests to separate module file
antimora Oct 10, 2025
85acb48
Improve read_into error handling and add tests
antimora Oct 10, 2025
ccf90f8
Refactor tests to use slice deref for byte comparison
antimora Oct 10, 2025
d59a7f0
Add overflow and consistency checks to burnpack reader/writer
antimora Oct 10, 2025
00ae841
Implement Error trait for custom error types
antimora Oct 10, 2025
82e984b
Return errors for corrupted tensor metadata in get_snapshots
antimora Oct 10, 2025
ba0d46d
Add security limits to Burnpack reader
antimora Oct 10, 2025
376f108
Add file size validation for tensor data in BurnpackReader
antimora Oct 10, 2025
fe60198
Add maximum file size limit to BurnpackReader
antimora Oct 10, 2025
478723f
Add platform-specific tensor size limits in burnpack
antimora Oct 10, 2025
207310c
Fix formatting
antimora Oct 10, 2025
3e00be3
Add ParamId persistence to Burnpack format
antimora Oct 12, 2025
28 changes: 28 additions & 0 deletions Cargo.lock

Generated file; diff not rendered by default.

3 changes: 2 additions & 1 deletion Cargo.toml
@@ -81,7 +81,8 @@ regex = { version = "1.11.3", default-features = false, features = [
 reqwest = { version = "0.12.23", default-features = false, features = [
     "rustls-tls",
 ] }
-rmp-serde = "1.3.0"
+ciborium = { version = "0.2", default-features = false }
+rmp-serde = { version = "1.3.0", default-features = false }
 rstest = "0.25.0"
 rusqlite = "0.37.0"
 rust-format = "0.3.4"
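The new ciborium dependency backs the switch of Burnpack metadata serialization to CBOR (commit 488d309), and loosening rmp-serde to default-features = false fits the same no-std posture (commit cd1423c makes serde and zip optional). As a hedged sketch of the round-trip this enables, assuming serde's derive feature and std I/O for brevity; the Metadata struct below is illustrative, not the real Burnpack header type:

use serde::{Deserialize, Serialize};

// Illustrative metadata shape only; the real Burnpack header layout
// lives in burn-store.
#[derive(Serialize, Deserialize, Debug, PartialEq)]
struct Metadata {
    format_version: u32,
    tensor_names: Vec<String>,
}

fn main() {
    let meta = Metadata {
        format_version: 1,
        tensor_names: vec!["encoder.weight".into()],
    };

    // CBOR-encode into an in-memory buffer.
    let mut bytes = Vec::new();
    ciborium::ser::into_writer(&meta, &mut bytes).expect("CBOR encode");

    // Decode back and confirm the round-trip is lossless.
    let decoded: Metadata = ciborium::de::from_reader(bytes.as_slice()).expect("CBOR decode");
    assert_eq!(meta, decoded);
}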
78 changes: 37 additions & 41 deletions crates/burn-core/src/module/base.rs
@@ -1,4 +1,4 @@
-use super::{ParamId, Quantizer};
+use super::{Param, ParamId, Quantizer};
 use crate::{
     record::Record,
     tensor::backend::{AutodiffBackend, Backend},
@@ -19,11 +19,12 @@ macro_rules! module {
         impl<B: Backend> ModuleMapper<B> for Mapper {
             fn map_float<const D: usize>(
                 &mut self,
-                _id: ParamId,
-                tensor: Tensor<B, D>,
-            ) -> Tensor<B, D> {
+                param: Param<Tensor<B, D>>,
+            ) -> Param<Tensor<B, D>> {
+                let (id, tensor, mapper) = param.consume();
                 let func = $item;
-                func(tensor)
+                let tensor = func(tensor);
+                Param::from_mapped_value(id, tensor, mapper)
             }
         }
         let mut mapper = Mapper;
@@ -35,9 +36,9 @@
             backend: core::marker::PhantomData<B>,
         }
         impl<'a, B: Backend> ModuleVisitor<B> for Visitor<'a, B> {
-            fn visit_float<const D: usize>(&mut self, _id: ParamId, tensor: &Tensor<B, D>) {
+            fn visit_float<const D: usize>(&mut self, param: &Param<Tensor<B, D>>) {
                 let func = $item;
-                func(tensor, &mut self.state)
+                func(&param.val(), &mut self.state)
             }
         }
         #[allow(clippy::redundant_closure_call)]
@@ -211,29 +212,26 @@ pub trait Module<B: Backend>: Clone + Send + core::fmt::Debug {

 /// Module visitor trait for traversing and inspecting module parameters.
 pub trait ModuleVisitor<B: Backend> {
-    /// Visit a float tensor in the module.
+    /// Visit a float parameter in the module.
     ///
     /// # Parameters
-    /// - `id`: The unique identifier of the parameter
-    /// - `tensor`: The float tensor to visit
+    /// - `param`: The float parameter to visit
     #[allow(unused_variables)]
-    fn visit_float<const D: usize>(&mut self, id: ParamId, tensor: &Tensor<B, D>) {}
+    fn visit_float<const D: usize>(&mut self, param: &Param<Tensor<B, D>>) {}

-    /// Visit an int tensor in the module.
+    /// Visit an int parameter in the module.
     ///
     /// # Parameters
-    /// - `id`: The unique identifier of the parameter
-    /// - `tensor`: The integer tensor to visit
+    /// - `param`: The integer parameter to visit
     #[allow(unused_variables)]
-    fn visit_int<const D: usize>(&mut self, id: ParamId, tensor: &Tensor<B, D, Int>) {}
+    fn visit_int<const D: usize>(&mut self, param: &Param<Tensor<B, D, Int>>) {}

-    /// Visit a bool tensor in the module.
+    /// Visit a bool parameter in the module.
     ///
     /// # Parameters
-    /// - `id`: The unique identifier of the parameter
-    /// - `tensor`: The boolean tensor to visit
+    /// - `param`: The boolean parameter to visit
     #[allow(unused_variables)]
-    fn visit_bool<const D: usize>(&mut self, id: ParamId, tensor: &Tensor<B, D, Bool>) {}
+    fn visit_bool<const D: usize>(&mut self, param: &Param<Tensor<B, D, Bool>>) {}

     /// Called when entering a submodule.
     ///
@@ -321,51 +319,49 @@ pub trait ModuleMapper<B: Backend> {
     #[allow(unused_variables)]
     fn exit_module(&mut self, name: &str, container_type: &str) {}

-    /// Map a float tensor in the module.
+    /// Map a float parameter in the module.
     ///
     /// # Parameters
-    /// - `id`: The unique identifier of the parameter
-    /// - `tensor`: The float tensor to transform
+    /// - `param`: The float parameter to transform
     ///
     /// # Returns
-    /// The transformed tensor
+    /// The transformed parameter
     #[allow(unused_variables)]
-    fn map_float<const D: usize>(&mut self, id: ParamId, tensor: Tensor<B, D>) -> Tensor<B, D> {
-        tensor
+    fn map_float<const D: usize>(&mut self, param: Param<Tensor<B, D>>) -> Param<Tensor<B, D>> {
+        let (id, tensor, mapper) = param.consume();
+        Param::from_mapped_value(id, tensor, mapper)
     }

-    /// Map an int tensor in the module.
+    /// Map an int parameter in the module.
     ///
     /// # Parameters
-    /// - `id`: The unique identifier of the parameter
-    /// - `tensor`: The integer tensor to transform
+    /// - `param`: The integer parameter to transform
     ///
     /// # Returns
-    /// The transformed tensor
+    /// The transformed parameter
     #[allow(unused_variables)]
     fn map_int<const D: usize>(
         &mut self,
-        id: ParamId,
-        tensor: Tensor<B, D, Int>,
-    ) -> Tensor<B, D, Int> {
-        tensor
+        param: Param<Tensor<B, D, Int>>,
+    ) -> Param<Tensor<B, D, Int>> {
+        let (id, tensor, mapper) = param.consume();
+        Param::from_mapped_value(id, tensor, mapper)
     }

-    /// Map a bool tensor in the module.
+    /// Map a bool parameter in the module.
     ///
     /// # Parameters
-    /// - `id`: The unique identifier of the parameter
-    /// - `tensor`: The boolean tensor to transform
+    /// - `param`: The boolean parameter to transform
     ///
     /// # Returns
-    /// The transformed tensor
+    /// The transformed parameter
     #[allow(unused_variables)]
     fn map_bool<const D: usize>(
         &mut self,
-        id: ParamId,
-        tensor: Tensor<B, D, Bool>,
-    ) -> Tensor<B, D, Bool> {
-        tensor
+        param: Param<Tensor<B, D, Bool>>,
+    ) -> Param<Tensor<B, D, Bool>> {
+        let (id, tensor, mapper) = param.consume();
+        Param::from_mapped_value(id, tensor, mapper)
     }
 }
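To make the new shape concrete, here is a minimal consumer-side sketch: a visitor that counts float parameters and a mapper that scales them. The Count and Scale types, the mul_scalar call, and the import paths are illustrative assumptions; only the trait signatures, Param::consume, and Param::from_mapped_value come from this diff.

use burn_core::module::{ModuleMapper, ModuleVisitor, Param};
use burn_core::tensor::{Tensor, backend::Backend};

// Visitor: read-only traversal; the whole Param arrives, not an (id, tensor) pair.
struct Count {
    floats: usize,
}

impl<B: Backend> ModuleVisitor<B> for Count {
    fn visit_float<const D: usize>(&mut self, param: &Param<Tensor<B, D>>) {
        let _shape = param.val().shape(); // val() yields the tensor, as in the macro above
        self.floats += 1;
    }
}

// Mapper: owning traversal; the id and lazy-init state are threaded through.
struct Scale(f32);

impl<B: Backend> ModuleMapper<B> for Scale {
    fn map_float<const D: usize>(&mut self, param: Param<Tensor<B, D>>) -> Param<Tensor<B, D>> {
        let (id, tensor, mapper) = param.consume();
        let tensor = tensor.mul_scalar(self.0);
        Param::from_mapped_value(id, tensor, mapper)
    }
}

A module would then be traversed with something like model.visit(&mut Count { floats: 0 }) and model = model.map(&mut Scale(0.5)), the same entry points as before the change.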
2 changes: 2 additions & 0 deletions crates/burn-core/src/module/initializer.rs
@@ -107,6 +107,7 @@ impl Initializer {
         let device = device.clone();
         let shape: Shape = shape.into();
         let config = self.clone();
+        let shape_for_closure = shape.clone();

         Param::uninitialized(
             ParamId::new(),
@@ -123,6 +124,7 @@
             },
             device,
             true,
+            shape_for_closure,
         )
     }

Expand Down
Loading