
add support for safetensors in pytorch reader #2721

Open · wants to merge 2 commits into main
Conversation

wandbrandon

Pull Request Template

Checklist

  • Confirmed that the run-checks all script has been executed.
  • Made sure the book is up to date with changes in this PR.

Related Issues/PRs

#626

Changes

Simple addition to the already implemented reader.rs, supporting the safetensors format using candle with a CPU device import.

Testing

In the examples/pytorch-import directory, there is a mnist.safetensors file that is successfully imported.


codecov bot commented Jan 20, 2025

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 83.60%. Comparing base (140ea75) to head (6a0330e).

Additional details and impacted files
@@           Coverage Diff           @@
##             main    #2721   +/-   ##
=======================================
  Coverage   83.60%   83.60%           
=======================================
  Files         819      819           
  Lines      106600   106605    +5     
=======================================
+ Hits        89124    89129    +5     
  Misses      17476    17476           


@laggui (Member) left a comment

Thanks for the addition 🙏

Looks pretty good overall, just some minor comments.

.map(|(key, tensor)| (key, CandleTensor(tensor)))
.collect();
// Check if it's a safetensors file
let is_safetensors = path.extension().is_some_and(|ext| ext == "safetensors");

Not sure if that's a very robust way to differentiate between both 😅 pretty sure I've seen a lot of .safetensor files (without the plural form).
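For reference, a content-based check is also possible: a safetensors file begins with an 8-byte little-endian u64 giving the length of a JSON header, so a reader could sniff the first bytes instead of trusting the extension. A std-only sketch (the helper name `looks_like_safetensors` is hypothetical, not part of burn-import):

```rust
use std::fs::File;
use std::io::{Read, Write};

// Sketch of a content-based check: a safetensors file starts with an
// 8-byte little-endian u64 (the JSON header length), followed by the
// JSON header itself, which begins with '{'.
fn looks_like_safetensors(path: &str) -> std::io::Result<bool> {
    let mut buf = [0u8; 9];
    File::open(path)?.read_exact(&mut buf)?;
    let header_len = u64::from_le_bytes(buf[..8].try_into().unwrap());
    Ok(header_len > 0 && buf[8] == b'{')
}

fn main() -> std::io::Result<()> {
    // Build a tiny fake file to demonstrate: header length 2, header "{}".
    let mut f = File::create("demo.safetensors")?;
    f.write_all(&2u64.to_le_bytes())?;
    f.write_all(b"{}")?;
    drop(f);
    assert!(looks_like_safetensors("demo.safetensors")?);
    println!("ok");
    Ok(())
}
```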

I think we should add a field to the LoadArgs instead, and users can then specify when they're loading a safetensor file:

LoadArgs::new(...).with_safetensors(true)
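Expanded into a minimal sketch of such an opt-in flag (the field and method names below are assumptions for illustration, not the actual burn-import API):

```rust
use std::path::PathBuf;

// Minimal sketch of the suggested LoadArgs flag; names are assumptions.
#[derive(Debug, Default)]
pub struct LoadArgs {
    pub file: PathBuf,
    pub safetensors: bool,
}

impl LoadArgs {
    pub fn new(file: impl Into<PathBuf>) -> Self {
        Self { file: file.into(), safetensors: false }
    }

    // Users opt in explicitly instead of the reader sniffing the extension.
    pub fn with_safetensors(mut self, enabled: bool) -> Self {
        self.safetensors = enabled;
        self
    }
}

fn main() {
    let args = LoadArgs::new("pytorch/mnist.safetensors").with_safetensors(true);
    assert!(args.safetensors);
    println!("{:?}", args.file);
}
```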

@@ -17,7 +17,7 @@ fn main() {

     // Load PyTorch weights into a model record.
     let record: model::ModelRecord<B> = PyTorchFileRecorder::<FullPrecisionSettings>::default()
-        .load("pytorch/mnist.pt".into(), &device)
+        .load("pytorch/mnist.safetensors".into(), &device)

To check if we should load the pickle file or safetensor, we can add a safetensors feature flag to the example and check it here with something like:

let ext = if std::env::var("CARGO_FEATURE_SAFETENSORS").is_ok() {
  "safetensors"
} else {
  "pt"
};

and then load the correct file.
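One caveat worth noting: Cargo only sets the CARGO_FEATURE_* environment variables while compiling build scripts, so a runtime std::env::var check in main would not see them. A compile-time cfg! check is an alternative sketch:

```rust
// Sketch: CARGO_FEATURE_* env vars are set by Cargo for build scripts,
// not for the running binary, so cfg!(feature = "...") — evaluated at
// compile time — is the usual way to branch on a feature in main().
fn main() {
    let ext = if cfg!(feature = "safetensors") {
        "safetensors"
    } else {
        "pt"
    };
    let path = format!("pytorch/mnist.{ext}");
    println!("{path}");
}
```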

We should update the README with a small mention.

@Nikaidou-Shinku (Contributor)

IMO maybe we can have something like pub mod safetensors; under a new feature gate in crate burn-import so users could build exactly what they need, since safetensors does not seem to be a format strongly related to PyTorch.

@wandbrandon (Author)

> IMO maybe we can have something like pub mod safetensors; under a new feature gate in crate burn-import so users could build exactly what they need, since safetensors does not seem to be a format strongly related to PyTorch.

I think this is a good point, and it also builds the scaffolding for potentially rewriting it to remove the Candle dependency.

@laggui (Member)

laggui commented Jan 21, 2025

> IMO maybe we can have something like pub mod safetensors; under a new feature gate in crate burn-import so users could build exactly what they need, since safetensors does not seem to be a format strongly related to PyTorch.

I agree that the format is not strongly related to pytorch, but I think most models available in safetensor format are pytorch models 😅

Unless you mean supporting the safetensor format as another recorder to load and save modules. In this case, not sure that this is a meaningful addition.
