from typing import Optional, Union

import torch
from transformers.generation import GenerationConfig, LogitsProcessorList, StoppingCriteriaList
from transformers.generation.utils import GenerateNonBeamOutput, GenerationMixin
from transformers.modeling_outputs import BaseModelOutputWithPast
from transformers.models.llama import LlamaForCausalLM

from petals.models.llama.config import DistributedLlamaConfig
from petals.models.llama.model import DistributedLlamaForCausalLM


class DistributedLlamaForSpeculativeGeneration(DistributedLlamaForCausalLM, GenerationMixin):
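    """
    Speculative decoding on top of Petals: a small local Llama drafts several tokens at once,
    and the large distributed model validates the draft in a single forward pass, keeping the
    longest prefix it agrees with. Only greedy decoding is supported for now.
    """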
    def __init__(self, config: DistributedLlamaConfig, small_model: LlamaForCausalLM):
        DistributedLlamaForCausalLM.__init__(self, config)
        self.small_model = small_model

    def _sample(
        self,
        input_ids: torch.LongTensor,
        logits_processor: LogitsProcessorList,
        stopping_criteria: StoppingCriteriaList,
        generation_config: GenerationConfig,
        synced_gpus: bool,
        streamer: Optional["BaseStreamer"],
        logits_warper: Optional[LogitsProcessorList],
        speculative_inference_iteration_size: int = 10,
        **model_kwargs,
    ) -> Union[GenerateNonBeamOutput, torch.LongTensor]:
        assert not generation_config.do_sample, "sampling is not supported for speculative generation yet"
        assert not synced_gpus, "synced_gpus is not supported for speculative generation yet"
        assert (
            not generation_config.return_dict_in_generate
        ), "return_dict_in_generate is not supported for speculative generation yet"

        has_eos_stopping_criteria = any(hasattr(criteria, "eos_token_id") for criteria in stopping_criteria)

        # keep track of which sequences are already finished
        batch_size = input_ids.shape[0]
        unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
        finished = False
        is_first_iteration = True

        while not finished:
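            # Never draft past the remote inference session's maximum length.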
            speculative_inference_iteration_size = min(
                speculative_inference_iteration_size, self.active_session._max_length - input_ids.shape[1]
            )
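            # Draft step: the small local model greedily proposes the next few tokens.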
            with torch.no_grad():
                speculative_outputs = self.small_model.generate(
                    input_ids,
                    max_new_tokens=speculative_inference_iteration_size,
                    do_sample=False,
                )
            speculative_tokens = speculative_outputs[:, -speculative_inference_iteration_size:]

            full_sequence = torch.cat([input_ids, speculative_tokens], dim=-1)
            assert input_ids.shape[1] + speculative_inference_iteration_size == full_sequence.shape[1]

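            # Verification input: on the first pass the whole prompt (minus the final drafted token)
            # is sent so the remote servers build their attention caches; on later passes the session
            # is rewound and only the last accepted token plus the current draft is re-sent.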
            input_for_validation = full_sequence
            if not is_first_iteration:
                self.active_session.position = input_ids.shape[1] - 1
                input_for_validation = input_for_validation[:, -speculative_inference_iteration_size - 1 :]
            else:
                is_first_iteration = False
                input_for_validation = input_for_validation[:, :-1]
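            # Verification step: logits at the last `speculative_inference_iteration_size` positions
            # of the distributed model's output predict exactly the drafted tokens.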
            with torch.no_grad():
                precise_model_outputs = self(input_for_validation)
            full_token_logits = precise_model_outputs.logits[:, -speculative_inference_iteration_size:, :].clone()

            all_valid_tokens = []
            first_token = None
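            # Accept drafted tokens while the large model's greedy choice matches the draft.
            # Note: the .item() comparison below assumes batch size 1.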
            for i in range(speculative_inference_iteration_size):
                token_logits = full_token_logits[:, i, :]
                token_scores = logits_processor(
                    input_for_validation[:, : -speculative_inference_iteration_size + 1 + i], token_logits
                )
                valid_token = torch.argmax(token_scores, dim=-1)

                if first_token is None:
                    first_token = valid_token

                if valid_token.item() == speculative_tokens[:, i].item():
                    all_valid_tokens.append(valid_token.unsqueeze(-1))
                else:
                    break

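            # If even the first drafted token was rejected, emit the large model's own first
            # prediction so that every iteration still makes progress.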
            if not all_valid_tokens and first_token is not None:
                all_valid_tokens.append(first_token.unsqueeze(-1))
            all_valid_tokens = torch.cat(all_valid_tokens, dim=-1)

            # finished sentences should have their next token be a padding token
            if has_eos_stopping_criteria:
                all_valid_tokens = all_valid_tokens * unfinished_sequences + generation_config.pad_token_id * (
                    1 - unfinished_sequences
                )

            # update generated ids, model inputs, and length for next step
            input_ids = torch.cat([input_ids, all_valid_tokens], dim=-1)

            if streamer is not None:
                streamer.put(all_valid_tokens.cpu())

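            # A sequence is finished once any stopping criterion (e.g. EOS or max length) fires for it.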
            unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, None)
            finished = unfinished_sequences.max() == 0

            del precise_model_outputs

        if streamer is not None:
            streamer.end()

        return input_ids
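
Not part of the diff: a minimal usage sketch of the class above. The checkpoint names are placeholders, the import path of `DistributedLlamaForSpeculativeGeneration` is an assumption, and passing `small_model` through `from_pretrained` is hypothetical (the PR only defines `__init__`); the draft model must share the target model's tokenizer and vocabulary.

from transformers import AutoTokenizer, LlamaForCausalLM

from petals.models.llama.speculative_model import DistributedLlamaForSpeculativeGeneration  # hypothetical module path

small_model = LlamaForCausalLM.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")  # placeholder draft model
model = DistributedLlamaForSpeculativeGeneration.from_pretrained(
    "meta-llama/Llama-2-70b-hf",  # placeholder target model served over Petals
    small_model=small_model,  # assumption: from_pretrained forwards this kwarg to __init__
)
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-70b-hf")

inputs = tokenizer("A cat sat on", return_tensors="pt")["input_ids"]
outputs = model.generate(inputs, max_new_tokens=32, do_sample=False)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))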