+eval code for 27M ppl 1.65 BPC 0.72 enwik8 model
www committed Mar 25, 2022
1 parent 71538e4 commit 88e921b
Showing 4 changed files with 90 additions and 59 deletions.
46 changes: 39 additions & 7 deletions RWKV-v2-RNN/run.py
@@ -4,16 +4,18 @@
########################################################################################################

import numpy as np
import math
import time
import types
import copy
import torch
from torch.nn import functional as F
from src.utils import TOKENIZER
from src.utils import TOKENIZER, Dataset
from src.model_run import RWKV_RNN
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.matmul.allow_tf32 = True
np.set_printoptions(precision=4, suppress=True, linewidth=200)

### Step 1: set model ##################################################################################

@@ -26,9 +28,11 @@
MODEL_NAME = 'trained-31'
WORD_NAME = 'vocab' # the .json vocab (generated by train.py

# ### uncompress enwik8-model.zip to test my enwik8 model
# ########## Uncomment these to test my 27M params enwik8 model ##########
# MODEL_NAME = 'enwik8-ppl1.65-6064-1024-RWKV-6-512-2022-03-25-21-05-13'
# WORD_NAME = 'enwik8-vocab'
# EVAL_DATA = 'enwik8' # uncomment this for EVAL MODE (no text generation)
# ########################################################################

# --> set UNKNOWN_CHAR to the rarest token in your vocab.json <--
# --> all unknown tokens in your context will be denoted by it <--
@@ -50,16 +54,44 @@

########################################################################################################

np.set_printoptions(precision=4, suppress=True, linewidth=200)

print(f'Loading {MODEL_NAME}...')
model = RWKV_RNN(MODEL_NAME, RUN_DEVICE, model_type, n_layer, n_embd, ctx_len)
tokenizer = TOKENIZER(WORD_NAME, UNKNOWN_CHAR=UNKNOWN_CHAR)

########################################################################################################

if 'EVAL_DATA' in vars() or 'EVAL_DATA' in globals():
    print('Evaluating on ' + EVAL_DATA + ' ...')

    data = open(EVAL_DATA, "r", encoding='utf-8').read()

    loss_table = np.zeros(ctx_len)

    N_SAMPLE = 1000

    for iii in range(N_SAMPLE):
        pos = np.random.randint(0, len(data) - ctx_len-1)
        context = data[pos:pos+ctx_len+1]
        ctx = [tokenizer.stoi.get(s, tokenizer.UNKNOWN_CHAR) for s in context]

        model.clear()
        for i in range(1, ctx_len+1):
            x = ctx[:i]
            out = model.run(x)
            prob = F.softmax(torch.tensor(out), dim=-1)
            loss_table[i-1] += -math.log(prob[ctx[i]])

        print(f'Tested {iii+1} samples: avg_loss over ctx_len =',
              np.mean(loss_table) / (iii+1))

    exit(0)

########################################################################################################

context = tokenizer.refine_context(context)
print('\nYour prompt has ' + str(len(context)) + ' tokens.')
print('\n--> Currently the first run takes a while if your prompt is long, as we are using RNN to process the prompt. This will be much faster in future versions. <--\n')

print(f'Loading {MODEL_NAME}...')
model = RWKV_RNN(MODEL_NAME, RUN_DEVICE, model_type, n_layer, n_embd, ctx_len)

for TRIAL in range(1 if DEBUG_DEBUG else NUM_TRIALS):
    t_begin = time.time_ns()

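Note: the eval loop above accumulates -log(prob) in natural-log units per character, so the figures in the commit title are related by BPC = avg_loss / ln 2 and character-level ppl = exp(avg_loss) = 2 ** BPC, matching the ppl = math.exp(self.avg_loss) shown in the trainer's progress bar. A minimal conversion sketch (not part of the commit; the helper name and the 0.50 nats/char input are assumed example values):

import math

# Illustrative helper (not in the repo): convert the avg_loss printed by the
# eval loop (natural-log loss per character) into bits-per-character and
# character-level perplexity.
def loss_to_metrics(avg_loss_nats):
    bpc = avg_loss_nats / math.log(2)   # bits per character
    ppl = math.exp(avg_loss_nats)       # equals 2 ** bpc
    return bpc, ppl

bpc, ppl = loss_to_metrics(0.50)        # ~0.50 nats/char as an example input
print(f'BPC {bpc:.3f}  ppl {ppl:.3f}')  # BPC 0.721  ppl 1.649 -- the 0.72 / 1.65 quoted in the title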
4 changes: 2 additions & 2 deletions RWKV-v2-RNN/src/trainer.py
@@ -22,7 +22,7 @@
logger = logging.getLogger(__name__)
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cuda.matmul.allow_tf32 = True

log_file = open("mylog.txt", "a")

@@ -151,7 +151,7 @@ def run_epoch(split):
                        self.avg_loss = self.avg_loss * \
                            (1.0 - factor) + now_loss * factor
                    pbar.set_description(
                        f"epoch {epoch+1} prog {progress*100.0:.2f}% iter {it}: ppl {math.exp(self.avg_loss):.2f} loss {self.avg_loss:.4f} lr {lr:e}")
                        f"mini-epoch {epoch+1} prog {progress*100.0:.2f}% iter {it}: ppl {math.exp(self.avg_loss):.2f} loss {self.avg_loss:.4f} lr {lr:e}")

        self.tokens = 0  # counter used for learning rate decay
        for epoch in range(config.max_epochs):
42 changes: 42 additions & 0 deletions RWKV-v2-RNN/src/utils.py
@@ -10,6 +10,48 @@
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.utils.data import Dataset


class Dataset(Dataset):
    def __init__(self, data, ctx_len, epoch_length_fixed):
        print('building token list...', end=' ')
        unique = sorted(list(set(data)))
        # print()
        # for u in unique:
        #     print(u, end=' ')
        # print('\n\n')

        xx = 0
        xxObj = {}
        for u in unique:
            xxObj[xx] = u
            xx += 1
        with open('vocab.json', "w", encoding="utf-16") as vocab_file:
            vocab_file.write(json.dumps(xxObj, ensure_ascii=False))

        data_size, vocab_size = len(data), len(unique)
        print('data has %d tokens, %d unique.' % (data_size, vocab_size))
        self.stoi = {ch: i for i, ch in enumerate(unique)}
        self.itos = {i: ch for i, ch in enumerate(unique)}
        self.ctx_len = ctx_len
        self.epoch_length_fixed = epoch_length_fixed
        self.vocab_size = vocab_size
        self.data = data

    def __len__(self):
        return self.epoch_length_fixed

    def __getitem__(self, idx):
        # cheat: pick a random spot in dataset
        i = np.random.randint(0, len(self.data) - (self.ctx_len + 1))
        chunk = self.data[i:i+self.ctx_len+1]
        dix = [self.stoi[s] for s in chunk]
        x = torch.tensor(dix[:-1], dtype=torch.long,
                         device=torch.device('cuda'))
        y = torch.tensor(dix[1:], dtype=torch.long,
                         device=torch.device('cuda'))
        return x, y


class TOKENIZER():
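Note: a minimal usage sketch of the relocated src.utils.Dataset, mirroring what train.py does below. The file name, ctx_len and epoch_length_fixed values are assumptions for illustration, and __getitem__ builds its tensors on CUDA, so a GPU is required as written:

from src.utils import Dataset

# Assumed example values; note that vocab.json is (re)written as a side effect of construction.
data = open('enwik8', 'r', encoding='utf-8').read()
train_dataset = Dataset(data, ctx_len=1024, epoch_length_fixed=10000)

# The index is effectively ignored: each item is a random ctx_len-sized chunk,
# with y equal to x shifted one character to the right.
x, y = train_dataset[0]
print(train_dataset.vocab_size, x.shape, y.shape)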
57 changes: 7 additions & 50 deletions RWKV-v2-RNN/train.py
@@ -7,12 +7,12 @@
import json
from src.model import GPT, GPTConfig
from src.trainer import Trainer, TrainerConfig
from torch.utils.data import Dataset
from src.utils import Dataset
import torch
import numpy as np
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cuda.matmul.allow_tf32 = True

### Step 1: set training data ##########################################################################

@@ -36,21 +36,20 @@
# If you see "CUDA out of memory", reduce it. Use GPU-Z to find the highest value for your VRAM.
batch_size = 12

### Step 4: set learning rate, training 'epochs' #######################################################
### Step 4: set learning rate, training mini-epochs #######################################################

lr_init = 6e-4
lr_final = 1e-5
# the 'epoch' here is very short and of fixed length (ctx_len * epoch_length_fixed tokens)
# the mini-epoch is very short and of fixed length (ctx_len * epoch_length_fixed tokens)
n_epoch = 500
# 0 = never, 1 = every 'epoch', 2 = every two 'epoch', etc.
# 0 = never, 1 = every mini-epoch, 2 = every two mini-epochs, etc.
epoch_save_frequency = 30
epoch_save_path = 'trained-'

epoch_length_fixed = 10000

########################################################################################################


# import src.utils
# src.utils.set_seed(42) # remember to change seed if you load a model

@@ -71,50 +70,8 @@
########################################################################################################

print('loading data... ' + datafile)


class Dataset(Dataset):
    def __init__(self, data, ctx_len):
        print('building token list...', end=' ')
        unique = sorted(list(set(data)))
        # print()
        # for u in unique:
        #     print(u, end=' ')
        # print('\n\n')

        xx = 0
        xxObj = {}
        for u in unique:
            xxObj[xx] = u
            xx += 1
        with open('vocab.json', "w", encoding="utf-16") as vocab_file:
            vocab_file.write(json.dumps(xxObj, ensure_ascii=False))

        data_size, vocab_size = len(data), len(unique)
        print('data has %d tokens, %d unique.' % (data_size, vocab_size))
        self.stoi = {ch: i for i, ch in enumerate(unique)}
        self.itos = {i: ch for i, ch in enumerate(unique)}
        self.ctx_len = ctx_len
        self.vocab_size = vocab_size
        self.data = data

    def __len__(self):
        return epoch_length_fixed

    def __getitem__(self, idx):
        # cheat: pick a random spot in dataset
        i = np.random.randint(0, len(self.data) - (self.ctx_len + 1))
        chunk = self.data[i:i+self.ctx_len+1]
        dix = [self.stoi[s] for s in chunk]
        x = torch.tensor(dix[:-1], dtype=torch.long,
                         device=torch.device('cuda'))
        y = torch.tensor(dix[1:], dtype=torch.long,
                         device=torch.device('cuda'))
        return x, y


train_dataset = Dataset(
    open(datafile, "r", encoding=datafile_encoding).read(), ctx_len)
train_dataset = Dataset(open(
    datafile, "r", encoding=datafile_encoding).read(), ctx_len, epoch_length_fixed)

########################################################################################################
# Train model
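Note: as the Step 4 comment says, each mini-epoch has a fixed length of ctx_len * epoch_length_fixed tokens. A quick back-of-the-envelope check (ctx_len = 1024 is an assumption taken from the '...-1024-RWKV-6-512-...' model name; it is configured elsewhere in train.py and not shown in this diff, while the other values are the defaults above):

ctx_len = 1024               # assumed for this example
epoch_length_fixed = 10000
n_epoch = 500

tokens_per_mini_epoch = ctx_len * epoch_length_fixed     # 10,240,000 tokens per mini-epoch
total_sampled_tokens = n_epoch * tokens_per_mini_epoch    # ~5.1e9 sampled tokens over the full run
print(tokens_per_mini_epoch, total_sampled_tokens)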
