Skip to content

Commit

Permalink
Allow locally downloaded models for QwenVL
Browse files Browse the repository at this point in the history
  • Loading branch information
bartowski1182 authored Dec 14, 2024
1 parent e52aba5 commit b01af27
Showing 1 changed file with 7 additions and 1 deletion.
8 changes: 7 additions & 1 deletion examples/llava/qwen2_vl_surgery.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ def main(args):
else:
raise ValueError()

local_model = False
model_name = args.model_name
print("model_name: ", model_name)
qwen2vl = Qwen2VLForConditionalGeneration.from_pretrained(
Expand All @@ -97,8 +98,10 @@ def main(args):
vcfg = cfg.vision_config

if os.path.isdir(model_name):
local_model = True
if model_name.endswith(os.sep):
model_name = model_name[:-1]
model_path = model_name
model_name = os.path.basename(model_name)
fname_out = f"{model_name.replace('/', '-').lower()}-vision.gguf"

Expand Down Expand Up @@ -139,7 +142,10 @@ def main(args):
it will be hardcoded in the `clip_image_build_graph` from `clip.cpp`.
"""

processor: Qwen2VLProcessor = AutoProcessor.from_pretrained(model_name)
if local_model:
processor: Qwen2VLProcessor = AutoProcessor.from_pretrained(model_path)

Check failure on line 146 in examples/llava/qwen2_vl_surgery.py

View workflow job for this annotation

GitHub Actions / pyright type-check

"model_path" is possibly unbound (reportPossiblyUnboundVariable)
else:
processor: Qwen2VLProcessor = AutoProcessor.from_pretrained(model_name)
fout.add_array("clip.vision.image_mean", processor.image_processor.image_mean) # type: ignore[reportAttributeAccessIssue]
fout.add_array("clip.vision.image_std", processor.image_processor.image_std) # type: ignore[reportAttributeAccessIssue]

Expand Down

0 comments on commit b01af27

Please sign in to comment.