Commit 445e0d5

Author: liyang (committed)
fix: Gemma2/Gemma3 inference issue caused by ln_1/ln_2 keys (switch to ln1/ln2)
1 parent af8d0bf commit 445e0d5
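Background on why the rename matters: the converter writes GGUF tensors under the exact key strings it chooses, and the C++ loader looks them up by the strings it formats from the TN_* patterns in tools/mtmd/clip-impl.h. If the two sides disagree by even an underscore, the lookup misses and the model fails to load. A minimal sketch of the failure mode (illustrative only, not repo code):

    # GGUF written by the old converter: block norms stored as "ln_1"/"ln_2".
    gguf_tensors = {"v.blk.0.ln_1.weight": ..., "v.blk.0.ln_2.weight": ...}

    # Loader side formats its expected key from the TN_LN_1 pattern ("ln1").
    expected = "%s.blk.%d.ln1.%s" % ("v", 0, "weight")

    # The lookup misses, which surfaces as a missing-tensor error at load time.
    print(expected in gguf_tensors)  # False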

File tree: 2 files changed, +55 −68 lines

convert_hf_to_gguf.py

Lines changed: 53 additions & 66 deletions
@@ -4897,7 +4897,7 @@ def _xlmroberta_set_vocab(self) -> None:
         with open(tokenizer_config_path, "r", encoding="utf-8") as fp:
             tokenizer_config_json = json.load(fp)

-        add_prefix = tokenizer.add_prefix_space
+        add_prefix = getattr(tokenizer, "add_prefix_space", False)
         remove_whitespaces = tokenizer.clean_up_tokenization_spaces
         precompiled_charsmap = b64decode(tokenizer_json["normalizer"]["precompiled_charsmap"])
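The getattr change above is a defensive-access pattern: not every tokenizer wrapper exposes add_prefix_space, so reading it with a default avoids an AttributeError for models that lack the attribute. A minimal sketch (the DummyTokenizer class is mine, not from the repo):

    class DummyTokenizer:
        clean_up_tokenization_spaces = True  # note: no add_prefix_space attribute

    tok = DummyTokenizer()
    # tok.add_prefix_space would raise AttributeError; getattr supplies a default.
    add_prefix = getattr(tok, "add_prefix_space", False)
    print(add_prefix)  # False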

@@ -6426,79 +6426,68 @@ def __init__(self, *args, **kwargs):

         # Load config for vision encoder
         config_path = self.dir_model / "config.json"
-        if config_path.exists():
-            with open(config_path, encoding="utf-8") as f:
-                self.vision_config = json.load(f)
-        else:
-            # Default JinaCLIP v2 vision configuration
-            self.vision_config = {
-                "image_size": 448,
-                "patch_size": 14,
-                "hidden_size": 1024,
-                "num_hidden_layers": 24,
-                "num_attention_heads": 16,
-                "intermediate_size": 2731,
-                "layer_norm_eps": 1e-5,
-                "projection_dim": 1024
-            }
+        if not config_path.exists():
+            raise FileNotFoundError(
+                f"JinaCLIPVisionModel: missing config.json in {self.dir_model}. "
+                "Please ensure the original model config is present; default hyperparameter fallbacks are not used."
+            )
+        with open(config_path, encoding="utf-8") as f:
+            self.vision_config = json.load(f)

     def set_vocab(self):
         # Vision encoder doesn't need vocabulary
         pass

     def set_gguf_parameters(self):
-        # Identification (arch/name is set by writer); mark vision encoder presence
-        self.gguf_writer.add_clip_has_vision_encoder(True)
+        cfg = self.vision_config

-        # Vision parameters
-        config = self.vision_config
-        img_sz = int(config.get("image_size", 448))
-        patch_sz = int(config.get("patch_size", 14))
-        n_embd = int(config.get("hidden_size", 1024))
-        n_layer = int(config.get("num_hidden_layers", 24))
-        n_head = int(config.get("num_attention_heads", 16))
-        n_ff = int(config.get("intermediate_size", 2731))
-        proj_dim = int(config.get("projection_dim", 1024))
-
-        # Use gguf writer helpers (constants + typed setters)
-        self.gguf_writer.add_vision_image_size(img_sz)
-        self.gguf_writer.add_vision_patch_size(patch_sz)
-        self.gguf_writer.add_vision_embedding_length(n_embd)
-        self.gguf_writer.add_vision_block_count(n_layer)
+        try:
+            width = int(cfg["width"])            # channel dim
+            head_width = int(cfg["head_width"])  # per-head dim
+            layers = int(cfg["layers"])          # block count
+            image_size = int(cfg["image_size"])  # input image size
+            patch_size = int(cfg["patch_size"])  # patch size
+        except KeyError as e:
+            raise KeyError(f"JinaCLIPVisionModel: missing key in config.json: {e}")
+
+        if width % head_width != 0:
+            raise ValueError(
+                f"JinaCLIPVisionModel: width ({width}) not divisible by head_width ({head_width})"
+            )
+        n_head = width // head_width
+
+        if "mlp_ratio" in cfg:
+            n_ff = int(width * float(cfg["mlp_ratio"]))
+        elif bool(cfg.get("naive_swiglu", False)):
+            n_ff = int((width * 8) // 3)
+        else:
+            raise ValueError("JinaCLIPVisionModel: unable to infer FFN size; please provide 'mlp_ratio' or set 'naive_swiglu' in config.json")
+
+        self.gguf_writer.add_clip_has_vision_encoder(True)
+        proj_dim = int(cfg.get("projection_dim", width))
         self.gguf_writer.add_vision_projection_dim(proj_dim)
-        self.gguf_writer.add_vision_feed_forward_length(n_ff)
+
+        self.gguf_writer.add_vision_image_size(image_size)
+        self.gguf_writer.add_vision_patch_size(patch_size)
+        self.gguf_writer.add_vision_embedding_length(width)
+        self.gguf_writer.add_vision_block_count(layers)
         self.gguf_writer.add_vision_head_count(n_head)
-        # LayerNorm epsilon comes from config (fallback 1e-5)
-        eps_attn = float(config.get("layer_norm_eps", 1e-5))
-        self.gguf_writer.add_vision_attention_layernorm_eps(eps_attn)
+        self.gguf_writer.add_vision_feed_forward_length(n_ff)
+
+        self.gguf_writer.add_vision_attention_layernorm_eps(float(cfg.get("layer_norm_eps", 1e-5)))

-        # Preprocessing defaults
-        self.gguf_writer.add_vision_image_mean([0.48145466, 0.4578275, 0.40821073])
-        self.gguf_writer.add_vision_image_std ([0.26862954, 0.26130258, 0.27577711])
+        mean = self.preprocessor_config.get("image_mean", self.preprocessor_config.get("mean"))
+        std = self.preprocessor_config.get("image_std", self.preprocessor_config.get("std"))
+        if mean is None or std is None:
+            raise KeyError(
+                "JinaCLIPVisionModel: preprocessor_config missing image mean/std (expected keys: 'image_mean'/'image_std' or 'mean'/'std')"
+            )
+        self.gguf_writer.add_vision_image_mean(mean)
+        self.gguf_writer.add_vision_image_std(std)

-        # Projector type and activation
-        # JinaCLIP v2 projector type string follows upstream style (family+major)
         self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.JINACLIP2)
         self.gguf_writer.add_vision_use_silu(True)

-        # RoPE parameter used by vision encoder (prefer config override)
-        try:
-            rt = config.get("rope_theta", None)
-            rope_theta = float(rt) if rt is not None else 10000.0
-        except Exception:
-            rope_theta = 10000.0
-        # writer currently has no dedicated setter for this key; keep direct write
-        self.gguf_writer.add_float32("clip.vision.rope_theta", rope_theta)
-
-        # Compatibility (mmproj) — not covered by gguf writer helpers yet
-        self.gguf_writer.add_uint32("mmproj.embedding_length", n_embd)
-        self.gguf_writer.add_uint32("mmproj.block_count", n_layer)
-
-        logger.info(
-            "mmproj(jinaclip): image_size=%d patch_size=%d n_embd=%d n_layer=%d n_head=%d n_ff=%d proj_dim=%d",
-            img_sz, patch_sz, n_embd, n_layer, n_head, n_ff, proj_dim
-        )
-
     # helpers to keep modify_tensors compact and consistent with other models
     def _strip_vm_prefix(self, name: str) -> str:
         return name[len('vision_model.'):] if name.startswith('vision_model.') else name
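For intuition on the FFN sizing above: mlp_ratio scales the channel width directly, while the naive_swiglu branch uses the common SwiGLU convention of roughly 8/3 of the width, since SwiGLU splits the up-projection into gate and up halves and 8/3 keeps the parameter count near a conventional 4x MLP. The same rule isolated as a standalone sketch (function name is mine, not the repo's):

    def infer_ffn_size(width: int, cfg: dict) -> int:
        # Mirrors the selection logic in the hunk above.
        if "mlp_ratio" in cfg:
            return int(width * float(cfg["mlp_ratio"]))
        if cfg.get("naive_swiglu", False):
            return (width * 8) // 3
        raise ValueError("cannot infer FFN size: need 'mlp_ratio' or 'naive_swiglu'")

    print(infer_ffn_size(1024, {"mlp_ratio": 4.0}))      # 4096
    print(infer_ffn_size(1024, {"naive_swiglu": True}))  # 2730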
@@ -6508,10 +6497,10 @@ def _map_block_tensor(self, layer: int, rest: str, data_torch: Tensor, name: str
         # layer norms
         if rest.startswith('norm1.'):
             suffix = parts[-1]
-            return [(f'v.blk.{layer}.ln_1.{suffix}', data_torch)]
+            return [(f'v.blk.{layer}.ln1.{suffix}', data_torch)]
         if rest.startswith('norm2.'):
             suffix = parts[-1]
-            return [(f'v.blk.{layer}.ln_2.{suffix}', data_torch)]
+            return [(f'v.blk.{layer}.ln2.{suffix}', data_torch)]
         if rest.startswith('attn.inner_attn_ln.'):
             suffix = parts[-1]
             return [(f'v.blk.{layer}.attn_ln.{suffix}', data_torch)]
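End to end, this mapping turns an original checkpoint name like vision_model.blocks.3.norm1.weight into the GGUF key v.blk.3.ln1.weight. A simplified mirror of the two norm branches above (the real method handles many more tensor kinds):

    def map_norm_tensor(layer: int, rest: str) -> str | None:
        suffix = rest.split('.')[-1]  # "weight" or "bias"
        if rest.startswith('norm1.'):
            return f'v.blk.{layer}.ln1.{suffix}'
        if rest.startswith('norm2.'):
            return f'v.blk.{layer}.ln2.{suffix}'
        return None

    print(map_norm_tensor(3, 'norm1.weight'))  # v.blk.3.ln1.weight
    print(map_norm_tensor(3, 'norm2.bias'))    # v.blk.3.ln2.bias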
@@ -6625,7 +6614,6 @@ def get_tensors(self) -> Iterator[tuple[str, Tensor]]:
         if model_path is None:
             raise FileNotFoundError(f"mmproj(jinaclip): no model weights found in {self.dir_model}")

-        logger.info("mmproj(jinaclip): loading weights from %s", model_path)
         if model_path.suffix == ".bin":
             state_dict = torch.load(model_path, map_location="cpu", weights_only=True)
         else:
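The dispatch above branches on file extension: .bin checkpoints go through torch.load, with weights_only=True so untrusted pickles cannot execute code, and the truncated else branch presumably handles safetensors files. A self-contained sketch of that dispatch, assuming a safetensors fallback:

    from pathlib import Path

    import torch
    from safetensors.torch import load_file

    def load_state_dict(model_path: Path) -> dict:
        # Pickle-based checkpoint; weights_only=True (PyTorch >= 1.13) blocks
        # arbitrary code execution from the pickle stream.
        if model_path.suffix == ".bin":
            return torch.load(model_path, map_location="cpu", weights_only=True)
        # Assumption: anything else is a .safetensors file, as in the elided branch.
        return load_file(model_path, device="cpu")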
@@ -6638,7 +6626,6 @@ def get_tensors(self) -> Iterator[tuple[str, Tensor]]:
             yield name, tensor
             count += 1

-        logger.info("mmproj(jinaclip): yielded %d raw tensors", count)

     def _should_be_f32(self, gguf_name: str) -> bool:
         """Return True if tensor should be stored as F32 to avoid type mismatches in C++ runtime.
@@ -6647,8 +6634,8 @@ def _should_be_f32(self, gguf_name: str) -> bool:
         binary-op dtype issues; patch embedding bias is also safer as F32.
         """
         patterns = (
-            ".ln_1.weight", ".ln_1.bias",
-            ".ln_2.weight", ".ln_2.bias",
+            ".ln1.weight", ".ln1.bias",
+            ".ln2.weight", ".ln2.bias",
             ".attn_ln.weight", ".attn_ln.bias",
             ".ffn_norm.weight", ".ffn_norm.bias",
             "v.patch_embd.proj.bias",

tools/mtmd/clip-impl.h

Lines changed: 2 additions & 2 deletions
@@ -76,8 +76,8 @@
 #define TN_FFN_GATE        "%s.blk.%d.ffn_gate.%s"
 #define TN_FFN_UP          "%s.blk.%d.ffn_up.%s"
 #define TN_FFN_NORM        "%s.blk.%d.ffn_norm.%s"
-#define TN_LN_1            "%s.blk.%d.ln_1.%s" // layer norm
-#define TN_LN_2            "%s.blk.%d.ln_2.%s" // layer norm
+#define TN_LN_1            "%s.blk.%d.ln1.%s" // layer norm
+#define TN_LN_2            "%s.blk.%d.ln2.%s" // layer norm
 #define TN_LS_1            "%s.blk.%d.ls1.%s" // layer scale
 #define TN_LS_2            "%s.blk.%d.ls2.%s" // layer scale
 #define TN_LN_PRE          "%s.pre_ln.%s"
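With both files updated, the converter's f-strings and the loader's macro expansion produce identical keys again. A quick cross-check, runnable from Python (TN_LN_1 copied from the header above):

    TN_LN_1 = "%s.blk.%d.ln1.%s"  # as defined in clip-impl.h after this commit

    layer, suffix = 2, "bias"
    assert TN_LN_1 % ("v", layer, suffix) == f"v.blk.{layer}.ln1.{suffix}"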
