reduce onnx export warning
Tlntin committed Jul 29, 2024
1 parent 92180e8 commit ac9f747
Showing 3 changed files with 61 additions and 53 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -24,7 +24,7 @@
--kv_cache_length=1024
```

2. Verify the ONNX model: go back to the project root and run cli_chat.py to check that chatting through ONNX works correctly.
2. Verify the ONNX model: go back to the project root and run cli_chat.py to check that chatting through ONNX works correctly (note: it runs on CPU, so inference is slow; please be patient).
```bash
python3 ./cli_chat.py \
--session_type=onnx \
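Before running the full chat loop, a quick way to confirm the exported file loads on CPU is to open it with ONNX Runtime directly. The sketch below is only illustrative; the model path is an assumed placeholder rather than the repository's actual output location.

```python
# Minimal smoke test for the exported model (assumes `onnxruntime` is installed).
import onnxruntime as ort

session = ort.InferenceSession(
    "./output/onnx/qwen2.onnx",          # hypothetical path; use your own export output
    providers=["CPUExecutionProvider"],  # CPU only, hence the slow chat speed noted above
)

# Printing the graph signature is a fast sanity check before running cli_chat.py.
print([inp.name for inp in session.get_inputs()])
print([out.name for out in session.get_outputs()])
```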
4 changes: 2 additions & 2 deletions export/export_onnx.py
@@ -93,7 +93,7 @@ def export_onnx(
model = Qwen2ForCausalLM.from_pretrained(
hf_model_dir,
torch_dtype=torch_dtype,
trust_remote_code=True
# trust_remote_code=True
).to(device)
quantize_cfg = {
"query_key_value": {
@@ -153,7 +153,7 @@ def export_onnx(
# None, # inputs_embeds: Optional[torch.FloatTensor] = None,
# None, # labels: Optional[torch.LongTensor] = None,
# True, # use_cache: Optional[bool] = None,
True, # output_attentions: Optional[bool] = None,
# True, # output_attentions: Optional[bool] = None,
# None, # output_hidden_states
# False # return_dict:
)
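For context on the tuple of example inputs above: the exporter traces `forward` with positional arguments, so commenting out `output_attentions` changes what the traced graph computes and avoids the SDPA fallback warning this commit targets. The sketch below shows the general `torch.onnx.export` pattern with such a positional tuple; the shapes, paths, input/output names, opset, and the import of `Qwen2ForCausalLM` from the patched modeling file are all illustrative assumptions, not the values used by export_onnx.py.

```python
import torch
from modeling_qwen2 import Qwen2ForCausalLM  # assumed: the patched modeling file in export/

# Hypothetical model directory and cache geometry; the real script takes these
# from CLI arguments such as --kv_cache_length.
hf_model_dir = "/path/to/qwen2-checkpoint"
batch, new_tokens, past_len = 1, 1, 1023
num_layers, num_kv_heads, head_dim = 24, 2, 64

model = Qwen2ForCausalLM.from_pretrained(hf_model_dir, torch_dtype=torch.float32).eval()

input_ids = torch.ones(batch, new_tokens, dtype=torch.int64)
attention_mask = torch.ones(batch, past_len + new_tokens, dtype=torch.int64)
position_ids = torch.arange(past_len, past_len + new_tokens, dtype=torch.int64).unsqueeze(0)
# One stacked KV-cache tensor; dimension 4 is the past length, matching the
# `past_key_values.shape[4]` access in get_masks further down this diff.
past_key_values = torch.zeros(num_layers, 2, batch, num_kv_heads, past_len, head_dim)

torch.onnx.export(
    model,
    (input_ids, attention_mask, position_ids, past_key_values),  # positional, as above
    "qwen2_sketch.onnx",
    input_names=["input_ids", "attention_mask", "position_ids", "past_key_values"],
    output_names=["logits", "out_key_values"],
    opset_version=17,
)
```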
108 changes: 58 additions & 50 deletions export/modeling_qwen2.py
@@ -247,7 +247,7 @@ def forward(
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
# output_attentions: bool = False,
# use_cache: bool = False,
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
@@ -322,8 +322,8 @@ def forward(

attn_output = self.o_proj(attn_output)

if not output_attentions:
attn_weights = None
# if not output_attentions:
# attn_weights = None

# return attn_output, attn_weights, past_key_value
return attn_output, attn_weights, out_cache
@@ -648,23 +648,23 @@ def forward(
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
# output_attentions: bool = False,
# use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
if output_attentions:
# TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
# logger.warning_once(
# "Qwen2Model is using Qwen2SdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
# 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
# )
return super().forward(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
# use_cache=use_cache,
)
# if output_attentions:
# # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
# # logger.warning_once(
# # "Qwen2Model is using Qwen2SdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
# # 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
# # )
# return super().forward(
# hidden_states=hidden_states,
# attention_mask=attention_mask,
# position_ids=position_ids,
# past_key_value=past_key_value,
# # output_attentions=output_attentions,
# # use_cache=use_cache,
# )

bsz, q_len, _ = hidden_states.size()

@@ -695,11 +695,11 @@ def forward(
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)

if attention_mask is not None:
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
raise ValueError(
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
)
# if attention_mask is not None:
# if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
# raise ValueError(
# f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
# )

# SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
# Reference: https://github.com/pytorch/pytorch/issues/112577.
@@ -723,7 +723,6 @@ def forward(
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

attn_output = self.o_proj(attn_output)

return attn_output, None, output_cache
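The SDPA path above returns `None` in the attention-weights slot because the fused kernel never materialises those weights; that is what made `output_attentions=True` fall back to the manual implementation (with a warning) before this change. The toy comparison below is independent of the repository code and only illustrates that difference.

```python
import math
import torch
import torch.nn.functional as F

# Toy tensors: (batch, heads, seq_len, head_dim). Purely illustrative.
q, k, v = (torch.randn(1, 2, 4, 8) for _ in range(3))

# Fused SDPA: fast, but the (seq_len x seq_len) attention matrix is never
# exposed, so there is nothing to return for output_attentions.
sdpa_out = F.scaled_dot_product_attention(q, k, v, is_causal=True)

# Manual ("eager") attention: the weights exist as a tensor and could be
# returned, at the cost of materialising the full matrix.
scores = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))
causal_mask = torch.full((4, 4), float("-inf")).triu(1)
weights = torch.softmax(scores + causal_mask, dim=-1)
eager_out = weights @ v

print(torch.allclose(sdpa_out, eager_out, atol=1e-5))  # True: same math, different kernels
```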


@@ -745,7 +744,8 @@ def __init__(self, config: Qwen2Config, layer_idx: int):
"unexpected results may be encountered."
)
# self.self_attn = QWEN2_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx)
self.self_attn = Qwen2SdpaAttention(config, layer_idx)
# because the NPU only supports Qwen2Attention
self.self_attn = Qwen2Attention(config, layer_idx)
self.mlp = Qwen2MLP(config)
self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -756,7 +756,7 @@ def forward(
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: Optional[bool] = False,
# output_attentions: Optional[bool] = False,
# use_cache: Optional[bool] = False,
**kwargs,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
@@ -789,7 +789,7 @@ def forward(
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
# output_attentions=output_attentions,
# use_cache=use_cache,
)
hidden_states = residual + hidden_states
@@ -802,8 +802,8 @@ def forward(

outputs = (hidden_states,)

if output_attentions:
outputs += (self_attn_weights,)
# if output_attentions:
# outputs += (self_attn_weights,)

# if use_cache:
# outputs += (present_key_value,)
@@ -962,21 +962,33 @@ def set_input_embeddings(self, value):
@staticmethod
def get_masks(input_ids, past_key_values, padding_mask=None):
batch_size, seq_length = input_ids.shape
full_attention_mask = torch.ones(batch_size, seq_length, seq_length,
device=input_ids.device)
full_attention_mask = torch.ones(
batch_size,
seq_length,
seq_length,
device=input_ids.device,
# dtype=torch.int64
)
full_attention_mask.tril_()
past_length = past_key_values.shape[4]
# if past_length is not None:
full_attention_mask = torch.cat(
(torch.ones(batch_size, seq_length, past_length,
device=input_ids.device), full_attention_mask),
(
torch.ones(
batch_size,
seq_length,
past_length,
device=input_ids.device,
# dtype=torch.int64
),
full_attention_mask
),
dim=-1
)
if padding_mask is not None:
full_attention_mask = full_attention_mask * padding_mask.unsqueeze(
1)
if not past_length and padding_mask is not None:
full_attention_mask -= padding_mask.unsqueeze(-1) - 1
full_attention_mask = full_attention_mask * padding_mask.unsqueeze(1)
# if not past_length and padding_mask is not None:
# full_attention_mask -= padding_mask.unsqueeze(-1) - 1
full_attention_mask = (full_attention_mask < 0.5).bool()
full_attention_mask.unsqueeze_(1)
return full_attention_mask
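To make the mask semantics concrete, the standalone toy computation below mirrors the logic of get_masks for a block of 2 new tokens on top of 3 cached positions, with a made-up padding pattern; it does not call the method itself. In the returned tensor, True marks positions that must not be attended to.

```python
import torch

B, S, P = 1, 2, 3  # batch, new tokens, cached (past) tokens -- toy sizes

causal = torch.ones(B, S, S).tril_()       # new tokens attend causally to each other
past = torch.ones(B, S, P)                 # and to every cached position
full = torch.cat((past, causal), dim=-1)   # shape (B, S, P + S)

padding_mask = torch.tensor([[0, 1, 1, 1, 1]])  # hypothetical: first cached slot is padding
full = full * padding_mask.unsqueeze(1)

blocked = (full < 0.5).bool().unsqueeze(1)  # (B, 1, S, P + S); True = masked out
print(blocked[0, 0].long())
# tensor([[1, 0, 0, 0, 1],    <- token 0: padding slot blocked, future token blocked
#         [1, 0, 0, 0, 0]])      token 1: only the padding slot blocked
```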
@@ -987,14 +999,14 @@ def forward(
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
past_key_values: Optional[torch.FloatTensor] = None,
# inputs_embeds: Optional[torch.FloatTensor] = None,
# use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
# output_attentions: Optional[bool] = None,
# output_hidden_states: Optional[bool] = None,
# return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
# output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
# output_hidden_states = (
# output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
# )
@@ -1113,15 +1125,15 @@ def forward(
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
# output_attentions=output_attentions,
# use_cache=use_cache,
)

hidden_states = layer_outputs[0]

# if use_cache:
# next_decoder_cache = layer_outputs[2 if output_attentions else 1]
presents.extend(layer_outputs[2])
presents.extend(layer_outputs[1])

# if output_attentions:
# all_self_attns += (layer_outputs[1],)
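The index change above follows directly from dropping the attention-weights entry from each decoder layer's output tuple, so the per-layer cache moves from slot 2 to slot 1; a trivial illustration:

```python
# Illustrative only: shape of the per-layer output tuple before and after this commit.
layer_outputs_before = ("hidden_states", "self_attn_weights", "present_key_value")
layer_outputs_after = ("hidden_states", "present_key_value")

assert layer_outputs_before.index("present_key_value") == 2  # old: layer_outputs[2]
assert layer_outputs_after.index("present_key_value") == 1   # new: layer_outputs[1]
```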
@@ -1191,20 +1203,16 @@ def forward(
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
past_key_values: Optional[torch.FloatTensor] = None,
# inputs_embeds: Optional[torch.FloatTensor] = None,
# labels: Optional[torch.LongTensor] = None,
# use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
# output_attentions: Optional[bool] = None,
# output_hidden_states: Optional[bool] = None,
# return_dict: Optional[bool] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
Args:
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
Returns:
@@ -1233,7 +1241,7 @@ def forward(
# len(past_key_values[0]), past_key_values[0][0].shape
# )
# # [24, 2, 1, 16, 20, 64]
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
# output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
# output_hidden_states = (
# output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
# )
@@ -1247,7 +1255,7 @@ def forward(
past_key_values=past_key_values,
# inputs_embeds=inputs_embeds,
# use_cache=use_cache,
output_attentions=output_attentions,
# output_attentions=output_attentions,
# output_hidden_states=output_hidden_states,
# return_dict=return_dict,
)
