From 5e37e0dc631aeccc4341e38a40f0c178b2178e22 Mon Sep 17 00:00:00 2001
From: Huiqiang Jiang
Date: Tue, 13 Aug 2024 17:10:42 +0800
Subject: [PATCH] Fix(MInference): fix the vs pattern loss / sqrt(dk) (#70)

Thanks! @PiotrNawrot

Co-authored-by: Piotr Nawrot
Co-authored-by: Yucheng Li
Co-authored-by: Chengruidong Zhang
Co-authored-by: Yuqing Yang
---
 minference/modules/minference_forward.py |  6 +++---
 minference/patch.py                      | 17 +++++++++++++++--
 minference/version.py                    |  2 +-
 3 files changed, 19 insertions(+), 6 deletions(-)

diff --git a/minference/modules/minference_forward.py b/minference/modules/minference_forward.py
index fdb46f5..d34bc84 100644
--- a/minference/modules/minference_forward.py
+++ b/minference/modules/minference_forward.py
@@ -373,7 +373,7 @@ def dialted(q,k,v, type):
     def vertical_and_slash_kernel(q, k, v, vertical_size, slash_size):
         vertical_size, slash_size = min(q_len, max(vertical_size, 30)), min(q_len, max(slash_size, 50))
         last_q = min(64, q_len)
-        qk = torch.einsum(f'bhmk, bhnk -> bhmn', q[:,:,-last_q:,:], k)
+        qk = torch.einsum(f'bhmk, bhnk -> bhmn', q[:,:,-last_q:,:], k) / math.sqrt(self.head_dim)
         qk[:, :, :, -last_q:] = torch.where(LAST_Q_MASK[...,-last_q:,-last_q:].to(q.device), qk[:, :, :, -last_q:], -torch.inf)
         qk = torch.nn.functional.softmax(qk, dim=-1, dtype=torch.float32)
         vertical = qk.sum(-2, keepdim=True)
@@ -391,7 +391,7 @@ def vertical_and_slash_kernel_extend(q, k, v, vertical_size, slash_size):
         vertical_size, slash_size = min(q_len, max(vertical_size + 100, 30)), min(q_len, max(slash_size, 50))
         last_q = min(64, q_len)
         last_start = 100
-        qk = torch.einsum(f'bhmk, bhnk -> bhmn', q[:,:,-last_q-last_start:-last_start,:], k)
+        qk = torch.einsum(f'bhmk, bhnk -> bhmn', q[:,:,-last_q-last_start:-last_start,:], k) / math.sqrt(self.head_dim)
         qk[:, :, :, -last_start:] = -torch.inf
         qk[:, :, :, -last_q-last_start:-last_start] = torch.where(LAST_Q_MASK[...,-last_q:,-last_q:].to(q.device), qk[:, :, :, -last_q-last_start:-last_start], -torch.inf)
         qk = torch.nn.functional.softmax(qk, dim=-1, dtype=torch.float32)
@@ -413,7 +413,7 @@ def vertical_and_slash_kernel_static(q, k, v, vertical_size, slash_size):
         else:
             vertical_size, slash_size = min(q_len, max(vertical_size, 30)), min(q_len, max(slash_size, 50))
         last_q = 64
-        qk = torch.einsum(f'bhmk, bhnk -> bhmn', q[:,:,-last_q:,:], k)
+        qk = torch.einsum(f'bhmk, bhnk -> bhmn', q[:,:,-last_q:,:], k) / math.sqrt(self.head_dim)
         qk[:, :, :, -last_q:] = torch.where(LAST_Q_MASK, qk[:, :, :, -last_q:], -torch.inf)
         qk = torch.nn.functional.softmax(qk, dim=-1, dtype=torch.float32)
         vertical = qk.sum(-2, keepdim=True)
diff --git a/minference/patch.py b/minference/patch.py
index 8f1e4cf..10fb24e 100644
--- a/minference/patch.py
+++ b/minference/patch.py
@@ -1279,9 +1279,22 @@ def model_forward(
         raise ValueError("Only supports llama, mistral and qwen2 models.")
 
     hf_rope = model.model.layers[0].self_attn.rotary_emb
-    base = base if base is not None else hf_rope.base
+    base = (
+        base
+        if base is not None
+        else (hf_rope.base if "base" in hf_rope.__dict__ else hf_rope.config.rope_theta)
+    )
     distance_scale = distance_scale if distance_scale is not None else 1.0
-    rope = RotaryEmbeddingESM(hf_rope.dim, base, distance_scale, is_glm4=is_glm4)
+    rope = RotaryEmbeddingESM(
+        (
+            hf_rope.dim
+            if "dim" in hf_rope.__dict__
+            else hf_rope.config.hidden_size // hf_rope.config.num_attention_heads
+        ),
+        base,
+        distance_scale,
+        is_glm4=is_glm4,
+    )
     model.model.position_bias = rope
     model.model.hf_position_bias = hf_rope
     DecoderLayer = model.model.layers[0].__class__
diff --git a/minference/version.py b/minference/version.py
index b81efd9..ff46d65 100644
--- a/minference/version.py
+++ b/minference/version.py
@@ -8,7 +8,7 @@
 _PATCH = "5"
 # This is mainly for nightly builds which have the suffix ".dev$DATE". See
 # https://semver.org/#is-v123-a-semantic-version for the semantics.
-_SUFFIX = ""
+_SUFFIX = ".post1"
 
 VERSION_SHORT = "{0}.{1}".format(_MAJOR, _MINOR)
 VERSION = "{0}.{1}.{2}{3}".format(_MAJOR, _MINOR, _PATCH, _SUFFIX)
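For context: the three kernels patched in minference_forward.py estimate the vertical/slash sparse-attention pattern by scoring only the last few queries against all keys, and the bug was that this estimation ran on unscaled logits, so the estimated pattern could diverge from the true softmax(QK^T / sqrt(d_k)) attention that is computed afterwards. Below is a minimal standalone sketch of the corrected estimation step; estimate_vertical_slash_scores is a hypothetical helper, not a function in the repo, and the plain causal mask stands in for the LAST_Q_MASK machinery in the actual code.

import math

import torch

def estimate_vertical_slash_scores(q, k, last_q=64):
    # q, k: (batch, heads, seq_len, head_dim).
    head_dim = q.shape[-1]
    q_len = q.shape[-2]
    last_q = min(last_q, q_len)
    # Scaled scores for the trailing queries against all keys; the bug
    # fixed by this patch was the missing 1/sqrt(head_dim) factor here.
    qk = torch.einsum("bhmk,bhnk->bhmn", q[:, :, -last_q:, :], k) / math.sqrt(head_dim)
    # Causal mask on the trailing block so a sampled query cannot
    # attend to keys that come after it.
    causal = torch.tril(torch.ones(last_q, last_q, dtype=torch.bool, device=q.device))
    qk[:, :, :, -last_q:] = qk[:, :, :, -last_q:].masked_fill(~causal, float("-inf"))
    attn = torch.nn.functional.softmax(qk, dim=-1, dtype=torch.float32)
    # Vertical score: total attention each key receives from the sampled
    # queries; a top-k over this selects the vertical lines of the pattern.
    vertical = attn.sum(-2, keepdim=True)
    return attn, vertical

Dividing by sqrt(head_dim) before the softmax matches the scaling the real attention computation uses, so the estimated vertical/slash indices now rank keys under the same distribution the sparse kernel later materializes.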
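The patch.py hunk hardens the RoPE patching against transformers versions where the rotary embedding module no longer exposes base and dim as instance attributes and instead only carries a config object. A hedged sketch of that resolution logic, factored into a hypothetical helper (resolve_rope_params is not a function in the repo; hf_rope is the model's rotary embedding module as in the diff):

def resolve_rope_params(hf_rope, base=None):
    # Prefer the instance attributes older transformers versions set
    # directly on the rotary embedding; newer versions keep them only on
    # the config, so fall back to it (same logic as the hunk above).
    if base is None:
        base = hf_rope.base if "base" in hf_rope.__dict__ else hf_rope.config.rope_theta
    dim = (
        hf_rope.dim
        if "dim" in hf_rope.__dict__
        else hf_rope.config.hidden_size // hf_rope.config.num_attention_heads
    )
    return base, dim

Checking __dict__ rather than using getattr avoids being fooled by properties or inherited attributes: it only accepts values the module instance actually set on itself.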