From d8b2da85f6673225abe37f95f745ecfcdf3a3482 Mon Sep 17 00:00:00 2001
From: Tlntin
Date: Mon, 29 Jul 2024 15:26:50 +0000
Subject: [PATCH] use Qwen2SdpaAttention to replace Qwen2Attention for speedup

---
 README.md                  |  2 ++
 export/modeling_qwen2.py   | 24 ++++++++++++++++--------
 export/test_pytorch_run.py |  3 +--
 3 files changed, 19 insertions(+), 10 deletions(-)

diff --git a/README.md b/README.md
index fe26dda..c69442a 100644
--- a/README.md
+++ b/README.md
@@ -2,6 +2,7 @@
 - This project is based on the [ascend-llm](https://gitee.com/yinghuo302/ascend-llm) project.
 - Tested only on the Ascend 310B1; other Ascend chips should be compatible in theory.
 - Only the qwen1.5-0.5b-chat and qwen2-1.5b-instruct models have been tested; in theory all chat/instruct models of the qwen1.5/qwen2 series are supported.
+- For CANN environment installation, see [this tutorial](https://www.hiascend.com/forum/thread-0286155882998311250-1-1.html).
 
 ### Preparation
 1. Download this project
@@ -12,6 +13,7 @@
 
 ### Quick start
 
+- Not available yet
 
 ### Step-by-step run
 
diff --git a/export/modeling_qwen2.py b/export/modeling_qwen2.py
index 81eab34..735bea1 100644
--- a/export/modeling_qwen2.py
+++ b/export/modeling_qwen2.py
@@ -744,8 +744,8 @@ def __init__(self, config: Qwen2Config, layer_idx: int):
                 "unexpected results may be encountered."
             )
         # self.self_attn = QWEN2_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx)
-        # because npu only support Qwen2Attention
-        self.self_attn = Qwen2Attention(config, layer_idx)
+        self.self_attn = Qwen2SdpaAttention(config, layer_idx)
+        # self.self_attn = Qwen2Attention(config, layer_idx)
         self.mlp = Qwen2MLP(config)
         self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -986,7 +986,8 @@ def get_masks(input_ids, past_key_values, padding_mask=None):
             dim=-1
         )
         if padding_mask is not None:
-            full_attention_mask = full_attention_mask * padding_mask.unsqueeze(1)
+            full_attention_mask = full_attention_mask * padding_mask.unsqueeze(
+                1)
         # if not past_length and padding_mask is not None:
         #     full_attention_mask -= padding_mask.unsqueeze(-1) - 1
         full_attention_mask = (full_attention_mask < 0.5).bool()
@@ -1088,18 +1089,25 @@ def forward(
                 sliding_window=self.config.sliding_window,
             )
         """
-        # copy from chatglm3-6b
+        # copied from chatglm3-6b for ONNX export
         full_attention_mask = self.get_masks(
             input_ids,
             past_key_values,
             attention_mask,
         )
-        dtype = past_key_values.dtype
-        device = input_ids.device
-        attention_mask = torch.zeros_like(full_attention_mask, dtype=dtype).to(device)
-        attention_mask.masked_fill_(full_attention_mask, torch.finfo(dtype).min)
+        # === if using Qwen2Attention ===
+        # dtype = past_key_values.dtype
+        # device = input_ids.device
+        # attention_mask = torch.zeros_like(full_attention_mask, dtype=dtype).to(device)
+        # attention_mask.masked_fill_(full_attention_mask, torch.finfo(dtype).min)
+
+        # === if using Qwen2SdpaAttention ===
+        # copied from chatglm3-6b
+        attention_mask = ~full_attention_mask
+
         hidden_states = inputs_embeds
+
         # decoder layers
         # all_hidden_states = () if output_hidden_states else None
         # all_self_attns = () if output_attentions else None
diff --git a/export/test_pytorch_run.py b/export/test_pytorch_run.py
index 6eb772d..669a479 100644
--- a/export/test_pytorch_run.py
+++ b/export/test_pytorch_run.py
@@ -125,7 +125,7 @@ def get_inputs(kv_cache, seq_len: int, real_kv_size=0, input_pos=0, past_kv_size
         position_ids,
         now_kv_cache,
         # use_cache=True,
-        output_attentions=True,
+        # output_attentions=True,
     )
     print("==== pytorch runtime ====")
     print("output length: ", len(outputs))
@@ -135,7 +135,6 @@
     print("logits shape: ", logits.shape)
     print("logits mean: ", logits.float().mean().item())
     print("logits max: ", logits.float().max().item())
-    new_kv_cache = outputs[1][:, :, :, :, :-1, :]  # 1: 0.0009
     # new_kv_cache = outputs[1][:, :, :, :, -1:, :]  # 2: 0.003526
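
Note: the core change in this patch is the mask convention. The eager Qwen2Attention path adds a float mask (0 for allowed positions, dtype-min for masked ones) to the attention scores, while the SDPA path can take a boolean mask where True means "attend", hence the `attention_mask = ~full_attention_mask` line. The sketch below is not part of the patch; shapes and tensor names are illustrative. It only shows that the two conventions produce the same attention output.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch, heads, q_len, kv_len, head_dim = 1, 2, 4, 4, 8
q = torch.randn(batch, heads, q_len, head_dim)
k = torch.randn(batch, heads, kv_len, head_dim)
v = torch.randn(batch, heads, kv_len, head_dim)

# full_attention_mask: True where a position must NOT be attended
# (same convention as the mask returned by get_masks in the diff).
full_attention_mask = torch.triu(
    torch.ones(q_len, kv_len, dtype=torch.bool), diagonal=1
)[None, None]  # [1, 1, q_len, kv_len], broadcast over batch and heads

# Eager Qwen2Attention convention: additive float mask, 0 where allowed, dtype-min where masked.
float_mask = torch.zeros(full_attention_mask.shape, dtype=q.dtype)
float_mask = float_mask.masked_fill(full_attention_mask, torch.finfo(q.dtype).min)
scores = (q @ k.transpose(-1, -2)) / head_dim ** 0.5 + float_mask
eager_out = torch.softmax(scores, dim=-1) @ v

# SDPA convention used by the patch: boolean mask where True means "attend", hence the inversion.
sdpa_out = F.scaled_dot_product_attention(q, k, v, attn_mask=~full_attention_mask)

print(torch.allclose(eager_out, sdpa_out, atol=1e-5))  # expected: True
```

Where the backend provides a fused `scaled_dot_product_attention` kernel, the SDPA path also avoids materializing the additive mask and the explicit softmax in Python, which is one reason it can be faster than the eager attention path.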