Skip to content

Commit ce91ef6

Browse files
committed
feat(decoder): add support for granite models
1 parent cb1a4d0 commit ce91ef6

File tree

6 files changed

+1276
-0
lines changed

6 files changed

+1276
-0
lines changed

optimum/exporters/neuron/model_configs/decoder_configs.py

+7
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
from optimum.exporters.tasks import TasksManager
1919

20+
from ....neuron.models.granite.model import GraniteForSampling
2021
from ....neuron.models.qwen2.model import Qwen2ForSampling
2122
from ..config import TextNeuronDecoderConfig
2223

@@ -63,3 +64,9 @@ class Qwen2NeuronConfig(TextNeuronDecoderConfig):
6364
NEURONX_CLASS = Qwen2ForSampling
6465
CONTINUOUS_BATCHING = True
6566
FUSE_QKV = False
67+
68+
69+
@register_in_tasks_manager("granite", "text-generation")
70+
class GraniteNeuronConfig(TextNeuronDecoderConfig):
71+
NEURONX_CLASS = GraniteForSampling
72+
CONTINUOUS_BATCHING = True
+14
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
# coding=utf-8
2+
# Copyright 2024 The HuggingFace Team. All rights reserved.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
+32
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
# coding=utf-8
2+
# Copyright 2024 The HuggingFace Team. All rights reserved.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
from transformers import PretrainedConfig
17+
from transformers_neuronx.llama.config import LlamaConfig
18+
19+
20+
class GraniteConfig(LlamaConfig):
21+
"""The Granite model uses the same configuration as the TnX LLama model"""
22+
23+
def __init__(
24+
self, config: PretrainedConfig, n_positions: int, batch_size: int, amp: str, tp_degree: int, **kwargs
25+
):
26+
super().__init__(config, n_positions, batch_size, amp, tp_degree, **kwargs)
27+
self.model_type = "granite"
28+
# These are parameters specific to the granite modeling
29+
self.attention_multiplier = config.attention_multiplier
30+
self.embedding_multiplier = config.embedding_multiplier
31+
self.logits_scaling = config.logits_scaling
32+
self.residual_multiplier = config.residual_multiplier

0 commit comments

Comments
 (0)