Skip to content

Commit 9cefdb2

Browse files
committed
Prevent duplicate downloading of models
1 parent bfe8001 commit 9cefdb2

File tree

2 files changed

+9
-1
lines changed

2 files changed

+9
-1
lines changed

taggui/auto_captioning/auto_captioning_model.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,11 @@ def replace_template_variables(text: str, image: Image) -> str:
3737

3838
class AutoCaptioningModel:
3939
dtype = torch.float16
40+
# When loading a model, if the `use_safetensors` argument is not set and
41+
# both a safetensors and a non-safetensors version of the model are
42+
# available, both versions get downloaded. This should be set to `None` for
43+
# models that do not have a safetensors version.
44+
use_safetensors = True
4045
model_load_context_manager = nullcontext()
4146
transformers_model_class = AutoModelForVision2Seq
4247
image_mode = 'RGB'
@@ -90,7 +95,8 @@ def get_processor(self):
9095
trust_remote_code=True)
9196

9297
def get_model_load_arguments(self) -> dict:
93-
arguments = {'device_map': self.device, 'trust_remote_code': True}
98+
arguments = {'device_map': self.device, 'trust_remote_code': True,
99+
'use_safetensors': self.use_safetensors}
94100
if self.load_in_4_bit:
95101
quantization_config = BitsAndBytesConfig(
96102
load_in_4bit=True,

taggui/auto_captioning/models/florence_2.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66

77
class Florence2(AutoCaptioningModel):
8+
use_safetensors = None
89
transformers_model_class = AutoModelForCausalLM
910
task_prompts = [
1011
'<CAPTION>',
@@ -30,6 +31,7 @@ def get_default_prompt(self) -> str:
3031

3132

3233
class Florence2Promptgen(Florence2):
34+
use_safetensors = True
3335
task_prompts = [
3436
'<GENERATE_PROMPT>',
3537
'<CAPTION>',

0 commit comments

Comments
 (0)