Commit 8e7174b

[P0] Adding gpt-oss support
1 parent: 3c6cb78

File tree: 8 files changed, +463 −9 lines

pyproject.toml

Lines changed: 1 addition & 1 deletion

@@ -18,7 +18,7 @@ license = { text = "Apache License 2.0" }
 requires-python = ">=3.9"
 dependencies = [
     "torch>=2.0.0",
-    "transformers>=4.45.1",
+    "transformers>=4.55.0.dev0",
     "tokenizers>=0.20.0",
     "datasets>=3.0.1",
     "protobuf>=3.20.0",

pyvene/__init__.py

Lines changed: 13 additions & 3 deletions

@@ -2,7 +2,11 @@
 
 # Generic APIs
 from .data_generators.causal_model import CausalModel
-from .models.intervenable_base import IntervenableModel, IntervenableNdifModel, build_intervenable_model
+from .models.intervenable_base import (
+    IntervenableModel,
+    IntervenableNdifModel,
+    build_intervenable_model,
+)
 from .models.configuration_intervenable_model import IntervenableConfig
 from .models.configuration_intervenable_model import RepresentationConfig
 
@@ -37,7 +41,10 @@
 # Utils
 from .models.basic_utils import *
 from .models.intervention_utils import _do_intervention_by_swap
-from .models.intervenable_modelcard import type_to_module_mapping, type_to_dimension_mapping
+from .models.intervenable_modelcard import (
+    type_to_module_mapping,
+    type_to_dimension_mapping,
+)
 from .models.gpt2.modelings_intervenable_gpt2 import create_gpt2
 from .models.gpt2.modelings_intervenable_gpt2 import create_gpt2_lm
 from .models.blip.modelings_intervenable_blip import create_blip
@@ -51,5 +58,8 @@
 from .models.gru.modelings_gru import GRUConfig
 from .models.llama.modelings_intervenable_llama import create_llama
 from .models.mlp.modelings_intervenable_mlp import create_mlp_classifier
-from .models.backpack_gpt2.modelings_intervenable_backpack_gpt2 import create_backpack_gpt2
+from .models.backpack_gpt2.modelings_intervenable_backpack_gpt2 import (
+    create_backpack_gpt2,
+)
 from .models.olmo.modelings_intervenable_olmo import create_olmo
+from .models.gpt_oss.modelings_intervenable_gpt_oss import create_gpt_oss
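The last hunk re-exports the new factory at the package root alongside the other `create_*` helpers; a minimal sketch of the resulting import path (the model name is the default from the new file):

```python
# Sketch: the re-export added above makes this top-level import work.
from pyvene import create_gpt_oss

config, tokenizer, gpt_oss = create_gpt_oss("openai/gpt-oss-20b")
```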

pyvene/models/README.md

Lines changed: 24 additions & 0 deletions

@@ -0,0 +1,24 @@
+## How to add new models?
+
+You can prompt an LM to generate new model files, or modify existing ones in this folder, by following these steps:
+
+- Get the relevant implementation file from `https://github.com/huggingface/transformers/blob/main/src/transformers/models/` (e.g., the implementation for `gpt-oss` [here](https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_oss/modeling_gpt_oss.py)).
+
+- Copy the whole transformers model src file.
+
+- Create a new folder for your new model.
+
+- Copy one of the existing model files into your new folder (e.g., `/gpt2/modelings_intervenable_gpt2.py`, along with the default `__init__.py` file).
+
+- Prompt a language model with the following template:
+
+```text
+[YOUR_EXAMPLE_PYVENE_MODEL_FILE_COPY]
+
+Generate a new mapping file based on the existing one above for the following new model:
+
+[HF_TRANSFORMER_MODEL_SRC_FILE_COPY]
+
+You also need to pay attention to these details:
+- [OTHER_REQ_GOES_HERE] (e.g., you need to take care of the MoE structure)
+```
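For `gpt-oss`, the end state of these steps is exactly the layout this commit adds (the two files shown next):

```text
pyvene/models/gpt_oss/
├── __init__.py                         (empty package marker)
└── modelings_intervenable_gpt_oss.py   (the generated mapping file)
```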

pyvene/models/gpt_oss/__init__.py

Whitespace-only changes.

pyvene/models/gpt_oss/modelings_intervenable_gpt_oss.py

Lines changed: 99 additions & 0 deletions

@@ -0,0 +1,99 @@
+"""
+Each modeling file in this library is a mapping between
+abstract naming of intervention anchor points and the actual
+model modules defined in the huggingface library.
+
+We also want to let the intervention library know how to
+configure the dimensions of interventions based on the model
+config defined in the huggingface library.
+"""
+
+from ..constants import *
+
+
+"""gpt-oss base model"""
+gpt_oss_type_to_module_mapping = {
+    "block_input": ("layers[%s]", CONST_INPUT_HOOK),
+    "block_output": ("layers[%s]", CONST_OUTPUT_HOOK),
+    "mlp_input": ("layers[%s].mlp", CONST_INPUT_HOOK),
+    "mlp_output": ("layers[%s].mlp", CONST_OUTPUT_HOOK),
+    "router_input": ("layers[%s].mlp.router", CONST_INPUT_HOOK),
+    "router_output": ("layers[%s].mlp.router", CONST_OUTPUT_HOOK),
+    "expert_input": ("layers[%s].mlp.experts", CONST_INPUT_HOOK),
+    "expert_output": ("layers[%s].mlp.experts", CONST_OUTPUT_HOOK),
+    "attention_input": ("layers[%s].self_attn", CONST_INPUT_HOOK),
+    "attention_output": ("layers[%s].self_attn", CONST_OUTPUT_HOOK),
+    "attention_value_output": ("layers[%s].self_attn.o_proj", CONST_INPUT_HOOK),
+    "head_attention_value_output": (
+        "layers[%s].self_attn.o_proj",
+        CONST_INPUT_HOOK,
+        (split_head_and_permute, "num_attention_heads"),
+    ),
+    "query_output": ("layers[%s].self_attn.q_proj", CONST_OUTPUT_HOOK),
+    "key_output": ("layers[%s].self_attn.k_proj", CONST_OUTPUT_HOOK),
+    "value_output": ("layers[%s].self_attn.v_proj", CONST_OUTPUT_HOOK),
+    "head_query_output": (
+        "layers[%s].self_attn.q_proj",
+        CONST_OUTPUT_HOOK,
+        (split_head_and_permute, "num_attention_heads"),
+    ),
+    "head_key_output": (
+        "layers[%s].self_attn.k_proj",
+        CONST_OUTPUT_HOOK,
+        (split_head_and_permute, "num_key_value_heads"),
+    ),
+    "head_value_output": (
+        "layers[%s].self_attn.v_proj",
+        CONST_OUTPUT_HOOK,
+        (split_head_and_permute, "num_key_value_heads"),
+    ),
+}
+
+
+gpt_oss_type_to_dimension_mapping = {
+    "num_attention_heads": ("num_attention_heads",),
+    "num_key_value_heads": ("num_key_value_heads",),
+    "num_local_experts": ("num_local_experts",),
+    "num_experts_per_tok": ("num_experts_per_tok",),
+    "block_input": ("hidden_size",),
+    "block_output": ("hidden_size",),
+    "mlp_input": ("hidden_size",),
+    "mlp_output": ("hidden_size",),
+    "router_input": ("hidden_size",),
+    "router_output": ("num_local_experts",),
+    "expert_input": ("hidden_size",),
+    "expert_output": ("hidden_size",),
+    "attention_input": ("hidden_size",),
+    "attention_output": ("hidden_size",),
+    "attention_value_output": ("hidden_size",),
+    "head_attention_value_output": ("hidden_size/num_attention_heads",),
+    "query_output": ("hidden_size",),
+    "key_output": ("hidden_size",),
+    "value_output": ("hidden_size",),
+    "head_query_output": ("hidden_size/num_attention_heads",),
+    "head_key_output": ("hidden_size/num_key_value_heads",),
+    "head_value_output": ("hidden_size/num_key_value_heads",),
+}
+
+
+"""gpt-oss model with LM head"""
+gpt_oss_lm_type_to_module_mapping = {}
+for k, v in gpt_oss_type_to_module_mapping.items():
+    gpt_oss_lm_type_to_module_mapping[k] = (f"model.{v[0]}",) + v[1:]
+
+gpt_oss_lm_type_to_dimension_mapping = gpt_oss_type_to_dimension_mapping
+
+
+def create_gpt_oss(name="openai/gpt-oss-20b", cache_dir=None, access_token=None):
+    """Creates a GPT-OSS model, config, and tokenizer from the given name and revision"""
+    from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
+
+    config = AutoConfig.from_pretrained(name, cache_dir=cache_dir, token=access_token)
+    tokenizer = AutoTokenizer.from_pretrained(
+        name, cache_dir=cache_dir, token=access_token
+    )
+    gpt_oss = AutoModelForCausalLM.from_pretrained(
+        name, cache_dir=cache_dir, token=access_token
+    )
+    print("loaded model")
+    return config, tokenizer, gpt_oss
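As a sanity check of the new anchor points, here is a minimal interchange-intervention sketch (not part of this commit) following the pattern from the pyvene README and tutorials; the prompts, layer index, and token position are illustrative only:

```python
import pyvene as pv

# Load gpt-oss via the new factory (downloads a large checkpoint).
config, tokenizer, gpt_oss = pv.create_gpt_oss("openai/gpt-oss-20b")

# "router_output" resolves through the new mapping to
# "model.layers[0].mlp.router" on the CausalLM variant; its dimension is
# num_local_experts (one routing logit per expert).
pv_config = pv.IntervenableConfig(
    representations=[{"layer": 0, "component": "router_output"}],
    intervention_types=pv.VanillaIntervention,
)
pv_gpt_oss = pv.IntervenableModel(pv_config, model=gpt_oss)

base = tokenizer("The capital of Spain is", return_tensors="pt")
source = tokenizer("The capital of Italy is", return_tensors="pt")

# Swap the layer-0 router logits at token position 4 from source into base.
original_outputs, intervened_outputs = pv_gpt_oss(
    base, [source], {"sources->base": 4}
)
```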

pyvene/models/intervenable_modelcard.py

Lines changed: 11 additions & 4 deletions

@@ -12,10 +12,11 @@
 from .blip.modelings_intervenable_blip_itm import *
 from .backpack_gpt2.modelings_intervenable_backpack_gpt2 import *
 from .llava.modelings_intervenable_llava import *
-from .qwen2.modelings_intervenable_qwen2 import *
+from .qwen2.modelings_intervenable_qwen2 import *
 from .olmo.modelings_intervenable_olmo import *
 from .esm.modelings_intervenable_esm import *
 from .mllama.modelings_intervenable_mllama import *
+from .gpt_oss.modelings_intervenable_gpt_oss import *
 
 #########################################################################
 """
@@ -65,7 +66,7 @@
     hf_models.gemma2.modeling_gemma2.Gemma2Model: gemma2_type_to_module_mapping,
     hf_models.gemma2.modeling_gemma2.Gemma2ForCausalLM: gemma2_lm_type_to_module_mapping,
     hf_models.olmo.modeling_olmo.OlmoModel: olmo_type_to_module_mapping,
-    hf_models.olmo.modeling_olmo.OlmoForCausalLM: olmo_lm_type_to_module_mapping,
+    hf_models.olmo.modeling_olmo.OlmoForCausalLM: olmo_lm_type_to_module_mapping,
     hf_models.esm.modeling_esm.EsmModel: esm_type_to_module_mapping,
     hf_models.esm.modeling_esm.EsmForMaskedLM: esm_mlm_type_to_module_mapping,
     hf_models.blip.modeling_blip.BlipForQuestionAnswering: blip_type_to_module_mapping,
@@ -80,6 +81,8 @@
     hf_models.qwen2.modeling_qwen2.Qwen2ForCausalLM: qwen2_lm_type_to_module_mapping,
     hf_models.qwen2.modeling_qwen2.Qwen2ForSequenceClassification: qwen2_classifier_type_to_module_mapping,
     hf_models.mllama.modeling_mllama.MllamaForConditionalGeneration: mllama_type_to_module_mapping,
+    hf_models.gpt_oss.modeling_gpt_oss.GptOssModel: gpt_oss_type_to_module_mapping,
+    hf_models.gpt_oss.modeling_gpt_oss.GptOssForCausalLM: gpt_oss_lm_type_to_module_mapping,
 }
 if enable_blip:
     type_to_module_mapping[BlipWrapper] = blip_wrapper_type_to_module_mapping
@@ -105,7 +108,7 @@
     hf_models.gemma2.modeling_gemma2.Gemma2Model: gemma2_type_to_dimension_mapping,
     hf_models.gemma2.modeling_gemma2.Gemma2ForCausalLM: gemma2_lm_type_to_dimension_mapping,
     hf_models.olmo.modeling_olmo.OlmoModel: olmo_type_to_dimension_mapping,
-    hf_models.olmo.modeling_olmo.OlmoForCausalLM: olmo_lm_type_to_dimension_mapping,
+    hf_models.olmo.modeling_olmo.OlmoForCausalLM: olmo_lm_type_to_dimension_mapping,
     hf_models.esm.modeling_esm.EsmModel: esm_type_to_dimension_mapping,
     hf_models.esm.modeling_esm.EsmForMaskedLM: esm_mlm_type_to_dimension_mapping,
     hf_models.blip.modeling_blip.BlipForQuestionAnswering: blip_type_to_dimension_mapping,
@@ -120,9 +123,13 @@
     hf_models.qwen2.modeling_qwen2.Qwen2ForCausalLM: qwen2_lm_type_to_dimension_mapping,
     hf_models.qwen2.modeling_qwen2.Qwen2ForSequenceClassification: qwen2_classifier_type_to_dimension_mapping,
     hf_models.mllama.modeling_mllama.MllamaForConditionalGeneration: mllama_type_to_dimension_mapping,
+    hf_models.gpt_oss.modeling_gpt_oss.GptOssModel: gpt_oss_type_to_dimension_mapping,
+    hf_models.gpt_oss.modeling_gpt_oss.GptOssForCausalLM: gpt_oss_lm_type_to_dimension_mapping,
 }
 
 if enable_blip:
     type_to_dimension_mapping[BlipWrapper] = blip_wrapper_type_to_dimension_mapping
-    type_to_dimension_mapping[BlipITMWrapper] = blip_itm_wrapper_type_to_dimension_mapping
+    type_to_dimension_mapping[BlipITMWrapper] = (
+        blip_itm_wrapper_type_to_dimension_mapping
+    )
 #########################################################################
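Because both registries are keyed by the transformers model class, a loaded gpt-oss model resolves to the new entries directly; a hypothetical check (a sketch, not from this commit):

```python
import pyvene as pv
from transformers import AutoModelForCausalLM

# GptOssForCausalLM is the class registered above, so an AutoModel load
# of a gpt-oss checkpoint keys straight into the new mappings.
gpt_oss = AutoModelForCausalLM.from_pretrained("openai/gpt-oss-20b")
print(pv.type_to_module_mapping[type(gpt_oss)]["router_output"])
# expected: ("model.layers[%s].mlp.router", <output hook constant>)
print(pv.type_to_dimension_mapping[type(gpt_oss)]["router_output"])
# expected: ("num_local_experts",)
```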

tutorials/basic_tutorials/Basic_Intervention.ipynb

Lines changed: 1 addition & 1 deletion

@@ -13,7 +13,7 @@
    "id": "89f31a38",
    "metadata": {},
    "source": [
-    "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/frankaging/pyvene/blob/main/tutorials/basic_tutorials/Basic_Intervention.ipynb)"
+    "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/stanfordnlp/pyvene/blob/main/tutorials/basic_tutorials/Basic_Intervention.ipynb)"
    ]
   },
   {
