diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..cd4c22c
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1 @@
+*__pycache__*
diff --git a/README.md b/README.md
index 6c09ed7..9c489ff 100644
--- a/README.md
+++ b/README.md
@@ -14,7 +14,7 @@ The speculative sampling is proposed by Google and Deepmind independently. So I
 In the sample, I use [bloomz-7b1](https://huggingface.co/bigscience/bloomz-7b1/tree/main) as the target model, [bloom-560m](https://huggingface.co/bigscience/bloom-560m/tree/main) as the approximation model.
 
 ```bash
-python sample.py \
+python main.py \
     --input "The quick brown fox jumps over the lazy " \
     --target_model_name bigscience/bloomz-7b1 \
     --approx_model_name bigscience/bloom-560m
diff --git a/main.py b/main.py
new file mode 100644
index 0000000..a6404f1
--- /dev/null
+++ b/main.py
@@ -0,0 +1,93 @@
+
+import torch
+import argparse
+import contexttimer
+
+from transformers import AutoTokenizer, AutoModelForCausalLM
+
+from sampling import autoregressive_sampling, speculative_sampling, speculative_sampling_v2
+
+class Decoder:
+    def __init__(self, tokenizer) -> None:
+        self.tokenizer = tokenizer
+
+    def decode(self, t : torch.Tensor) -> str:
+        # assert t.dim == 2, "t must be 2d tensor"
+        return self.tokenizer.decode(t[0], skip_special_tokens=True)
+
+DECODER : Decoder = None
+
+MODELZOO = {
+    "llama7b": "/share_nfs/tianzhi/code/llama-7b",
+    "bloom7b": "/share_nfs/fangjiarui/root/code/hf_models/bloomz-7b1",
+    "bloom-560m": "/share_nfs/fangjiarui/root/code/hf_models/bloom-560m",
+}
+
+def parse_arguments():
+    parser = argparse.ArgumentParser(description='args for main.py')
+
+    parser.add_argument('--input', type=str, default="Suggest at least five related search terms to \"Mạng neural nhân tạo\".")
+    parser.add_argument('--approx_model_name', type=str, default="/share_nfs/fangjiarui/root/code/hf_models/bloom-560m")
+    parser.add_argument('--target_model_name', type=str, default="/share_nfs/fangjiarui/root/code/hf_models/bloomz-7b1")
+    parser.add_argument('--verbose', '-v', action='store_true', default=False, help='enable verbose mode')
+    parser.add_argument('--seed', '-s', type=int, default=None, help='set a random seed')
+    args = parser.parse_args()
+    return args
+
+
+def generate(input_text, approx_model_name, target_model_name, num_tokens=40, random_seed = None, verbose = False):
+    # NOTE() approx_model_name and target_model_name should use the same tokenizer!
+
+    torch_device = 'cuda' if torch.cuda.is_available() else 'cpu'
+
+    tokenizer = AutoTokenizer.from_pretrained(approx_model_name)
+
+    global DECODER
+    DECODER = Decoder(tokenizer)
+
+    print("begin loading models")
+    small_model = AutoModelForCausalLM.from_pretrained(approx_model_name).to(torch_device)
+    large_model = AutoModelForCausalLM.from_pretrained(target_model_name).to(torch_device)
+    print("finish loading models")
+
+    input_ids = tokenizer.encode(input_text, return_tensors='pt').to(torch_device)
+
+    top_k = 10
+    top_p = 0.9
+
+    torch.manual_seed(123)
+    output = autoregressive_sampling(input_ids, large_model, num_tokens, top_k = top_k, top_p=top_p)
+    generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
+    print(f"large (target) model autoregressive_sampling: {generated_text}")
+
+    TEST_TIME = 10
+    with contexttimer.Timer() as t:
+        for _ in range(TEST_TIME):
+            output = autoregressive_sampling(input_ids, large_model, num_tokens, top_k = top_k, top_p=top_p)
+    print(f"large (target) model autoregressive_sampling 10 times, tokens/sec: {len(output[0]) / t.elapsed / TEST_TIME}")
+
+
+    torch.manual_seed(123)
+    output = autoregressive_sampling(input_ids, small_model, num_tokens, top_k = top_k, top_p=top_p)
+    generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
+    print(f"small (approx) model autoregressive_sampling: {generated_text}")
+
+    torch.manual_seed(123)
+    output = speculative_sampling_v2(input_ids, small_model, large_model, num_tokens, top_k = top_k, top_p=top_p, random_seed = random_seed)
+    generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
+    print(f"deepmind's speculative_sampling: {generated_text}")
+
+    torch.manual_seed(123)
+    output = speculative_sampling(input_ids, small_model, large_model, num_tokens, top_k = top_k, top_p=top_p, random_seed = random_seed, verbose = verbose)
+    generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
+    print(f"google's speculative_sampling: {generated_text}")
+
+    with contexttimer.Timer() as t:
+        for _ in range(TEST_TIME):
+            output = speculative_sampling(input_ids, small_model, large_model, num_tokens, top_k = top_k, top_p=top_p, random_seed = random_seed)
+    print(f"speculative_sampling 10 times, tokens/sec: {len(output[0]) / t.elapsed / TEST_TIME}")
+
+if __name__ == "__main__":
+    args = parse_arguments()
+    # args.approx_model_name = MODELZOO["llama7b"]
+    generate(args.input, args.approx_model_name, args.target_model_name, random_seed = args.seed, verbose=args.verbose)
diff --git a/requirements.txt b/requirements.txt
index be43cc7..fdd9f70 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,2 +1,3 @@
 transformers==4.29.2
 torch==2.0.1
+contexttimer
\ No newline at end of file
diff --git a/sampling/__init__.py b/sampling/__init__.py
new file mode 100644
index 0000000..c652916
--- /dev/null
+++ b/sampling/__init__.py
@@ -0,0 +1,4 @@
+from sampling.speculative_sampling import speculative_sampling, speculative_sampling_v2
+from sampling.autoregressive_sampling import autoregressive_sampling
+
+__all__ = ["speculative_sampling", "speculative_sampling_v2", "autoregressive_sampling"]
\ No newline at end of file
diff --git a/sampling/autoregressive_sampling.py b/sampling/autoregressive_sampling.py
new file mode 100644
index 0000000..54009f9
--- /dev/null
+++ b/sampling/autoregressive_sampling.py
@@ -0,0 +1,31 @@
+import torch
+
+from tqdm import tqdm
+from sampling.utils import norm_logits, sample
+
+@torch.no_grad()
+def autoregressive_sampling(x : torch.Tensor, model : torch.nn.Module, N : int,
+                            temperature : float = 1, top_k : int = 0, top_p : float = 0):
+    n = len(x)
+    T = len(x) + N
+
+    past_key_values = None
+    with tqdm(total=N, desc="autoregressive sampling") as pbar:
+        while n < T:
+            # outputs = model(x)
+            if past_key_values:
+                last_ids = x[:, -1]
+                if last_ids.dim() == 1:
+                    last_ids = torch.unsqueeze(last_ids, 0)
+                outputs = model(last_ids, past_key_values = past_key_values, use_cache = True)
+            else:
+                outputs = model(x)
+            last_p = norm_logits(outputs.logits[::, -1, :], temperature, top_k, top_p)
+            past_key_values = outputs.past_key_values
+            idx_next = sample(last_p)
+            x = torch.cat((x, idx_next), dim=1)
+            n += 1
+            pbar.update(1)
+
+    return x
+
diff --git a/kvcache_model.py b/sampling/kvcache_model.py
similarity index 79%
rename from kvcache_model.py
rename to sampling/kvcache_model.py
index 934a2de..1b31480 100644
--- a/kvcache_model.py
+++ b/sampling/kvcache_model.py
@@ -1,7 +1,10 @@
 import torch
-from utils import norm_logits, sample
 from typing import Optional
 
+from sampling.utils import norm_logits, sample
+from transformers.models.llama.modeling_llama import LlamaForCausalLM
+from transformers.models.bloom.modeling_bloom import BloomForCausalLM
+
 def _debug_show_kvcache(past_key_values):
     if past_key_values is None:
         return
@@ -96,10 +99,22 @@ def rollback(self, end_pos : int):
             k, v = kv
             # NOTE() the indexing is specific for bloom. This won't work for other models
             # For example llama k, v should be (batch, num_head, seq_len, hidden_dim)
-            k = k[:, :, :end_pos]
-            v = v[:, :end_pos, :]
-            kv_trimmed = (k, v)
-            past_key_values_trimmed.append(kv_trimmed)
+
+            if isinstance(self._model, BloomForCausalLM):
+                # k (batch * head, hidden_dim, seq); v (batch * head, seq, hidden_dim)
+                k = k[:, :, :end_pos]
+                v = v[:, :end_pos, :]
+                kv_trimmed = (k, v)
+                past_key_values_trimmed.append(kv_trimmed)
+            elif isinstance(self._model, LlamaForCausalLM):
+                # k, v (batch, head, seq, hidden_dim)
+                k = k[:, :, :end_pos, :]
+                v = v[:, :, :end_pos, :]
+                kv_trimmed = (k, v)
+                past_key_values_trimmed.append(kv_trimmed)
+            else:
+                # check the model implementation to see the layout of K, V
+                raise TypeError(f"unknown model type {type(self._model)} for KV Cache trim operations")
 
         self._past_key_values = past_key_values_trimmed
         self._prob_history = self._prob_history[:, :end_pos, :]
diff --git a/sample.py b/sampling/speculative_sampling.py
similarity index 61%
rename from sample.py
rename to sampling/speculative_sampling.py
index 0ebf939..27999db 100644
--- a/sample.py
+++ b/sampling/speculative_sampling.py
@@ -1,148 +1,9 @@
-from transformers import AutoTokenizer, AutoModelForCausalLM
+import torch
 from tqdm import tqdm
 import torch
-from torch.nn import functional as F
-import argparse
-from typing import Tuple
-import torch.utils.benchmark as benchmark
-
-from kvcache_model import KVCacheModel
-from utils import norm_logits, sample
-class Decoder:
-    def __init__(self, tokenizer) -> None:
-        self.tokenizer = tokenizer
-
-    def decode(self, t : torch.Tensor) -> str:
-        # assert t.dim == 2, "t must be 2d tensor"
-        return self.tokenizer.decode(t[0], skip_special_tokens=True)
-
-DECODER : Decoder = None
-
-
-def parse_arguments():
-    parser = argparse.ArgumentParser(description='args for sample.py')
-
-    parser.add_argument('--input', type=str, default="Suggest at least five related search terms to \"Mạng neural nhân tạo\".")
-    parser.add_argument('--approx_model_name', type=str, default="/share_nfs/fangjiarui/root/code/hf_models/bloom-560m")
-    parser.add_argument('--target_model_name', type=str, default="/share_nfs/fangjiarui/root/code/hf_models/bloomz-7b1")
-    parser.add_argument('--verbose', '-v', action='store_true', default=False, help='enable verbose mode')
-    parser.add_argument('--seed', '-s', type=int, default=None, help='set a random seed')
-    args = parser.parse_args()
-    return args
-
-
-
-@torch.no_grad()
-def autoregressive_sampling(x : torch.Tensor, model : torch.nn.Module, N : int,
-                            temperature : float = 1, top_k : int = 0, top_p : float = 0):
-    n = len(x)
-    T = len(x) + N
-
-    past_key_values = None
-    with tqdm(total=N, desc="autoregressive sampling") as pbar:
-        while n < T:
-            # outputs = model(x)
-            if past_key_values:
-                last_ids = x[:, -1]
-                if last_ids.dim() == 1:
-                    last_ids = torch.unsqueeze(last_ids, 0)
-                outputs = model(last_ids, past_key_values = past_key_values, use_cache = True)
-            else:
-                outputs = model(x)
-            last_p = norm_logits(outputs.logits[::, -1, :], temperature, top_k, top_p)
-            past_key_values = outputs.past_key_values
-            idx_next = sample(last_p)
-            x = torch.cat((x, idx_next), dim=1)
-            n += 1
-            pbar.update(1)
-
-    return x
-
-def max_fn(x):
-    """
-        norm(max (x, 0))
-    """
-    x_max = torch.where(x > 0, x, torch.zeros_like(x))
-    x_max_sum = torch.sum(x_max, dim=1, keepdim=True)
-    return x_max / x_max_sum
-
-@torch.no_grad()
-def speculative_sampling_v2(prefix : torch.Tensor, approx_model : torch.nn.Module, target_model : torch.nn.Module,
-                         max_len : int , gamma : int = 4,
-                         temperature : float = 1, top_k : int = 0, top_p : float = 0, random_seed : int = None) -> torch.Tensor:
-    """
-    DeepMind version Speculative Sampling.
-    Accelerating Large Language Model Decoding with Speculative Sampling
-    https://arxiv.org/abs/2302.01318
-    No KV Cache Optimization
-
-    Args:
-        x (torch.Tensor): input sequence, (batch, prefix_seqlen), Note that the batch dim is always 1 now.
-        approx_model (torch.nn.Module): approx model, the small one
-        target_model (torch.nn.Module): target model, the large one
-        max_len (int): the max overall generated tokens number.
-        gamma (int): $\gamma$, the token number small model guesses.
-        temperature (float, optional): Defaults to 1.
-        top_k (int, optional): Defaults to 0.
-        top_p (float, optional): Defaults to 0.
-
-    Returns:
-        torch.Tensor: generated tokens (batch, target_seqlen)
-    """
-    seq_len = prefix.shape[1]
-    T = seq_len + max_len
-
-    assert prefix.shape[0] == 1, "input batch size must be 1"
-
-    with tqdm(total=T, desc="speculative sampling") as pbar:
-        while prefix.shape[1] < T:
-            # q = M_q[prefix + x_0, x_1, .., x_(gamma-2)]
-            x = prefix
-            prefix_len = prefix.shape[1]
-            for _ in range(gamma):
-                # p.logits shape (batch, seq, vocab)
-                q = approx_model(x).logits
-                next_tok = sample(norm_logits(q[:, -1, :],
-                                  temperature, top_k, top_p), random_seed = random_seed)
-                x = torch.cat((x, next_tok), dim=1)
-
-            # normalize the logits
-            for i in range(q.shape[1]):
-                q[:,i,:] = norm_logits(q[:,i,:],
-                                temperature, top_k, top_p)
-            # p = M_p[prefix + x_0, x_0, .., x_(gamma-1)]
-            p = target_model(x).logits
-            for i in range(p.shape[1]):
-                p[:,i,:] = norm_logits(p[:,i,:],
-                                temperature, top_k, top_p)
-
-            # n the end position of the valid prefix
-            # x = x_[:prefix_len-1] + x_0, ... x_(gamma-1)
-
-            is_all_accept = True
-            n = prefix_len - 1
-            for i in range(gamma):
-                r = torch.rand(1, device = p.device)
-                j = x[:, prefix_len + i]
-
-                if r < torch.min(torch.tensor([1], device=q.device), p[:, prefix_len + i - 1, j] / q[:, prefix_len + i - 1, j]):
-                    # accept, and update n
-                    n += 1
-                else:
-                    # reject
-                    t = sample(max_fn(p[:, n, :] - q[:, n, :]), random_seed = random_seed)
-                    is_all_accept = False
-                    break
-
-            prefix = x[:, :n + 1]
-
-            if is_all_accept:
-                t = sample(p[:, -1, :], random_seed = random_seed)
-
-            prefix = torch.cat((prefix, t), dim=1)
-            pbar.update(n - pbar.n)
-    return prefix
+from sampling.kvcache_model import KVCacheModel
+from sampling.utils import norm_logits, sample, max_fn
@@ -235,48 +96,82 @@ def speculative_sampling(prefix : torch.Tensor, approx_model : torch.nn.Module,
 
     return prefix
 
 
-def generate(input_text, approx_model_name, target_model_name, num_tokens=40, random_seed = None, verbose = False):
-    # NOTE() approx_model_name and target_model_name should use the same tokenizer!
-
-    torch_device = 'cuda' if torch.cuda.is_available() else 'cpu'
-
-    tokenizer = AutoTokenizer.from_pretrained(approx_model_name)
-
-    global DECODER
-    DECODER = Decoder(tokenizer)
-
-    print("begin loading models")
-    small_model = AutoModelForCausalLM.from_pretrained(approx_model_name).to(torch_device)
-    large_model = AutoModelForCausalLM.from_pretrained(target_model_name).to(torch_device)
-    print("finish loading models")
-
-    input_ids = tokenizer.encode(input_text, return_tensors='pt').to(torch_device)
-    top_k = 10
-    top_p = 0.9
+@torch.no_grad()
+def speculative_sampling_v2(prefix : torch.Tensor, approx_model : torch.nn.Module, target_model : torch.nn.Module,
+                         max_len : int , gamma : int = 4,
+                         temperature : float = 1, top_k : int = 0, top_p : float = 0, random_seed : int = None) -> torch.Tensor:
+    """
+    DeepMind version Speculative Sampling.
+    Accelerating Large Language Model Decoding with Speculative Sampling
+    https://arxiv.org/abs/2302.01318
+    No KV Cache Optimization
+
+    Args:
+        x (torch.Tensor): input sequence, (batch, prefix_seqlen), Note that the batch dim is always 1 now.
+        approx_model (torch.nn.Module): approx model, the small one
+        target_model (torch.nn.Module): target model, the large one
+        max_len (int): the max overall generated tokens number.
+        gamma (int): $\gamma$, the token number small model guesses.
+        temperature (float, optional): Defaults to 1.
+        top_k (int, optional): Defaults to 0.
+        top_p (float, optional): Defaults to 0.
-    torch.manual_seed(123)
-    output = autoregressive_sampling(input_ids, large_model, num_tokens, top_k = top_k, top_p=top_p)
-    generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
-    print(f"large (target) model autoregressive_sampling: {generated_text}")
+
+    Returns:
+        torch.Tensor: generated tokens (batch, target_seqlen)
+    """
+    seq_len = prefix.shape[1]
+    T = seq_len + max_len
+
+    assert prefix.shape[0] == 1, "input batch size must be 1"
-    torch.manual_seed(123)
-    output = autoregressive_sampling(input_ids, small_model, num_tokens, top_k = top_k, top_p=top_p)
-    generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
-    print(f"small (approx) model autoregressive_sampling: {generated_text}")
+    with tqdm(total=T, desc="speculative sampling") as pbar:
+        while prefix.shape[1] < T:
+            # q = M_q[prefix + x_0, x_1, .., x_(gamma-2)]
+            x = prefix
+            prefix_len = prefix.shape[1]
+            for _ in range(gamma):
+                # p.logits shape (batch, seq, vocab)
+                q = approx_model(x).logits
+                next_tok = sample(norm_logits(q[:, -1, :],
+                                  temperature, top_k, top_p), random_seed = random_seed)
+                x = torch.cat((x, next_tok), dim=1)
+
+            # normalize the logits
+            for i in range(q.shape[1]):
+                q[:,i,:] = norm_logits(q[:,i,:],
+                                temperature, top_k, top_p)
+            # p = M_p[prefix + x_0, x_0, .., x_(gamma-1)]
+            p = target_model(x).logits
+            for i in range(p.shape[1]):
+                p[:,i,:] = norm_logits(p[:,i,:],
+                                temperature, top_k, top_p)
-    torch.manual_seed(123)
-    output = speculative_sampling(input_ids, small_model, large_model, num_tokens, top_k = top_k, top_p=top_p, random_seed = random_seed, verbose = verbose)
-    generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
-    print(f"google's speculative_sampling: {generated_text}")
-
-    torch.manual_seed(123)
-    output = speculative_sampling_v2(input_ids, small_model, large_model, num_tokens, top_k = top_k, top_p=top_p, random_seed = random_seed)
-    generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
-    print(f"deepmind's speculative_sampling: {generated_text}")
+
+            # n the end position of the valid prefix
+            # x = x_[:prefix_len-1] + x_0, ... x_(gamma-1)
+
+            is_all_accept = True
+            n = prefix_len - 1
+            for i in range(gamma):
+                r = torch.rand(1, device = p.device)
+                j = x[:, prefix_len + i]
+
+                if r < torch.min(torch.tensor([1], device=q.device), p[:, prefix_len + i - 1, j] / q[:, prefix_len + i - 1, j]):
+                    # accept, and update n
+                    n += 1
+                else:
+                    # reject
+                    t = sample(max_fn(p[:, n, :] - q[:, n, :]), random_seed = random_seed)
+                    is_all_accept = False
+                    break
+
+            prefix = x[:, :n + 1]
+
+            if is_all_accept:
+                t = sample(p[:, -1, :], random_seed = random_seed)
+
+            prefix = torch.cat((prefix, t), dim=1)
+            pbar.update(n - pbar.n)
-if __name__ == "__main__":
-    args = parse_arguments()
-    generate(args.input, args.approx_model_name, args.target_model_name, random_seed = args.seed, verbose=args.verbose)
+
+    return prefix
-
diff --git a/utils.py b/sampling/utils.py
similarity index 90%
rename from utils.py
rename to sampling/utils.py
index 87f0941..2bf23f9 100644
--- a/utils.py
+++ b/sampling/utils.py
@@ -54,3 +54,12 @@ def sample(probs : torch.Tensor, num_samples: int = 1, random_seed = None):
     if (idx_next.item() == 0):
         raise RuntimeError
     return idx_next
+
+
+def max_fn(x):
+    """
+        norm(max (x, 0))
+    """
+    x_max = torch.where(x > 0, x, torch.zeros_like(x))
+    x_max_sum = torch.sum(x_max, dim=1, keepdim=True)
+    return x_max / x_max_sum