use Qwen2SdpaAttention to replace Qwen2Attention for speedup
Tlntin committed Jul 29, 2024
1 parent ac9f747 commit d8b2da8
Showing 3 changed files with 19 additions and 10 deletions.
README.md: 2 additions & 0 deletions

@@ -2,6 +2,7 @@
- This project references the [ascend-llm](https://gitee.com/yinghuo302/ascend-llm) project.
- Tested only on the Ascend 310B1; in theory other Ascend chips are also compatible.
- Only the qwen1.5-0.5b-chat and qwen2-1.5b-instruct models have been tested; in theory all chat/instruct models in the qwen1.5/qwen2 series are supported.
+ - For CANN environment setup, see [this tutorial](https://www.hiascend.com/forum/thread-0286155882998311250-1-1.html)

### Preparation
1. Download this project
@@ -12,6 +13,7 @@


### Quick start
+ - None yet


### 分步骤运行
export/modeling_qwen2.py: 16 additions & 8 deletions

@@ -744,8 +744,8 @@ def __init__(self, config: Qwen2Config, layer_idx: int):
"unexpected results may be encountered."
)
# self.self_attn = QWEN2_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx)
- # because npu only support Qwen2Attention
- self.self_attn = Qwen2Attention(config, layer_idx)
+ self.self_attn = Qwen2SdpaAttention(config, layer_idx)
+ # self.self_attn = Qwen2Attention(config, layer_idx)
self.mlp = Qwen2MLP(config)
self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
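Why this swap presumably helps: Qwen2SdpaAttention dispatches to torch.nn.functional.scaled_dot_product_attention, which can use fused kernels, instead of the explicit matmul-softmax-matmul of the eager Qwen2Attention path. A minimal sketch of the contrast (names and shapes are illustrative, not taken from modeling_qwen2.py):

```python
import torch
import torch.nn.functional as F

def eager_attention(q, k, v, additive_mask):
    # Qwen2Attention-style eager path: explicit matmul + softmax.
    # additive_mask holds 0 for visible positions and a large negative
    # value for blocked ones.
    scores = q @ k.transpose(-2, -1) / (q.size(-1) ** 0.5)
    return torch.softmax(scores + additive_mask, dim=-1) @ v

def sdpa_attention(q, k, v, bool_mask):
    # Qwen2SdpaAttention-style path: one fused call; a boolean mask here
    # means True = "may attend".
    return F.scaled_dot_product_attention(q, k, v, attn_mask=bool_mask)
```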
@@ -986,7 +986,8 @@ def get_masks(input_ids, past_key_values, padding_mask=None):
dim=-1
)
if padding_mask is not None:
- full_attention_mask = full_attention_mask * padding_mask.unsqueeze(1)
+ full_attention_mask = full_attention_mask * padding_mask.unsqueeze(
+     1)
# if not past_length and padding_mask is not None:
# full_attention_mask -= padding_mask.unsqueeze(-1) - 1
full_attention_mask = (full_attention_mask < 0.5).bool()
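For context, the mask pipeline in get_masks works roughly as follows (a standalone sketch under assumed shapes; seq_len and the sample padding values are made up):

```python
import torch

seq_len = 4
# Causal lower-triangular mask: 1 = key visible to query.
full_attention_mask = torch.tril(torch.ones(1, seq_len, seq_len))
# Hypothetical padding mask: last position is padding.
padding_mask = torch.tensor([[1.0, 1.0, 1.0, 0.0]])
# Zero out padded key positions, as in the line changed above.
full_attention_mask = full_attention_mask * padding_mask.unsqueeze(1)
# Flip to boolean, True = "do not attend" (matches the `< 0.5` test).
full_attention_mask = (full_attention_mask < 0.5).bool()
```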
@@ -1088,18 +1089,25 @@ def forward(
sliding_window=self.config.sliding_window,
)
"""
- # copy from chatglm3-6b
+ # copy from chatglm3-6b for onnx export
full_attention_mask = self.get_masks(
input_ids,
past_key_values,
attention_mask,
)
- dtype = past_key_values.dtype
- device = input_ids.device
- attention_mask = torch.zeros_like(full_attention_mask, dtype=dtype).to(device)
- attention_mask.masked_fill_(full_attention_mask, torch.finfo(dtype).min)
+ # === if use Qwen2Attention ===
+ # dtype = past_key_values.dtype
+ # device = input_ids.device
+ # attention_mask = torch.zeros_like(full_attention_mask, dtype=dtype).to(device)
+ # attention_mask.masked_fill_(full_attention_mask, torch.finfo(dtype).min)
+
+ # === if use Qwen2SdpaAttention ===
+ # copy from chatglm3-6b
+ attention_mask = ~full_attention_mask
+
hidden_states = inputs_embeds


# decoder layers
# all_hidden_states = () if output_hidden_states else None
# all_self_attns = () if output_attentions else None
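The reason for the mask change above, as I read it (semantics assumed from the two attention implementations): the eager Qwen2Attention consumes an additive float mask, while SDPA also accepts a boolean mask where True means "may attend", so the new code simply inverts full_attention_mask, whose True means "blocked". A small sketch of the two conventions:

```python
import torch

# True = blocked, as produced by get_masks above.
full_attention_mask = torch.tensor([[False, True],
                                    [False, False]])

# Eager / Qwen2Attention convention: additive float mask.
dtype = torch.float16
additive = torch.zeros_like(full_attention_mask, dtype=dtype)
additive.masked_fill_(full_attention_mask, torch.finfo(dtype).min)

# SDPA / Qwen2SdpaAttention convention: boolean mask, True = attend.
boolean = ~full_attention_mask
```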
export/test_pytorch_run.py: 1 addition & 2 deletions

@@ -125,7 +125,7 @@ def get_inputs(kv_cache, seq_len: int, real_kv_size=0, input_pos=0, past_kv_size
position_ids,
now_kv_cache,
# use_cache=True,
- output_attentions=True,
+ # output_attentions=True,
)
print("==== pytorch runtime ====")
print("output length: ", len(outputs))
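Presumably output_attentions is disabled here because the SDPA kernel never materializes per-head attention weights, so requesting them would force a fallback to eager attention and lose the speedup. A quick illustration (assumed behavior, not code from this repo):

```python
import torch
import torch.nn.functional as F

q = k = v = torch.randn(1, 2, 4, 8)  # (batch, heads, seq, head_dim)
out = F.scaled_dot_product_attention(q, k, v)
# Only the attention output is returned; there are no attention weights
# available to surface via output_attentions on the SDPA path.
```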
@@ -135,7 +135,6 @@ def get_inputs(kv_cache, seq_len: int, real_kv_size=0, input_pos=0, past_kv_size
print("logits shape: ", logits.shape)
print("logits mean: ", logits.float().mean().item())
print("logits max: ", logits.float().max().item())
-
new_kv_cache = outputs[1][:, :, :, :, :-1, :] # 1: 0.0009:
# new_kv_cache = outputs[1][:, :, :, :, -1:, :] # 2: 0.003526

