Add logit softcapping to GQA #876

Merged: 4 commits merged on Oct 31, 2024

Changes from all commits
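For context, attention logit softcapping bounds the raw Q x K' scores before the softmax so they cannot grow without limit; Gemma 2 style models expose this as attn_logit_softcapping in their config. The snippet below is a minimal NumPy sketch of that transform, assuming the common cap * tanh(scores / cap) formulation; it is illustrative only and not code from this PR.

import numpy as np

def softcap_logits(scores: np.ndarray, softcap: float) -> np.ndarray:
    # Squash raw attention scores into (-softcap, softcap) before the softmax.
    # A softcap of 0.0 (the GroupQueryAttention default) leaves scores unchanged.
    if softcap == 0.0:
        return scores
    return softcap * np.tanh(scores / softcap)

# Example: with softcap=50.0, extreme scores are bounded to roughly +/-49.2.
scores = np.array([-120.0, -10.0, 0.0, 10.0, 120.0])
print(softcap_logits(scores, softcap=50.0))  # approx [-49.18, -9.87, 0., 9.87, 49.18]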
9 changes: 7 additions & 2 deletions src/python/py/models/builder.py
@@ -147,6 +147,7 @@ def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir, extra_options):
         }

         # LayerNorm-specific variables
+        epsilon = config.rms_norm_eps if hasattr(config, "rms_norm_eps") else 1e-06
         self.layernorm_attrs = {
             "simple": True, # Use SimplifiedLayerNorm/SkipSimplifiedLayerNorm vs. LayerNorm/SkipLayerNorm
             "first_layernorm": True, # 1st LayerNorm = LayerNorm, then SkipLayerNorm for all subsequent LayerNorms
@@ -156,6 +157,7 @@ def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir, extra_options):
             "output_0": "", # Output 0 for LayerNorm and SkipLayerNorm
             "output_3": "", # Output 3 for SkipLayerNorm
             "add_offset": 0, # Offset value for LayerNorm weight
+            "epsilon": epsilon, # Epsilon value to avoid `sqrt(0)` in LayerNorm
         }

         # MatMul-specific variables
@@ -212,6 +214,8 @@ def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir, extra_options):
         }

         # Attention-specific variables (MHA, GQA, GQA + Rot.Emb., etc.)
+        softcap = config.attn_logit_softcapping if hasattr(config, "attn_logit_softcapping") else 0.0 # default is 0.0 in GroupQueryAttention kernel
+
         # Block-sparse attention-specific variables
         sparse_block_size = config.blocksparse_block_size if hasattr(config, "blocksparse_block_size") else 0
         kernel_block_size = config.blocksparse_triton_kernel_block_size if hasattr(config, "blocksparse_triton_kernel_block_size") else 0
@@ -224,6 +228,7 @@ def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir, extra_options):
             "v_path": "", # V path to attention
             "op_type": "MultiHeadAttention", # Attention op to use
             "scale": 1 / np.sqrt(self.head_size), # Scale value after calculating Q x K' in attention
+            "softcap": softcap, # Softcap value to prevent values from exploding in attention
             "use_rotemb_in_attn": False, # Use rotary embeddings within attention (instead of a separate RotaryEmbedding op)
             "use_packed_matmul": False, # Use packed MatMul (instead of 3 separate MatMuls for Q/K/V)
             "block_sparse": { # Block-sparse attention-specific variables
@@ -969,7 +974,7 @@ def make_layernorm(self, layer_id, layernorm, skip, simple, location):

         name = f"/model/layers.{layer_id}/{location}_layernorm/{'Skip' if skip else ''}LayerNorm"
         op_type = f"{'Skip' if skip else ''}{'Simplified' if simple else ''}LayerNormalization"
-        kwargs = {"epsilon": 9.999999747378752e-06}
+        kwargs = {"epsilon": self.layernorm_attrs["epsilon"]}
         if not skip:
             kwargs.update({"axis": -1, "stash_type": 1})

@@ -1381,7 +1386,7 @@ def make_group_query_attention(self, name, **kwargs):
         self.make_node(
             "GroupQueryAttention", inputs=inputs, outputs=outputs, name=name, domain="com.microsoft",
             num_heads=self.num_attn_heads, kv_num_heads=self.num_kv_heads, scale=self.attention_attrs["scale"], # local_window_size=self.window_size, # Disable sliding window attribute temporarily
-            do_rotary=self.attention_attrs["use_rotemb_in_attn"], rotary_interleaved=self.rotemb_attrs["interleaved"],
+            softcap=self.attention_attrs["softcap"], do_rotary=self.attention_attrs["use_rotemb_in_attn"], rotary_interleaved=self.rotemb_attrs["interleaved"],
         )
         self.make_value_info(output, self.io_dtype, shape=['batch_size', 'sequence_length', self.head_size * self.num_attn_heads])

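A quick way to sanity-check a model exported with these changes is to read the attributes back off the graph with the onnx Python package. The model.onnx path below is hypothetical; the check simply confirms that GroupQueryAttention nodes carry a softcap attribute and that the LayerNorm variants carry the config-derived epsilon instead of the previously hardcoded 9.999999747378752e-06.

import onnx
from onnx import helper

model = onnx.load("model.onnx")  # hypothetical path to a model produced by builder.py
for node in model.graph.node:
    attrs = {a.name: helper.get_attribute_value(a) for a in node.attribute}
    if node.op_type == "GroupQueryAttention":
        print(node.name, "softcap =", attrs.get("softcap"))
    elif node.op_type.endswith("LayerNormalization"):
        print(node.name, "epsilon =", attrs.get("epsilon"))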