add support for safetensors in pytorch reader #2721
base: main
Conversation
Codecov Report: All modified and coverable lines are covered by tests ✅

Additional details and impacted files:

@@           Coverage Diff           @@
##             main    #2721   +/-   ##
=======================================
  Coverage   83.60%   83.60%
=======================================
  Files         819      819
  Lines      106600   106605       +5
=======================================
+ Hits        89124    89129       +5
  Misses      17476    17476

View full report in Codecov by Sentry.
Thanks for the addition 🙏
Looks pretty good overall, just some minor comments.
    .map(|(key, tensor)| (key, CandleTensor(tensor)))
    .collect();

// Check if it's a safetensors file.
let is_safetensors = path.extension().is_some_and(|ext| ext == "safetensors");
Not sure if that's a very robust way to differentiate between the two 😅 pretty sure I've seen a lot of `.safetensor` files (without the plural form). I think we should add a field to `LoadArgs` instead, so users can specify explicitly when they're loading a safetensors file: `LoadArgs::new(...).with_safetensors(true)`
@@ -17,7 +17,7 @@ fn main() {
     // Load PyTorch weights into a model record.
     let record: model::ModelRecord<B> = PyTorchFileRecorder::<FullPrecisionSettings>::default()
-        .load("pytorch/mnist.pt".into(), &device)
+        .load("pytorch/mnist.safetensors".into(), &device)
To decide whether we should load the pickle file or the safetensors file, we can add a `safetensors` feature flag to the example and check it here with something like:

let ext = if std::env::var("CARGO_FEATURE_SAFETENSORS").is_ok() {
    "safetensors"
} else {
    "pt"
};

and then load the correct file.
We should update the README with a small mention.
IMO maybe we can have something like
I think this is a good point, and it also builds the scaffolding for potentially rewriting it to remove the Candle dependency.
I agree that the format is not strongly tied to PyTorch, but I think most models available in safetensors format are PyTorch models 😅 Unless you mean supporting the safetensors format as another recorder to load and save modules. In that case, I'm not sure it's a meaningful addition.
Pull Request Template

Checklist

- The run-checks all script has been executed.

Related Issues/PRs

#626

Changes

Simple addition to the already implemented reader.rs, supporting the safetensors format using Candle with CPU device import.

Testing

In the examples/pytorch-import directory, there is a mnist.safetensors file that is successfully imported.