add time benchmarking and organize the directory better
feifeibear authored Sep 21, 2023
2 parents 60740d0 + ec5823b commit 02d664b
Showing 9 changed files with 236 additions and 187 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -0,0 +1 @@
*__pycache__*
2 changes: 1 addition & 1 deletion README.md
@@ -14,7 +14,7 @@ The speculative sampling is proposed by Google and Deepmind independently. So I
In the sample, I use [bloomz-7b1](https://huggingface.co/bigscience/bloomz-7b1/tree/main) as the target model, [bloom-560m](https://huggingface.co/bigscience/bloom-560m/tree/main) as the approximation model.

```bash
-python sample.py \
+python main.py \
 --input "The quick brown fox jumps over the lazy " \
 --target_model_name bigscience/bloomz-7b1 \
 --approx_model_name bigscience/bloom-560m
```
93 changes: 93 additions & 0 deletions main.py
@@ -0,0 +1,93 @@

import torch
import argparse
import contexttimer

from transformers import AutoTokenizer, AutoModelForCausalLM

from sampling import autoregressive_sampling, speculative_sampling, speculative_sampling_v2

class Decoder:
    def __init__(self, tokenizer) -> None:
        self.tokenizer = tokenizer

    def decode(self, t : torch.Tensor) -> str:
        # assert t.dim == 2, "t must be 2d tensor"
        return self.tokenizer.decode(t[0], skip_special_tokens=True)

DECODER : Decoder = None

MODELZOO = {
    "llama7b": "/share_nfs/tianzhi/code/llama-7b",
    "bloom7b": "/share_nfs/fangjiarui/root/code/hf_models/bloomz-7b1",
    "bloom-560m": "/share_nfs/fangjiarui/root/code/hf_models/bloom-560m",
}

def parse_arguments():
    parser = argparse.ArgumentParser(description='args for sample.py')

    parser.add_argument('--input', type=str, default="Suggest at least five related search terms to \"Mạng neural nhân tạo\".")
    parser.add_argument('--approx_model_name', type=str, default="/share_nfs/fangjiarui/root/code/hf_models/bloom-560m")
    parser.add_argument('--target_model_name', type=str, default="/share_nfs/fangjiarui/root/code/hf_models/bloomz-7b1")
    parser.add_argument('--verbose', '-v', action='store_true', default=False, help='enable verbose mode')
    parser.add_argument('--seed', '-s', type=int, default=None, help='set a random seed')
    args = parser.parse_args()
    return args


def generate(input_text, approx_model_name, target_model_name, num_tokens=40, random_seed = None, verbose = False):
    # NOTE() approx_model_name and target_model_name should use the same tokenizer!

    torch_device = 'cuda' if torch.cuda.is_available() else 'cpu'

    tokenizer = AutoTokenizer.from_pretrained(approx_model_name)

    global DECODER
    DECODER = Decoder(tokenizer)

    print("begin loading models")
    small_model = AutoModelForCausalLM.from_pretrained(approx_model_name).to(torch_device)
    large_model = AutoModelForCausalLM.from_pretrained(target_model_name).to(torch_device)
    print("finish loading models")

    input_ids = tokenizer.encode(input_text, return_tensors='pt').to(torch_device)

    top_k = 10
    top_p = 0.9

    torch.manual_seed(123)
    output = autoregressive_sampling(input_ids, large_model, num_tokens, top_k = top_k, top_p=top_p)
    generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
    print(f"large (target) model autoregressive_sampling: {generated_text}")

    TEST_TIME = 10
    with contexttimer.Timer() as t:
        for _ in range(TEST_TIME):
            output = autoregressive_sampling(input_ids, large_model, num_tokens, top_k = top_k, top_p=top_p)
    # tokens/sec over TEST_TIME runs (len(output[0]) counts prompt + generated tokens)
    print(f"large (target) model autoregressive_sampling {TEST_TIME} times, tokens/sec: {len(output[0]) * TEST_TIME / t.elapsed}")


    torch.manual_seed(123)
    output = autoregressive_sampling(input_ids, small_model, num_tokens, top_k = top_k, top_p=top_p)
    generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
    print(f"small (approx) model autoregressive_sampling: {generated_text}")

    torch.manual_seed(123)
    output = speculative_sampling_v2(input_ids, small_model, large_model, num_tokens, top_k = top_k, top_p=top_p, random_seed = random_seed)
    generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
    print(f"deepmind's speculative_sampling: {generated_text}")

    torch.manual_seed(123)
    output = speculative_sampling(input_ids, small_model, large_model, num_tokens, top_k = top_k, top_p=top_p, random_seed = random_seed, verbose = verbose)
    generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
    print(f"google's speculative_sampling: {generated_text}")

    with contexttimer.Timer() as t:
        for _ in range(TEST_TIME):
            output = speculative_sampling(input_ids, small_model, large_model, num_tokens, top_k = top_k, top_p=top_p, random_seed = random_seed)
    print(f"speculative_sampling {TEST_TIME} times, tokens/sec: {len(output[0]) * TEST_TIME / t.elapsed}")

if __name__ == "__main__":
    args = parse_arguments()
    # args.approx_model_name = MODELZOO["llama7b"]
    generate(args.input, args.approx_model_name, args.target_model_name, random_seed = args.seed, verbose=args.verbose)
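
For reference, the same benchmark can be driven from Python by calling `generate()` directly instead of the CLI; this is an illustrative sketch (not part of the commit), reusing the public bloom checkpoints named in the README:

```python
# Illustrative only: mirrors `python main.py ...` from the README by calling
# generate() directly with the README's bloom checkpoints.
from main import generate

generate(
    "The quick brown fox jumps over the lazy ",
    approx_model_name="bigscience/bloom-560m",
    target_model_name="bigscience/bloomz-7b1",
    random_seed=123,
    verbose=True,
)
```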
1 change: 1 addition & 0 deletions requirements.txt
@@ -1,2 +1,3 @@
transformers==4.29.2
torch==2.0.1
contexttimer
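
The new `contexttimer` dependency backs the timing loops in `main.py`. A minimal sketch of that pattern, standalone and with a dummy workload in place of a sampling run:

```python
# Sketch of the timing pattern used in main.py: Timer measures the wall-clock
# time of the with-block and exposes it as t.elapsed (seconds).
import contexttimer

TEST_TIME = 10
with contexttimer.Timer() as t:
    for _ in range(TEST_TIME):
        sum(i * i for i in range(100_000))  # dummy workload standing in for a sampling run
print(f"average seconds per run: {t.elapsed / TEST_TIME:.4f}")
```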
4 changes: 4 additions & 0 deletions sampling/__init__.py
@@ -0,0 +1,4 @@
from sampling.speculative_sampling import speculative_sampling, speculative_sampling_v2
from sampling.autoregressive_sampling import autoregressive_sampling

__all__ = ["speculative_sampling", "speculative_sampling_v2", "autoregressive_sampling"]
31 changes: 31 additions & 0 deletions sampling/autoregressive_sampling.py
@@ -0,0 +1,31 @@
import torch

from tqdm import tqdm
from sampling.utils import norm_logits, sample

@torch.no_grad()
def autoregressive_sampling(x : torch.Tensor, model : torch.nn.Module, N : int,
                            temperature : float = 1, top_k : int = 0, top_p : float = 0):
    n = len(x)
    T = len(x) + N

    past_key_values = None
    with tqdm(total=N, desc="autoregressive sampling") as pbar:
        while n < T:
            # outputs = model(x)
            if past_key_values:
                last_ids = x[:, -1]
                if last_ids.dim() == 1:
                    last_ids = torch.unsqueeze(last_ids, 0)
                outputs = model(last_ids, past_key_values = past_key_values, use_cache = True)
            else:
                outputs = model(x)
            last_p = norm_logits(outputs.logits[::, -1, :], temperature, top_k, top_p)
            past_key_values = outputs.past_key_values
            idx_next = sample(last_p)
            x = torch.cat((x, idx_next), dim=1)
            n += 1
            pbar.update(1)

    return x
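
A minimal usage sketch of this cached decoding loop (not part of the commit; it assumes the `sampling` package is importable and uses the bloom-560m checkpoint mentioned in the README):

```python
# Illustrative standalone call of the cached autoregressive loop above.
from transformers import AutoTokenizer, AutoModelForCausalLM

from sampling import autoregressive_sampling

tokenizer = AutoTokenizer.from_pretrained("bigscience/bloom-560m")
model = AutoModelForCausalLM.from_pretrained("bigscience/bloom-560m")

input_ids = tokenizer.encode("The quick brown fox jumps over the lazy ", return_tensors="pt")
output = autoregressive_sampling(input_ids, model, 20, top_k=10, top_p=0.9)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```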

25 changes: 20 additions & 5 deletions kvcache_model.py → sampling/kvcache_model.py
@@ -1,7 +1,10 @@
 import torch
-from utils import norm_logits, sample
 from typing import Optional

+from sampling.utils import norm_logits, sample
+from transformers.models.llama.modeling_llama import LlamaForCausalLM
+from transformers.models.bloom.modeling_bloom import BloomForCausalLM

 def _debug_show_kvcache(past_key_values):
     if past_key_values is None:
         return
@@ -96,10 +99,22 @@ def rollback(self, end_pos : int):
             k, v = kv
             # NOTE() the indexing is specific for bloom. This won't work for other models
             # For example llama k, v should be (batch, num_head, seq_len, hidden_dim)
-            k = k[:, :, :end_pos]
-            v = v[:, :end_pos, :]
-            kv_trimmed = (k, v)
-            past_key_values_trimmed.append(kv_trimmed)

+            if isinstance(self._model, BloomForCausalLM):
+                # k (batch * head, hidden_dim, seq); v (batch * head, seq, hidden_dim)
+                k = k[:, :, :end_pos]
+                v = v[:, :end_pos, :]
+                kv_trimmed = (k, v)
+                past_key_values_trimmed.append(kv_trimmed)
+            elif isinstance(self._model, LlamaForCausalLM):
+                # k, v (batch, head, seq, hidden_dim)
+                k = k[:, :, :end_pos, :]
+                v = v[:, :, :end_pos, :]
+                kv_trimmed = (k, v)
+                past_key_values_trimmed.append(kv_trimmed)
+            else:
+                # check the model implementation to see the layout of K, V
+                raise TypeError(f"unknown model type {type(self._model)} for KV Cache trim operations")

         self._past_key_values = past_key_values_trimmed
         self._prob_history = self._prob_history[:, :end_pos, :]
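
A quick shape check of why the trim indexing differs between the two cache layouts described in the comments above (illustrative tensor shapes only, not code from this commit):

```python
# Bloom's cache layout is (batch * heads, hidden, seq) for k and
# (batch * heads, seq, hidden) for v, while Llama keeps (batch, heads, seq, hidden)
# for both, so trimming to end_pos has to slice a different dimension.
import torch

batch, heads, seq, hidden = 1, 16, 12, 64
end_pos = 8

k_bloom = torch.zeros(batch * heads, hidden, seq)
v_bloom = torch.zeros(batch * heads, seq, hidden)
assert k_bloom[:, :, :end_pos].shape == (batch * heads, hidden, end_pos)
assert v_bloom[:, :end_pos, :].shape == (batch * heads, end_pos, hidden)

k_llama = torch.zeros(batch, heads, seq, hidden)
assert k_llama[:, :, :end_pos, :].shape == (batch, heads, end_pos, hidden)
```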
