feat(decoder): add support for granite models (2)
Apply Granite-specific multipliers.
dacorvo committed Dec 23, 2024
1 parent 4293b52 commit fc34974
Showing 2 changed files with 21 additions and 2 deletions.
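
For orientation before the diff: the commit wires four Granite-specific multipliers into the decoder. The sketch below shows, in plain PyTorch-style pseudocode rather than the HLO builder API used in the actual change, where each factor enters the forward pass. The config field names (embedding_multiplier, attention_multiplier, residual_multiplier, logits_scaling) are the standard Granite ones; embed, attention, mlp, norm and lm_head are placeholders, not functions from this repository.

# Illustrative only: where each Granite multiplier acts in the forward pass.
hidden = embed(input_ids) * config.embedding_multiplier        # scale embeddings

for layer in decoder_layers:
    attn_out = attention(norm(hidden))                         # scores scaled by attention_multiplier
    hidden = hidden + attn_out * config.residual_multiplier    # scaled residual (attention)
    mlp_out = mlp(norm(hidden))
    hidden = hidden + mlp_out * config.residual_multiplier     # scaled residual (MLP)

logits = lm_head(norm(hidden)) / config.logits_scaling         # divide logits by scaling factor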
21 changes: 19 additions & 2 deletions optimum/neuron/models/granite/hlo.py
@@ -24,6 +24,16 @@
from .config import GraniteConfig


def scale_mul(t, scale):
"""Multiply a tensor by a float scale"""
dtype = t.dtype
# Convert float to a constant scalar tensor of the target dtype
scale_t = dtype.Constant(constant_value=scale)
# Expand the scalar tensor to the target shape
scale_br_t = dtype[t.sizes].Broadcast(scale_t, dimensions=[])
return dtype[t.sizes].Multiply(t, scale_br_t)


class GraniteForSamplingNoEmbeddingHlo:

def __init__(self, config: GraniteConfig, neuron_config: Optional[NeuronConfig] = None):
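
The scale_mul helper added above is just an elementwise multiply of a tensor by a Python float, expressed with the HLO builder API (constant, broadcast, multiply). A rough NumPy equivalent, purely for illustration and not part of the commit:

import numpy as np

def scale_mul_reference(t: np.ndarray, scale: float) -> np.ndarray:
    # Broadcast the scalar to the tensor's dtype and shape, then multiply,
    # mirroring the Constant -> Broadcast -> Multiply sequence above.
    scale_t = np.broadcast_to(np.asarray(scale, dtype=t.dtype), t.shape)
    return t * scale_t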
@@ -118,6 +128,9 @@ def pre_layer(self, hidden, cache_ids, start_ids, last_token_id, *weights):
else:
block_to_seq = None

# Granite specific: embeddings are multiplied by embedding_multiplier
hidden = scale_mul(hidden, self.config.embedding_multiplier)

head_dim = self.config.attention_head_size
pos_embed = rotary.hlo_rotary_embedding(
hidden.dtype,
@@ -297,6 +310,8 @@ def layer(
attn_out_scales,
attn_out_bias,
)
# Granite specific: attention output is multiplied by residual multiplier
attn_output = scale_mul(attn_output, self.config.residual_multiplier)
hidden = hlo.add(attn_output, hidden)
gated_mlp = hlo.gated_mlp_bsh if is_bsh else hlo.gated_mlp
rms_norm_dim = 2 if is_bsh else 0
@@ -327,6 +342,8 @@ def layer(
tp_degree=self.config.tp_degree,
neuron_config=self.neuron_config,
)
# Granite specific: MLP output is multiplied by residual_multiplier
mlp_hidden = scale_mul(mlp_hidden, self.config.residual_multiplier)
res_hidden = hlo.add(mlp_hidden, hidden)
return res_hidden, out_attn_k_cache, out_attn_v_cache
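
Both hunks above apply the same pattern: the sub-layer output is scaled by residual_multiplier before being added to the residual stream. A minimal sketch of that pattern in plain tensor code (illustrative only; hlo.add and scale_mul do the equivalent on HLO tensors):

def scaled_residual(residual, sublayer_out, residual_multiplier: float):
    # Granite residual connection: down-weight each sub-layer's contribution.
    return residual + sublayer_out * residual_multiplier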

@@ -657,8 +674,8 @@ def attention(
shard_over_batch=self.shard_over_batch,
)

- # Q = Q / sqrt(d_head)
- query = attention.scale(query, d_head)
+ # Granite specific: instead of dividing the QK product, multiply it by the attention_multiplier
+ query = scale_mul(query, self.config.attention_multiplier)

# In BSH cache layout, the output of QKV linear projection is still kept as SBH for all QKV.
bsh_cache_layout = False
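
The replaced lines above swap the standard query scaling by 1/sqrt(d_head) for the model's configurable attention_multiplier; since scaling the query scales the whole QK^T product, only the constant changes. A small numeric sketch with hypothetical values:

import math

d_head = 128
attention_multiplier = 0.0078125            # hypothetical config value

standard_scale = 1.0 / math.sqrt(d_head)    # ~0.088, derived from the head size
granite_scale = attention_multiplier        # read from GraniteConfig instead
# In both cases the scores are (query * scale) @ key.T; only the factor differs.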
2 changes: 2 additions & 0 deletions optimum/neuron/models/granite/model.py
@@ -297,5 +297,7 @@ def forward(self, input_ids, cache_ids=None, start_ids=None, last_token_id=None,
# either input_embeddings are generated (off device embedding), or input_ids will be padded from preprocess_and_embed (on device embedding)
inputs = input_embeddings if input_embeddings is not None else input_ids
logits = self._forward(inputs, *rst)
# Granite specific: divide logits by scaling factor
logits = logits / self.config.logits_scaling
logits = self._postprocess(logits, start_ids=start_ids, **kwargs)
return logits
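
The division by logits_scaling effectively acts like a fixed temperature baked into the model, applied before any user-supplied sampling temperature. A minimal sketch of the effect (illustrative, NumPy):

import numpy as np

def scaled_softmax(logits: np.ndarray, logits_scaling: float) -> np.ndarray:
    # Dividing by logits_scaling > 1 flattens the distribution, < 1 sharpens it,
    # just like a temperature applied before softmax.
    scaled = logits / logits_scaling
    exp = np.exp(scaled - scaled.max(axis=-1, keepdims=True))
    return exp / exp.sum(axis=-1, keepdims=True)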
