prompt_mamba.py
import torch
from transformers import AutoTokenizer
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
import sys
# Validate CLI params
if len(sys.argv) != 2:
    print("Usage: python prompt_mamba.py state-spaces/mamba-130m")
    exit()

# Take in the model you want to prompt
model_name = sys.argv[1]
# Choose a tokenizer
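# The state-spaces Mamba checkpoints were trained with the GPT-NeoX-20B tokenizer,
# so it is reused here rather than a model-specific one.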
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
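# The tokenizer has no pad token by default, so reuse the end-of-text token
# for both EOS and padding.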
tokenizer.eos_token = "<|endoftext|>"
tokenizer.pad_token = tokenizer.eos_token
# Instantiate the MambaLMHeadModel from the state-spaces/mamba GitHub repo
# https://github.com/state-spaces/mamba/blob/main/mamba_ssm/models/mixer_seq_simple.py#L173
model = MambaLMHeadModel.from_pretrained(model_name, device="cuda", dtype=torch.float16)
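# mamba_ssm's fused selective-scan kernels target CUDA, which is why the model
# is loaded straight onto the GPU in float16 here.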
while True:
    # Take the user input from the command line
    user_message = input("\n> ")

    # Create a few-shot prompt
    n_shot_prompting = [
        {
            "question": "What is the capital of France?",
            "answer": "Paris"
        },
        {
            "question": "Who invented the Segway?",
            "answer": "Dean Kamen"
        },
        {
            "question": "What is the fastest animal?",
            "answer": "Cheetah"
        }
    ]
    prompt = "You are a Trivia QA bot.\nAnswer the following question succinctly and accurately."
    prompt = f"{prompt}\n\n" + "\n\n".join([f"Q: {p['question']}\nA: {p['answer']}" for p in n_shot_prompting])
    prompt = f"{prompt}\n\nQ: {user_message}"
    # Debug print to make sure our prompt looks good
    print(prompt)

    # Encode the prompt into token IDs and move the tensor onto the GPU
    input_ids = torch.LongTensor([tokenizer.encode(prompt)]).cuda()
    print(input_ids)
    # Generate an output sequence of tokens given the input
    # "out" will contain the raw token IDs as integers
    out = model.generate(
        input_ids=input_ids,
        max_length=256,
        eos_token_id=tokenizer.eos_token_id
    )
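    # Note: max_length is the total sequence length (prompt plus newly generated
    # tokens), so a longer prompt leaves less room for the answer.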
    # Use the tokenizer to decode the token IDs back into strings
    decoded = tokenizer.batch_decode(out)[0]
    print("=" * 80)

    # "out" holds the whole sequence, prompt included, so strip the prompt
    cleaned = decoded.replace(prompt, "")
    # The model will keep generating Q/A pairs, so optionally keep only the first answer
    # cleaned = cleaned.split("\n\n")[0]
    print(cleaned)