BurnpackStore #3792
Conversation
Replaces candle-core pickle parsing with burn-store's PyTorch reader for improved compatibility and maintainability. Adds tensor snapshot support, updates config and reader modules, and adjusts dependencies. Includes new test files and updates feature flags in Cargo.toml files.
Implemented handling for 'BoolStorage' in the PyTorch pickle reader, allowing boolean tensors to be loaded. Updated the test to enable the boolean tensor test, which was previously ignored due to lack of support.
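For illustration, a dispatch of this kind might look like the following minimal sketch. It is not the actual burn-store code; the `DType` enum is a stand-in, but the storage class names are the ones PyTorch emits in its pickle streams.

```rust
/// Stand-in element-type enum for the example.
#[derive(Debug, Clone, Copy, PartialEq)]
enum DType {
    F32,
    F64,
    I64,
    I32,
    I16,
    I8,
    U8,
    Bool,
}

/// Map a PyTorch pickle storage class name to an element type.
/// `BoolStorage` is the newly supported case; booleans are stored
/// one byte per element in the checkpoint.
fn storage_dtype(storage_class: &str) -> Option<DType> {
    Some(match storage_class {
        "FloatStorage" => DType::F32,
        "DoubleStorage" => DType::F64,
        "LongStorage" => DType::I64,
        "IntStorage" => DType::I32,
        "ShortStorage" => DType::I16,
        "CharStorage" => DType::I8,
        "ByteStorage" => DType::U8,
        "BoolStorage" => DType::Bool,
        _ => return None,
    })
}
```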
Introduces extensive tests for the PyTorch file reader covering various tensor types, shapes, edge cases, and nested structures. Updates the pickle_reader to correctly parse int32, int16, and int8 tensor data. Adds a Python script to generate test .pt files and integrates all test data into the repository for robust validation.
This update adds robust handling for legacy PyTorch checkpoint formats (pre-1.6), including sequential pickle streams and embedded storage data. The pickle reader now supports additional opcodes and improved error messages, while the main reader detects legacy formats and extracts tensors with correct storage offsets. New tests and test data files verify compatibility with legacy files, shared storage, and error handling for corrupted files.
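The detection step can be pictured with this small, self-contained sketch (illustrative only, not the reader's actual code): ZIP-based checkpoints (PyTorch >= 1.6) start with the ZIP local-file-header magic, while legacy files begin directly with a pickle stream.

```rust
/// PyTorch >= 1.6 checkpoints are ZIP archives.
fn is_zip_checkpoint(header: &[u8]) -> bool {
    header.starts_with(b"PK\x03\x04")
}

/// Legacy (pre-1.6) checkpoints begin with a raw pickle stream, whose
/// first opcode is PROTO (0x80) followed by the protocol number.
fn looks_like_legacy_pickle(header: &[u8]) -> bool {
    matches!(header, [0x80, proto, ..] if *proto <= 5)
}
```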
Introduces the PytorchMetadata struct and related enums to capture metadata about loaded PyTorch files, including format type, version, byte order, tensor count, and data size. Updates PytorchReader to expose metadata and adds tests to verify metadata extraction for ZIP and legacy formats.
Replaces legacy zip/pickle reading logic in burn-import's config.rs with the new PytorchReader and PickleValue API from burn-store. Adds PickleValue enum and read_pickle_data method to PytorchReader for simplified config extraction. Updates error handling, test coverage, and public API to support reading configuration and metadata from PyTorch files in a more robust and extensible way.
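To make the shape of that API concrete, here is a hypothetical `PickleValue` tree and a helper that walks it to extract a config field. The variant set is an assumption for the example; the real enum in burn-store may differ.

```rust
use std::collections::HashMap;

/// Hypothetical pickle value tree for the example.
#[derive(Debug)]
enum PickleValue {
    None,
    Bool(bool),
    Int(i64),
    Float(f64),
    String(String),
    List(Vec<PickleValue>),
    Dict(HashMap<String, PickleValue>),
}

/// Follow a chain of dictionary keys (e.g. ["config", "num_layers"])
/// and return the integer at the end, if present.
fn get_int(root: &PickleValue, path: &[&str]) -> Option<i64> {
    let mut node = root;
    for key in path {
        match node {
            PickleValue::Dict(map) => node = map.get(*key)?,
            _ => return None,
        }
    }
    match node {
        PickleValue::Int(v) => Some(*v),
        _ => None,
    }
}
```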
Introduces PytorchReader::load_config for deserializing configuration data from PyTorch files using serde. Refactors config loading in burn-import to use this new API, adds related tests, and updates dependencies to support custom serde features.
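A hedged usage sketch: only the existence of `PytorchReader::load_config` comes from this PR; the constructor, the exact signature, and the config struct below are assumptions for illustration.

```rust
use serde::Deserialize;

/// Hypothetical config shape for the example.
#[derive(Debug, Deserialize)]
struct NetConfig {
    num_layers: usize,
    hidden_size: usize,
}

fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Assumed call shape; the real API may take a key selecting the
    // config entry inside the checkpoint.
    let reader = burn_store::PytorchReader::new("model.pt")?;
    let config: NetConfig = reader.load_config()?;
    println!("{config:?}");
    Ok(())
}
```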
Introduces the PytorchStore struct for loading models from PyTorch checkpoint files (.pt/.pth), with support for filtering, remapping, and validation. Adds comprehensive tests for various model types and error handling. Saving to PyTorch format is not yet supported.
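Usage might look roughly like this; treat it as a sketch, since `from_file` is an assumed constructor and `model` stands in for any Burn module (the `load_from` name comes from a later commit in this PR).

```rust
// Hedged sketch: load PyTorch weights into an existing Burn module.
let mut store = PytorchStore::from_file("checkpoint.pth");
model.load_from(&mut store)?; // PyTorchToBurnAdapter is applied automatically
```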
Improves the adapter system to use container type information for correct tensor transformations (e.g., transposing linear weights, renaming normalization parameters) and refactors the Applier to support adapters and provide more detailed error handling. Updates tests and documentation to reflect new module-aware behavior.
Unified the `collect` and `apply` methods to accept an optional `PathFilter` for flexible tensor filtering. Deprecated specialized methods in favor of a single interface, updated all usages and tests, and improved documentation for clarity. This change simplifies the API and enhances consistency across tensor snapshot operations.
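Sketched below with hedging: the optional `PathFilter` parameter is what this commit introduces, but `with_regex` and the exact argument lists are assumptions.

```rust
// Restrict snapshot operations to the encoder's tensors.
let filter = PathFilter::new().with_regex(r"^encoder\.");

// Pass None instead of Some(filter) to collect everything.
let snapshots = module.collect(Some(filter.clone()), None);
module.apply(snapshots, Some(filter), None)?;
```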
The Collector and ModuleSnapshot traits now accept an optional ModuleAdapter to transform tensors during collection. This change centralizes tensor adaptation logic, simplifies usage in SafetensorsStore, and updates all relevant tests and usages to support the new adapter parameter.
Introduces a test for verifying multi-layer neural network loading from SafeTensors format using a PyTorch adapter. The test checks successful parameter loading and validates the model's forward output against expected values.
Enhanced crate-level and module-level documentation for burn-store, detailing key features, usage examples, and configuration options for model storage and PyTorch interoperability. Improves clarity for users integrating Burn with PyTorch and using advanced storage features.
Documentation and examples now clarify that PyTorchToBurnAdapter is applied automatically when loading PyTorch models, handling weight transposition and normalization parameter renaming by default. Code comments and docstrings have been updated for consistency and improved guidance.
Introduces a new benchmark suite for PyTorch model loading in burn-store, including a Python script to generate model files and a Rust benchmark comparing old and new loading methods across multiple backends. Updates Cargo.toml to register the new benchmark.
Introduces the LazyDataSource abstraction to support efficient, on-demand loading of tensor data from PyTorch files, including ZIP archives and legacy multi-storage formats. Refactors pickle_reader and reader modules to utilize lazy loading, reducing memory usage and improving performance for large models.
Removes unused FileSource and related code, improves lazy boundary detection for legacy multi-storage format by tracking storage usage and storage keys, and adds skip_pickle for efficient pickle skipping. Updates pickle_reader to support optional data sources and refactors error handling for tensor data loading. Enhances legacy format metadata extraction and adds a detailed test for legacy metadata correctness.
Introduces a Python script and Rust benchmark for loading and profiling a ResNet18 PyTorch model in burn-store. Refactors lazy loading in LegacyMultiStorageSource to strictly require storage boundaries, removing fallback to full blob loading. Updates lazy data range reading to only read requested tensor ranges. Ensures storage usage is tracked immediately during tensor reconstruction for accurate lazy boundary detection.
Replaces all usages of Param::into_initialized with Param::from_mapped_value across module, quantize, reinit, and optimizer adaptor code. Updates the method name in the Param implementation to improve clarity and consistency in parameter mapping operations.
Renamed Param<Tensor> methods: 'save' to 'transform_for_save' and 'load' to 'transform_for_load' to better reflect their purpose of applying transformations during serialization and deserialization. Updated all usages in burn-core and burn-store accordingly for improved code clarity.
Replaces all usages and documentation of `collect_to` and `apply_from` with the more descriptive `save_into` and `load_from` methods for model serialization and deserialization. Updates all code, tests, examples, and documentation to use the new method names, improving clarity and consistency across the burn-store crate.
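The renamed round trip then reads as follows; `from_file` is an assumed constructor, while the method names and the `.bpk` extension appear elsewhere in this PR.

```rust
// Save a module's tensors, then load them back.
let mut store = BurnpackStore::from_file("model.bpk");
model.save_into(&mut store)?;

let mut store = BurnpackStore::from_file("model.bpk");
model.load_from(&mut store)?;
```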
Replaces all usages and references of the ModuleSnapshoter trait with ModuleStore across the codebase, including trait implementations, imports, and documentation. This change improves naming consistency and clarity for module storage operations.
Renamed the BurnpackHeader::to_bytes method to into_bytes for clarity and consistency with Rust naming conventions. Updated all usages in tests and writer modules accordingly.
Refactored BurnpackWriter to support writing directly into caller-provided buffers via a new write_into() method, and added a size() method to calculate the required buffer size. Updated internal APIs and tests to use the new Bytes type for in-memory storage, improved memory efficiency, and enabled buffer reuse for serialization. Also added comprehensive tests for buffer-based writing and error handling.
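The buffer-reuse pattern this enables can be shown with a stand-in writer that has the same `size()`/`write_into()` shape (this is not the actual `BurnpackWriter`, just a self-contained illustration of the contract):

```rust
/// Stand-in writer illustrating the size()/write_into() contract.
struct TinyWriter<'a> {
    payload: &'a [u8],
}

impl TinyWriter<'_> {
    /// Number of bytes write_into will produce, so callers can size a
    /// buffer once and reuse it across serializations.
    fn size(&self) -> usize {
        self.payload.len()
    }

    fn write_into(&self, buf: &mut [u8]) -> Result<(), &'static str> {
        if buf.len() < self.size() {
            return Err("buffer too small");
        }
        buf[..self.payload.len()].copy_from_slice(self.payload);
        Ok(())
    }
}

fn main() {
    let mut buffer = Vec::new();
    for payload in [&b"first"[..], &b"second, longer payload"[..]] {
        let writer = TinyWriter { payload };
        buffer.resize(writer.size(), 0); // reuse the allocation where possible
        writer.write_into(&mut buffer).unwrap();
        assert_eq!(&buffer[..], payload);
    }
}
```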
Updated all references, documentation, tests, and implementation logic to use the .bpk file extension instead of .burnpack for BurnpackStore files.
Replaces references to `load()` and `save()` with `transform_for_load()` and `transform_for_save()` in the ParamMapper documentation to accurately describe where transformations are applied.
Enhanced the StorageBackend::read_into method to return errors for out-of-bounds and offset overflow conditions, ensuring consistent and safe behavior across backends. Added unit tests to verify error handling for out-of-bounds reads and offset overflows.
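The behavior being tested can be pictured with this self-contained sketch (not the burn-store implementation): the read fails cleanly instead of panicking when the requested range overflows or exceeds the backing data.

```rust
#[derive(Debug, PartialEq)]
enum ReadError {
    OffsetOverflow,
    OutOfBounds { requested_end: usize, len: usize },
}

fn read_into(data: &[u8], offset: usize, buf: &mut [u8]) -> Result<(), ReadError> {
    // checked_add rejects offset + buf.len() wrapping past usize::MAX.
    let end = offset
        .checked_add(buf.len())
        .ok_or(ReadError::OffsetOverflow)?;
    if end > data.len() {
        return Err(ReadError::OutOfBounds { requested_end: end, len: data.len() });
    }
    buf.copy_from_slice(&data[offset..end]);
    Ok(())
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn out_of_bounds_read_is_an_error() {
        let mut buf = [0u8; 8];
        assert!(matches!(
            read_into(&[0u8; 4], 0, &mut buf),
            Err(ReadError::OutOfBounds { .. })
        ));
    }

    #[test]
    fn offset_overflow_is_an_error() {
        let mut buf = [0u8; 1];
        assert_eq!(
            read_into(&[0u8; 4], usize::MAX, &mut buf),
            Err(ReadError::OffsetOverflow)
        );
    }
}
```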
Replaces calls to `as_ref()` with `&*bytes` in assertions comparing byte slices in writer tests.
This commit adds overflow checking for metadata size, tensor shape dimensions, and data offsets in BurnpackReader and BurnpackWriter. It also validates tensor data length consistency during writing, ensuring that actual and expected lengths match. These changes improve robustness against corrupted or malformed files and prevent potential panics or undefined behavior due to integer overflows.
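The overflow checks follow the standard checked-arithmetic pattern, roughly like this illustrative sketch (not the actual reader/writer code):

```rust
/// Byte length of a tensor computed with checked multiplication, so
/// corrupted shape metadata cannot overflow into a bogus small value.
fn tensor_byte_len(shape: &[u64], elem_size: u64) -> Option<u64> {
    shape.iter().try_fold(elem_size, |acc, &dim| acc.checked_mul(dim))
}

/// A data range is valid only if start + len neither overflows nor
/// extends past the end of the file.
fn range_is_valid(start: u64, len: u64, file_len: u64) -> bool {
    match start.checked_add(len) {
        Some(end) => end <= file_len,
        None => false, // overflow: reject the file as malformed
    }
}
```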
@laggui, I have addressed all of @nathanielsimard's feedback. The implementation is more robust, and I have fixed potential data corruption issues (stricter validation throughout). One of the biggest improvements is the use of `Bytes`/`[u8]`. I also added TODOs to give us finer control over byte allocation as we transition to the backends' allocator, so that transition should now be easier.
Added implementations of the core::error::Error trait for ApplyError, BurnpackError, and TensorSnapshotError. Also implemented Display for TensorSnapshotError to improve error reporting and compatibility with standard error handling.
Refactored BurnpackReader::get_snapshots to return Result and propagate BurnpackError instead of panicking on corrupted tensor shape or offset data. Updated call sites and added tests to ensure errors are returned for invalid tensor metadata, improving robustness and error handling.
Introduced maximum limits for metadata size, tensor size, tensor count, and CBOR deserialization recursion depth to prevent resource exhaustion and DoS attacks. Updated BurnpackReader to validate these limits during file parsing and tensor access.
Introduces checks in BurnpackReader to ensure the underlying file or buffer is large enough to contain all claimed tensor data, so truncated files are rejected up front. Adds tests to verify correct error handling for truncated files and successful reading when the file size is exactly correct.
Introduces a MAX_FILE_SIZE constant (100 GB) to prevent resource exhaustion from extremely large files. Both mmap and buffered file loading methods in BurnpackReader now validate the file size before proceeding.
Addressed denial-of-service (DoS) vulnerabilities related to resource exhaustion. This matters for production use with untrusted input.
Set MAX_TENSOR_SIZE to 2 GB on 32-bit platforms and 10 GB on 64-bit platforms to prevent memory exhaustion. Also, conditionally define MAX_FILE_SIZE and its usage based on the 'std' feature to improve portability.
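As a sketch, the conditional limits could be written as below; the values come from the commit messages above, while the `u64` typing and the exact `cfg` gates are assumptions.

```rust
/// Cap on a single tensor's byte size: 2 GB on 32-bit targets,
/// 10 GB on 64-bit targets (values per the commit above).
#[cfg(target_pointer_width = "32")]
const MAX_TENSOR_SIZE: u64 = 2 * 1024 * 1024 * 1024;

#[cfg(not(target_pointer_width = "32"))]
const MAX_TENSOR_SIZE: u64 = 10 * 1024 * 1024 * 1024;

/// File-size validation relies on std file APIs, so the 100 GB cap is
/// only defined when the `std` feature is enabled.
#[cfg(feature = "std")]
const MAX_FILE_SIZE: u64 = 100 * 1024 * 1024 * 1024;
```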
Introduces ParamId support in the Burnpack file format for stateful training continuation. Updates the format specification, core types, reader and writer logic, and adds comprehensive tests to ensure ParamId is preserved and backward compatible. Documentation is updated to reflect the new feature.
🔥🔥🔥
This PR improves Burn's DefaultFileRecorder by replacing the inefficient NamedMpkFileRecorder with BurnpackStore and a new Burnpack format, addressing critical performance and compatibility issues.
Problems with the current NamedMpkFileRecorder:
Pull Request Template
Checklist
The `cargo run-checks` command has been executed.
Related Issues/PRs
Depends on #3741 (PyTorch store changes)
Changes
A new native storage format (Burnpack Format) that serves as Burn's improved DefaultFileRecorder (NamedMpkFileRecorder):
Key Improvements:
Technical Changes:
Testing
Benchmarks
LOADING benchmarks (median values): load time, maximum memory allocation during load, and the load time vs. memory trade-off.
SAVING benchmarks (median values): save time, maximum memory allocation during save, and the save time vs. memory trade-off.
Benchmark details
load_report.txt
save_report.txt