diff --git a/utils/engine.py b/utils/engine.py index ebfb697..9812141 100644 --- a/utils/engine.py +++ b/utils/engine.py @@ -88,6 +88,7 @@ def __init__(self, config: InferenceConfig, context=None,callback=None): self.exit_flag = False self.max_batch = config.max_batch self.kv_cache_length = config.kv_cache_length + self.max_prefill_length = config.max_prefill_length # kv_cache的长度和max_output_length的长度一样 self.past_kv_size=self.kv_cache_length self.input_pos = 0 @@ -130,9 +131,9 @@ def get_inputs(self, seq_len: int) -> List[np.ndarray]: self.num_hidden_layers * 2 * self.num_key_value_heads, self.per_head_dim ) - """ + """ temp_seq_len = self.real_kv_size + seq_len - if temp_seq_len <= self.kv_cache_length // 2: + if self.max_prefill_length > 1 and temp_seq_len <= self.kv_cache_length // 2: temp_kv_size = self.kv_cache_length // 2 else: temp_kv_size = self.kv_cache_length