Add support for TGI truncate parameter (#647)
feat(tgi): support truncate parameter
dacorvo authored Jul 3, 2024
1 parent fca97de commit 20a49e9
Showing 3 changed files with 66 additions and 15 deletions.
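
As implemented below, only the last `truncate` input tokens of a request are kept. For context, a rough sketch (not part of this commit) of how a client might exercise the parameter against a running TGI endpoint; the URL and values are placeholders:

import requests

# Sketch only: assumes a TGI server is already running at this address.
payload = {
    "inputs": "a very long prompt that exceeds the static sequence length ...",
    "parameters": {"truncate": 64, "max_new_tokens": 20},
}
response = requests.post("http://localhost:8080/generate", json=payload)
print(response.json())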
@@ -107,6 +107,7 @@ def clear(self):
self._batch_id = None
self._request_id = None
self._inputs = ""
self._truncate = 0
self._generation_config = None
self._tokens = []
self._mask = torch.tensor([])
@@ -158,6 +159,8 @@ def assign(self, batch_id: int, request: Request, generation_config: GenerationC
self._batch_id = batch_id
self._request_id = request.id
self._inputs = request.inputs
if request.truncate:
self._truncate = request.truncate
self._generation_config = copy.deepcopy(generation_config)
# Update generation config with request parameters
self._generation_config.do_sample = request.parameters.do_sample
@@ -300,6 +303,10 @@ def attention_mask(self) -> torch.LongTensor:
def max_token(self) -> int:
return self._generation_config.max_length

@property
def truncate(self) -> int:
return self._truncate


class NeuronGenerator(Generator):
"""A Generator for Neuron models."""
@@ -311,9 +318,10 @@ def __init__(
):
self.model = model
self.rebuild_cache_on_prefill = not self.model.continuous_batching
# Specify padding options for decoder-only architecture
# Specify padding and truncation options for decoder-only architecture
tokenizer.pad_token_id = tokenizer.eos_token_id
tokenizer.padding_side = "left"
tokenizer.truncation_side = "left"
self.tokenizer = tokenizer
self.special_tokens = self.tokenizer.all_special_ids
self.slots = [Slot(i, tokenizer) for i in range(self.model.batch_size)]
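
The tokenizer is now configured for left padding and left truncation, so when a prompt is cut it is the oldest tokens that are dropped and the end of the prompt that is kept. A minimal sketch of that behaviour, assuming the gpt2 tokenizer and an arbitrary max_length purely for illustration:

from transformers import AutoTokenizer

# Illustrative only: the checkpoint and max_length are arbitrary choices.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token_id = tokenizer.eos_token_id
tokenizer.padding_side = "left"
tokenizer.truncation_side = "left"

batch = tokenizer(
    ["hi", "one two three four five six seven"],
    return_tensors="pt", padding=True, truncation=True, max_length=4,
)
# The first row is left-padded; the second row keeps only its last 4 tokens.
print(batch.input_ids)
print(batch.attention_mask)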
@@ -390,13 +398,21 @@ def prefill(self, batch: Batch) -> Tuple[List[Generation], CachedBatch]:
# - the inputs for new requests,
# - only when rebuilding the cache, the inputs and the generated text that has already
# been cached (i.e. excluding the last generated token) for unfinished requests.
inputs = [slot.cached_text for slot in prefill_slots]
# Tokenize with padding
padded_inputs = self.tokenizer(inputs, return_tensors="pt", padding=True)
# If needed truncate sequences to fit into the static dimensions
seq_length = min(padded_inputs.input_ids.shape[-1], self.model.max_length)
input_ids = padded_inputs.input_ids[:, :seq_length]
attention_mask = padded_inputs.attention_mask[:, :seq_length]
inputs = []
max_length = 0
for slot in prefill_slots:
inputs.append(slot.cached_text)
# Apply truncation, making sure we fit into static dimensions
if slot.truncate == 0:
max_length = self.model.max_length
elif slot.truncate > max_length and slot.truncate < self.model.max_length:
max_length = slot.truncate
# Tokenize with padding and truncation
padded_inputs = self.tokenizer(
inputs, return_tensors="pt", padding=True, truncation=True, max_length=max_length
)
input_ids = padded_inputs.input_ids
attention_mask = padded_inputs.attention_mask
# Pause previously active slots during generation
next_tokens = []
for slot in active_slots:
@@ -405,12 +421,12 @@ def prefill(self, batch: Batch) -> Tuple[List[Generation], CachedBatch]:
# The slot will be reset, so we need to store its next token
next_tokens.append(slot.next_token)
# Each slot must be reset with the padded inputs and masks
if self.rebuild_cache_on_prefill:
reset_slots = self.slots
else:
reset_slots = prefill_slots
for i, slot in enumerate(reset_slots):
for i, slot in enumerate(prefill_slots):
if slot.state != slot.state.EMPTY:
if slot.truncate > 0 and slot.truncate < input_ids.shape[-1]:
# Apply per-request truncation
input_ids[i, : -slot.truncate] = self.tokenizer.pad_token_id
attention_mask[i, : -slot.truncate] = 0
slot_input_ids = input_ids[i : i + 1, :]
# Padded input ids are also required to set logits processors and stopping criterias
selector = TokenSelector.create(
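Putting the two pieces above together: the batch is first tokenized with truncation to the largest length any slot is allowed to keep, then rows whose requested `truncate` is smaller are masked on the left. A standalone sketch of that flow, with a made-up static max_length, made-up truncate values, and the gpt2 tokenizer standing in for the real one:

import torch
from transformers import AutoTokenizer

MODEL_MAX_LENGTH = 8   # stand-in for self.model.max_length
truncates = [0, 3]     # one value per slot; 0 means no request-level truncation
texts = [
    "a first fairly long prompt used for slot zero",
    "a second fairly long prompt used for slot one",
]

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token_id = tokenizer.eos_token_id
tokenizer.padding_side = "left"
tokenizer.truncation_side = "left"

# Step 1: batch-level truncation to the largest length any slot may keep.
max_length = 0
for truncate in truncates:
    if truncate == 0:
        max_length = MODEL_MAX_LENGTH
    elif truncate > max_length and truncate < MODEL_MAX_LENGTH:
        max_length = truncate
padded = tokenizer(texts, return_tensors="pt", padding=True, truncation=True, max_length=max_length)
input_ids, attention_mask = padded.input_ids, padded.attention_mask

# Step 2: per-request truncation, masking the leftmost tokens of shorter requests.
for i, truncate in enumerate(truncates):
    if truncate > 0 and truncate < input_ids.shape[-1]:
        input_ids[i, :-truncate] = tokenizer.pad_token_id
        attention_mask[i, :-truncate] = 0

print(input_ids)       # slot 1 keeps only its last 3 tokens
print(attention_mask)
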
7 changes: 5 additions & 2 deletions text-generation-inference/tests/server/helpers.py
@@ -10,7 +10,8 @@
def create_request(
id: int,
inputs: str,
max_new_tokens=20,
truncate: int = 0,
max_new_tokens: int = 20,
do_sample: bool = False,
top_k: int = 50,
top_p: float = 0.9,
@@ -27,7 +28,9 @@ def create_request(
repetition_penalty=repetition_penalty,
)
stopping_parameters = StoppingCriteriaParameters(max_new_tokens=max_new_tokens)
return Request(id=id, inputs=inputs, parameters=parameters, stopping_parameters=stopping_parameters)
return Request(
id=id, inputs=inputs, truncate=truncate, parameters=parameters, stopping_parameters=stopping_parameters
)


def check_prefill(input_text, expected_token_id, expected_token_text, do_sample, batch_size, model_path):
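A minimal usage sketch of the updated helper (all values are arbitrary; `truncate=0` keeps the previous behaviour of no request-level truncation):

# Sketch only: assumes create_request is imported from these test helpers.
request = create_request(id=0, inputs="One ring to rule them all", truncate=5, max_new_tokens=10)
assert request.truncate == 5
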
32 changes: 32 additions & 0 deletions text-generation-inference/tests/server/test_prefill.py
@@ -41,3 +41,35 @@ def _test_prefill(config_name, generator, batch_size, do_sample):
tokens = g.tokens
assert tokens.ids[0] == expectations[0]
assert tokens.texts[0] == expectations[1]


def test_prefill_truncate(neuron_model_config):
config_name = neuron_model_config["name"]
neuron_model_path = neuron_model_config["neuron_model_path"]
generator = NeuronGenerator.from_pretrained(neuron_model_path)
batch_size = generator.model.batch_size
# We apply truncation to all requests but the first one
truncate = [
None,
] + [i * 3 for i in range(1, batch_size)]
input_text = (
"Two gin-scented tears trickled down the sides of his nose."
" But it was all right, everything was all right, the struggle was finished."
" He had won the victory over himself. He loved Big Brother."
)
requests = []
for i in range(batch_size):
requests.append(create_request(id=i, inputs=input_text, truncate=truncate[i]))
max_length = generator.model.max_length
batch = Batch(id=0, requests=requests, size=batch_size, max_tokens=batch_size * max_length)
generations, _ = generator.prefill(batch)
# Even if the input text is identical for all requests, the first generated token might
# be different because of the truncation
expectations = {
"gpt2": [" He", " He", "\n", " He"],
"llama": ["\n", "\n", " He", "\n"],
"mistral": [" He", "\n", " He", " He"],
}[config_name]
for i, g in enumerate(generations):
tokens = g.tokens
assert tokens.texts[0] == expectations[i]
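
To see why the first generated token can differ per request, one can inspect what each truncated prompt actually keeps. A small sketch, assuming the gpt2 tokenizer purely as an example (the test itself loads whichever neuron model the fixture provides):

from transformers import AutoTokenizer

# Illustrative only: the checkpoint is an assumption, not what the test loads.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
text = "He had won the victory over himself. He loved Big Brother."
ids = tokenizer(text).input_ids
for truncate in (3, 6, 9):
    # Left truncation keeps only the rightmost `truncate` tokens of the prompt.
    print(truncate, repr(tokenizer.decode(ids[-truncate:])))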
