-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathshakespeare-cli.py
89 lines (69 loc) · 2.35 KB
/
shakespeare-cli.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import argparse
import os
import random
import pickle
import torch
import time
import sys
from pathlib import Path
from model import encode, GPT
from config import GPTConfig
import warnings
from warnings import filterwarnings
def load_pickles():
    """Load the two vocabulary lookup tables stored beside this script.

    Unpickles ``components/itos.bin`` and ``components/stoi.bin``, both
    resolved relative to the directory containing this file.

    Returns:
        A ``(itos, stoi)`` tuple: the index-to-string lookup followed by the
        string-to-index lookup.
    """
    components_dir = Path(__file__).parent / "components"
    tables = []
    # Order matters: the result is unpacked as (itos, stoi).
    # NOTE: pickle.load can execute arbitrary code; these files must be trusted.
    for filename in ("itos.bin", "stoi.bin"):
        with open(components_dir / filename, "rb") as handle:
            tables.append(pickle.load(handle))
    itos, stoi = tables
    return itos, stoi
def load_model(itos):
    """Build the GPT model and load its saved weights when they are available.

    The device is ``"cuda"`` when ``torch.cuda.is_available()`` and ``"cpu"``
    otherwise. Weights are read from ``model_weights/gpt.pth`` next to this
    file. When that file does not exist, or loading it raises, the model is
    returned anyway (a load failure is reported on stdout first).

    Args:
        itos: Index-to-string lookup, passed through to ``GPT``.

    Returns:
        A ``(gpt, config)`` tuple: the model moved to ``config.device`` and the
        ``GPTConfig`` it was built from.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    config = GPTConfig(device=device)
    gpt = GPT(config, itos=itos).to(config.device)

    # Load weights (if exists)
    weights_path = Path(__file__).parent / "model_weights/gpt.pth"
    if os.path.isfile(weights_path):
        try:
            # Always remap storages onto the selected device. Without
            # map_location, torch.load restores tensors to the device they
            # were saved from, which fails when that device (e.g. "cuda:1")
            # is not present on this machine.
            state_dict = torch.load(
                weights_path, map_location=torch.device(device)
            )
            gpt.load_state_dict(state_dict)
        except Exception as e:
            # Deliberately best-effort: report the problem and carry on.
            print("Loading weights failed!")
            print(e)
    return gpt, config
def typing_effect(text, delay=0.01):
    """Write *text* to stdout one character at a time, like live typing.

    Args:
        text: Iterable of strings (normally a ``str``) to emit piece by piece.
        delay: Seconds to sleep after each piece is written.

    A trailing newline is written once all of *text* has been emitted.
    """
    stream = sys.stdout
    for ch in text:
        stream.write(ch)
        # Flush right away so each character shows up before the pause.
        stream.flush()
        time.sleep(delay)
    stream.write("\n")
def main():
    """Parse CLI arguments, generate text from the prompt, and print it.

    Command-line arguments:
        input_string: Prompt text the generation starts from.
        --tokens: Number of tokens to predict (default 500).
        --temperature: Sampling temperature (default 1.0).

    An empty prompt is rejected before the model is loaded. If ``encode``
    raises ``KeyError`` for the prompt, a hint is printed and the function
    returns without generating anything.
    """
    parser = argparse.ArgumentParser(description="Generate Shakespearean text.")
    parser.add_argument("input_string", type=str, help="Input text for generation")
    parser.add_argument(
        "--tokens", type=int, default=500, help="Number of tokens to predict"
    )
    parser.add_argument(
        "--temperature", type=float, default=1.0, help="Temperature for text generation"
    )
    args = parser.parse_args()

    # An empty prompt would produce a 0-element tensor, and reshape(1, -1)
    # cannot infer a dimension from zero elements; bail out early instead.
    if not args.input_string:
        print("Input text must not be empty.")
        return

    itos, stoi = load_pickles()
    gpt, config = load_model(itos)

    try:
        token_ids = encode(args.input_string, stoi=stoi)
    except KeyError:
        print("Try different words....")
        return

    # Build the integer tensor directly rather than going through a float32
    # tensor and casting; shape it as a single batch row.
    x = torch.tensor(token_ids, dtype=torch.long).reshape(1, -1)

    out = gpt.write(
        x.to(config.device), max_new_tokens=args.tokens, temperature=args.temperature
    )
    # The first element of the result is what gets typed out.
    typing_effect(out[0])
if __name__ == "__main__":
    # Suppress UserWarning-category warnings for the CLI run, then start it.
    filterwarnings(action="ignore", category=UserWarning)
    main()