
Commit 64231d0

Merge pull request #225 from stanfordnlp/olmo2
Add OLMo2 and Qwen3 models
2 parents 4f812ef + 1d83080

File tree

5 files changed: +181 -0 lines changed

pyvene/models/intervenable_modelcard.py
pyvene/models/olmo2/__init__.py
pyvene/models/olmo2/modelings_intervenable_olmo2.py
pyvene/models/qwen3/__init__.py
pyvene/models/qwen3/modelings_intervenable_qwen3.py

pyvene/models/intervenable_modelcard.py

Lines changed: 10 additions & 0 deletions

@@ -14,6 +14,8 @@
 from .llava.modelings_intervenable_llava import *
 from .qwen2.modelings_intervenable_qwen2 import *
 from .olmo.modelings_intervenable_olmo import *
+from .olmo2.modelings_intervenable_olmo2 import *
+from .qwen3.modelings_intervenable_qwen3 import *
 from .esm.modelings_intervenable_esm import *
 from .mllama.modelings_intervenable_mllama import *
 from .gpt_oss.modelings_intervenable_gpt_oss import *
@@ -67,6 +69,10 @@
     hf_models.gemma2.modeling_gemma2.Gemma2ForCausalLM: gemma2_lm_type_to_module_mapping,
     hf_models.olmo.modeling_olmo.OlmoModel: olmo_type_to_module_mapping,
     hf_models.olmo.modeling_olmo.OlmoForCausalLM: olmo_lm_type_to_module_mapping,
+    hf_models.olmo2.modeling_olmo2.Olmo2Model: olmo2_type_to_module_mapping,
+    hf_models.olmo2.modeling_olmo2.Olmo2ForCausalLM: olmo2_lm_type_to_module_mapping,
+    hf_models.qwen3.modeling_qwen3.Qwen3Model: qwen3_type_to_module_mapping,
+    hf_models.qwen3.modeling_qwen3.Qwen3ForCausalLM: qwen3_lm_type_to_module_mapping,
     hf_models.esm.modeling_esm.EsmModel: esm_type_to_module_mapping,
     hf_models.esm.modeling_esm.EsmForMaskedLM: esm_mlm_type_to_module_mapping,
     hf_models.blip.modeling_blip.BlipForQuestionAnswering: blip_type_to_module_mapping,
@@ -109,6 +115,10 @@
     hf_models.gemma2.modeling_gemma2.Gemma2ForCausalLM: gemma2_lm_type_to_dimension_mapping,
     hf_models.olmo.modeling_olmo.OlmoModel: olmo_type_to_dimension_mapping,
     hf_models.olmo.modeling_olmo.OlmoForCausalLM: olmo_lm_type_to_dimension_mapping,
+    hf_models.olmo2.modeling_olmo2.Olmo2Model: olmo2_type_to_dimension_mapping,
+    hf_models.olmo2.modeling_olmo2.Olmo2ForCausalLM: olmo2_lm_type_to_dimension_mapping,
+    hf_models.qwen3.modeling_qwen3.Qwen3Model: qwen3_type_to_dimension_mapping,
+    hf_models.qwen3.modeling_qwen3.Qwen3ForCausalLM: qwen3_lm_type_to_dimension_mapping,
     hf_models.esm.modeling_esm.EsmModel: esm_type_to_dimension_mapping,
     hf_models.esm.modeling_esm.EsmForMaskedLM: esm_mlm_type_to_dimension_mapping,
     hf_models.blip.modeling_blip.BlipForQuestionAnswering: blip_type_to_dimension_mapping,
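
A quick sketch of how these registrations get used: pyvene keys its lookups on the Hugging Face model class, so once a checkpoint is loaded the right mapping comes straight out of the dict. The dictionary name below is assumed from context (the hunks above do not show it); treat this as illustrative.

from transformers import AutoModelForCausalLM

# Assumed dict name: type_to_module_mapping (not visible in the hunks above).
from pyvene.models.intervenable_modelcard import type_to_module_mapping

model = AutoModelForCausalLM.from_pretrained("allenai/OLMo-2-1124-7B")
print(type_to_module_mapping[type(model)]["block_output"])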

pyvene/models/olmo2/__init__.py

Whitespace-only changes.

pyvene/models/olmo2/modelings_intervenable_olmo2.py

Lines changed: 94 additions & 0 deletions

@@ -0,0 +1,94 @@
"""
Each modeling file in this library is a mapping between
abstract names of intervention anchor points and the actual
model modules defined in the HuggingFace library.

We also want to let the intervention library know how to
configure the dimensions of an intervention based on the model
config defined in the HuggingFace library.
"""


import torch
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
from ..constants import *


olmo2_type_to_module_mapping = {
    "block_input": ("layers[%s]", CONST_INPUT_HOOK),
    "block_output": ("layers[%s]", CONST_OUTPUT_HOOK),
    "mlp_activation": ("layers[%s].mlp.act_fn", CONST_OUTPUT_HOOK),
    "mlp_output": ("layers[%s].mlp", CONST_OUTPUT_HOOK),
    "mlp_input": ("layers[%s].mlp", CONST_INPUT_HOOK),
    "attention_value_output": ("layers[%s].self_attn.o_proj", CONST_INPUT_HOOK),
    "head_attention_value_output": ("layers[%s].self_attn.o_proj", CONST_INPUT_HOOK, (split_head_and_permute, "n_head")),
    "attention_output": ("layers[%s].self_attn", CONST_OUTPUT_HOOK),
    "attention_input": ("layers[%s].self_attn", CONST_INPUT_HOOK),
    "query_output": ("layers[%s].self_attn.q_proj", CONST_OUTPUT_HOOK),
    "key_output": ("layers[%s].self_attn.k_proj", CONST_OUTPUT_HOOK),
    "value_output": ("layers[%s].self_attn.v_proj", CONST_OUTPUT_HOOK),
    "head_query_output": ("layers[%s].self_attn.q_proj", CONST_OUTPUT_HOOK, (split_head_and_permute, "n_head")),
    "head_key_output": ("layers[%s].self_attn.k_proj", CONST_OUTPUT_HOOK, (split_head_and_permute, "n_kv_head")),
    "head_value_output": ("layers[%s].self_attn.v_proj", CONST_OUTPUT_HOOK, (split_head_and_permute, "n_kv_head")),
}
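
The "%s" in each component path is a layer-index placeholder, so one entry covers every transformer block. A minimal sketch of the resolution step, under the assumption that pyvene fills the placeholder and then attaches a hook of the given type (the resolver shown is illustrative, not pyvene's internal code):

# Illustrative: resolve the "query_output" anchor at layer 3.
component, hook_type = olmo2_type_to_module_mapping["query_output"][:2]
module_path = component % 3  # -> "layers[3].self_attn.q_proj"
# pyvene then hooks the input or output of that submodule, per hook_type.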

olmo2_type_to_dimension_mapping = {
    "n_head": ("num_attention_heads",),
    "n_kv_head": ("num_key_value_heads",),
    "block_input": ("hidden_size",),
    "block_output": ("hidden_size",),
    "mlp_activation": ("intermediate_size",),
    "mlp_output": ("hidden_size",),
    "mlp_input": ("hidden_size",),
    "attention_value_output": ("hidden_size",),
    "head_attention_value_output": ("hidden_size/num_attention_heads",),
    "attention_output": ("hidden_size",),
    "attention_input": ("hidden_size",),
    "query_output": ("hidden_size",),
    "key_output": ("hidden_size",),
    "value_output": ("hidden_size",),
    "head_query_output": ("hidden_size/num_attention_heads",),
    "head_key_output": ("hidden_size/num_attention_heads",),
    "head_value_output": ("hidden_size/num_attention_heads",),
}
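
Each dimension entry is resolved against the Hugging Face model config; strings like "hidden_size/num_attention_heads" denote the per-head size. A hedged arithmetic check (the numbers in the comment are typical 7B values, not taken from this diff):

from transformers import AutoConfig

cfg = AutoConfig.from_pretrained("allenai/OLMo-2-1124-7B")
head_dim = cfg.hidden_size // cfg.num_attention_heads
# e.g. hidden_size=4096 with num_attention_heads=32 gives head_dim=128,
# the unit pyvene intervenes on for "head_query_output" and friends.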

"""olmo2 model with LM head"""
olmo2_lm_type_to_module_mapping = {}
for k, v in olmo2_type_to_module_mapping.items():
    # The CausalLM wrapper nests the base transformer under `.model`,
    # so every component path gains a "model." prefix.
    olmo2_lm_type_to_module_mapping[k] = (f"model.{v[0]}",) + v[1:]


olmo2_lm_type_to_dimension_mapping = olmo2_type_to_dimension_mapping


"""olmo2 model with classifier head"""
olmo2_classifier_type_to_module_mapping = {}
for k, v in olmo2_type_to_module_mapping.items():
    olmo2_classifier_type_to_module_mapping[k] = (f"model.{v[0]}",) + v[1:]


olmo2_classifier_type_to_dimension_mapping = olmo2_type_to_dimension_mapping

def create_olmo2(
    name="allenai/OLMo-2-1124-7B", cache_dir=None, dtype=torch.bfloat16, config=None,
    revision="main",
):
    """Creates an OLMo2 causal LM model, config, and tokenizer from the given name and revision."""
    if config is None:
        config = AutoConfig.from_pretrained(name, cache_dir=cache_dir)
        olmo2 = AutoModelForCausalLM.from_pretrained(
            name,
            config=config,
            cache_dir=cache_dir,
            torch_dtype=dtype,
            revision=revision,
        )
        tokenizer = AutoTokenizer.from_pretrained(name, cache_dir=cache_dir)
    else:
        # `AutoModelForCausalLM` cannot be instantiated directly; build an
        # uninitialized model from the supplied config instead.
        olmo2 = AutoModelForCausalLM.from_config(config)
        tokenizer = AutoTokenizer.from_pretrained(name, cache_dir=cache_dir)
    print("loaded model")
    return config, tokenizer, olmo2
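
Putting the factory to work: a minimal end-to-end sketch wiring the returned model into an IntervenableModel. The config dict follows pyvene's documented style; treat the exact keys as an assumption rather than part of this PR.

import pyvene as pv
from pyvene.models.olmo2.modelings_intervenable_olmo2 import create_olmo2

config, tokenizer, olmo2 = create_olmo2()
pv_olmo2 = pv.IntervenableModel({
    "layer": 0,
    "component": "block_output",          # resolved via the mapping above
    "intervention_type": pv.VanillaIntervention,
}, model=olmo2)

create_qwen3 below follows the same triple-return convention for Qwen3.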

pyvene/models/qwen3/__init__.py

Whitespace-only changes.

pyvene/models/qwen3/modelings_intervenable_qwen3.py

Lines changed: 77 additions & 0 deletions

@@ -0,0 +1,77 @@
"""
Each modeling file in this library is a mapping between
abstract names of intervention anchor points and the actual
model modules defined in the HuggingFace library.

We also want to let the intervention library know how to
configure the dimensions of an intervention based on the model
config defined in the HuggingFace library.
"""
import torch
from ..constants import *

qwen3_type_to_module_mapping = {
    "block_input": ("layers[%s]", CONST_INPUT_HOOK),
    "block_output": ("layers[%s]", CONST_OUTPUT_HOOK),
    "mlp_activation": ("layers[%s].mlp.act_fn", CONST_OUTPUT_HOOK),
    "mlp_output": ("layers[%s].mlp", CONST_OUTPUT_HOOK),
    "mlp_input": ("layers[%s].mlp", CONST_INPUT_HOOK),
    "attention_value_output": ("layers[%s].self_attn.o_proj", CONST_INPUT_HOOK),
    "head_attention_value_output": ("layers[%s].self_attn.o_proj", CONST_INPUT_HOOK, (split_head_and_permute, "n_head")),
    "attention_output": ("layers[%s].self_attn", CONST_OUTPUT_HOOK),
    "attention_input": ("layers[%s].self_attn", CONST_INPUT_HOOK),
    "query_output": ("layers[%s].self_attn.q_proj", CONST_OUTPUT_HOOK),
    "key_output": ("layers[%s].self_attn.k_proj", CONST_OUTPUT_HOOK),
    "value_output": ("layers[%s].self_attn.v_proj", CONST_OUTPUT_HOOK),
    "head_query_output": ("layers[%s].self_attn.q_proj", CONST_OUTPUT_HOOK, (split_head_and_permute, "n_head")),
    "head_key_output": ("layers[%s].self_attn.k_proj", CONST_OUTPUT_HOOK, (split_head_and_permute, "n_kv_head")),
    "head_value_output": ("layers[%s].self_attn.v_proj", CONST_OUTPUT_HOOK, (split_head_and_permute, "n_kv_head")),
}
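
Qwen3 uses grouped-query attention, which is why "head_key_output" and "head_value_output" split on "n_kv_head" while the query side splits on "n_head". A conceptual sketch of what a split-head transform such as split_head_and_permute does to a raw projection output (shapes illustrative; this is not pyvene's exact implementation):

import torch

batch, seq, n_kv_head, head_dim = 2, 5, 8, 128
kv = torch.randn(batch, seq, n_kv_head * head_dim)  # raw k_proj/v_proj output
heads = kv.view(batch, seq, n_kv_head, head_dim).permute(0, 2, 1, 3)
print(heads.shape)  # (2, 8, 5, 128): one addressable slice per KV head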

qwen3_type_to_dimension_mapping = {
    "n_head": ("num_attention_heads",),
    "n_kv_head": ("num_key_value_heads",),
    "block_input": ("hidden_size",),
    "block_output": ("hidden_size",),
    "mlp_activation": ("intermediate_size",),
    "mlp_output": ("hidden_size",),
    "mlp_input": ("hidden_size",),
    "attention_value_output": ("hidden_size",),
    "head_attention_value_output": ("hidden_size/num_attention_heads",),
    "attention_output": ("hidden_size",),
    "attention_input": ("hidden_size",),
    "query_output": ("hidden_size",),
    "key_output": ("hidden_size",),
    "value_output": ("hidden_size",),
    "head_query_output": ("hidden_size/num_attention_heads",),
    "head_key_output": ("hidden_size/num_attention_heads",),
    "head_value_output": ("hidden_size/num_attention_heads",),
}

"""qwen3 model with LM head"""
qwen3_lm_type_to_module_mapping = {}
for k, v in qwen3_type_to_module_mapping.items():
    qwen3_lm_type_to_module_mapping[k] = (f"model.{v[0]}",) + v[1:]
qwen3_lm_type_to_dimension_mapping = qwen3_type_to_dimension_mapping

"""qwen3 model with classifier head"""
qwen3_classifier_type_to_module_mapping = {}
for k, v in qwen3_type_to_module_mapping.items():
    qwen3_classifier_type_to_module_mapping[k] = (f"model.{v[0]}",) + v[1:]
qwen3_classifier_type_to_dimension_mapping = qwen3_type_to_dimension_mapping

def create_qwen3(
    name="Qwen/Qwen3-8B", cache_dir=None, dtype=torch.bfloat16
):
    """Creates a Qwen3 causal LM model, config, and tokenizer from the given name."""
    from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

    config = AutoConfig.from_pretrained(name, cache_dir=cache_dir)
    tokenizer = AutoTokenizer.from_pretrained(name, cache_dir=cache_dir)
    model = AutoModelForCausalLM.from_pretrained(
        name,
        config=config,
        cache_dir=cache_dir,
        torch_dtype=dtype,
    )
    print("loaded model")
    return config, tokenizer, model
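
As a smoke test, the returned triple drops straight into the standard Hugging Face generation loop (prompt and decoding settings are illustrative):

from pyvene.models.qwen3.modelings_intervenable_qwen3 import create_qwen3

config, tokenizer, model = create_qwen3()
inputs = tokenizer("Hello, Qwen3!", return_tensors="pt").to(model.device)
output = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(output[0], skip_special_tokens=True))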
