Skip to content

Commit

Permalink
Fix(MInference): fix the vs pattern loss / sqrt(dk) (#70)
Browse files Browse the repository at this point in the history
Thanks! @PiotrNawrot

Co-authored-by: Piotr Nawrot <[email protected]>
Co-authored-by: Yucheng Li <[email protected]>
Co-authored-by: Chengruidong Zhang <[email protected]>
Co-authored-by: Yuqing Yang <[email protected]>
  • Loading branch information
5 people authored Aug 13, 2024
1 parent f0cae77 commit 5e37e0d
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 6 deletions.
6 changes: 3 additions & 3 deletions minference/modules/minference_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -373,7 +373,7 @@ def dialted(q,k,v, type):
def vertical_and_slash_kernel(q, k, v, vertical_size, slash_size):
vertical_size, slash_size = min(q_len, max(vertical_size, 30)), min(q_len, max(slash_size, 50))
last_q = min(64, q_len)
qk = torch.einsum(f'bhmk, bhnk -> bhmn', q[:,:,-last_q:,:], k)
qk = torch.einsum(f'bhmk, bhnk -> bhmn', q[:,:,-last_q:,:], k) / math.sqrt(self.head_dim)
qk[:, :, :, -last_q:] = torch.where(LAST_Q_MASK[...,-last_q:,-last_q:].to(q.device), qk[:, :, :, -last_q:], -torch.inf)
qk = torch.nn.functional.softmax(qk, dim=-1, dtype=torch.float32)
vertical = qk.sum(-2, keepdim=True)
Expand All @@ -391,7 +391,7 @@ def vertical_and_slash_kernel_extend(q, k, v, vertical_size, slash_size):
vertical_size, slash_size = min(q_len, max(vertical_size + 100, 30)), min(q_len, max(slash_size, 50))
last_q = min(64, q_len)
last_start = 100
qk = torch.einsum(f'bhmk, bhnk -> bhmn', q[:,:,-last_q-last_start:-last_start,:], k)
qk = torch.einsum(f'bhmk, bhnk -> bhmn', q[:,:,-last_q-last_start:-last_start,:], k) / math.sqrt(self.head_dim)
qk[:, :, :, -last_start:] = -torch.inf
qk[:, :, :, -last_q-last_start:-last_start] = torch.where(LAST_Q_MASK[...,-last_q:,-last_q:].to(q.device), qk[:, :, :, -last_q-last_start:-last_start], -torch.inf)
qk = torch.nn.functional.softmax(qk, dim=-1, dtype=torch.float32)
Expand All @@ -413,7 +413,7 @@ def vertical_and_slash_kernel_static(q, k, v, vertical_size, slash_size):
else:
vertical_size, slash_size = min(q_len, max(vertical_size, 30)), min(q_len, max(slash_size, 50))
last_q = 64
qk = torch.einsum(f'bhmk, bhnk -> bhmn', q[:,:,-last_q:,:], k)
qk = torch.einsum(f'bhmk, bhnk -> bhmn', q[:,:,-last_q:,:], k) / math.sqrt(self.head_dim)
qk[:, :, :, -last_q:] = torch.where(LAST_Q_MASK, qk[:, :, :, -last_q:], -torch.inf)
qk = torch.nn.functional.softmax(qk, dim=-1, dtype=torch.float32)
vertical = qk.sum(-2, keepdim=True)
Expand Down
17 changes: 15 additions & 2 deletions minference/patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -1279,9 +1279,22 @@ def model_forward(
raise ValueError("Only supports llama, mistral and qwen2 models.")

hf_rope = model.model.layers[0].self_attn.rotary_emb
base = base if base is not None else hf_rope.base
base = (
base
if base is not None
else (hf_rope.base if "base" in hf_rope.__dict__ else hf_rope.config.rope_theta)
)
distance_scale = distance_scale if distance_scale is not None else 1.0
rope = RotaryEmbeddingESM(hf_rope.dim, base, distance_scale, is_glm4=is_glm4)
rope = RotaryEmbeddingESM(
(
hf_rope.dim
if "dim" in hf_rope.__dict__
else hf_rope.config.hidden_size // hf_rope.config.num_attention_heads
),
base,
distance_scale,
is_glm4=is_glm4,
)
model.model.position_bias = rope
model.model.hf_position_bias = hf_rope
DecoderLayer = model.model.layers[0].__class__
Expand Down
2 changes: 1 addition & 1 deletion minference/version.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
_PATCH = "5"
# This is mainly for nightly builds which have the suffix ".dev$DATE". See
# https://semver.org/#is-v123-a-semantic-version for the semantics.
_SUFFIX = ""
_SUFFIX = ".post1"

VERSION_SHORT = "{0}.{1}".format(_MAJOR, _MINOR)
VERSION = "{0}.{1}.{2}{3}".format(_MAJOR, _MINOR, _PATCH, _SUFFIX)

0 comments on commit 5e37e0d

Please sign in to comment.