From 20a49e9f0e3c0385f6c3d66a7fca213a93c91417 Mon Sep 17 00:00:00 2001
From: David Corvoysier
Date: Wed, 3 Jul 2024 10:33:01 +0200
Subject: [PATCH] Add support for TGI truncate parameter (#647)

feat(tgi): support truncate parameter
---
 .../text_generation_server/generator.py | 42 +++++++++++++------
 .../tests/server/helpers.py             |  7 +++-
 .../tests/server/test_prefill.py        | 32 ++++++++++++++
 3 files changed, 66 insertions(+), 15 deletions(-)

diff --git a/text-generation-inference/server/text_generation_server/generator.py b/text-generation-inference/server/text_generation_server/generator.py
index eaa106088..9be650c98 100644
--- a/text-generation-inference/server/text_generation_server/generator.py
+++ b/text-generation-inference/server/text_generation_server/generator.py
@@ -107,6 +107,7 @@ def clear(self):
         self._batch_id = None
         self._request_id = None
         self._inputs = ""
+        self._truncate = 0
         self._generation_config = None
         self._tokens = []
         self._mask = torch.tensor([])
@@ -158,6 +159,8 @@ def assign(self, batch_id: int, request: Request, generation_config: GenerationC
         self._batch_id = batch_id
         self._request_id = request.id
         self._inputs = request.inputs
+        if request.truncate:
+            self._truncate = request.truncate
         self._generation_config = copy.deepcopy(generation_config)
         # Update generation config with request parameters
         self._generation_config.do_sample = request.parameters.do_sample
@@ -300,6 +303,10 @@ def attention_mask(self) -> torch.LongTensor:
     def max_token(self) -> int:
         return self._generation_config.max_length
 
+    @property
+    def truncate(self) -> int:
+        return self._truncate
+
 
 class NeuronGenerator(Generator):
     """A Generator for Neuron models."""
@@ -311,9 +318,10 @@ def __init__(
         self,
     ):
         self.model = model
         self.rebuild_cache_on_prefill = not self.model.continuous_batching
-        # Specify padding options for decoder-only architecture
+        # Specify padding and truncation options for decoder-only architecture
         tokenizer.pad_token_id = tokenizer.eos_token_id
         tokenizer.padding_side = "left"
+        tokenizer.truncation_side = "left"
         self.tokenizer = tokenizer
         self.special_tokens = self.tokenizer.all_special_ids
         self.slots = [Slot(i, tokenizer) for i in range(self.model.batch_size)]
@@ -390,13 +398,21 @@ def prefill(self, batch: Batch) -> Tuple[List[Generation], CachedBatch]:
         # - the inputs for new requests,
         # - only when rebuilding the cache, the inputs and the generated text that has already
         #   been cached (i.e. excluding the last generated token) for unfinished requests.
-        inputs = [slot.cached_text for slot in prefill_slots]
-        # Tokenize with padding
-        padded_inputs = self.tokenizer(inputs, return_tensors="pt", padding=True)
-        # If needed truncate sequences to fit into the static dimensions
-        seq_length = min(padded_inputs.input_ids.shape[-1], self.model.max_length)
-        input_ids = padded_inputs.input_ids[:, :seq_length]
-        attention_mask = padded_inputs.attention_mask[:, :seq_length]
+        inputs = []
+        max_length = 0
+        for slot in prefill_slots:
+            inputs.append(slot.cached_text)
+            # Apply truncation, making sure we fit into static dimensions
+            if slot.truncate == 0:
+                max_length = self.model.max_length
+            elif slot.truncate > max_length and slot.truncate < self.model.max_length:
+                max_length = slot.truncate
+        # Tokenize with padding and truncation
+        padded_inputs = self.tokenizer(
+            inputs, return_tensors="pt", padding=True, truncation=True, max_length=max_length
+        )
+        input_ids = padded_inputs.input_ids
+        attention_mask = padded_inputs.attention_mask
         # Pause previously active slots during generation
         next_tokens = []
         for slot in active_slots:
@@ -405,12 +421,12 @@ def prefill(self, batch: Batch) -> Tuple[List[Generation], CachedBatch]:
                 # The slot will be reset, so we need to store its next token
                 next_tokens.append(slot.next_token)
         # Each slot must be reset with the padded inputs and masks
-        if self.rebuild_cache_on_prefill:
-            reset_slots = self.slots
-        else:
-            reset_slots = prefill_slots
-        for i, slot in enumerate(reset_slots):
+        for i, slot in enumerate(prefill_slots):
             if slot.state != slot.state.EMPTY:
+                if slot.truncate > 0 and slot.truncate < input_ids.shape[-1]:
+                    # Apply per-request truncation
+                    input_ids[i, : -slot.truncate] = self.tokenizer.pad_token_id
+                    attention_mask[i, : -slot.truncate] = 0
                 slot_input_ids = input_ids[i : i + 1, :]
                 # Padded input ids are also required to set logits processors and stopping criterias
                 selector = TokenSelector.create(
diff --git a/text-generation-inference/tests/server/helpers.py b/text-generation-inference/tests/server/helpers.py
index dbf2dcad3..81547cb6a 100644
--- a/text-generation-inference/tests/server/helpers.py
+++ b/text-generation-inference/tests/server/helpers.py
@@ -10,7 +10,8 @@ def create_request(
     id: int,
     inputs: str,
-    max_new_tokens=20,
+    truncate: int = 0,
+    max_new_tokens: int = 20,
     do_sample: bool = False,
     top_k: int = 50,
     top_p: float = 0.9,
@@ -27,7 +28,9 @@ def create_request(
         repetition_penalty=repetition_penalty,
     )
     stopping_parameters = StoppingCriteriaParameters(max_new_tokens=max_new_tokens)
-    return Request(id=id, inputs=inputs, parameters=parameters, stopping_parameters=stopping_parameters)
+    return Request(
+        id=id, inputs=inputs, truncate=truncate, parameters=parameters, stopping_parameters=stopping_parameters
+    )
 
 
 def check_prefill(input_text, expected_token_id, expected_token_text, do_sample, batch_size, model_path):
diff --git a/text-generation-inference/tests/server/test_prefill.py b/text-generation-inference/tests/server/test_prefill.py
index 651fa87f0..6412f926f 100644
--- a/text-generation-inference/tests/server/test_prefill.py
+++ b/text-generation-inference/tests/server/test_prefill.py
@@ -41,3 +41,35 @@ def _test_prefill(config_name, generator, batch_size, do_sample):
         tokens = g.tokens
         assert tokens.ids[0] == expectations[0]
         assert tokens.texts[0] == expectations[1]
+
+
+def test_prefill_truncate(neuron_model_config):
+    config_name = neuron_model_config["name"]
+    neuron_model_path = neuron_model_config["neuron_model_path"]
+    generator = NeuronGenerator.from_pretrained(neuron_model_path)
+    batch_size = generator.model.batch_size
+    # We apply truncation to all requests but the first one
+    truncate = [
+        None,
+    ] + [i * 3 for i in range(1, batch_size)]
+    input_text = (
+        "Two gin-scented tears trickled down the sides of his nose."
+        " But it was all right, everything was all right, the struggle was finished."
+        " He had won the victory over himself. He loved Big Brother."
+    )
+    requests = []
+    for i in range(batch_size):
+        requests.append(create_request(id=i, inputs=input_text, truncate=truncate[i]))
+    max_length = generator.model.max_length
+    batch = Batch(id=0, requests=requests, size=batch_size, max_tokens=batch_size * max_length)
+    generations, _ = generator.prefill(batch)
+    # Even if the input text is identical for all requests, the first generated token might
+    # be different because of the truncation
+    expectations = {
+        "gpt2": [" He", " He", "\n", " He"],
+        "llama": ["\n", "\n", " He", "\n"],
+        "mistral": [" He", "\n", " He", " He"],
+    }[config_name]
+    for i, g in enumerate(generations):
+        tokens = g.tokens
+        assert tokens.texts[0] == expectations[i]
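
Note on the per-request truncation applied during prefill: because all prompts in a batch are tokenized together with left padding, the patch truncates an individual request by overwriting everything to the left of its last `truncate` tokens with padding. The snippet below is a minimal standalone sketch of that masking step; it is illustrative only and not part of the patch, and the toy tensors and pad token id are made up.

import torch

pad_token_id = 0
# Toy batch of two left-padded sequences of length 6.
input_ids = torch.tensor([[0, 0, 11, 12, 13, 14],
                          [21, 22, 23, 24, 25, 26]])
attention_mask = torch.tensor([[0, 0, 1, 1, 1, 1],
                               [1, 1, 1, 1, 1, 1]])
truncate = [0, 3]  # request 0 is not truncated, request 1 keeps only its last 3 tokens

for i, t in enumerate(truncate):
    if t > 0 and t < input_ids.shape[-1]:
        # Everything left of the last `t` tokens becomes padding, as in the prefill loop above.
        input_ids[i, :-t] = pad_token_id
        attention_mask[i, :-t] = 0

print(input_ids.tolist())       # [[0, 0, 11, 12, 13, 14], [0, 0, 0, 24, 25, 26]]
print(attention_mask.tolist())  # [[0, 0, 1, 1, 1, 1], [0, 0, 0, 1, 1, 1]]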