diff --git a/README.md b/README.md
index 95872a1..fe26dda 100644
--- a/README.md
+++ b/README.md
@@ -24,7 +24,7 @@
   --kv_cache_length=1024
 ```
 
-2. Verify the ONNX model: return to the project root directory and run cli_chat.py to check that chatting with the ONNX model works correctly.
+2. Verify the ONNX model: return to the project root directory and run cli_chat.py to check that chatting with the ONNX model works correctly (note: this runs on the CPU, so it is slow; please be patient).
 ```bash
 python3 ./cli_chat.py \
   --session_type=onnx \
diff --git a/export/export_onnx.py b/export/export_onnx.py
index 07f7336..d52e1e9 100644
--- a/export/export_onnx.py
+++ b/export/export_onnx.py
@@ -93,7 +93,7 @@ def export_onnx(
     model = Qwen2ForCausalLM.from_pretrained(
         hf_model_dir,
         torch_dtype=torch_dtype,
-        trust_remote_code=True
+        # trust_remote_code=True
     ).to(device)
     quantize_cfg = {
         "query_key_value": {
@@ -153,7 +153,7 @@ def export_onnx(
        # None,  # inputs_embeds: Optional[torch.FloatTensor] = None,
        # None,  # labels: Optional[torch.LongTensor] = None,
        # True,  # use_cache: Optional[bool] = None,
-        True,  # output_attentions: Optional[bool] = None,
+        # True,  # output_attentions: Optional[bool] = None,
        # None,  # output_hidden_states
        # False  # return_dict:
     )
diff --git a/export/modeling_qwen2.py b/export/modeling_qwen2.py
index 370157a..81eab34 100644
--- a/export/modeling_qwen2.py
+++ b/export/modeling_qwen2.py
@@ -247,7 +247,7 @@ def forward(
         attention_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
         past_key_value: Optional[Cache] = None,
-        output_attentions: bool = False,
+        # output_attentions: bool = False,
         # use_cache: bool = False,
         **kwargs,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
@@ -322,8 +322,8 @@ def forward(
 
         attn_output = self.o_proj(attn_output)
 
-        if not output_attentions:
-            attn_weights = None
+        # if not output_attentions:
+        #     attn_weights = None
 
         # return attn_output, attn_weights, past_key_value
         return attn_output, attn_weights, out_cache
@@ -648,23 +648,23 @@ def forward(
         attention_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
         past_key_value: Optional[Cache] = None,
-        output_attentions: bool = False,
+        # output_attentions: bool = False,
         # use_cache: bool = False,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-        if output_attentions:
-            # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
-            # logger.warning_once(
-            #     "Qwen2Model is using Qwen2SdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
-            #     'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
-            # )
-            return super().forward(
-                hidden_states=hidden_states,
-                attention_mask=attention_mask,
-                position_ids=position_ids,
-                past_key_value=past_key_value,
-                output_attentions=output_attentions,
-                # use_cache=use_cache,
-            )
+        # if output_attentions:
+        #     # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
+        #     # logger.warning_once(
+        #     #     "Qwen2Model is using Qwen2SdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
+        #     #     'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
+        #     # )
+        #     return super().forward(
+        #         hidden_states=hidden_states,
+        #         attention_mask=attention_mask,
+        #         position_ids=position_ids,
+        #         past_key_value=past_key_value,
+        #         # output_attentions=output_attentions,
+        #         # use_cache=use_cache,
+        #     )
 
         bsz, q_len, _ = hidden_states.size()
 
@@ -695,11 +695,11 @@ def forward(
         key_states = repeat_kv(key_states, self.num_key_value_groups)
         value_states = repeat_kv(value_states, self.num_key_value_groups)
 
-        if attention_mask is not None:
-            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
-                raise ValueError(
-                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
-                )
+        # if attention_mask is not None:
+        #     if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
+        #         raise ValueError(
+        #             f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
+        #         )
 
         # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
         # Reference: https://github.com/pytorch/pytorch/issues/112577.
@@ -723,7 +723,6 @@ def forward(
         attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
 
         attn_output = self.o_proj(attn_output)
-
 
         return attn_output, None, output_cache
 
@@ -745,7 +744,8 @@ def __init__(self, config: Qwen2Config, layer_idx: int):
                 "unexpected results may be encountered."
             )
         # self.self_attn = QWEN2_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx)
-        self.self_attn = Qwen2SdpaAttention(config, layer_idx)
+        # because the NPU only supports Qwen2Attention
+        self.self_attn = Qwen2Attention(config, layer_idx)
         self.mlp = Qwen2MLP(config)
         self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -756,7 +756,7 @@ def forward(
         attention_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
         past_key_value: Optional[Tuple[torch.Tensor]] = None,
-        output_attentions: Optional[bool] = False,
+        # output_attentions: Optional[bool] = False,
         # use_cache: Optional[bool] = False,
         **kwargs,
     ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
@@ -789,7 +789,7 @@ def forward(
             attention_mask=attention_mask,
             position_ids=position_ids,
             past_key_value=past_key_value,
-            output_attentions=output_attentions,
+            # output_attentions=output_attentions,
             # use_cache=use_cache,
         )
         hidden_states = residual + hidden_states
@@ -802,8 +802,8 @@ def forward(
 
         outputs = (hidden_states,)
 
-        if output_attentions:
-            outputs += (self_attn_weights,)
+        # if output_attentions:
+        #     outputs += (self_attn_weights,)
 
         # if use_cache:
         #     outputs += (present_key_value,)
@@ -962,21 +962,33 @@ def set_input_embeddings(self, value):
     @staticmethod
     def get_masks(input_ids, past_key_values, padding_mask=None):
         batch_size, seq_length = input_ids.shape
-        full_attention_mask = torch.ones(batch_size, seq_length, seq_length,
-                                         device=input_ids.device)
+        full_attention_mask = torch.ones(
+            batch_size,
+            seq_length,
+            seq_length,
+            device=input_ids.device,
+            # dtype=torch.int64
+        )
         full_attention_mask.tril_()
         past_length = past_key_values.shape[4]
         # if past_length is not None:
         full_attention_mask = torch.cat(
-            (torch.ones(batch_size, seq_length, past_length,
-                        device=input_ids.device), full_attention_mask),
+            (
+                torch.ones(
+                    batch_size,
+                    seq_length,
+                    past_length,
+                    device=input_ids.device,
+                    # dtype=torch.int64
+                ),
+                full_attention_mask
+            ),
             dim=-1
         )
         if padding_mask is not None:
-            full_attention_mask = full_attention_mask * padding_mask.unsqueeze(
-                1)
-        if not past_length and padding_mask is not None:
-            full_attention_mask -= padding_mask.unsqueeze(-1) - 1
+            full_attention_mask = full_attention_mask * padding_mask.unsqueeze(1)
+        # if not past_length and padding_mask is not None:
+        #     full_attention_mask -= padding_mask.unsqueeze(-1) - 1
         full_attention_mask = (full_attention_mask < 0.5).bool()
         full_attention_mask.unsqueeze_(1)
         return full_attention_mask
@@ -987,14 +999,14 @@ def forward(
         input_ids: torch.LongTensor = None,
         attention_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
-        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        past_key_values: Optional[torch.FloatTensor] = None,
         # inputs_embeds: Optional[torch.FloatTensor] = None,
         # use_cache: Optional[bool] = None,
-        output_attentions: Optional[bool] = None,
+        # output_attentions: Optional[bool] = None,
         # output_hidden_states: Optional[bool] = None,
         # return_dict: Optional[bool] = None,
     ) -> Union[Tuple, BaseModelOutputWithPast]:
-        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        # output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         # output_hidden_states = (
         #     output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
         # )
@@ -1113,7 +1125,7 @@ def forward(
                 attention_mask=attention_mask,
                 position_ids=position_ids,
                 past_key_value=past_key_values,
-                output_attentions=output_attentions,
+                # output_attentions=output_attentions,
                 # use_cache=use_cache,
             )
 
@@ -1121,7 +1133,7 @@ def forward(
 
             # if use_cache:
             #     next_decoder_cache = layer_outputs[2 if output_attentions else 1]
-            presents.extend(layer_outputs[2])
+            presents.extend(layer_outputs[1])
 
             # if output_attentions:
             #     all_self_attns += (layer_outputs[1],)
@@ -1191,20 +1203,16 @@ def forward(
         input_ids: torch.LongTensor = None,
         attention_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
-        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        past_key_values: Optional[torch.FloatTensor] = None,
        # inputs_embeds: Optional[torch.FloatTensor] = None,
        # labels: Optional[torch.LongTensor] = None,
        # use_cache: Optional[bool] = None,
-        output_attentions: Optional[bool] = None,
+        # output_attentions: Optional[bool] = None,
        # output_hidden_states: Optional[bool] = None,
        # return_dict: Optional[bool] = None,
     ) -> Union[Tuple, CausalLMOutputWithPast]:
         r"""
         Args:
-            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
-                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
-                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
-                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
 
         Returns:
 
@@ -1233,7 +1241,7 @@ def forward(
         #     len(past_key_values[0]), past_key_values[0][0].shape
         # )
         # # [24, 2, 1, 16, 20, 64]
-        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        # output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         # output_hidden_states = (
         #     output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
         # )
@@ -1247,7 +1255,7 @@ def forward(
             past_key_values=past_key_values,
             # inputs_embeds=inputs_embeds,
             # use_cache=use_cache,
-            output_attentions=output_attentions,
+            # output_attentions=output_attentions,
             # output_hidden_states=output_hidden_states,
             # return_dict=return_dict,
         )
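
The patched `get_masks()` above builds the boolean attention mask by hand for a fixed-shape KV cache, instead of going through the `output_attentions`/cache plumbing that this diff comments out. Below is a minimal standalone sketch of that mask logic; the `build_mask` helper and the example shapes are illustrative and not part of the repository.

```python
# Sketch of the mask produced by the patched get_masks():
# a causal (lower-triangular) mask over the new tokens, prepended with
# all-ones columns for the cached positions, then inverted to a boolean
# "masked-out" tensor of shape [batch, 1, seq_len, past_len + seq_len].
import torch


def build_mask(input_ids: torch.Tensor, past_length: int) -> torch.Tensor:
    batch_size, seq_length = input_ids.shape
    mask = torch.ones(batch_size, seq_length, seq_length, device=input_ids.device)
    mask.tril_()  # causal part over the new tokens
    past = torch.ones(batch_size, seq_length, past_length, device=input_ids.device)
    mask = torch.cat((past, mask), dim=-1)  # new tokens may attend to every cached position
    return (mask < 0.5).bool().unsqueeze(1)  # True marks positions that are masked out


if __name__ == "__main__":
    ids = torch.zeros(1, 4, dtype=torch.long)
    print(build_mask(ids, past_length=3).shape)  # torch.Size([1, 1, 4, 7])
```

In the diff itself, `past_length` comes from `past_key_values.shape[4]`, which is why the `past_key_values` type hint changes from `Optional[List[torch.FloatTensor]]` to a single fixed-shape `Optional[torch.FloatTensor]`.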