feat(decoder): add support for granite models (3)
Apply granite specific multipliers.
dacorvo committed Dec 23, 2024
1 parent 42a462f commit c7829f6
Showing 2 changed files with 9 additions and 3 deletions.
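For context, Granite differs from Llama by a handful of scalar scaling factors read from the model config, and this commit wires them into the Neuron decoder. The attribute names below match transformers' GraniteConfig; the eager-style sketch is illustrative only, not the HLO code touched in this diff, and embed_tokens / layers / lm_head are hypothetical stand-ins for the real graph ops:

    # Illustrative eager-style sketch of where Granite's scaling factors enter the forward pass.
    # Hypothetical callables (embed_tokens, layers, lm_head); attribute names follow GraniteConfig.
    def granite_forward_sketch(config, embed_tokens, layers, lm_head, input_ids):
        hidden = embed_tokens(input_ids) * config.embedding_multiplier    # hlo.py: embedding()
        for layer in layers:
            # each residual branch output is scaled by residual_multiplier (the MLP branch is shown in this diff);
            # attention scores inside layer() use config.attention_multiplier instead of 1/sqrt(d_head)
            hidden = hidden + layer(hidden) * config.residual_multiplier  # hlo.py: layer()
        logits = lm_head(hidden)
        return logits / config.logits_scaling                             # model.py: forward()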
10 changes: 7 additions & 3 deletions optimum/neuron/models/granite/hlo.py
@@ -83,7 +83,8 @@ def embedding(self, input_ids, cache_ids, start_ids, last_token_id, *weights):
        hidden = hlo.slice_along(hidden, dim=-1, limit=self.config.hidden_size, start=0)
        if self.neuron_config.attention_layout == LAYOUT_HSB:
            hidden = hlo.transpose210(hidden)
-       return hidden
+       # Granite specific: embeddings are multiplied by embedding_multiplier
+       return hidden * self.config.embedding_multiplier

    def token_tree_embedding(
        self, input_ids, cache_ids, start_ids, last_token_id, previous_cache_ids, reorder_mapping, *weights
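In eager terms, the embedding change above amounts to scaling the token embeddings before they reach the first decoder layer (sketch only; embed_tokens is a stand-in for the graph's embedding op):

    hidden = embed_tokens(input_ids)               # [batch, seq, hidden_size]
    hidden = hidden * config.embedding_multiplier  # Granite specific scaling of the embeddings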
@@ -327,7 +328,8 @@ def layer(
            tp_degree=self.config.tp_degree,
            neuron_config=self.neuron_config,
        )
-       res_hidden = hlo.add(mlp_hidden, hidden)
+       # Granite specific: MLP output is multiplied by residual_multiplier
+       res_hidden = hlo.add(mlp_hidden * self.config.residual_multiplier, hidden)
        return res_hidden, out_attn_k_cache, out_attn_v_cache

    def token_tree_layer(
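The hunk above swaps a plain residual add for a scaled one. An eager-mode sketch of the same change, reusing the names from the HLO code:

    # before: res_hidden = mlp_hidden + hidden
    res_hidden = mlp_hidden * config.residual_multiplier + hidden  # Granite scales the MLP branch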
@@ -658,7 +660,9 @@ def attention(
        )

        # Q = Q / sqrt(d_head)
-       query = attention.scale(query, d_head)
+       # Granite specific: instead of dividing the QK product, multiply it by the attention_multiplier
+       # query = attention.scale(query, d_head)
+       query = query * self.config.attention_multiplier

        # In BSH cache layout, the output of QKV linear projection is still kept as SBH for all QKV.
        bsh_cache_layout = False
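The attention change replaces the usual 1/sqrt(d_head) scaling of the scores with the configurable attention_multiplier; since the factor is applied to Q before the QK^T product, the scores end up multiplied by it. A minimal PyTorch sketch of the two variants, assuming q and k of shape [batch, heads, seq, d_head]:

    import math
    import torch

    def scores_standard(q, k, d_head):
        # classic scaled dot-product attention scores
        return torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(d_head)

    def scores_granite(q, k, attention_multiplier):
        # Granite: scale Q by the config's attention_multiplier instead of dividing by sqrt(d_head)
        return torch.matmul(q * attention_multiplier, k.transpose(-1, -2))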
2 changes: 2 additions & 0 deletions optimum/neuron/models/granite/model.py
@@ -297,5 +297,7 @@ def forward(self, input_ids, cache_ids=None, start_ids=None, last_token_id=None,
        # either input_embeddings are generated (off device embedding), or input_ids will be padded from preprocess_and_embed (on device embedding)
        inputs = input_embeddings if input_embeddings is not None else input_ids
        logits = self._forward(inputs, *rst)
+       # Granite specific: divide logits by scaling factor
+       logits = logits / self.config.logits_scaling
        logits = self._postprocess(logits, start_ids=start_ids, **kwargs)
        return logits
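A rough eager equivalent of the logits change (sketch only; the projection itself already happens inside self._forward here):

    logits = lm_head(hidden)                 # done by self._forward in this model
    logits = logits / config.logits_scaling  # Granite specific: soften logits by a configured factor

Dividing by logits_scaling acts like a fixed softmax temperature applied before decoding.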
