305 changes: 293 additions & 12 deletions convert_hf_to_gguf.py
@@ -4897,7 +4897,7 @@ def _xlmroberta_set_vocab(self) -> None:
with open(tokenizer_config_path, "r", encoding="utf-8") as fp:
tokenizer_config_json = json.load(fp)

add_prefix = tokenizer.add_prefix_space
add_prefix = getattr(tokenizer, "add_prefix_space", False)
remove_whitespaces = tokenizer.clean_up_tokenization_spaces
precompiled_charsmap = b64decode(tokenizer_json["normalizer"]["precompiled_charsmap"])

@@ -5183,7 +5183,18 @@ def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path,

if lora_names := hparams.get("lora_adaptations"):
self._lora_names = lora_names
self.model_arch = gguf.MODEL_ARCH.JINA_BERT_V3

try:
text_cfg = hparams.get("text_config", {}) if isinstance(hparams.get("text_config", {}), dict) else {}
pe_type = (text_cfg.get("position_embedding_type") or hparams.get("position_embedding_type") or "").lower()
rope_base = text_cfg.get("rotary_emb_base", hparams.get("rotary_emb_base"))
name_path = (hparams.get("_name_or_path") or "").lower()
is_vx = ("jina" in name_path and ("v2" in name_path or "v3" in name_path))
is_v3 = (pe_type == "rotary" or rope_base is not None) and is_vx
if (is_v3) or self._lora_names:
self.model_arch = gguf.MODEL_ARCH.JINA_BERT_V3
except Exception:
pass

super().__init__(dir_model, ftype, fname_out, hparams=hparams, **kwargs)
self._xlmroberta_tokenizer_init()
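
For reference, a minimal standalone sketch of the v3-detection heuristic added above, run against a hypothetical config dict (the values and model path are illustrative, not taken from this PR):

# Hypothetical config; the checks below mirror the __init__ logic added above.
hparams = {
    "_name_or_path": "jinaai/jina-embeddings-v3",
    "position_embedding_type": "rotary",
    "rotary_emb_base": 10000.0,
}
text_cfg = hparams.get("text_config", {}) if isinstance(hparams.get("text_config", {}), dict) else {}
pe_type = (text_cfg.get("position_embedding_type") or hparams.get("position_embedding_type") or "").lower()
rope_base = text_cfg.get("rotary_emb_base", hparams.get("rotary_emb_base"))
name_path = (hparams.get("_name_or_path") or "").lower()
is_vx = "jina" in name_path and ("v2" in name_path or "v3" in name_path)
is_v3 = (pe_type == "rotary" or rope_base is not None) and is_vx
assert is_v3  # this config would select gguf.MODEL_ARCH.JINA_BERT_V3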
@@ -6405,6 +6416,271 @@ def set_vocab(self):
raise NotImplementedError(f'Tokenizer {tokenizer_class} is not supported for JinaBertModel')


@ModelBase.register("JinaCLIPVisionModel", "JinaCLIPModel")
class JinaCLIPVisionModel(MmprojModel):
"""JinaCLIP v2 Vision Encoder Model - handles vision component only"""
model_arch = gguf.MODEL_ARCH.MMPROJ

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

# Load config for vision encoder
config_path = self.dir_model / "config.json"
if not config_path.exists():
raise FileNotFoundError(
f"JinaCLIPVisionModel: missing config.json in {self.dir_model}. "
"Please ensure the original model config is present; default hyperparameter fallbacks are not used."
)
with open(config_path, encoding="utf-8") as f:
self.vision_config = json.load(f)

def set_vocab(self):
# Vision encoder doesn't need vocabulary
pass

def set_gguf_parameters(self):
cfg = self.vision_config

try:
width = int(cfg["width"]) # channel dim
head_width = int(cfg["head_width"]) # per-head dim
layers = int(cfg["layers"]) # block count
image_size = int(cfg["image_size"]) # input image size
patch_size = int(cfg["patch_size"]) # patch size
except KeyError as e:
raise KeyError(f"JinaCLIPVisionModel: missing key in config.json: {e}")

if width % head_width != 0:
raise ValueError(
f"JinaCLIPVisionModel: width ({width}) not divisible by head_width ({head_width})"
)
n_head = width // head_width

if "mlp_ratio" in cfg:
n_ff = int(width * float(cfg["mlp_ratio"]))
elif bool(cfg.get("naive_swiglu", False)):
n_ff = int((width * 8) // 3)
else:
raise ValueError("JinaCLIPVisionModel: unable to infer FFN size; please provide 'mlp_ratio' or set 'naive_swiglu' in config.json")

self.gguf_writer.add_clip_has_vision_encoder(True)
proj_dim = int(cfg.get("projection_dim", width))
self.gguf_writer.add_vision_projection_dim(proj_dim)

self.gguf_writer.add_vision_image_size(image_size)
self.gguf_writer.add_vision_patch_size(patch_size)
self.gguf_writer.add_vision_embedding_length(width)
self.gguf_writer.add_vision_block_count(layers)
self.gguf_writer.add_vision_head_count(n_head)
self.gguf_writer.add_vision_feed_forward_length(n_ff)

self.gguf_writer.add_vision_attention_layernorm_eps(float(cfg.get("layer_norm_eps", 1e-5)))

mean = self.preprocessor_config.get("image_mean", self.preprocessor_config.get("mean"))
std = self.preprocessor_config.get("image_std", self.preprocessor_config.get("std"))
if mean is None or std is None:
raise KeyError(
"JinaCLIPVisionModel: preprocessor_config missing image mean/std (expected keys: 'image_mean'/'image_std' or 'mean'/'std')"
)
self.gguf_writer.add_vision_image_mean(mean)
self.gguf_writer.add_vision_image_std(std)

self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.JINACLIP2)
self.gguf_writer.add_vision_use_silu(True)
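
As a quick illustration of the FFN-size inference in set_gguf_parameters above, with hypothetical numbers (not taken from a real JinaCLIP config):

# Hypothetical width; mirrors the mlp_ratio / naive_swiglu branches above.
width = 1024
n_ff_from_ratio = int(width * float(4.0))   # mlp_ratio path  -> 4096
n_ff_from_swiglu = int((width * 8) // 3)    # naive_swiglu path -> 2730, the common 2/3 * 4 * width SwiGLU sizing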

# helpers to keep modify_tensors compact and consistent with other models
def _strip_vm_prefix(self, name: str) -> str:
return name[len('vision_model.'):] if name.startswith('vision_model.') else name

def _map_block_tensor(self, layer: int, rest: str, data_torch: Tensor, name: str) -> list[tuple[str, Tensor]] | None:
parts = rest.split('.')
# layer norms
if rest.startswith('norm1.'):
suffix = parts[-1]
return [(f'v.blk.{layer}.ln1.{suffix}', data_torch)]
if rest.startswith('norm2.'):
suffix = parts[-1]
return [(f'v.blk.{layer}.ln2.{suffix}', data_torch)]
if rest.startswith('attn.inner_attn_ln.'):
suffix = parts[-1]
return [(f'v.blk.{layer}.attn_ln.{suffix}', data_torch)]

# fused qkv
if rest == 'attn.qkv.weight':
w = data_torch
wdim = w.shape[0]
if wdim % 3 != 0:
logger.warning('mmproj(jinaclip): unexpected qkv weight shape %s for %s', tuple(w.shape), name)
d = wdim // 3
q, k, v = w[0:d, :], w[d:2 * d, :], w[2 * d:, :]
return [
(f'v.blk.{layer}.attn_q.weight', q),
(f'v.blk.{layer}.attn_k.weight', k),
(f'v.blk.{layer}.attn_v.weight', v),
]
if rest == 'attn.qkv.bias':
b = data_torch
bdim = b.shape[0]
if bdim % 3 != 0:
logger.warning('mmproj(jinaclip): unexpected qkv bias shape %s for %s', tuple(b.shape), name)
d = bdim // 3
qb, kb, vb = b[0:d], b[d:2 * d], b[2 * d:]
return [
(f'v.blk.{layer}.attn_q.bias', qb),
(f'v.blk.{layer}.attn_k.bias', kb),
(f'v.blk.{layer}.attn_v.bias', vb),
]
# separate q/v bias (some checkpoints)
if rest == 'attn.q_bias':
return [(f'v.blk.{layer}.attn_q.bias', data_torch)]
if rest == 'attn.v_bias':
return [(f'v.blk.{layer}.attn_v.bias', data_torch)]

# separate projections
if rest.startswith('attn.q_proj.'):
suffix = parts[-1]
return [(f'v.blk.{layer}.attn_q.{suffix}', data_torch)]
if rest.startswith('attn.k_proj.'):
suffix = parts[-1]
return [(f'v.blk.{layer}.attn_k.{suffix}', data_torch)]
if rest.startswith('attn.v_proj.'):
suffix = parts[-1]
return [(f'v.blk.{layer}.attn_v.{suffix}', data_torch)]
if rest.startswith('attn.proj.'):
suffix = parts[-1]
return [(f'v.blk.{layer}.attn_out.{suffix}', data_torch)]

# MLP
if rest.startswith('mlp.w1.'):
suffix = parts[-1]
return [(f'v.blk.{layer}.ffn_gate.{suffix}', data_torch)]
if rest.startswith('mlp.w2.'):
suffix = parts[-1]
return [(f'v.blk.{layer}.ffn_up.{suffix}', data_torch)]
if rest.startswith('mlp.w3.'):
suffix = parts[-1]
return [(f'v.blk.{layer}.ffn_down.{suffix}', data_torch)]
if rest.startswith('mlp.ffn_ln.'):
suffix = parts[-1]
return [(f'v.blk.{layer}.ffn_norm.{suffix}', data_torch)]
if rest.startswith('mlp.fc1.'):
suffix = parts[-1]
return [(f'v.blk.{layer}.ffn_up.{suffix}', data_torch)]
if rest.startswith('mlp.fc2.'):
suffix = parts[-1]
return [(f'v.blk.{layer}.ffn_down.{suffix}', data_torch)]
return None
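
As a standalone illustration of the fused-QKV split performed above, with a dummy tensor (shapes are hypothetical):

# Dummy fused [Wq; Wk; Wv] weight, split the same way as in _map_block_tensor.
import torch
d = 64                          # hypothetical per-block width
qkv_w = torch.randn(3 * d, d)   # fused QKV as stored in the checkpoint
q, k, v = qkv_w[0:d, :], qkv_w[d:2 * d, :], qkv_w[2 * d:, :]
assert q.shape == k.shape == v.shape == (d, d)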

def map_tensor_name(self, name: str, try_suffixes: Sequence[str] = (".weight", ".bias")) -> str:
"""Prefer base table-driven mapping; keep Jina-specific targets if already mapped; fallback to legacy mapper."""
# Already a GGUF target name (e.g., "v.*" or "mm.*"): return as-is
if name.startswith('v.') or name.startswith('mm.'):
return name
# Try the base mapping first
try:
return super().map_tensor_name(name, try_suffixes=try_suffixes)
except Exception:
# Fallback to legacy Jina-specific mapper for any remaining edge keys
if hasattr(self, "_map_jinaclip_tensor_name"):
mapped = self._map_jinaclip_tensor_name(name) # type: ignore[attr-defined]
if mapped:
return mapped
return name

def get_tensors(self) -> Iterator[tuple[str, Tensor]]:
"""Yield tensors for the vision encoder.
Prefer the base implementation (supports sharded/indexed weights). If that fails
or no parts are detected, fall back to a direct single-file load.
"""
yielded_any = False
try:
for name, tensor in super().get_tensors():
yielded_any = True
yield name, tensor
except Exception as e:
logger.warning("mmproj(jinaclip): base get_tensors failed, falling back: %s", e)
if yielded_any:
return

import torch
candidates = [
self.dir_model / "pytorch_model.bin",
self.dir_model / "vision_model_weights.bin",
]
model_path = next((p for p in candidates if p.exists()), None)
if model_path is None:
raise FileNotFoundError(f"mmproj(jinaclip): no model weights found in {self.dir_model}")
try:
state_dict = torch.load(model_path, map_location="cpu", weights_only=True)
except TypeError:
state_dict = torch.load(model_path, map_location="cpu")

for name, tensor in state_dict.items():
yield name, tensor

def _should_be_f32(self, gguf_name: str) -> bool:
"""Return True if tensor should be stored as F32 to avoid type mismatches in C++ runtime.

Keep the list minimal: LayerNorm weights/bias are the common source of
binary-op dtype issues; patch embedding bias is also safer as F32.
"""
patterns = (
".ln1.weight", ".ln1.bias",
".ln2.weight", ".ln2.bias",
".attn_ln.weight", ".attn_ln.bias",
".ffn_norm.weight", ".ffn_norm.bias",
"v.patch_embd.proj.bias",
)
return any(p in gguf_name for p in patterns)
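
A quick check of which GGUF tensor names the F32 filter above would catch (the names are examples following the v.blk.{i}.* scheme used by this converter):

# Hypothetical tensor names run through the same patterns as _should_be_f32.
patterns = (".ln1.weight", ".ln1.bias", ".ln2.weight", ".ln2.bias",
            ".attn_ln.weight", ".attn_ln.bias", ".ffn_norm.weight", ".ffn_norm.bias",
            "v.patch_embd.proj.bias")
examples = ["v.blk.0.ln1.weight", "v.patch_embd.proj.bias", "v.blk.0.attn_q.weight"]
print({n: any(p in n for p in patterns) for n in examples})
# -> {'v.blk.0.ln1.weight': True, 'v.patch_embd.proj.bias': True, 'v.blk.0.attn_q.weight': False}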

def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
"""Normalize JinaCLIP vision tensor names to base-friendly patterns, with Jina-specific exceptions.

- Emit Jina-specific targets directly for: patch/proj, pos_embed, inner-attn LN, SwiGLU FFN names.
- If fused QKV is encountered, split into Q/K/V.
- For standard pieces (norm1/norm2, q/k/v/out), map to v.blk.{i}.* targets.
"""
del bid # unused

src = name
if src.startswith('v.') or src.startswith('mm.'):
return [(src, data_torch)]

# Drop 'vision_model.' prefix if present
src_no_vm = self._strip_vm_prefix(src)

# Top-level direct mappings — use gguf constants directly for canonical names
if src_no_vm == 'cls_token':
base = gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.V_ENC_EMBD_CLS]
return [(base, data_torch)]
if src_no_vm.startswith('patch_embed.proj.'):
suffix = src_no_vm.split('.')[-1]
base = gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.V_ENC_EMBD_PATCH]
return [(f'{base}.{suffix}', data_torch)]
if src_no_vm == 'pos_embed':
pos_name = gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.V_ENC_EMBD_POS] + '.weight'
return [(pos_name, data_torch)]
if src_no_vm.startswith('norm.'):
suffix = src_no_vm.split('.')[-1]
base = gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.V_POST_NORM]
return [(f'{base}.{suffix}', data_torch)]

if src_no_vm.startswith('blocks.'):
parts = src_no_vm.split('.')
if len(parts) >= 3 and parts[1].isdigit():
layer = int(parts[1])
rest = '.'.join(parts[2:])
mapped = self._map_block_tensor(layer, rest, data_torch, name)
if mapped is not None:
return mapped

try:
return [(self.map_tensor_name(name), data_torch)]
except Exception:
logger.debug("mmproj(jinaclip): skip unmapped tensor %s", name)
return []
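
For reference, a minimal sketch of the name normalization that modify_tensors performs for a typical block tensor (the source key is an example, not taken from a real checkpoint):

# Example: 'vision_model.blocks.3.norm1.weight' -> 'v.blk.3.ln1.weight'
src = "vision_model.blocks.3.norm1.weight"
src_no_vm = src[len("vision_model."):] if src.startswith("vision_model.") else src
parts = src_no_vm.split(".")
layer, rest = int(parts[1]), ".".join(parts[2:])
target = f"v.blk.{layer}.ln1.{rest.split('.')[-1]}"
assert target == "v.blk.3.ln1.weight"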


@ModelBase.register("OpenELMForCausalLM")
class OpenELMModel(TextModel):
model_arch = gguf.MODEL_ARCH.OPENELM
@@ -9789,16 +10065,21 @@ def main() -> None:
else:
model_class = MistralModel

model_instance = model_class(dir_model, output_type, fname_out,
is_big_endian=args.bigendian, use_temp_file=args.use_temp_file,
eager=args.no_lazy,
metadata_override=args.metadata, model_name=args.model_name,
split_max_tensors=args.split_max_tensors,
split_max_size=split_str_to_n_bytes(args.split_max_size), dry_run=args.dry_run,
small_first_shard=args.no_tensor_first_split,
remote_hf_model_id=hf_repo_id, disable_mistral_community_chat_template=disable_mistral_community_chat_template,
sentence_transformers_dense_modules=args.sentence_transformers_dense_modules
)
common_kwargs = dict(
is_big_endian=args.bigendian,
use_temp_file=args.use_temp_file,
eager=args.no_lazy,
metadata_override=args.metadata,
model_name=args.model_name,
split_max_tensors=args.split_max_tensors,
split_max_size=split_str_to_n_bytes(args.split_max_size),
dry_run=args.dry_run,
small_first_shard=args.no_tensor_first_split,
remote_hf_model_id=hf_repo_id,
disable_mistral_community_chat_template=disable_mistral_community_chat_template,
sentence_transformers_dense_modules=args.sentence_transformers_dense_modules,
)
model_instance = model_class(dir_model, output_type, fname_out, **common_kwargs)

if args.vocab_only:
logger.info("Exporting model vocab...")
1 change: 1 addition & 0 deletions gguf-py/gguf/constants.py
@@ -3057,6 +3057,7 @@ class VisionProjectorType:
QWEN25VL = "qwen2.5vl_merger"
ULTRAVOX = "ultravox"
INTERNVL = "internvl"
JINACLIP2 = "jinaclip2"
QWEN2A = "qwen2a" # audio
QWEN25O = "qwen2.5o" # omni
VOXTRAL = "voxtral"