# Add logit softcapping to GQA (#876)
### Description

This PR adds the `softcap` attribute to the `GroupQueryAttention` op.

### Motivation and Context

This PR helps resolve the `NaN` output issue with Gemma-2 raised in
[this issue](#692).
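For context, logit softcapping bounds the attention scores with a `tanh` before the softmax, so Gemma-2's large logits cannot overflow in reduced precision. The snippet below is a minimal NumPy sketch of the transform that the `softcap` attribute enables; the score values and the softcap of `50.0` are illustrative only (Gemma-2 configs expose the value as `attn_logit_softcapping`), and a softcap of `0.0` is treated as "disabled":

```python
import numpy as np

def softcap_logits(scores: np.ndarray, softcap: float) -> np.ndarray:
    """Apply logit softcapping: softcap * tanh(scores / softcap).

    A softcap of 0.0 means "disabled" and returns the scores unchanged.
    Otherwise the output is bounded to (-softcap, softcap), so the softmax
    never sees arbitrarily large inputs.
    """
    if softcap == 0.0:
        return scores
    return softcap * np.tanh(scores / softcap)

# Illustrative values only.
raw_scores = np.array([-300.0, -10.0, 0.0, 10.0, 300.0], dtype=np.float32)
capped = softcap_logits(raw_scores, softcap=50.0)
print(capped)  # every value now lies strictly within (-50, 50)
```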
kunal-vaishnavi authored Oct 31, 2024
1 parent fb60d82 commit e222963
Showing 1 changed file with 7 additions and 2 deletions.
9 changes: 7 additions & 2 deletions src/python/py/models/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,7 @@ def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir, extra_options):
}

# LayerNorm-specific variables
epsilon = config.rms_norm_eps if hasattr(config, "rms_norm_eps") else 1e-06
self.layernorm_attrs = {
"simple": True, # Use SimplifiedLayerNorm/SkipSimplifiedLayerNorm vs. LayerNorm/SkipLayerNorm
"first_layernorm": True, # 1st LayerNorm = LayerNorm, then SkipLayerNorm for all subsequent LayerNorms
Expand All @@ -156,6 +157,7 @@ def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir, extra_options):
"output_0": "", # Output 0 for LayerNorm and SkipLayerNorm
"output_3": "", # Output 3 for SkipLayerNorm
"add_offset": 0, # Offset value for LayerNorm weight
"epsilon": epsilon, # Epsilon value to avoid `sqrt(0)` in LayerNorm
}

# MatMul-specific variables
Expand Down Expand Up @@ -212,6 +214,8 @@ def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir, extra_options):
}

# Attention-specific variables (MHA, GQA, GQA + Rot.Emb., etc.)
softcap = config.attn_logit_softcapping if hasattr(config, "attn_logit_softcapping") else 0.0 # default is 0.0 in GroupQueryAttention kernel

# Block-sparse attention-specific variables
sparse_block_size = config.blocksparse_block_size if hasattr(config, "blocksparse_block_size") else 0
kernel_block_size = config.blocksparse_triton_kernel_block_size if hasattr(config, "blocksparse_triton_kernel_block_size") else 0
Expand All @@ -224,6 +228,7 @@ def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir, extra_options):
"v_path": "", # V path to attention
"op_type": "MultiHeadAttention", # Attention op to use
"scale": 1 / np.sqrt(self.head_size), # Scale value after calculating Q x K' in attention
"softcap": softcap, # Softcap value to prevent values from exploding in attention
"use_rotemb_in_attn": False, # Use rotary embeddings within attention (instead of a separate RotaryEmbedding op)
"use_packed_matmul": False, # Use packed MatMul (instead of 3 separate MatMuls for Q/K/V)
"block_sparse": { # Block-sparse attention-specific variables
Expand Down Expand Up @@ -969,7 +974,7 @@ def make_layernorm(self, layer_id, layernorm, skip, simple, location):

name = f"/model/layers.{layer_id}/{location}_layernorm/{'Skip' if skip else ''}LayerNorm"
op_type = f"{'Skip' if skip else ''}{'Simplified' if simple else ''}LayerNormalization"
kwargs = {"epsilon": 9.999999747378752e-06}
kwargs = {"epsilon": self.layernorm_attrs["epsilon"]}
if not skip:
kwargs.update({"axis": -1, "stash_type": 1})

Expand Down Expand Up @@ -1381,7 +1386,7 @@ def make_group_query_attention(self, name, **kwargs):
self.make_node(
"GroupQueryAttention", inputs=inputs, outputs=outputs, name=name, domain="com.microsoft",
num_heads=self.num_attn_heads, kv_num_heads=self.num_kv_heads, scale=self.attention_attrs["scale"], # local_window_size=self.window_size, # Disable sliding window attribute temporarily
do_rotary=self.attention_attrs["use_rotemb_in_attn"], rotary_interleaved=self.rotemb_attrs["interleaved"],
softcap=self.attention_attrs["softcap"], do_rotary=self.attention_attrs["use_rotemb_in_attn"], rotary_interleaved=self.rotemb_attrs["interleaved"],
)
self.make_value_info(output, self.io_dtype, shape=['batch_size', 'sequence_length', self.head_size * self.num_attn_heads])

Expand Down
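As a rough illustration of what the changed `make_group_query_attention` call ends up emitting into the graph, the hypothetical snippet below builds a standalone `GroupQueryAttention` node with the new `softcap` attribute via `onnx.helper`. The input/output names, head counts, head size, and softcap value are placeholders for illustration, not values taken from the builder:

```python
from onnx import helper

# Hypothetical node mirroring what builder.py emits for a Gemma-2-style layer.
gqa_node = helper.make_node(
    "GroupQueryAttention",
    inputs=[
        "query", "key", "value",
        "past_key", "past_value",
        "seqlens_k", "total_sequence_length",
    ],
    outputs=["attn_output", "present_key", "present_value"],
    name="/model/layers.0/attn/GroupQueryAttention",
    domain="com.microsoft",
    num_heads=8,                 # placeholder attention head count
    kv_num_heads=4,              # placeholder KV head count
    scale=1 / (256 ** 0.5),      # placeholder 1/sqrt(head_size)
    softcap=50.0,                # forwarded from config.attn_logit_softcapping
    do_rotary=1,
    rotary_interleaved=0,
)
print(gqa_node)  # inspect the serialized node, including the softcap attribute
```

Leaving `softcap` at its default of `0.0` keeps the previous behavior, so models without `attn_logit_softcapping` in their config are unaffected.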
