Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add support for safetensors in pytorch reader #2721

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ members = [
"crates/*",
"crates/burn-import/pytorch-tests",
"crates/burn-import/onnx-tests",
"crates/burn-import/safetensors-tests",
"examples/*",
"examples/pytorch-import/model",
"xtask",
Expand Down
4 changes: 2 additions & 2 deletions burn-book/src/import/pytorch-model.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
## Introduction

Whether you've trained your model in PyTorch or you want to use a pre-trained model from PyTorch,
you can import them into Burn. Burn supports importing PyTorch model weights with `.pt` file
extension. Compared to ONNX models, `.pt` files only contain the weights of the model, so you will
you can import them into Burn. Burn supports importing PyTorch model weights with the `.pt` and `.safetensors` file
extensions. Compared to ONNX models, `.pt` and `.safetensors` files only contain the weights of the model, so you will
need to reconstruct the model architecture in Burn.

Here in this document we will show the full workflow of exporting a PyTorch model and importing it.
Expand Down
7 changes: 5 additions & 2 deletions crates/burn-import/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,15 @@ version.workspace = true
default-run = "onnx2burn"

[features]
default = ["onnx", "pytorch"]
default = ["onnx", "pytorch", "safetensors"]
onnx = []
pytorch = ["burn/record-item-custom-serde", "thiserror", "zip"]
safetensors = ["burn/record-item-custom-serde", "thiserror", "zip"]

[dependencies]
burn = { path = "../burn", version = "0.17.0", default-features = false, features = ["std"]}
burn = { path = "../burn", version = "0.17.0", default-features = false, features = [
"std",
] }
burn-ndarray = { path = "../burn-ndarray", version = "0.17.0", default-features = false }
onnx-ir = { path = "../onnx-ir", version = "0.17.0" }
candle-core = { workspace = true }
Expand Down
17 changes: 17 additions & 0 deletions crates/burn-import/safetensors-tests/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
[package]
name = "safetensors-tests"
version.workspace = true
edition.workspace = true
license.workspace = true

[dev-dependencies]
burn = { path = "../../burn" }
burn-ndarray = { path = "../../burn-ndarray" }
burn-autodiff = { path = "../../burn-autodiff" }
serde = { workspace = true }
float-cmp = { workspace = true }
burn-import = { path = "../", features = ["safetensors"] }


[build-dependencies]
burn-import = { path = "../", features = ["safetensors"] }
1 change: 1 addition & 0 deletions crates/burn-import/safetensors-tests/src/lib.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@

Binary file not shown.
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
#!/usr/bin/env python3

import torch
import torch.nn as nn
import torch.nn.functional as F
from safetensors.torch import save_file


class Model(nn.Module):
    """Minimal network holding a single 2D batch-norm layer over 5 channels."""

    def __init__(self):
        super().__init__()
        self.norm1 = nn.BatchNorm2d(5)

    def forward(self, x):
        """Apply the batch-norm layer to `x` and return the result."""
        return self.norm1(x)


def main():
    """Export the batch-norm model to safetensors and print reference outputs.

    One forward pass is run in training mode so the running statistics are
    updated, then the model is switched to eval mode and its state dict is
    written to `batch_norm2d.safetensors`. The output printed for the second
    input is the reference value used by the Rust test.
    """
    torch.set_printoptions(precision=8)
    torch.manual_seed(1)

    model = Model().to(torch.device("cpu"))

    # Condition batch norm (each forward will affect the running stats)
    x1 = torch.ones(1, 5, 2, 2) - 0.5
    _ = model(x1)
    model.eval()  # Set to eval mode to freeze running stats
    # Save the model to safetensors after the first forward
    save_file(model.state_dict(), "batch_norm2d.safetensors")

    x2 = torch.ones(1, 5, 2, 2) - 0.3
    # `print` does not substitute `{}` placeholders; format explicitly.
    print("Input shape: {}".format(x2.shape))
    output = model(x2)
    print("Output: {}".format(output))
    print("Output Shape: {}".format(output.shape))


if __name__ == "__main__":
main()
60 changes: 60 additions & 0 deletions crates/burn-import/safetensors-tests/tests/batch_norm/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
use burn::{
module::Module,
nn::{BatchNorm, BatchNormConfig},
tensor::{backend::Backend, Tensor},
};

/// Test model wrapping a single 2D batch-norm layer.
///
/// The field name `norm1` matches the attribute name used by the PyTorch
/// `Model` whose state dict is exported to `batch_norm2d.safetensors`.
#[derive(Module, Debug)]
pub struct Net<B: Backend> {
    /// Batch-norm layer applied in `forward`.
    norm1: BatchNorm<B, 2>,
}

impl<B: Backend> Net<B> {
    /// Creates the network with a freshly initialised batch-norm layer on `device`.
    ///
    /// The layer is configured for 5 features so that it agrees with
    /// `nn.BatchNorm2d(5)` in the PyTorch model that produced
    /// `batch_norm2d.safetensors`, and with the 5-channel input used by the test.
    pub fn new(device: &B::Device) -> Self {
        Self {
            norm1: BatchNormConfig::new(5).init(device),
        }
    }

    /// Forward pass of the model: applies the batch-norm layer to `x`.
    pub fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
        self.norm1.forward(x)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    use burn::record::{FullPrecisionSettings, Recorder};
    use burn_import::safetensors::SafeTensorsFileRecorder;

    type Backend = burn_ndarray::NdArray<f32>;

    /// Loads the exported batch-norm weights and compares the model output
    /// with the expected reference tensor.
    #[test]
    fn batch_norm2d() {
        let device = Default::default();

        let recorder = SafeTensorsFileRecorder::<FullPrecisionSettings>::default();
        let record = recorder
            .load("tests/batch_norm/batch_norm2d.safetensors".into(), &device)
            .expect("Should decode state successfully");
        let model = Net::<Backend>::new(&device).load_record(record);

        let input = Tensor::<Backend, 4>::ones([1, 5, 2, 2], &device) - 0.3;
        let actual = model.forward(input).to_data();

        // Every element of the expected [1, 5, 2, 2] output has the same value,
        // so build one 2x2 plane and repeat it for each of the 5 channels.
        let value = 0.68515635;
        let plane = [[value, value], [value, value]];
        let expected = Tensor::<Backend, 4>::from_data([[plane; 5]], &device).to_data();

        actual.assert_approx_eq(&expected, 5);
    }
}
Binary file not shown.
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
#!/usr/bin/env python3

import torch
import torch.nn as nn
import torch.nn.functional as F
from safetensors.torch import save_file


class Model(nn.Module):
    """Minimal network whose only state is a persistent boolean buffer."""

    def __init__(self):
        super().__init__()
        self.register_buffer(
            "buffer", torch.tensor([True, False, True]), persistent=True
        )

    def forward(self, x):
        """Ignore `x` and return the stored boolean buffer."""
        return self.buffer


def main():
    """Export the boolean-buffer model to safetensors and print reference outputs.

    The state dict is written to `boolean.safetensors`; the input and output
    of one forward pass are then printed for use as reference values.
    """
    torch.set_printoptions(precision=8)
    torch.manual_seed(1)

    model = Model().to(torch.device("cpu"))

    save_file(model.state_dict(), "boolean.safetensors")

    # Named `sample` rather than `input` to avoid shadowing the builtin.
    sample = torch.ones(3, 3)
    # `print` does not substitute `{}` placeholders; format explicitly.
    print("Input shape: {}".format(sample.shape))
    print("Input: {}".format(sample))
    output = model(sample)
    print("Output: {}".format(output))
    print("Output Shape: {}".format(output.shape))


if __name__ == "__main__":
main()
58 changes: 58 additions & 0 deletions crates/burn-import/safetensors-tests/tests/boolean/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
use burn::{
module::{Module, Param},
tensor::{backend::Backend, Bool, Tensor},
};

/// Test model whose only state is a one-dimensional boolean tensor.
///
/// The field name `buffer` matches the buffer name registered by the PyTorch
/// `Model` whose state dict is exported to `boolean.safetensors`.
#[derive(Module, Debug)]
pub struct Net<B: Backend> {
    /// Boolean tensor returned unchanged by `forward`.
    buffer: Param<Tensor<B, 1, Bool>>,
}

impl<B: Backend> Net<B> {
    /// Builds the model by taking the buffer stored in `record`.
    pub fn new_with(record: NetRecord<B>) -> Self {
        let buffer = record.buffer;
        Self { buffer }
    }

    /// Runs the model: the input is ignored and the stored boolean buffer is returned.
    pub fn forward(&self, _x: Tensor<B, 2>) -> Tensor<B, 1, Bool> {
        self.buffer.val()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    use burn::{
        record::{FullPrecisionSettings, Recorder},
        tensor::TensorData,
    };
    use burn_import::safetensors::SafeTensorsFileRecorder;

    type Backend = burn_ndarray::NdArray<f32>;

    /// Loads the exported boolean buffer and checks that the model returns it unchanged.
    #[test]
    #[ignore = "It appears loading boolean tensors are not supported yet"]
    // Error skipping: Msg("unsupported storage type BoolStorage")
    fn boolean() {
        let device = Default::default();

        let recorder = SafeTensorsFileRecorder::<FullPrecisionSettings>::default();
        let record = recorder
            .load("tests/boolean/boolean.safetensors".into(), &device)
            .expect("Should decode state successfully");
        let model = Net::<Backend>::new_with(record);

        let actual = model
            .forward(Tensor::<Backend, 2>::ones([3, 3], &device))
            .to_data();
        let expected =
            Tensor::<Backend, 1, Bool>::from_bool(TensorData::from([true, false, true]), &device)
                .to_data();

        assert_eq!(actual, expected);
    }
}
Binary file not shown.
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
#!/usr/bin/env python3

import torch
import torch.nn as nn
import torch.nn.functional as F
from safetensors.torch import save_file


class Model(nn.Module):
    """Minimal network whose only state is a persistent 3x3 buffer of ones."""

    def __init__(self):
        super().__init__()
        self.register_buffer("buffer", torch.ones(3, 3), persistent=True)

    def forward(self, x):
        """Return the stored buffer added element-wise to `x`."""
        return self.buffer + x


def main():
    """Export the buffer model to safetensors and print reference outputs.

    The state dict is written to `buffer.safetensors`; the input and output
    of one forward pass are then printed for use as reference values.
    """
    torch.set_printoptions(precision=8)
    torch.manual_seed(1)

    model = Model().to(torch.device("cpu"))

    save_file(model.state_dict(), "buffer.safetensors")

    # Named `sample` rather than `input` to avoid shadowing the builtin.
    sample = torch.ones(3, 3)
    # `print` does not substitute `{}` placeholders; format explicitly.
    print("Input shape: {}".format(sample.shape))
    print("Input: {}".format(sample))
    output = model(sample)
    print("Output: {}".format(output))
    print("Output Shape: {}".format(output.shape))


if __name__ == "__main__":
main()
51 changes: 51 additions & 0 deletions crates/burn-import/safetensors-tests/tests/buffer/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
use burn::{
module::{Module, Param},
tensor::{backend::Backend, Tensor},
};

/// Test model whose only state is a two-dimensional float tensor.
///
/// The field name `buffer` matches the buffer name registered by the PyTorch
/// `Model` whose state dict is exported to `buffer.safetensors`.
#[derive(Module, Debug)]
pub struct Net<B: Backend> {
    /// Tensor added to the input in `forward`.
    buffer: Param<Tensor<B, 2>>,
}

impl<B: Backend> Net<B> {
    /// Builds the model by taking the buffer stored in `record`.
    pub fn new_with(record: NetRecord<B>) -> Self {
        let buffer = record.buffer;
        Self { buffer }
    }

    /// Runs the model: returns the stored buffer added element-wise to `x`.
    pub fn forward(&self, x: Tensor<B, 2>) -> Tensor<B, 2> {
        let buffer = self.buffer.val();
        buffer + x
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    use burn::record::{FullPrecisionSettings, Recorder};
    use burn_import::safetensors::SafeTensorsFileRecorder;

    type Backend = burn_ndarray::NdArray<f32>;

    /// Loads the exported buffer and checks that `forward` adds it to the input.
    #[test]
    fn buffer() {
        let device = Default::default();

        let recorder = SafeTensorsFileRecorder::<FullPrecisionSettings>::default();
        let record = recorder
            .load("tests/buffer/buffer.safetensors".into(), &device)
            .expect("Should decode state successfully");
        let model = Net::<Backend>::new_with(record);

        let actual = model
            .forward(Tensor::<Backend, 2>::ones([3, 3], &device))
            .to_data();
        let expected = (Tensor::<Backend, 2>::ones([3, 3], &device) * 2.0).to_data();

        actual.assert_approx_eq(&expected, 3);
    }
}
Binary file not shown.
Loading
Loading