Commit 7b986ed

Qualcomm AI Engine Direct - GA Static Gemma-2b-instruct

Summary:
- e2e script for Gemma-2b-it in static llama version
- add model params file & model weight converter

1 parent 5ae40f1 commit 7b986ed

File tree

11 files changed (+267, -11 lines)

backends/qualcomm/tests/test_qnn_delegate.py

Lines changed: 59 additions & 0 deletions
```diff
@@ -4762,6 +4762,65 @@ def test_qnn_backend_seq_mse(self):
 
 
 class TestExampleLLMScript(TestQNN):
+    def test_static_gemma_2b(self):
+        if not self.required_envs():
+            self.skipTest("missing required envs")
+
+        prompt = "My favourite condiment is "
+        cmds = [
+            "python",
+            f"{self.executorch_root}/examples/qualcomm/oss_scripts/llama/llama.py",
+            "--artifact",
+            self.artifact_dir,
+            "--build_folder",
+            self.build_folder,
+            "--model",
+            self.model,
+            "--ip",
+            self.ip,
+            "--port",
+            str(self.port),
+            "--prompt",
+            f"{prompt}",
+            "--decoder_model",
+            "gemma-2b",
+            "--model_mode",
+            "kv",
+            "--max_seq_len",
+            "1024",
+            "--eval_perplexity",
+            "--tasks",
+            "wikitext",
+            "--limit",
+            "1",
+        ]
+        if self.compile_only:
+            cmds.extend(["--compile_only"])
+        elif self.device:
+            cmds.extend(["--device", self.device])
+            if self.host:
+                cmds.extend(["--host", self.host])
+        elif self.enable_x86_64:
+            cmds.extend(["--enable_x86_64"])
+        if self.pre_gen_pte:
+            cmds.extend(["--pre_gen_pte", self.pre_gen_pte])
+
+        p = subprocess.Popen(cmds, stdout=subprocess.DEVNULL)
+        with Listener((self.ip, self.port)) as listener:
+            conn = listener.accept()
+            p.communicate()
+            msg = json.loads(conn.recv())
+            if "Error" in msg:
+                self.fail(msg["Error"])
+            else:
+                inference_speed_ref = {"SM8650": 32, "SM8750": 36}
+                self.assertLessEqual(msg["wiki_ppl"], 35)
+                self.assertLessEqual(msg["pte_size"], 2_700_000_000)  # 2.7GB
+                if self.model in inference_speed_ref:
+                    self.assertGreaterEqual(
+                        msg["inference_speed"], inference_speed_ref[self.model]
+                    )
+
     def test_static_gemma3_1b(self):
         if not self.required_envs():
             self.skipTest("missing required envs")
```

examples/models/gemma/__init__.py

Lines changed: 16 additions & 0 deletions
```diff
@@ -0,0 +1,16 @@
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from executorch.examples.models.gemma.convert_weights import convert_weights
+from executorch.examples.models.llama.model import Llama2Model
+
+
+class GemmaModel(Llama2Model):
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+
+
+__all__ = [
+    "GemmaModel",
+    "convert_weights",
+]
```
examples/models/gemma/config/2b_config.json

Lines changed: 19 additions & 0 deletions

```diff
@@ -0,0 +1,19 @@
+{
+  "dim": 2048,
+  "ffn_dim_multiplier": 1,
+  "hidden_dim": 16384,
+  "n_heads": 8,
+  "head_dim": 256,
+  "n_kv_heads": 1,
+  "n_layers": 18,
+  "act_fn": "gelu",
+  "norm_type": "gemma3",
+  "norm_eps": 1e-06,
+  "rope_theta": 10000.0,
+  "use_scaled_rope": false,
+  "apply_embedding": true,
+  "embedding_scale_factor": 45.254833995939045,
+  "vocab_size": 256000,
+  "use_hf_rope": true,
+  "attention_qkv_bias": false
+}
```
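
Note: as a quick sanity check, the values in this params file are internally consistent; a minimal sketch, assuming the file sits at its in-tree path:

```python
import json
import math

# Assumed checkout-relative path of the params file added in this commit.
with open("examples/models/gemma/config/2b_config.json") as f:
    params = json.load(f)

# Per-head widths multiply back to the model width: 8 * 256 == 2048.
assert params["n_heads"] * params["head_dim"] == params["dim"]
# Gemma scales token embeddings by sqrt(dim); 45.2548... is sqrt(2048).
assert math.isclose(params["embedding_scale_factor"], math.sqrt(params["dim"]))
print(params["n_layers"], "layers, vocab size", params["vocab_size"])
```
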
examples/models/gemma/convert_weights.py

Lines changed: 106 additions & 0 deletions

```diff
@@ -0,0 +1,106 @@
+import argparse
+
+import json
+import os
+from typing import Dict
+
+import torch
+from safetensors.torch import load_file
+
+from torchtune.models.convert_weights import get_mapped_key
+
+
+# Weight mappings from Gemma's checkpoint to ExecuTorch's transformer parameters.
+_GEMMA_TO_EXECUTORCH = {
+    "model.embed_tokens.weight": "tok_embeddings.weight",
+    "model.norm.weight": "norm.weight",
+    "model.layers.{}.self_attn.k_proj.weight": "layers.{}.attention.wk.weight",
+    "model.layers.{}.self_attn.q_proj.weight": "layers.{}.attention.wq.weight",
+    "model.layers.{}.self_attn.v_proj.weight": "layers.{}.attention.wv.weight",
+    "model.layers.{}.self_attn.o_proj.weight": "layers.{}.attention.wo.weight",
+    "model.layers.{}.input_layernorm.weight": "layers.{}.attention_norm.weight",
+    "model.layers.{}.post_attention_layernorm.weight": "layers.{}.ffn_norm.weight",
+    "model.layers.{}.mlp.gate_proj.weight": "layers.{}.feed_forward.w1.weight",
+    "model.layers.{}.mlp.down_proj.weight": "layers.{}.feed_forward.w2.weight",
+    "model.layers.{}.mlp.up_proj.weight": "layers.{}.feed_forward.w3.weight",
+}
+
+
+def gemma_to_executorch(
+    state_dict: Dict[str, torch.Tensor]
+) -> Dict[str, torch.Tensor]:
+    """
+    Convert the state dict so that it matches what ExecuTorch's transformer definition expects.
+    """
+    converted_state_dict = {}
+    for key, value in state_dict.items():
+        new_key = get_mapped_key(key, _GEMMA_TO_EXECUTORCH)
+        converted_state_dict[new_key] = value
+    converted_state_dict["output.weight"] = converted_state_dict[
+        "tok_embeddings.weight"
+    ]
+    return converted_state_dict
+
+
+def load_checkpoint_from_safetensors(input_dir: str) -> Dict:
+    index_path = os.path.join(input_dir, "model.safetensors.index.json")
+    if os.path.exists(index_path):
+        # Sharded checkpoint.
+        with open(index_path, "r") as f:
+            index = json.load(f)
+        weight_map = index["weight_map"]
+        checkpoint_shards = sorted(set(weight_map.values()))
+
+        # Load all the shards into memory
+        shard_to_weights = {}
+        for shard in checkpoint_shards:
+            shard_to_weights[shard] = load_file(os.path.join(input_dir, shard))
+
+        # Merge tensors into consolidated state dict.
+        merged_state_dict = {}
+        for weight_name, shard in weight_map.items():
+            tensor = shard_to_weights[shard][weight_name]
+            merged_state_dict[weight_name] = tensor
+        return merged_state_dict
+    else:
+        # Single checkpoint.
+        state_dict = load_file(os.path.join(input_dir, "model.safetensors"))
+        return state_dict
+
+
+def load_checkpoint(input_dir: str) -> Dict:
+    pytorch_path = os.path.join(input_dir, "pytorch_model.bin")
+    if os.path.exists(pytorch_path):
+        print("Loading checkpoint from PyTorch .bin file")
+        return torch.load(pytorch_path, map_location="cpu", weights_only=True)
+    print("Loading checkpoint from safetensors directory")
+    return load_checkpoint_from_safetensors(input_dir)
+
+
+def convert_weights(input_dir: str, output_file: str) -> None:
+    print("Loading checkpoint...")
+    sd = load_checkpoint(input_dir)
+    print("Converting checkpoint...")
+    sd = gemma_to_executorch(sd)
+    print("Saving checkpoint...")
+    torch.save(sd, output_file)
+    print("Done.")
+
+
+def main():
+    parser = argparse.ArgumentParser(
+        description="Convert Gemma weights to ExecuTorch transformer format."
+    )
+    parser.add_argument(
+        "input_dir",
+        type=str,
+        help="Path to directory containing safetensor checkpoint files, or PyTorch checkpoint file.",
+    )
+    parser.add_argument("output", type=str, help="Path to the output checkpoint")
+
+    args = parser.parse_args()
+    convert_weights(args.input_dir, args.output)
+
+
+if __name__ == "__main__":
+    main()
```
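
Note: a minimal usage sketch for the converter above, with placeholder paths (the gemma-2b-it checkpoint is assumed to have been downloaded separately from the google/gemma-2b-it repo):

```python
# Convert a local gemma-2b-it checkpoint directory (safetensors shards or
# pytorch_model.bin) into a single consolidated state dict for ExecuTorch.
from executorch.examples.models.gemma import convert_weights

convert_weights(
    "/path/to/gemma-2b-it",      # placeholder: directory holding the HF checkpoint
    "/path/to/gemma_2b_it.pth",  # placeholder: output file saved via torch.save
)
```

The same conversion is available from the command line, e.g. `python examples/models/gemma/convert_weights.py /path/to/gemma-2b-it /path/to/gemma_2b_it.pth`.
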

examples/qualcomm/oss_scripts/llama/README.md

Lines changed: 14 additions & 6 deletions
````diff
@@ -5,12 +5,13 @@ This file provides you the instructions to run LLM Decoder model with different
 1. LLAMA2 Stories 110M
 2. LLAMA3.2 1B
 3. LLAMA3.2 3B
-4. Gemma3 1B
-5. Phi4-mini-instruct
-6. QWEN2.5 0.5B / 1.5B
-7. QWEN3 0.6B / 1.7B
-8. SmolLM2 135M
-9. SmolLM3 3B
+4. Gemma 2B
+5. Gemma3 1B
+6. Phi4-mini-instruct
+7. QWEN2.5 0.5B / 1.5B
+8. QWEN3 0.6B / 1.7B
+9. SmolLM2 135M
+10. SmolLM3 3B
 
 
 We offer the following modes to execute the model:
@@ -78,6 +79,13 @@ Default example using kv mode.
 python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --checkpoint consolidated.00.pth --params params.json --tokenizer_model tokenizer.model --decoder_model llama3_2-3b_instruct --model_mode kv --max_seq_len 1024 --prompt "I would like to learn python, could you teach me with a simple example?" --tasks wikitext --limit 1
 ```
 
+#### Gemma 2B
+Default example using hybrid mode
+```bash
+python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --temperature 0 --model_mode hybrid --max_seq_len 1024 --prefill_ar_len 128 --decoder_model gemma-2b --prompt "I would like to learn python, could you teach me with a simple example?" --tasks wikitext --limit 1
+```
+
+
 #### Gemma3 1B
 Default example using hybrid mode
 ```bash
````

examples/qualcomm/oss_scripts/llama/__init__.py

Lines changed: 31 additions & 0 deletions
```diff
@@ -24,6 +24,7 @@
 )
 from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype
 
+from executorch.examples.models.gemma import convert_weights as convert_gemma_weights
 from executorch.examples.models.gemma3 import convert_weights as convert_gemma3_weights
 from executorch.examples.models.phi_4_mini import (
     convert_weights as convert_phi_4_mini_weights,
@@ -300,6 +301,36 @@ class Llama3_2_3B_Instruct(LLMModelConfig):
 )
 
 
+@register_llm_model("gemma-2b")
+@dataclass(init=False, frozen=True)
+class Gemma_2B(LLMModelConfig):
+    repo_id: str = "google/gemma-2b-it"
+    params_path: str = os.path.join(
+        BASE_DIR, "../../../models/gemma/config/2b_config.json"
+    )
+    convert_weights = convert_gemma_weights
+    transform_weight = False
+    instruct_model = True
+
+    num_sharding = 4
+    # quant config
+    ptq = QuantDtype.use_16a4w_block
+    group_size = 64
+    masked_softmax = True
+    seq_mse_candidates = 0
+    r1 = False
+    r2 = False
+    r3 = False
+    quantization_config_wv_sha_16a8w = get_ptq_per_channel_quant_config(
+        torch.uint16, weight_dtype=torch.int8, act_observer=MinMaxObserver
+    )
+    custom_annotation = (
+        annotate_kv_8bit,
+        annotate_output_16a8w,
+        partial(annotate_wv_sha, quantization_config=quantization_config_wv_sha_16a8w),
+    )
+
+
 @register_llm_model("gemma3-1b")
 @dataclass(init=False, frozen=True)
 class Gemma3(LLMModelConfig):
```

examples/qualcomm/oss_scripts/llama/decoder_constants.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -14,6 +14,7 @@
 DECODER_MODEL_VERSION = {
     "stories260k": "llama2",
     "stories110m": "llama2",
+    "gemma-2b": "gemma",
     "gemma3-1b": "gemma3",
     "phi_4_mini": "phi_4_mini",
     "llama3_2-1b_instruct": "llama3",
```

examples/qualcomm/oss_scripts/llama/llama.py

Lines changed: 13 additions & 3 deletions
```diff
@@ -325,6 +325,13 @@ def quantize(
                 chat_template, args.prompt[0], args.system_prompt
             )
         )
+
+        # Gemma may produce unexpected output if the prompt contains an extra <bos> token.
+        # This can happen after applying a prompt template, which might inject <bos> unintentionally.
+        # To prevent decoding issues, we explicitly remove <bos> token
+        if chat_template and args.decoder_model in {"gemma-2b", "gemma3-1b"}:
+            prompt = prompt.replace("<bos>", "")
+
         graph_module_inference(
             use_kv_cache=self.llama_meta["get_use_kv_cache"],
             get_example_inputs=self.get_example_inputs,
@@ -538,14 +545,13 @@ def compile(
             state_dict = torch.load(
                 checkpoint, weights_only=True, map_location="cpu", mmap=True
             )
-            if args.decoder_model == "gemma3-1b":
+            if args.decoder_model in {"gemma-2b", "gemma3-1b"}:
                 for k, v in state_dict.items():
                     if "norm" not in k:
                         continue
                     # Llama does x.to(float16) * w whilst Gemma3 is (x * w).to(float16)
                     # See https://github.com/huggingface/transformers/pull/29402
                     state_dict[k] = v.float() + torch.ones(v.shape, dtype=torch.float32)
-
         else:
             state_dict = torch.load(
                 args.checkpoint, weights_only=True, map_location="cpu", mmap=True
@@ -1284,7 +1290,11 @@ def export_llama(args) -> None:
         )
         tokenizer_artifacts = tokenizer.save_pretrained(args.artifact)
         tokenizer_config = tokenizer_artifacts[0]
-        runtime_tokenizer_path = tokenizer_artifacts[-1]
+        if args.decoder_model == "gemma-2b":
+            # For Gemma, use tokenizer.model as it doesn't provide pre_tokenizer in tokenizer.json.
+            runtime_tokenizer_path = tokenizer_artifacts[-3]
+        else:
+            runtime_tokenizer_path = tokenizer_artifacts[-1]
         tokenizer = get_tokenizer(runtime_tokenizer_path, tokenizer_config)
 
         # TODO: Remove this once error is resolved.
```
examples/qualcomm/oss_scripts/llama/qnn_llama_runner.cpp

Lines changed: 2 additions & 1 deletion
```diff
@@ -9,7 +9,7 @@
 /**
  * @file
  *
- * This tool can run Llama2 110M, Llama3.2 1B / 3B, Gemma3 1B,
+ * This tool can run Llama2 110M, Llama3.2 1B / 3B, Gemma 2B, Gemma3 1B,
  * phi4-mini-instruct, Qwen2.5 0.5B / 1.5B, Qwen3 0.6B / 1.7B, SmolLM2 135M,
  * SmolLM3 3B with Qualcomm AI Engine Direct.
  *
@@ -117,6 +117,7 @@ std::string get_formatted_prompt(
       formatted_prompt.append(
           "<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n");
       break;
+    case example::DecoderModelVersion::kGemma:
     case example::DecoderModelVersion::kGemma3:
       formatted_prompt.append("<start_of_turn>user\n");
       formatted_prompt.append(prompt);
```

examples/qualcomm/oss_scripts/llama/runner/runner.cpp

Lines changed: 5 additions & 1 deletion
```diff
@@ -122,6 +122,8 @@ Runner<T>::Runner(
     decoder_model_version_ = DecoderModelVersion::kLlama2;
   } else if (decoder_model_version == "llama3") {
     decoder_model_version_ = DecoderModelVersion::kLlama3;
+  } else if (decoder_model_version == "gemma") {
+    decoder_model_version_ = DecoderModelVersion::kGemma;
   } else if (decoder_model_version == "gemma3") {
     decoder_model_version_ = DecoderModelVersion::kGemma3;
     cache_mode_ = CacheMode::HybridCache;
@@ -199,7 +201,9 @@ Error Runner<T>::load() {
       decoder_model_version_ == DecoderModelVersion::kSmollm2_135m ||
       decoder_model_version_ == DecoderModelVersion::kSmollm3) {
     eos_ids->insert(tokenizer_->encode("<|im_end|>", 0, 0).get()[0]);
-  } else if (decoder_model_version_ == DecoderModelVersion::kGemma3) {
+  } else if (
+      decoder_model_version_ == DecoderModelVersion::kGemma ||
+      decoder_model_version_ == DecoderModelVersion::kGemma3) {
     eos_ids->insert(tokenizer_->encode("<end_of_turn>", 0, 0).get()[0]);
   }
 
```