Commit

add server

feifeibear authored Sep 21, 2023
2 parents 02d664b + 2b62bec commit 1da8d0a
Showing 5 changed files with 93 additions and 13 deletions.
11 changes: 11 additions & 0 deletions README.md
@@ -11,6 +11,7 @@ The speculative sampling is proposed by Google and Deepmind independently. So I
- 2023.08.16: First release, implement the paper's algorithm.

## Usage
### Inference
In the example, I use [bloomz-7b1](https://huggingface.co/bigscience/bloomz-7b1/tree/main) as the target model and [bloom-560m](https://huggingface.co/bigscience/bloom-560m/tree/main) as the approximation model.

```bash
@@ -24,6 +25,16 @@ You can also use the `--v` flag to see which model generated each token.

![example image](./imgs/sps.jpg "console output")

### Serving
Start an inference server.
```bash
python serving.py
```

Test the server with curl:
```bash
curl -X POST -H "Content-Type: application/json" -d '{"prompt": "Who is the president of the USA"}' http://127.0.0.1:5000/predict
```
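
For programmatic access, here is a minimal Python client sketch (not part of the repo). It assumes the `requests` package is installed and that, as in `serving.py`, the endpoint returns the generated text as a JSON-encoded string:
```python
import requests

# Assumes the Flask server from serving.py is running locally on port 5000.
resp = requests.post(
    "http://127.0.0.1:5000/predict",
    json={"prompt": "Who is the president of the USA"},
    timeout=120,
)
resp.raise_for_status()
print(resp.json())  # the generated text, JSON-encoded by the server
```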
## References
```
@inproceedings{leviathan2023fast,
20 changes: 13 additions & 7 deletions main.py
@@ -17,18 +17,21 @@ def decode(self, t : torch.Tensor) -> str:

DECODER : Decoder = None

# my local models
MODELZOO = {
"llama7b": "/share_nfs/tianzhi/code/llama-7b",
"bloom7b": "/share_nfs/fangjiarui/root/code/hf_models/bloomz-7b1",
"bloom-560m": "/share_nfs/fangjiarui/root/code/hf_models/bloom-560m",
"baichuan-13b": "/share_nfs/duanqiyuan/models/source_models/hf/Baichuan-13B-Base",
"baichuan-7b": "/share_nfs/duanqiyuan/models/source_models/hf/baichuan-7B"
}

def parse_arguments():
parser = argparse.ArgumentParser(description='args for sample.py')

parser.add_argument('--input', type=str, default="Suggest at least five related search terms to \"Mạng neural nhân tạo\".")
parser.add_argument('--approx_model_name', type=str, default="/share_nfs/fangjiarui/root/code/hf_models/bloom-560m")
parser.add_argument('--target_model_name', type=str, default="/share_nfs/fangjiarui/root/code/hf_models/bloomz-7b1")
parser.add_argument('--approx_model_name', type=str, default=MODELZOO["bloom-560m"])
parser.add_argument('--target_model_name', type=str, default=MODELZOO["bloom7b"])
parser.add_argument('--verbose', '-v', action='store_true', default=False, help='enable verbose mode')
parser.add_argument('--seed', '-s', type=int, default=None, help='set a random seed')
args = parser.parse_args()
@@ -40,14 +43,14 @@ def generate(input_text, approx_model_name, target_model_name, num_tokens=40, ra

torch_device = 'cuda' if torch.cuda.is_available() else 'cpu'

tokenizer = AutoTokenizer.from_pretrained(approx_model_name)
tokenizer = AutoTokenizer.from_pretrained(approx_model_name, trust_remote_code=True)

global DECODER
DECODER = Decoder(tokenizer)

print("begin loading models")
small_model = AutoModelForCausalLM.from_pretrained(approx_model_name).to(torch_device)
large_model = AutoModelForCausalLM.from_pretrained(target_model_name).to(torch_device)
print(f"begin loading models: \n {approx_model_name} \n {target_model_name}")
small_model = AutoModelForCausalLM.from_pretrained(approx_model_name, trust_remote_code=True).to(torch_device)
large_model = AutoModelForCausalLM.from_pretrained(target_model_name, trust_remote_code=True).to(torch_device)
print("finish loading models")

input_ids = tokenizer.encode(input_text, return_tensors='pt').to(torch_device)
@@ -89,5 +92,8 @@ def generate(input_text, approx_model_name, target_model_name, num_tokens=40, ra

if __name__ == "__main__":
args = parse_arguments()
# args.approx_model_name = MODELZOO["llama7b"]

# NOTE: these hard-coded assignments override the command-line arguments parsed above
args.approx_model_name = MODELZOO["baichuan-7b"]
args.target_model_name = MODELZOO["baichuan-13b"]

generate(args.input, args.approx_model_name, args.target_model_name, random_seed = args.seed, verbose=args.verbose)
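
For reference, the `generate` entry point above can also be called from Python rather than the CLI. A hypothetical sketch (not part of this commit), assuming `main.py` is importable from the working directory and using the Hugging Face model IDs linked in the README instead of the local NFS paths:
```python
# Hypothetical programmatic use of generate() from main.py (not part of this commit).
from main import generate

generate(
    "Suggest at least five related search terms to \"artificial neural networks\".",
    approx_model_name="bigscience/bloom-560m",   # draft / approximation model
    target_model_name="bigscience/bloomz-7b1",   # target model
    num_tokens=40,
    random_seed=42,
    verbose=True,
)
```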
4 changes: 3 additions & 1 deletion requirements.txt
@@ -1,3 +1,5 @@
transformers==4.29.2
torch==2.0.1
contexttimer
contexttimer
flask
transformers_stream_generator
7 changes: 2 additions & 5 deletions sampling/kvcache_model.py
@@ -2,7 +2,6 @@
from typing import Optional

from sampling.utils import norm_logits, sample
from transformers.models.llama.modeling_llama import LlamaForCausalLM
from transformers.models.bloom.modeling_bloom import BloomForCausalLM

def _debug_show_kvcache(past_key_values):
@@ -100,21 +99,19 @@ def rollback(self, end_pos : int):
# NOTE() the indexing is specific for bloom. This won't work for other models
# For example llama k, v should be (batch, num_head, seq_len, hidden_dim)

# Bloom is a special case
if isinstance(self._model, BloomForCausalLM):
# k (batch * head, hidden_dim, seq); v (batch * head, seq, hidden_dim)
k = k[:, :, :end_pos]
v = v[:, :end_pos, :]
kv_trimmed = (k, v)
past_key_values_trimmed.append(kv_trimmed)
elif isinstance(self._model, LlamaForCausalLM):
else:
# k, v (batch, head, seq, hidden_dim)
k = k[:, :, :end_pos, :]
v = v[:, :, :end_pos, :]
kv_trimmed = (k, v)
past_key_values_trimmed.append(kv_trimmed)
else:
# check the model implementation to see the layout of K, V
raise TypeError(f"unknown model type {type(self._model)} for KV Cache trim operations")

self._past_key_values = past_key_values_trimmed
self._prob_history = self._prob_history[:, :end_pos, :]
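
The rollback logic above trims the cached keys and values back to `end_pos`, and the two branches index differently because Bloom folds the attention heads into the batch dimension of its KV cache. A small standalone sketch (not part of the commit; shapes follow the comments in the diff):
```python
import torch

# Illustrative shapes only; real caches come from the model's forward pass.
batch, heads, seq, hidden = 1, 16, 12, 64
end_pos = 8  # keep only the first 8 cached positions

# Bloom layout: k is (batch*head, hidden_dim, seq); v is (batch*head, seq, hidden_dim)
k_bloom = torch.randn(batch * heads, hidden, seq)
v_bloom = torch.randn(batch * heads, seq, hidden)
k_bloom, v_bloom = k_bloom[:, :, :end_pos], v_bloom[:, :end_pos, :]

# Common layout (e.g. llama-style): k and v are (batch, head, seq, hidden_dim)
k_std = torch.randn(batch, heads, seq, hidden)
v_std = torch.randn(batch, heads, seq, hidden)
k_std, v_std = k_std[:, :, :end_pos, :], v_std[:, :, :end_pos, :]

print(k_bloom.shape, v_bloom.shape)  # torch.Size([16, 64, 8]) torch.Size([16, 8, 64])
print(k_std.shape, v_std.shape)      # both torch.Size([1, 16, 8, 64])
```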
64 changes: 64 additions & 0 deletions serving.py
@@ -0,0 +1,64 @@
from flask import Flask, request, jsonify
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import logging

from sampling import autoregressive_sampling, speculative_sampling, speculative_sampling_v2

app = Flask(__name__)

GLOBAL_SERVER = None

class Server:
    def __init__(self, approx_model_name, target_model_name) -> None:
        self._device = 'cuda' if torch.cuda.is_available() else 'cpu'

        logging.info("begin loading models")
        self._small_model = AutoModelForCausalLM.from_pretrained(approx_model_name, trust_remote_code=True).to(self._device)
        self._large_model = AutoModelForCausalLM.from_pretrained(target_model_name, trust_remote_code=True).to(self._device)
        self._tokenizer = AutoTokenizer.from_pretrained(approx_model_name)
        logging.info("finish loading models")

        self.num_tokens = 40
        self.top_k = 10
        self.top_p = 0.9

    def process_request(self, request_data: dict) -> str:
        input_str = request_data['prompt']
        logging.info(f"receive request {input_str}")
        input_ids = self._tokenizer.encode(input_str, return_tensors='pt').to(self._device)
        output = speculative_sampling(input_ids,
                                      self._small_model,
                                      self._large_model, self.num_tokens,
                                      top_k=self.top_k,
                                      top_p=self.top_p)
        generated_text = self._tokenizer.decode(output[0], skip_special_tokens=True)
        return generated_text

# Set up a route to listen for inference requests
@app.route('/predict', methods=['POST'])
def predict():
    # Check the content type of the request
    if request.headers['Content-Type'] != 'application/json':
        return jsonify({'error': 'Invalid content type'})

    # Get the request data
    request_data = request.json

    # Perform inference
    result = GLOBAL_SERVER.process_request(request_data)

    # Return the inference results
    return jsonify(result)

if __name__ == '__main__':
    # Load the models
    # load_model("/share_nfs/fangjiarui/root/code/hf_models/bloom-560m")

    GLOBAL_SERVER = Server(approx_model_name="/share_nfs/fangjiarui/root/code/hf_models/bloom-560m",
                           target_model_name="/share_nfs/fangjiarui/root/code/hf_models/bloomz-7b1")
    # Start the Flask service
    app.run(host='0.0.0.0', port=5000)
