wip: remove granite specifics
dacorvo committed Dec 23, 2024
1 parent e017917 commit da9a01a
Showing 2 changed files with 3 additions and 9 deletions.
10 changes: 3 additions & 7 deletions optimum/neuron/models/granite/hlo.py
@@ -83,8 +83,7 @@ def embedding(self, input_ids, cache_ids, start_ids, last_token_id, *weights):
         hidden = hlo.slice_along(hidden, dim=-1, limit=self.config.hidden_size, start=0)
         if self.neuron_config.attention_layout == LAYOUT_HSB:
             hidden = hlo.transpose210(hidden)
-        # Granite specific: embeddings are multiplied by embedding_multiplier
-        return hidden * self.config.embedding_multiplier
+        return hidden

     def token_tree_embedding(
         self, input_ids, cache_ids, start_ids, last_token_id, previous_cache_ids, reorder_mapping, *weights
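For context, the removed line implemented Granite's embedding scaling, where the embedding output is multiplied by the config's embedding_multiplier before entering the first layer. A minimal NumPy sketch of that behavior, not part of this commit; the multiplier value below is a placeholder, not taken from any real checkpoint:

import numpy as np

# Placeholder standing in for self.config.embedding_multiplier.
embedding_multiplier = 12.0

def scale_embeddings(hidden: np.ndarray) -> np.ndarray:
    # Granite-style scaling: multiply the embedding output by a constant.
    # This commit drops the multiplication and returns `hidden` unchanged.
    return hidden * embedding_multiplier

hidden = np.random.randn(2, 4, 8).astype(np.float32)  # (batch, seq, hidden_size)
assert scale_embeddings(hidden).shape == hidden.shape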
@@ -328,8 +327,7 @@ def layer(
             tp_degree=self.config.tp_degree,
             neuron_config=self.neuron_config,
         )
-        # Granite specific: MLP output is multiplied by residual_multiplier
-        res_hidden = hlo.add(mlp_hidden * self.config.residual_multiplier, hidden)
+        res_hidden = hlo.add(mlp_hidden, hidden)
         return res_hidden, out_attn_k_cache, out_attn_v_cache

     def token_tree_layer(
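The removed residual scaling corresponds to Granite's residual_multiplier, which scales a sub-layer's output before it is added back to the residual stream. A hedged sketch of the two variants in plain NumPy (the hlo helpers are replaced by ordinary array ops, and the multiplier value is a placeholder):

import numpy as np

# Placeholder standing in for self.config.residual_multiplier.
residual_multiplier = 0.22

def residual_add(mlp_hidden: np.ndarray, hidden: np.ndarray, granite_style: bool) -> np.ndarray:
    # granite_style=True mirrors the removed line (scaled residual branch);
    # granite_style=False mirrors the plain residual add this commit switches to.
    if granite_style:
        return mlp_hidden * residual_multiplier + hidden
    return mlp_hidden + hidden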
@@ -660,9 +658,7 @@ def attention(
         )

         # Q = Q / sqrt(d_head)
-        # Granite specific: instead of dividing the QK product, multiply it by the attention_multiplier
-        # query = attention.scale(query, d_head)
-        query = query * self.config.attention_multiplier
+        query = attention.scale(query, d_head)

         # In BSH cache layout, the output of QKV linear projection is still kept as SBH for all QKV.
         bsh_cache_layout = False
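This hunk swaps Granite's constant attention_multiplier back to the standard 1/sqrt(d_head) query scaling. A small NumPy sketch comparing the two conventions, with illustrative values only (neither constant is read from this repository):

import numpy as np

d_head = 64
attention_multiplier = 0.0078125  # placeholder for self.config.attention_multiplier

query = np.random.randn(2, 16, d_head).astype(np.float32)

# Standard scaling restored by this commit: Q / sqrt(d_head).
standard_scaled = query / np.sqrt(d_head)

# Removed Granite-specific scaling: Q * attention_multiplier.
granite_scaled = query * attention_multiplier

# The two agree only when attention_multiplier == 1 / sqrt(d_head).
print(np.allclose(standard_scaled, granite_scaled, atol=1e-6))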
2 changes: 0 additions & 2 deletions optimum/neuron/models/granite/model.py
@@ -297,7 +297,5 @@ def forward(self, input_ids, cache_ids=None, start_ids=None, last_token_id=None,
         # either input_embeddings are generated (off device embedding), or input_ids will be padded from preprocess_and_embed (on device embedding)
         inputs = input_embeddings if input_embeddings is not None else input_ids
         logits = self._forward(inputs, *rst)
-        # Granite specific: divide logits by scaling factor
-        logits = logits / self.config.logits_scaling
         logits = self._postprocess(logits, start_ids=start_ids, **kwargs)
         return logits
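The model.py change drops Granite's final logits scaling, where raw logits are divided by the config's logits_scaling before post-processing. A minimal sketch of what the removed step computed, with a placeholder constant and vocabulary size:

import numpy as np

logits_scaling = 16.0  # placeholder for self.config.logits_scaling

def scale_logits(logits: np.ndarray) -> np.ndarray:
    # Removed Granite-specific step: divide raw logits by a constant
    # before the usual post-processing (e.g. sampling / top-k).
    return logits / logits_scaling

logits = np.random.randn(2, 32000).astype(np.float32)  # (batch, vocab_size)
print(scale_logits(logits).shape)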
