diff --git a/generate.py b/generate.py index 8446d115..388e188f 100644 --- a/generate.py +++ b/generate.py @@ -177,6 +177,7 @@ def generate( next_token = prefill(model, prompt.view(1, -1), input_pos, **sampling_kwargs) if is_speculative: prefill(draft_model, prompt.view(1, -1), input_pos, **sampling_kwargs) + next_token = next_token.clone() seq[T] = next_token input_pos = torch.tensor([T], device=device, dtype=torch.int)