From 7ae6095bbc9bdb93795ab8d2f6ea35617207a572 Mon Sep 17 00:00:00 2001 From: Tlntin Date: Mon, 29 Jul 2024 13:03:52 +0000 Subject: [PATCH] reduce some waring for onnx export, add test code --- export/export_onnx.py | 10 +- export/modeling_qwen2.py | 255 +++++++++++++++++++------------------ export/onnx2om.py | 6 +- export/test_onnx_run.py | 128 +++++++++++++++++++ export/test_pytorch_run.py | 124 ++++++++++++++++++ 5 files changed, 390 insertions(+), 133 deletions(-) create mode 100644 export/test_onnx_run.py create mode 100644 export/test_pytorch_run.py diff --git a/export/export_onnx.py b/export/export_onnx.py index 72ae07c..07f7336 100644 --- a/export/export_onnx.py +++ b/export/export_onnx.py @@ -150,12 +150,12 @@ def export_onnx( attention_mask, position_ids, past_key_values, - None, # inputs_embeds: Optional[torch.FloatTensor] = None, - None, # labels: Optional[torch.LongTensor] = None, - True, # use_cache: Optional[bool] = None, + # None, # inputs_embeds: Optional[torch.FloatTensor] = None, + # None, # labels: Optional[torch.LongTensor] = None, + # True, # use_cache: Optional[bool] = None, True, # output_attentions: Optional[bool] = None, - None, # output_hidden_states - False # return_dict: + # None, # output_hidden_states + # False # return_dict: ) model.eval() with torch.no_grad(): diff --git a/export/modeling_qwen2.py b/export/modeling_qwen2.py index b86ddd8..370157a 100644 --- a/export/modeling_qwen2.py +++ b/export/modeling_qwen2.py @@ -248,7 +248,7 @@ def forward( position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, output_attentions: bool = False, - use_cache: bool = False, + # use_cache: bool = False, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: if "padding_mask" in kwargs: @@ -292,30 +292,30 @@ def forward( attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) - if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): - raise ValueError( - f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" - f" {attn_weights.size()}" - ) - - if attention_mask is not None: - if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): - raise ValueError( - f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" - ) + # if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): + # raise ValueError( + # f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" + # f" {attn_weights.size()}" + # ) - attn_weights = attn_weights + attention_mask + # if attention_mask is not None: + # if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): + # raise ValueError( + # f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" + # ) + # attn_weights = attn_weights + attention_mask + attn_weights = attn_weights + attention_mask # upcast attention to fp32 attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) attn_output = torch.matmul(attn_weights, value_states) - if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) + # if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + # raise ValueError( 
+ # f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + # f" {attn_output.size()}" + # ) attn_output = attn_output.transpose(1, 2).contiguous() attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) @@ -649,21 +649,21 @@ def forward( position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, output_attentions: bool = False, - use_cache: bool = False, + # use_cache: bool = False, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: if output_attentions: # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. - logger.warning_once( - "Qwen2Model is using Qwen2SdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " - 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' - ) + # logger.warning_once( + # "Qwen2Model is using Qwen2SdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " + # 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + # ) return super().forward( hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, output_attentions=output_attentions, - use_cache=use_cache, + # use_cache=use_cache, ) bsz, q_len, _ = hidden_states.size() @@ -683,7 +683,7 @@ def forward( cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) - + output_cache = (key_states, value_states) if past_key_value is not None: # cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models # key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) @@ -724,7 +724,7 @@ def forward( attn_output = self.o_proj(attn_output) - return attn_output, None, past_key_value + return attn_output, None, output_cache QWEN2_ATTENTION_CLASSES = { @@ -757,7 +757,7 @@ def forward( position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None, output_attentions: Optional[bool] = False, - use_cache: Optional[bool] = False, + # use_cache: Optional[bool] = False, **kwargs, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: if "padding_mask" in kwargs: @@ -790,7 +790,7 @@ def forward( position_ids=position_ids, past_key_value=past_key_value, output_attentions=output_attentions, - use_cache=use_cache, + # use_cache=use_cache, ) hidden_states = residual + hidden_states @@ -805,8 +805,9 @@ def forward( if output_attentions: outputs += (self_attn_weights,) - if use_cache: - outputs += (present_key_value,) + # if use_cache: + # outputs += (present_key_value,) + outputs += (present_key_value,) return outputs @@ -987,38 +988,39 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = None, + # 
inputs_embeds: Optional[torch.FloatTensor] = None, + # use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, + # output_hidden_states: Optional[bool] = None, + # return_dict: Optional[bool] = None, ) -> Union[Tuple, BaseModelOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - use_cache = use_cache if use_cache is not None else self.config.use_cache + # output_hidden_states = ( + # output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + # ) + # use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = return_dict if return_dict is not None else self.config.use_return_dict + # return_dict = return_dict if return_dict is not None else self.config.use_return_dict # retrieve input_ids and inputs_embeds - if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") - elif input_ids is not None: - batch_size, seq_length = input_ids.shape - elif inputs_embeds is not None: - batch_size, seq_length, _ = inputs_embeds.shape - else: - raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") + batch_size, seq_length = input_ids.shape + # if input_ids is not None and inputs_embeds is not None: + # raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") + # elif input_ids is not None: + # batch_size, seq_length = input_ids.shape + # elif inputs_embeds is not None: + # batch_size, seq_length, _ = inputs_embeds.shape + # else: + # raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." - ) - use_cache = False + # if self.gradient_checkpointing and self.training: + # if use_cache: + # logger.warning_once( + # "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
+ # ) + # use_cache = False # if past_key_values is not None: - past_key_values_length = past_key_values.shape[4] + # past_key_values_length = past_key_values.shape[4] # else: # past_key_values_length = 0 @@ -1028,17 +1030,19 @@ def forward( # past_key_values = DynamicCache.from_legacy_cache(past_key_values) # past_key_values_length = past_key_values.get_usable_length(seq_length) - if position_ids is None: - device = input_ids.device if input_ids is not None else inputs_embeds.device - position_ids = torch.arange( - past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device - ) - position_ids = position_ids.unsqueeze(0).view(-1, seq_length) - else: - position_ids = position_ids.view(-1, seq_length).long() + # if position_ids is None: + # device = input_ids.device if input_ids is not None else inputs_embeds.device + # position_ids = torch.arange( + # past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device + # ) + # position_ids = position_ids.unsqueeze(0).view(-1, seq_length) + # else: + # position_ids = position_ids.view(-1, seq_length).long() + position_ids = position_ids.view(-1, seq_length).long() - if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) + # if inputs_embeds is None: + # inputs_embeds = self.embed_tokens(input_ids) + inputs_embeds = self.embed_tokens(input_ids) """ if attention_mask is not None and self._attn_implementation == "flash_attention_2" and use_cache: is_padding_right = attention_mask[:, -1].sum().item() != batch_size @@ -1085,59 +1089,59 @@ def forward( hidden_states = inputs_embeds # decoder layers - all_hidden_states = () if output_hidden_states else None - all_self_attns = () if output_attentions else None + # all_hidden_states = () if output_hidden_states else None + # all_self_attns = () if output_attentions else None # next_decoder_cache = None presents = [] for decoder_layer in self.layers: - if output_hidden_states: - all_hidden_states += (hidden_states,) - - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, - hidden_states, - attention_mask, - position_ids, - past_key_values, - output_attentions, - use_cache, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_values, - output_attentions=output_attentions, - use_cache=use_cache, - ) + # if output_hidden_states: + # all_hidden_states += (hidden_states,) + + # if self.gradient_checkpointing and self.training: + # layer_outputs = self._gradient_checkpointing_func( + # decoder_layer.__call__, + # hidden_states, + # attention_mask, + # position_ids, + # past_key_values, + # output_attentions, + # use_cache, + # ) + # else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + # use_cache=use_cache, + ) hidden_states = layer_outputs[0] - if use_cache: + # if use_cache: # next_decoder_cache = layer_outputs[2 if output_attentions else 1] - presents.extend(layer_outputs[2]) + presents.extend(layer_outputs[2]) - if output_attentions: - all_self_attns += (layer_outputs[1],) + # if output_attentions: + # all_self_attns += (layer_outputs[1],) hidden_states = self.norm(hidden_states) # add hidden states from the last decoder layer - if output_hidden_states: - all_hidden_states += (hidden_states,) + # if 
output_hidden_states: + # all_hidden_states += (hidden_states,) # next_cache = None - if use_cache: + # if use_cache: # next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache - one_shape = [len(presents) // 2, 2] + list(presents[0].shape) - presents = torch.concat(presents).reshape(one_shape) + one_shape = [len(presents) // 2, 2] + list(presents[0].shape) + presents = torch.concat(presents).reshape(one_shape) return ( hidden_states, presents, - all_hidden_states, - all_self_attns + # all_hidden_states, + # all_self_attns ) # if not return_dict: @@ -1188,12 +1192,12 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, + # inputs_embeds: Optional[torch.FloatTensor] = None, + # labels: Optional[torch.LongTensor] = None, + # use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, + # output_hidden_states: Optional[bool] = None, + # return_dict: Optional[bool] = None, ) -> Union[Tuple, CausalLMOutputWithPast]: r""" Args: @@ -1230,10 +1234,10 @@ def forward( # ) # # [24, 2, 1, 16, 20, 64] output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict + # output_hidden_states = ( + # output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + # ) + # return_dict = return_dict if return_dict is not None else self.config.use_return_dict # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) outputs = self.model( @@ -1241,33 +1245,34 @@ def forward( attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, + # inputs_embeds=inputs_embeds, + # use_cache=use_cache, output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, + # output_hidden_states=output_hidden_states, + # return_dict=return_dict, ) hidden_states = outputs[0] logits = self.lm_head(hidden_states) logits = logits.float() - loss = None - if labels is not None: - # Shift so that tokens < n predict n - shift_logits = logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - # Flatten the tokens - loss_fct = CrossEntropyLoss() - shift_logits = shift_logits.view(-1, self.config.vocab_size) - shift_labels = shift_labels.view(-1) - # Enable model parallelism - shift_labels = shift_labels.to(shift_logits.device) - loss = loss_fct(shift_logits, shift_labels) + # loss = None + # if labels is not None: + # # Shift so that tokens < n predict n + # shift_logits = logits[..., :-1, :].contiguous() + # shift_labels = labels[..., 1:].contiguous() + # # Flatten the tokens + # loss_fct = CrossEntropyLoss() + # shift_logits = shift_logits.view(-1, self.config.vocab_size) + # shift_labels = shift_labels.view(-1) + # # Enable model parallelism + # shift_labels = shift_labels.to(shift_logits.device) + # loss = loss_fct(shift_logits, shift_labels) # if not return_dict: output = (logits,) + 
outputs[1:] - return (loss,) + output if loss is not None else output + # return (loss,) + output if loss is not None else output + return output # return CausalLMOutputWithPast( # loss=loss, diff --git a/export/onnx2om.py b/export/onnx2om.py index 09d506f..84dd96c 100644 --- a/export/onnx2om.py +++ b/export/onnx2om.py @@ -20,19 +20,19 @@ '--hf_model_dir', type=str, help="model and tokenizer path, only support huggingface model", - default=os.path.join(project_dir, "download", "Qwen1_5_0_5B_Chat") + default=os.path.join(project_dir, "download", "Qwen2-1.5B-Instruct") ) parser.add_argument( "--onnx_model_path", help="output onnx path", type=str, - default=os.path.join(onnx_model_dir, "qwen1.5_0.5b_chat.onnx") + default=os.path.join(onnx_model_dir, "qwen2_1.5b_chat.onnx") ) parser.add_argument( "--om_model_path", help=".om model path", type=str, - default= os.path.join(model_dir, "qwen1.5_0.5b_chat") + default= os.path.join(model_dir, "qwen2_1.5b_chat") ) parser.add_argument( "--kv_cache_length", diff --git a/export/test_onnx_run.py b/export/test_onnx_run.py new file mode 100644 index 0000000..da18eea --- /dev/null +++ b/export/test_onnx_run.py @@ -0,0 +1,128 @@ +import os +import numpy as np +import onnxruntime +import argparse +from transformers.models.qwen2 import Qwen2Tokenizer, Qwen2Config + + +now_dir = os.path.dirname(os.path.abspath(__file__)) +project_dir = os.path.dirname(now_dir) + +parser = argparse.ArgumentParser() +parser.add_argument( + '--hf_model_dir', + type=str, + help="model and tokenizer path, only support huggingface model", + default=os.path.join(project_dir, "download", "Qwen2-1.5B-Instruct") +) +parser.add_argument( + "--onnx_model_path", + help="output onnx path", + type=str, + default=os.path.join(project_dir, "output", "onnx", "qwen2_1.5b_chat.onnx") +) +args = parser.parse_args() + + +def create_kv_cache(config: Qwen2Config, kv_cache_length=1024): + return np.zeros( + [ + config.num_hidden_layers, + 2, + 1, + config.num_key_value_heads, + kv_cache_length, + config.hidden_size // config.num_attention_heads + ], + dtype=np.float16 + ) + + +def get_inputs(kv_cache, seq_len: int, real_kv_size=0, input_pos=0, past_kv_size: int = 1024): + """ + 获取指定长度的kv_cache, 顺便生成mask和position_id + Args: + kv_cache + seq_len (int): 待获取的kv-cache长度 + real_kv_size: 真实kv_size长度 + input_pos: 当前真实token所在位置 + past_kv_size + + Returns: + List[np.ndarray]: _description_ + """ + + """ + self.kv_cache shape ( + self.num_hidden_layers, + 2, + 1, + self.num_key_value_heads, + self.kv_cache_length, + self.per_head_dim + ) + """ + cache = kv_cache[:, :, :, :, :past_kv_size] + mask = np.ones((1, past_kv_size + seq_len), dtype=np.int64) + mask[:, real_kv_size: past_kv_size] = 0 + pos_id = np.arange( + input_pos, + input_pos + seq_len, + dtype=np.int64 + ).reshape(1, -1) + return cache, mask, pos_id + + +tokenizer = Qwen2Tokenizer.from_pretrained(args.hf_model_dir) +model_config = Qwen2Config.from_pretrained(args.hf_model_dir) +prompt = "你好" +system_prompt: str = "You are a helpful assistant." 
+history = [] +if len(history) == 0: + history = [{"role": "system", "content": system_prompt}] +history.append({"role": "user", "content": prompt}) +print("history: ", history) +text = tokenizer.apply_chat_template( + history, + tokenize=False, + add_generation_prompt=True +) +print("raw_text", text) +input_ids = tokenizer( + [text], return_tensors="np" +)["input_ids"].astype(np.int64) +print("input_ids", input_ids) + +options = onnxruntime.SessionOptions() +llm_session = onnxruntime.InferenceSession( + args.onnx_model_path, + sess_options=options, + providers=[ + "CPUExecutionProvider", + ], +) + +seq_len = input_ids.shape[-1] +kv_cache1 = create_kv_cache(model_config) +now_kv_cache, attn_mask, position_ids = get_inputs(kv_cache1, 1) +print("now_kv_cache shape: ", now_kv_cache.shape) +print("attention_mask shape: ", attn_mask.shape) +print("position_ids shape: ", position_ids.shape) +outputs = llm_session.run(None, { + "input_ids": input_ids[:, :1], + "attention_mask": attn_mask, + "position_ids": position_ids, + "past_key_values": now_kv_cache, +}) +print("==== onnx runtime ====") +print("output length: ", len(outputs)) +logits = outputs[0] +print("logits shape: ", logits.shape) +print("logits mean: ", logits.astype(np.float32).mean().item()) +print("logits max: ", logits.astype(np.float32).max().item()) +new_kv_cache = outputs[1] # [:, :, :, :, :-1, :] +print("new_kv_cache: shape", new_kv_cache.shape) +print("new_kv_cache: mean: ", new_kv_cache.astype(np.float32).mean().item()) +print("new_kv_cache: max: ", new_kv_cache.astype(np.float32).max().item()) + + diff --git a/export/test_pytorch_run.py b/export/test_pytorch_run.py new file mode 100644 index 0000000..d0649ad --- /dev/null +++ b/export/test_pytorch_run.py @@ -0,0 +1,124 @@ +import os +import torch +import argparse +from modeling_qwen2 import Qwen2ForCausalLM +from transformers.models.qwen2 import Qwen2Tokenizer, Qwen2Config + +device = "cpu" +now_dir = os.path.dirname(os.path.abspath(__file__)) +project_dir = os.path.dirname(now_dir) + +parser = argparse.ArgumentParser() +parser.add_argument( + '--hf_model_dir', + type=str, + help="model and tokenizer path, only support huggingface model", + default=os.path.join(project_dir, "download", "Qwen2-1.5B-Instruct") +) + +args = parser.parse_args() + + +def create_kv_cache(config: Qwen2Config, kv_cache_length=1024): + return torch.zeros( + [ + config.num_hidden_layers, + 2, + 1, + config.num_key_value_heads, + kv_cache_length, + config.hidden_size // config.num_attention_heads + ], + dtype=torch.float16 + ).to(device) + + +def get_inputs(kv_cache, seq_len: int, real_kv_size=0, input_pos=0, past_kv_size: int = 1024): + """ + 获取指定长度的kv_cache, 顺便生成mask和position_id + Args: + kv_cache + seq_len (int): 待获取的kv-cache长度 + real_kv_size: 真实kv_size长度 + input_pos: 当前真实token所在位置 + past_kv_size + + Returns: + List[np.ndarray]: _description_ + """ + + """ + self.kv_cache shape ( + self.num_hidden_layers, + 2, + 1, + self.num_key_value_heads, + self.kv_cache_length, + self.per_head_dim + ) + """ + cache = kv_cache[:, :, :, :, :past_kv_size] + mask = torch.ones((1, past_kv_size + seq_len), dtype=torch.long).to(device) + mask[:, real_kv_size: past_kv_size] = 0 + pos_id = torch.arange( + input_pos, + input_pos + seq_len, + dtype=torch.long + ).reshape(1, -1).to(device) + return cache, mask, pos_id + + +tokenizer = Qwen2Tokenizer.from_pretrained(args.hf_model_dir) +model_config = Qwen2Config.from_pretrained(args.hf_model_dir) +model = Qwen2ForCausalLM.from_pretrained( + args.hf_model_dir, + 
torch_dtype=torch.float16 +).to(device) +prompt = "你好" +system_prompt: str = "You are a helpful assistant." +history = [] +if len(history) == 0: + history = [{"role": "system", "content": system_prompt}] +history.append({"role": "user", "content": prompt}) +print("history: ", history) +text = tokenizer.apply_chat_template( + history, + tokenize=False, + add_generation_prompt=True +) +print("raw_text", text) +input_ids = tokenizer( + [text], return_tensors="pt" +)["input_ids"].to(device) +print("input_ids", input_ids) +kv_cache1 = create_kv_cache(model_config) +now_kv_cache, attn_mask, position_ids = get_inputs(kv_cache1, 2, ) +print("now_kv_cache shape: ", now_kv_cache.shape) +print("attention_mask shape: ", attn_mask.shape) +print("position_ids shape: ", position_ids.shape) +outputs = model.forward( + input_ids[:, :2], + attn_mask, + position_ids, + now_kv_cache, + # use_cache=True, + output_attentions=True, +) +print("==== pytorch runtime ====") +print("output length: ", len(outputs)) +logits = outputs[0][:, :-1, :] # 1: -0.10800 +# logits = outputs[0][:, -1:, :] # 2: -0.008756 + +print("logits shape: ", logits.shape) +print("logits mean: ", logits.float().mean().item()) +print("logits max: ", logits.float().max().item()) + +new_kv_cache = outputs[1][:, :, :, :, :-1, :] # 1: 0.0009: +# new_kv_cache = outputs[1][:, :, :, :, -1:, :] # 2: 0.003526 + +print("new_kv_cache: shape:", new_kv_cache.shape) +# print("new_kv_cache: mean: ", new_kv_cache.astype(np.float32).mean().item()) +print("new_kv_cache: mean: ", new_kv_cache.float().mean().item()) +# print("new_kv_cache: max: ", new_kv_cache.astype(np.float32).max().item()) +print("new_kv_cache: max: ", new_kv_cache.float().max().item()) +
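
Suggested smoke test for the two new scripts (an illustrative invocation only — it assumes the repository root as the working directory and the default paths baked into the argparse defaults above; pass --hf_model_dir / --onnx_model_path explicitly if your layout differs):

    # eager PyTorch reference (uses the patched modeling_qwen2.py)
    python export/test_pytorch_run.py --hf_model_dir ./download/Qwen2-1.5B-Instruct

    # exported ONNX graph on the same prompt
    python export/test_onnx_run.py \
        --hf_model_dir ./download/Qwen2-1.5B-Instruct \
        --onnx_model_path ./output/onnx/qwen2_1.5b_chat.onnx

Each script prints the mean/max of the logits and of the returned kv cache for a single forward step, so the two backends can be compared by hand. Note that, as written, test_onnx_run.py feeds only the first prompt token (input_ids[:, :1] with seq_len=1) while test_pytorch_run.py feeds the first two tokens (input_ids[:, :2] with seq_len=2), so align the slicing before expecting the statistics to match; the inline comments in test_pytorch_run.py (e.g. "# 1: -0.10800") record values observed for those slices.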