From 2317050239d852792f23b8b72324584545194ad7 Mon Sep 17 00:00:00 2001 From: Ultr4_dev Date: Tue, 13 Aug 2024 00:00:43 +0200 Subject: [PATCH 1/2] Update torch.load() call with weights_only=True --- whisper/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/whisper/__init__.py b/whisper/__init__.py index d7fbba36f..7525f9baa 100644 --- a/whisper/__init__.py +++ b/whisper/__init__.py @@ -143,7 +143,7 @@ def load_model( with ( io.BytesIO(checkpoint_file) if in_memory else open(checkpoint_file, "rb") ) as fp: - checkpoint = torch.load(fp, map_location=device) + checkpoint = torch.load(fp, map_location=device, weights_only=True) del checkpoint_file dims = ModelDimensions(**checkpoint["dims"]) From 895e4fb88ebc5f3c0cc0eb50d5edf40bbaf27340 Mon Sep 17 00:00:00 2001 From: Ultr4_dev Date: Tue, 13 Aug 2024 00:02:11 +0200 Subject: [PATCH 2/2] Add weights_only parameter to load_model function and extent docstring --- whisper/__init__.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/whisper/__init__.py b/whisper/__init__.py index 7525f9baa..3e081e973 100644 --- a/whisper/__init__.py +++ b/whisper/__init__.py @@ -101,6 +101,7 @@ def load_model( device: Optional[Union[str, torch.device]] = None, download_root: str = None, in_memory: bool = False, + weights_only: bool = False, ) -> Whisper: """ Load a Whisper ASR model @@ -116,6 +117,8 @@ def load_model( path to download the model files; by default, it uses "~/.cache/whisper" in_memory: bool whether to preload the model weights into host memory + weights_only: bool + whether to load only the model weights Returns ------- @@ -143,7 +146,7 @@ def load_model( with ( io.BytesIO(checkpoint_file) if in_memory else open(checkpoint_file, "rb") ) as fp: - checkpoint = torch.load(fp, map_location=device, weights_only=True) + checkpoint = torch.load(fp, map_location=device, weights_only=weights_only) del checkpoint_file dims = ModelDimensions(**checkpoint["dims"])