From 8c90e4dd5297e14758371b0d675ad3fdb517ed14 Mon Sep 17 00:00:00 2001 From: Masahiro Tanaka Date: Tue, 23 Apr 2024 01:45:10 +0000 Subject: [PATCH 01/10] reuse kv cache prefix --- mii/batching/ragged_batching.py | 21 ++++++++++++++------- 1 file changed, 14 insertions(+), 7 deletions(-) diff --git a/mii/batching/ragged_batching.py b/mii/batching/ragged_batching.py index e402d01f..b5c0e5f5 100644 --- a/mii/batching/ragged_batching.py +++ b/mii/batching/ragged_batching.py @@ -273,24 +273,31 @@ def _schedule_prompts(self, requests: List[Request]) -> None: max_blocks = free_blocks - self.scheduled_req_blocks - if len(r.input_tokens) > 1: + cache_hit_length, block_ids = self.inference_engine.lookup_cache(r.uid, r.input_tokens) + input_tokens = r.input_tokens[cache_hit_length:] + + if len(input_tokens) > 1: # When the KV cache is out of capacity, we release KV cache blocks for a request. # However, we can immediately schedule the request again if we split the request. # So we make sure that we have capacity for the entire prompt (+tokens already generated). - req_tokens, _ = self.inference_engine.query(r.uid, len(r.input_tokens), max_blocks) - if req_tokens < len(r.input_tokens): + req_tokens, _ = self.inference_engine.query(r.uid, len(input_tokens), max_blocks) + if req_tokens < len(input_tokens): break - req_tokens = min(len(r.input_tokens), max_batch_size) + req_tokens = min(len(input_tokens), max_batch_size) req_tokens, req_blocks = self.inference_engine.query(r.uid, req_tokens, max_blocks) if req_tokens <= 0: continue # Decompose the prompt to fit to the max ragged batch size - decomposed = req_tokens < len(r.input_tokens) - remaining_tokens = r.input_tokens[req_tokens:] - r.input_tokens = r.input_tokens[:req_tokens] + if cache_hit_length > 0: + self.inference_engine.setup_cached_sequence(r.uid, r.input_tokens, block_ids) + r.seq_length = r.seq_length + cache_hit_length + + decomposed = req_tokens < len(input_tokens) + remaining_tokens = input_tokens[req_tokens:] + r.input_tokens = input_tokens[:req_tokens] r.last_in_prompt = not decomposed # Schedule the request From 18da1441bb5bd905da49a9b612897351ebe97b41 Mon Sep 17 00:00:00 2001 From: Masahiro Tanaka Date: Fri, 26 Apr 2024 03:03:23 +0000 Subject: [PATCH 02/10] fix argument --- mii/batching/ragged_batching.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mii/batching/ragged_batching.py b/mii/batching/ragged_batching.py index b5c0e5f5..45eb0f74 100644 --- a/mii/batching/ragged_batching.py +++ b/mii/batching/ragged_batching.py @@ -273,7 +273,7 @@ def _schedule_prompts(self, requests: List[Request]) -> None: max_blocks = free_blocks - self.scheduled_req_blocks - cache_hit_length, block_ids = self.inference_engine.lookup_cache(r.uid, r.input_tokens) + cache_hit_length, block_ids = self.inference_engine.lookup_cache(r.input_tokens) input_tokens = r.input_tokens[cache_hit_length:] if len(input_tokens) > 1: From 08f579d2ca16def7d3c9f55e1da066c896ebff9c Mon Sep 17 00:00:00 2001 From: Masahiro Tanaka Date: Sun, 28 Apr 2024 22:02:15 +0000 Subject: [PATCH 03/10] fix alloc arg --- mii/batching/ragged_batching.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mii/batching/ragged_batching.py b/mii/batching/ragged_batching.py index 45eb0f74..af2ccde0 100644 --- a/mii/batching/ragged_batching.py +++ b/mii/batching/ragged_batching.py @@ -292,7 +292,7 @@ def _schedule_prompts(self, requests: List[Request]) -> None: # Decompose the prompt to fit to the max ragged batch size if cache_hit_length > 0: - 
self.inference_engine.setup_cached_sequence(r.uid, r.input_tokens, block_ids) + self.inference_engine.setup_cached_sequence(r.uid, r.input_tokens.numel(), block_ids) r.seq_length = r.seq_length + cache_hit_length decomposed = req_tokens < len(input_tokens) From cfb9b3a64bdc312d46da8126ad635811e9942a80 Mon Sep 17 00:00:00 2001 From: Masahiro Tanaka Date: Wed, 8 May 2024 01:14:06 +0000 Subject: [PATCH 04/10] add option to prefix cache --- mii/batching/ragged_batching.py | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/mii/batching/ragged_batching.py b/mii/batching/ragged_batching.py index af2ccde0..40b30893 100644 --- a/mii/batching/ragged_batching.py +++ b/mii/batching/ragged_batching.py @@ -82,6 +82,8 @@ def __init__(self, inference_engine, tokenizer, model_config): self.socket.setsockopt_string(zmq.SUBSCRIBE, "") self.socket.setsockopt(zmq.RCVTIMEO, ZMQ_RECV_TIMEOUT) + self.enable_prefix_cache = self.inference_engine._config.enable_prefix_cache + @cached_property def local_rank(self) -> int: return get_accelerator().current_device() @@ -131,6 +133,11 @@ def generate(self) -> None: if not r.stop_generation: r.set_next_as_input() self.request_queue.put(r) + if r.stop_generation: + if len(r.generated_tokens) > 0: + all_tokens = torch.cat([t.unsqueeze(0) for t in r.generated_tokens], dim=0) + all_tokens = torch.cat([r.prompt_tokens, all_tokens], dim=0) + self.inference_engine.update_prefix_cache(r.uid, all_tokens) # 7. Update scheduled requests self.scheduled_requests.prune(running_requests.completed.uids) @@ -273,8 +280,13 @@ def _schedule_prompts(self, requests: List[Request]) -> None: max_blocks = free_blocks - self.scheduled_req_blocks - cache_hit_length, block_ids = self.inference_engine.lookup_cache(r.input_tokens) - input_tokens = r.input_tokens[cache_hit_length:] + input_tokens = r.input_tokens + if r.seq_length == 0: + cache_hit_length, block_ids = self.inference_engine.lookup_cache(r.input_tokens) + input_tokens = input_tokens[cache_hit_length:] + else: + cache_hit_length = 0 + block_ids = [] if len(input_tokens) > 1: # When the KV cache is out of capacity, we release KV cache blocks for a request. 
@@ -292,7 +304,7 @@ def _schedule_prompts(self, requests: List[Request]) -> None: # Decompose the prompt to fit to the max ragged batch size if cache_hit_length > 0: - self.inference_engine.setup_cached_sequence(r.uid, r.input_tokens.numel(), block_ids) + self.inference_engine.setup_cached_sequence(r.uid, cache_hit_length, block_ids) r.seq_length = r.seq_length + cache_hit_length decomposed = req_tokens < len(input_tokens) From 7444160aaf7c29bcebe82bcf9738e1dbb823d970 Mon Sep 17 00:00:00 2001 From: Masahiro Tanaka Date: Fri, 17 May 2024 22:59:55 +0000 Subject: [PATCH 05/10] flush the last request --- mii/batching/ragged_batching.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/mii/batching/ragged_batching.py b/mii/batching/ragged_batching.py index 40b30893..1f60d025 100644 --- a/mii/batching/ragged_batching.py +++ b/mii/batching/ragged_batching.py @@ -135,7 +135,8 @@ def generate(self) -> None: self.request_queue.put(r) if r.stop_generation: if len(r.generated_tokens) > 0: - all_tokens = torch.cat([t.unsqueeze(0) for t in r.generated_tokens], dim=0) + all_tokens = torch.cat([t.unsqueeze(0) for t in r.generated_tokens], + dim=0) all_tokens = torch.cat([r.prompt_tokens, all_tokens], dim=0) self.inference_engine.update_prefix_cache(r.uid, all_tokens) @@ -304,7 +305,9 @@ def _schedule_prompts(self, requests: List[Request]) -> None: # Decompose the prompt to fit to the max ragged batch size if cache_hit_length > 0: - self.inference_engine.setup_cached_sequence(r.uid, cache_hit_length, block_ids) + self.inference_engine.setup_cached_sequence(r.uid, + cache_hit_length, + block_ids) r.seq_length = r.seq_length + cache_hit_length decomposed = req_tokens < len(input_tokens) @@ -589,11 +592,15 @@ def __call__(self, if self.is_rank_0: # Rank 0 runs generate() until all responses are returned - while uids_running: + while uids_running \ + or not self.request_queue.empty() \ + or self.scheduled_requests.requests_to_flush.uids: self.generate() while not self.result_queues[self.tid].empty(): uid, response = self._get_response() outputs.append(response) + # We can't directly call flush because the flush request is broadcasted + # to other ranks after taken from the queue self._queue_flush_request(uid) uids_complete_order.append(uid) uids_running.remove(uid) From 42a0386fd66db09d2809be9bed8ff9956fb9aed7 Mon Sep 17 00:00:00 2001 From: Masahiro Tanaka Date: Sat, 25 May 2024 02:19:49 +0000 Subject: [PATCH 06/10] add debug func --- mii/batching/ragged_batching.py | 90 +++++++++++++++++++++++++++++++++ 1 file changed, 90 insertions(+) diff --git a/mii/batching/ragged_batching.py b/mii/batching/ragged_batching.py index 1f60d025..80cd88e3 100644 --- a/mii/batching/ragged_batching.py +++ b/mii/batching/ragged_batching.py @@ -624,6 +624,96 @@ def __call__(self, return outputs + def call_dbg(self, + prompts: Union[str, + List[str]], + **generate_kwargs) -> List[Response]: + """ + Generates text for the given prompts + + :param prompts: The string or list of strings used as prompts for generation. + :param \**generate_kwargs: Generation keywords. A full list can be found + in :class:`GenerateParamsConfig `. + + :return: A list of :class:`Response` objects containing the generated + text for all prompts. 
+ """ # noqa: W605 + if self._destroyed: + raise RuntimeError( + "The inference engine of this pipeline has been destroyed.") + + delay = 3 + + if isinstance(prompts, str): + prompts = [prompts] + outputs: List[Response] = [] + uids_running: List[int] = [] + uids_complete_order: List[int] = [] + + requests = [] + for uid, input in enumerate(prompts): + request_kwargs = generate_kwargs.copy() + requests.append((uid, input, request_kwargs)) + + req = requests.pop(0) + self._put_request(req[0], req[1], req[2]) + uids_running.append(req[0]) + + self.schedule_requests() + iteration = 0 + + if self.is_rank_0: + # Rank 0 runs generate() until all responses are returned + while uids_running \ + or not self.request_queue.empty() \ + or self.scheduled_requests.requests_to_flush.uids: + + print(f"iteration {iteration} running_requests {uids_running} scheduled_requests {self.scheduled_requests.requests_to_run.uids} flush_requests {self.scheduled_requests.requests_to_flush.uids}") + + if iteration % delay == 0 and len(requests) > 0: + req = requests.pop(0) + self._put_request(req[0], req[1], req[2]) + uids_running.append(req[0]) + print(f"Request uid={req[0]} added to queue") + + self.generate() + while not self.result_queues[self.tid].empty(): + uid, response = self._get_response() + outputs.append(response) + # We can't directly call flush because the flush request is broadcasted + # to other ranks after taken from the queue + self._queue_flush_request(uid) + uids_complete_order.append(uid) + uids_running.remove(uid) + + + if iteration > 1000: + # import pdb; pdb.set_trace() + break + + iteration += 1 + # Ensure final flush requests broadcast and + # kick ranks 1 -> n out of the while loop + self._bcast_requests(force=True) + else: + # Ranks 1 -> n just run generate() until there are no more requests + while self.scheduled_requests: + self.generate() + iteration += 1 + + outputs = [ + r for idx, + r in sorted(zip(uids_complete_order, + outputs), + key=lambda pair: pair[0]) + ] + + if self._all_rank_output: + outputs = self._bcast_responses(outputs) + + return outputs + + def _put_request(self, uid: int, input: str, kwargs: Dict[str, Any]) -> None: self.result_queues[self.tid] = queue.Queue() input_tokens = self.tokenizer.encode(input) From b94bc7ad7469bfcbec38cf9ee880ea3ee19d1b40 Mon Sep 17 00:00:00 2001 From: Masahiro Tanaka Date: Mon, 27 May 2024 20:30:47 +0000 Subject: [PATCH 07/10] remove debug function --- mii/batching/ragged_batching.py | 90 --------------------------------- 1 file changed, 90 deletions(-) diff --git a/mii/batching/ragged_batching.py b/mii/batching/ragged_batching.py index 80cd88e3..1f60d025 100644 --- a/mii/batching/ragged_batching.py +++ b/mii/batching/ragged_batching.py @@ -624,96 +624,6 @@ def __call__(self, return outputs - def call_dbg(self, - prompts: Union[str, - List[str]], - **generate_kwargs) -> List[Response]: - """ - Generates text for the given prompts - - :param prompts: The string or list of strings used as prompts for generation. - :param \**generate_kwargs: Generation keywords. A full list can be found - in :class:`GenerateParamsConfig `. - - :return: A list of :class:`Response` objects containing the generated - text for all prompts. 
- """ # noqa: W605 - if self._destroyed: - raise RuntimeError( - "The inference engine of this pipeline has been destroyed.") - - delay = 3 - - if isinstance(prompts, str): - prompts = [prompts] - outputs: List[Response] = [] - uids_running: List[int] = [] - uids_complete_order: List[int] = [] - - requests = [] - for uid, input in enumerate(prompts): - request_kwargs = generate_kwargs.copy() - requests.append((uid, input, request_kwargs)) - - req = requests.pop(0) - self._put_request(req[0], req[1], req[2]) - uids_running.append(req[0]) - - self.schedule_requests() - iteration = 0 - - if self.is_rank_0: - # Rank 0 runs generate() until all responses are returned - while uids_running \ - or not self.request_queue.empty() \ - or self.scheduled_requests.requests_to_flush.uids: - - print(f"iteration {iteration} running_requests {uids_running} scheduled_requests {self.scheduled_requests.requests_to_run.uids} flush_requests {self.scheduled_requests.requests_to_flush.uids}") - - if iteration % delay == 0 and len(requests) > 0: - req = requests.pop(0) - self._put_request(req[0], req[1], req[2]) - uids_running.append(req[0]) - print(f"Request uid={req[0]} added to queue") - - self.generate() - while not self.result_queues[self.tid].empty(): - uid, response = self._get_response() - outputs.append(response) - # We can't directly call flush because the flush request is broadcasted - # to other ranks after taken from the queue - self._queue_flush_request(uid) - uids_complete_order.append(uid) - uids_running.remove(uid) - - - if iteration > 1000: - # import pdb; pdb.set_trace() - break - - iteration += 1 - # Ensure final flush requests broadcast and - # kick ranks 1 -> n out of the while loop - self._bcast_requests(force=True) - else: - # Ranks 1 -> n just run generate() until there are no more requests - while self.scheduled_requests: - self.generate() - iteration += 1 - - outputs = [ - r for idx, - r in sorted(zip(uids_complete_order, - outputs), - key=lambda pair: pair[0]) - ] - - if self._all_rank_output: - outputs = self._bcast_responses(outputs) - - return outputs - - def _put_request(self, uid: int, input: str, kwargs: Dict[str, Any]) -> None: self.result_queues[self.tid] = queue.Queue() input_tokens = self.tokenizer.encode(input) From 414f6964c085dcb663b61a7140b731fd1bd57094 Mon Sep 17 00:00:00 2001 From: Masahiro Tanaka Date: Thu, 30 May 2024 02:41:02 +0000 Subject: [PATCH 08/10] save prefix cache at every iteration --- mii/batching/data_classes.py | 6 ++++++ mii/batching/ragged_batching.py | 8 ++------ 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/mii/batching/data_classes.py b/mii/batching/data_classes.py index 81cabd32..a005fa25 100644 --- a/mii/batching/data_classes.py +++ b/mii/batching/data_classes.py @@ -134,6 +134,12 @@ def is_done(self, is_done: bool) -> None: def generated_tokens(self) -> List[torch.Tensor]: return self._generated_tokens + @property + def all_tokens(self) -> List[torch.Tensor]: + return torch.cat([self.prompt_tokens] + + [t.unsqueeze(0) for t in self.generated_tokens], + dim=0) + @property def finish_reason(self) -> GenerationFinishReason: return self._finish_reason diff --git a/mii/batching/ragged_batching.py b/mii/batching/ragged_batching.py index f61c4d29..6f902efe 100644 --- a/mii/batching/ragged_batching.py +++ b/mii/batching/ragged_batching.py @@ -133,12 +133,8 @@ def generate(self) -> None: if not r.stop_generation: r.set_next_as_input() self.request_queue.put(r) - if r.stop_generation: - if len(r.generated_tokens) > 0: - all_tokens = 
torch.cat([t.unsqueeze(0) for t in r.generated_tokens], - dim=0) - all_tokens = torch.cat([r.prompt_tokens, all_tokens], dim=0) - self.inference_engine.update_prefix_cache(r.uid, all_tokens) + if len(r.generated_tokens) > 0: + self.inference_engine.update_prefix_cache(r.uid, r.all_tokens) # 7. Update scheduled requests self.scheduled_requests.prune(running_requests.completed.uids) From 263bf2155c0967ae3ff18c1a9c6869bd6d15c8a0 Mon Sep 17 00:00:00 2001 From: Masahiro Tanaka Date: Fri, 31 May 2024 07:36:39 +0000 Subject: [PATCH 09/10] fix token count for prefix cache --- mii/batching/ragged_batching.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mii/batching/ragged_batching.py b/mii/batching/ragged_batching.py index 6f902efe..3db9eb1e 100644 --- a/mii/batching/ragged_batching.py +++ b/mii/batching/ragged_batching.py @@ -126,6 +126,8 @@ def generate(self) -> None: # 6. Accumulate generated tokens, check completion, and generate output for r in running_requests.last_in_prompt: + if len(r.generated_tokens) > 0: + self.inference_engine.update_prefix_cache(r.uid, r.all_tokens) r.accumulate_generated_token() self._num_generated_tokens += 1 if r.stop_generation or r.stream: @@ -133,8 +135,6 @@ def generate(self) -> None: if not r.stop_generation: r.set_next_as_input() self.request_queue.put(r) - if len(r.generated_tokens) > 0: - self.inference_engine.update_prefix_cache(r.uid, r.all_tokens) # 7. Update scheduled requests self.scheduled_requests.prune(running_requests.completed.uids) From 584407c5057f381f4d9028264072df1dd892b786 Mon Sep 17 00:00:00 2001 From: Masahiro Tanaka Date: Mon, 3 Jun 2024 00:10:12 +0000 Subject: [PATCH 10/10] fix prefix cache target --- mii/batching/ragged_batching.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/mii/batching/ragged_batching.py b/mii/batching/ragged_batching.py index 3db9eb1e..2724da7b 100644 --- a/mii/batching/ragged_batching.py +++ b/mii/batching/ragged_batching.py @@ -124,10 +124,11 @@ def generate(self) -> None: # 5. Schedule requests while we wait for the forward pass to finish self._reset_scheduler_bookkeeping() + for r in running_requests: + self.inference_engine.update_prefix_cache(r.uid, r.all_tokens) + # 6. Accumulate generated tokens, check completion, and generate output for r in running_requests.last_in_prompt: - if len(r.generated_tokens) > 0: - self.inference_engine.update_prefix_cache(r.uid, r.all_tokens) r.accumulate_generated_token() self._num_generated_tokens += 1 if r.stop_generation or r.stream:
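
Note on the overall flow (not part of the patches above): taken together, the series makes the ragged-batching scheduler prefix-cache aware. A new request first calls lookup_cache(input_tokens) to learn how many leading tokens are already resident in KV blocks, schedules only the uncached suffix, attaches the reused blocks via setup_cached_sequence(uid, cache_hit_length, block_ids), and advances seq_length by the hit length; after each generation step the engine is told about the now-materialized tokens through update_prefix_cache(uid, r.all_tokens). The sketch below is a minimal, self-contained illustration of that control flow only. The call names lookup_cache, setup_cached_sequence, and update_prefix_cache mirror the calls added in these patches, but ToyPrefixCache, ToySequence, BLOCK_SIZE, schedule_prompt, and the whole-block keying scheme are assumptions made for illustration; they are not DeepSpeed-MII's actual cache implementation.

# Minimal sketch (assumed toy implementation, not the real inference engine).
from dataclasses import dataclass, field
from typing import Dict, List, Tuple

BLOCK_SIZE = 4  # assumed block granularity for the toy cache


@dataclass
class ToySequence:
    seq_length: int = 0                       # tokens already materialized in KV cache
    block_ids: List[int] = field(default_factory=list)


class ToyPrefixCache:
    """Maps full-block token prefixes to previously allocated KV block ids."""

    def __init__(self) -> None:
        self._blocks: Dict[Tuple[int, ...], int] = {}
        self._next_block_id = 0
        self._sequences: Dict[int, ToySequence] = {}

    def lookup_cache(self, input_tokens: List[int]) -> Tuple[int, List[int]]:
        # Longest prefix (in whole blocks) already present in the cache.
        hit, block_ids = 0, []
        for end in range(BLOCK_SIZE, len(input_tokens) + 1, BLOCK_SIZE):
            key = tuple(input_tokens[:end])
            if key not in self._blocks:
                break
            hit, block_ids = end, block_ids + [self._blocks[key]]
        return hit, block_ids

    def setup_cached_sequence(self, uid: int, cache_hit_length: int,
                              block_ids: List[int]) -> None:
        # Attach reused blocks so the cached prefix is not recomputed.
        seq = self._sequences.setdefault(uid, ToySequence())
        seq.block_ids.extend(block_ids)
        seq.seq_length += cache_hit_length

    def update_prefix_cache(self, uid: int, all_tokens: List[int]) -> None:
        # Register every completed block of (prompt + generated) tokens.
        for end in range(BLOCK_SIZE, len(all_tokens) + 1, BLOCK_SIZE):
            key = tuple(all_tokens[:end])
            if key not in self._blocks:
                self._blocks[key] = self._next_block_id
                self._next_block_id += 1


def schedule_prompt(cache: ToyPrefixCache, uid: int,
                    input_tokens: List[int], seq_length: int) -> List[int]:
    """Mirrors the scheduler change: only uncached suffix tokens are scheduled."""
    if seq_length == 0:  # only brand-new requests consult the prefix cache
        cache_hit_length, block_ids = cache.lookup_cache(input_tokens)
    else:
        cache_hit_length, block_ids = 0, []
    if cache_hit_length > 0:
        cache.setup_cached_sequence(uid, cache_hit_length, block_ids)
    return input_tokens[cache_hit_length:]  # tokens that still need a forward pass


if __name__ == "__main__":
    cache = ToyPrefixCache()
    prompt = list(range(10))
    cache.update_prefix_cache(uid=0, all_tokens=prompt)   # first request fills the cache
    remaining = schedule_prompt(cache, uid=1, input_tokens=prompt, seq_length=0)
    print(len(prompt) - len(remaining), "tokens reused,", len(remaining), "recomputed")

The later patches in the series also shift where update_prefix_cache is called: first only when a request stops generating, then at every iteration, and finally for every running request rather than only those finishing a prompt chunk, with the concatenation of prompt and generated tokens factored into the Request.all_tokens property.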