This defines configurations for the RWKV model. The configurations are calculated using option functions. These are lazily loaded, so only the necessary modules are computed.

from labml.configs import BaseConfigs


class RWKVConfigs(BaseConfigs):
Number of attention heads
    n_heads: int = 8
Transformer embedding size
    d_model: int = 512
Number of layers
    n_layers: int = 6
Dropout probability
    dropout: float = 0.1
Number of tokens in the source vocabulary (for token embeddings)
    n_src_vocab: int
Number of tokens in the target vocabulary (to generate logits for prediction)
    n_tgt_vocab: int
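For illustration, this is how an alternative, lazily computed value for one of these options could be registered (a minimal sketch; the 'from_heads' variant below is hypothetical and not defined anywhere in this code):

from labml.configs import option


# Hypothetical named option: selected by setting d_model to 'from_heads',
# and only computed if the configuration is actually used.
@option(RWKVConfigs.d_model, 'from_heads')
def _d_model_from_heads(c: RWKVConfigs):
    # Derive the embedding size from the number of heads (64 per head is an assumption)
    return c.n_heads * 64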
The following is the experiment code that trains an RWKV model on the Tiny Shakespeare dataset.

import inspect
import math

import torch
import torch.nn as nn
from labml_nn.RWKV.configs import RWKVConfigs

from labml_nn.RWKV import RWKV
from labml_nn.RWKV import TimeMixing
from labml import experiment
from labml.configs import option
from labml_nn.experiments.nlp_autoregression import NLPAutoRegressionConfigs
class Configs(NLPAutoRegressionConfigs):
RWKV model
    model: RWKV
RWKV configurations
    rwkv: RWKVConfigs
Number of warmup iterations
    warmup_iters: int = 2000
Total number of training iterations
    max_iters: int = 600000
Weight decay
    weight_decay: float = 1e-1
AdamW betas and the custom optimizer
    beta1: float = 0.9
    beta2: float = 0.95
    optimizer = 'rwkv_optimizer'
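Note that warmup_iters and max_iters are not consumed by any of the code shown here; they suggest a warmup-then-decay learning rate schedule. A sketch of one common choice (linear warmup followed by cosine decay), assuming a base learning_rate and a floor min_lr are given, would be:

import math


def get_lr(it: int, learning_rate: float, min_lr: float, warmup_iters: int, max_iters: int) -> float:
    # Linear warmup for the first `warmup_iters` steps
    if it < warmup_iters:
        return learning_rate * it / warmup_iters
    # After `max_iters`, stay at the minimum learning rate
    if it > max_iters:
        return min_lr
    # Cosine decay from `learning_rate` down to `min_lr` in between
    decay_ratio = (it - warmup_iters) / (max_iters - warmup_iters)
    coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))
    return min_lr + coeff * (learning_rate - min_lr)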
@option(Configs.rwkv, 'RWKV')
def _rwkv_configs(c: Configs):
We use our configurable RWKV implementation
    conf = RWKVConfigs()
Set the vocabulary sizes for embeddings and generating logits
    conf.n_src_vocab = c.n_tokens
    conf.n_tgt_vocab = c.n_tokens

    return conf
def _init_weights(module, rwkv: RWKVConfigs):
Initialize the vector parameters in TimeMixing
    if isinstance(module, TimeMixing):
        layer_id = module.layer_id
        n_layer = module.n_layer
        n_embd = module.n_embd
        attn_sz = n_embd

        with torch.no_grad():
            ratio_0_to_1 = layer_id / (n_layer - 1)  # 0 to 1
            ratio_1_to_almost0 = 1.0 - (layer_id / n_layer)  # 1 to ~0
            ddd = torch.ones(1, 1, n_embd)
            for i in range(n_embd):
                ddd[0, 0, i] = i / n_embd

            decay_speed = torch.ones(attn_sz)
            for h in range(attn_sz):
                decay_speed[h] = -5 + 8 * (h / (attn_sz - 1)) ** (0.7 + 1.3 * ratio_0_to_1)
            module.time_decay = nn.Parameter(decay_speed)

            zigzag = torch.tensor([(i + 1) % 3 - 1 for i in range(attn_sz)]) * 0.5
            module.time_first = nn.Parameter(torch.ones(attn_sz) * math.log(0.3) + zigzag)
            module.time_mix_key = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0))
            module.time_mix_value = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0) + 0.3 * ratio_0_to_1)
            module.time_mix_receptance = nn.Parameter(torch.pow(ddd, 0.5 * ratio_1_to_almost0))
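The same initialization can be written without Python loops; an equivalent sketch for the ddd and decay_speed tensors above (the sizes and layer ratio below are example values, not taken from the experiment):

n_embd = attn_sz = 512          # example sizes (assumptions)
ratio_0_to_1 = 0.5              # example layer ratio

# i / n_embd for i = 0 .. n_embd - 1, shaped (1, 1, n_embd)
ddd = (torch.arange(n_embd, dtype=torch.float32) / n_embd).view(1, 1, n_embd)

# -5 + 8 * (h / (attn_sz - 1)) ** (0.7 + 1.3 * ratio_0_to_1) for each channel h
h = torch.arange(attn_sz, dtype=torch.float32)
decay_speed = -5 + 8 * (h / (attn_sz - 1)) ** (0.7 + 1.3 * ratio_0_to_1)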
Create RWKV model and initialize weights

@option(Configs.model)
def _model(c: Configs):
    m = RWKV(c.rwkv).to(c.device)

Apply custom weight initialization; Module.apply takes a single-argument function, so wrap _init_weights in a lambda
    m.apply(lambda module: _init_weights(module, c.rwkv))

    return m
Register the optimizer option under the name set in Configs.optimizer

@option(NLPAutoRegressionConfigs.optimizer, 'rwkv_optimizer')
def _configure_optimizers(c: Configs):
Start with all of the candidate parameters
    param_dict = {pn: p for pn, p in c.model.named_parameters()}
Filter out those that do not require grad
    param_dict = {pn: p for pn, p in param_dict.items() if p.requires_grad}
Create optimizer groups. Any parameter that is 2D will be weight decayed, otherwise not; i.e. all weight tensors in matmuls and embeddings decay, while all biases and layer norms don't.
    decay_params = [p for n, p in param_dict.items() if p.dim() >= 2]
    nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2]
    optim_groups = [
        {'params': decay_params, 'weight_decay': c.weight_decay},
        {'params': nodecay_params, 'weight_decay': 0.0}
    ]
    num_decay_params = sum(p.numel() for p in decay_params)
    num_nodecay_params = sum(p.numel() for p in nodecay_params)
    print(f"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters")
    print(f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters")
Create the AdamW optimizer, using the fused version if it is available
    fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters
    use_fused = fused_available and c.device.type == 'cuda'
    extra_args = dict(fused=True) if use_fused else dict()
    optimizer = torch.optim.AdamW(optim_groups, lr=c.learning_rate, betas=(c.beta1, c.beta2), **extra_args)
    print(f"using fused AdamW: {use_fused}")

    return optimizer
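As a quick check of the grouping rule (a standalone sketch, not part of the experiment): for a linear layer followed by a layer norm, only the 2D linear weight is weight decayed.

import torch.nn as nn

m = nn.Sequential(nn.Linear(4, 4), nn.LayerNorm(4))
decay = [n for n, p in m.named_parameters() if p.dim() >= 2]    # ['0.weight']
no_decay = [n for n, p in m.named_parameters() if p.dim() < 2]  # ['0.bias', '1.weight', '1.bias']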
def main():
Create experiment
    experiment.create(name="RWKV")
Create configs
    conf = Configs()
Override configurations
    experiment.configs(conf, {
Use character level tokenizer
        'tokenizer': 'character',
Prompt separator is blank
        'prompt_separator': '',
Starting prompt for sampling
        'prompt': 'It is ',
Use Tiny Shakespeare dataset
        'text': 'tiny_shakespeare',
Use a context size of 128
        'seq_len': 128,
Train for 32 epochs
        'epochs': 32,
Batch size of 128
        'batch_size': 128,
Switch between training and validation 10 times per epoch
        'inner_iterations': 10,

        'rwkv.block_size': 1024,
RWKV model size
        'rwkv.n_layer': 12,
        'rwkv.n_heads': 12,
        'rwkv.n_embd': 768
    })

    print(conf.model)
Set models for saving and loading
    experiment.add_pytorch_models({'model': conf.model})
Start the experiment
    with experiment.start():
Run training
        conf.run()


if __name__ == '__main__':
    main()
This is a tutorial/implementation of RWKV, from the paper RWKV: Reinventing RNNs for the Transformer Era, in PyTorch.

Full definition of an RWKV language model, all of it in this single file. References: 1) the official RWKV PyTorch implementation released by Bo Peng, and 2) the huggingface/transformers PyTorch implementation.

import torch
import torch.nn as nn
from torch.nn import functional as F

from labml_helpers.module import Module

Indices into the per-layer recurrent state used for token-by-token inference
PREV_X_TIME = 0
NUM_STATE = 1
DEN_STATE = 2
MAX_STATE = 3
PREV_X_CHANNEL = 4
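These constants index the second-to-last dimension of the recurrent state. A sketch of how such a state buffer would be allocated by a caller doing recurrent inference (the shape follows the indexing used in the blocks below; the sizes are example values):

# One slot per layer; 5 entries per layer: the shifted inputs for time/channel mixing
# and the numerator, denominator and running maximum of the WKV recurrence.
n_layer, batch_size, n_embd = 12, 1, 768  # example sizes (assumptions)
state = torch.zeros(n_layer, batch_size, 5, n_embd)
state[:, :, MAX_STATE, :] = -1e38  # the running maximum starts at a very small value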
Layer normalization with an optional bias term
class LayerNorm(Module):
    def __init__(self, ndim, bias):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(ndim))
        self.bias = nn.Parameter(torch.zeros(ndim)) if bias else None

    def forward(self, input):
        return F.layer_norm(input, self.weight.shape, self.weight, self.bias, 1e-5)
An autograd function that passes the loss through unchanged in the forward pass and, in the backward pass, adds a small gradient that pushes the largest logit of every position towards zero
class L2Wrap(torch.autograd.Function):
    @staticmethod
    def forward(ctx, loss, y):
        ctx.save_for_backward(y)
        return loss

    @staticmethod
    def backward(ctx, grad_output):
        y = ctx.saved_tensors[0]
To encourage the logits to be close to 0
        factor = 1e-4 / (y.shape[0] * y.shape[1])
        maxx, ids = torch.max(y, -1, keepdim=True)
        gy = torch.zeros_like(y)
        gy.scatter_(-1, ids, maxx * factor)
        return grad_output, gy
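The injected gradient matches what an explicit auxiliary penalty on the largest logit of every position would produce; a self-contained sketch of that equivalence (the logit shapes are example values, and the penalty is never actually added to the loss):

import torch

y = torch.randn(2, 8, 65, requires_grad=True)  # example logits (batch, seq, vocab)
factor = 1e-4 / (y.shape[0] * y.shape[1])
aux = 0.5 * factor * y.max(dim=-1).values.pow(2).sum()
aux.backward()
# y.grad now holds factor * max(y) at each position's argmax, matching gy above.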
The channel mixing block of RWKV
class ChannelMixing(Module):
    def __init__(self, config, layer_id):
        super().__init__()
Token shifting
        self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
        self.layer_id = layer_id

        n_embd = config.n_embd
        intermediate_size = (
            config.intermediate_size if config.intermediate_size is not None else 4 * n_embd
        )

Learnable matrices
        self.key_proj = nn.Linear(n_embd, intermediate_size, bias=False)
        self.value_proj = nn.Linear(intermediate_size, n_embd, bias=False)
        self.receptance_proj = nn.Linear(n_embd, n_embd, bias=False)

Learnable vectors
        self.time_mix_key = nn.Parameter(torch.empty(1, 1, n_embd))
        self.time_mix_receptance = nn.Parameter(torch.empty(1, 1, n_embd))
    def forward(self, x, state=None):
Token shift: use the stored previous token during recurrent inference, or the shifted sequence otherwise
        if state is not None:
            prev_x = state[self.layer_id, :, [PREV_X_CHANNEL], :]
            state[self.layer_id, :, [PREV_X_CHANNEL], :] = x
        else:
            prev_x = self.time_shift(x)

Receptance
        receptance = x * self.time_mix_receptance + prev_x * (1 - self.time_mix_receptance)
        receptance = self.receptance_proj(receptance)

Key
        key = x * self.time_mix_key + prev_x * (1 - self.time_mix_key)
        key = self.key_proj(key)

Value, with a squared-ReLU activation
        value = self.value_proj(torch.square(torch.relu(key)))

Gate the value with the sigmoid of the receptance
        out = F.sigmoid(receptance) * value
        return out, state
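The token shift implemented with nn.ZeroPad2d((0, 0, 1, -1)) pads one step at the start of the time dimension and drops the last step, so row t of the result holds the previous token's features. A small self-contained check:

import torch
import torch.nn as nn

time_shift = nn.ZeroPad2d((0, 0, 1, -1))
x = torch.arange(6.0).view(1, 3, 2)   # (batch, time, channel)
print(time_shift(x))
# tensor([[[0., 0.],
#          [0., 1.],
#          [2., 3.]]])
# Each step now sees the previous step's values; step 0 sees zeros.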
The time mixing block of RWKV
class TimeMixing(Module):
    def __init__(self, config, layer_id):
        super().__init__()
        self.config = config
        self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
        self.layer_id = layer_id

        n_embd = config.n_embd
        attn_sz = n_embd
Keep these for the custom weight initialization in the experiment
        self.n_layer = config.n_layer
        self.n_embd = n_embd

Learnable matrices
        self.key_proj = nn.Linear(n_embd, attn_sz, bias=False)
        self.value_proj = nn.Linear(n_embd, attn_sz, bias=False)
        self.receptance_proj = nn.Linear(n_embd, attn_sz, bias=False)
        self.output_proj = nn.Linear(attn_sz, n_embd, bias=False)

Learnable vectors
        self.time_decay = nn.Parameter(torch.empty(attn_sz))
        self.time_first = nn.Parameter(torch.empty(attn_sz))
        self.time_mix_key = nn.Parameter(torch.empty(1, 1, n_embd))
        self.time_mix_value = nn.Parameter(torch.empty(1, 1, n_embd))
        self.time_mix_receptance = nn.Parameter(torch.empty(1, 1, n_embd))
x has shape (batch, time, channel)
    def forward(self, x, state=None):
Token shift
        if state is not None:
            prev_x = state[self.layer_id, :, [PREV_X_TIME], :]
            state[self.layer_id, :, [PREV_X_TIME], :] = x
        else:
            prev_x = self.time_shift(x)

Receptance
        receptance = x * self.time_mix_receptance + prev_x * (1 - self.time_mix_receptance)
        receptance = self.receptance_proj(receptance)

Key
        key = x * self.time_mix_key + prev_x * (1 - self.time_mix_key)
        key = self.key_proj(key)

Value
        value = x * self.time_mix_value + prev_x * (1 - self.time_mix_value)
        value = self.value_proj(value)
WKV calculation
        _, seq_length, _ = key.size()
        output = torch.zeros_like(key)

        if state is None:
            num_state = torch.zeros_like(key[:, 0], dtype=torch.float32)
            den_state = torch.zeros_like(key[:, 0], dtype=torch.float32)
            max_state = torch.zeros_like(key[:, 0], dtype=torch.float32) - 1e38
        else:
            num_state = state[self.layer_id, :, NUM_STATE, :]
            den_state = state[self.layer_id, :, DEN_STATE, :]
            max_state = state[self.layer_id, :, MAX_STATE, :]

        time_decay = -torch.exp(self.time_decay)

        for current_index in range(seq_length):
            current_key = key[:, current_index].float()
            current_value = value[:, current_index]

WKV output at the current time step, computed with a running maximum for numerical stability
            max_for_output = torch.maximum(max_state, current_key + self.time_first)
            e1 = torch.exp(max_state - max_for_output)
            e2 = torch.exp(current_key + self.time_first - max_for_output)
            numerator = e1 * num_state + e2 * current_value
            denominator = e1 * den_state + e2
            output[:, current_index] = (numerator / denominator).to(output.dtype)

Update the recurrence for the next iteration
            max_for_state = torch.maximum(max_state + time_decay, current_key)
            e1 = torch.exp(max_state + time_decay - max_for_state)
            e2 = torch.exp(current_key - max_for_state)
            num_state = e1 * num_state + e2 * current_value
            den_state = e1 * den_state + e2
            max_state = max_for_state

Store the updated states when running in recurrent mode
        if state is not None:
            state[self.layer_id, :, NUM_STATE, :] = num_state
            state[self.layer_id, :, DEN_STATE, :] = den_state
            state[self.layer_id, :, MAX_STATE, :] = max_state
        wkv = output

Gate the WKV output with the sigmoid of the receptance and project back
        rwkv = F.sigmoid(receptance) * wkv
        rwkv = self.output_proj(rwkv)

        return rwkv, state
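Ignoring the running-maximum trick, the loop above computes the standard WKV recurrence. A naive standalone reference (numerically unstable for long sequences; key and value have shape (batch, seq, channels), time_decay and time_first are the per-channel vectors; this is a sketch for reference, not part of the model):

def naive_wkv(key, value, time_decay, time_first):
    # wkv_t = (a_{t-1} + exp(u + k_t) * v_t) / (b_{t-1} + exp(u + k_t))
    # a_t   = exp(w) * a_{t-1} + exp(k_t) * v_t
    # b_t   = exp(w) * b_{t-1} + exp(k_t)
    w = -torch.exp(time_decay)   # per-channel decay, always negative
    u = time_first               # per-channel bonus for the current token
    a = torch.zeros_like(key[:, 0])
    b = torch.zeros_like(key[:, 0])
    wkv = torch.zeros_like(key)
    for t in range(key.size(1)):
        k, v = key[:, t], value[:, t]
        wkv[:, t] = (a + torch.exp(u + k) * v) / (b + torch.exp(u + k))
        a = torch.exp(w) * a + torch.exp(k) * v
        b = torch.exp(w) * b + torch.exp(k)
    return wkv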
An RWKV block: time mixing and channel mixing, each with a layer norm and a residual connection
class Block(Module):
    def __init__(self, config, layer_id):
        super().__init__()
        self.ln_1 = LayerNorm(config.n_embd, bias=config.bias)
        self.attn = TimeMixing(config, layer_id)
        self.ln_2 = LayerNorm(config.n_embd, bias=config.bias)
        self.ffn = ChannelMixing(config, layer_id)

state has shape (n_layer, batch_size, 5, n_embd)
    def forward(self, x, state=None):
Time mixing
        residual = x
        x, state = self.attn(self.ln_1(x), state=state)
        x = x + residual

Channel mixing
        residual = x
        x, state = self.ffn(self.ln_2(x), state=state)
        x = x + residual
        return x, state
The RWKV language model
class RWKV(Module):
    def __init__(self, config, lr_init=0.0008):
        super().__init__()
        assert config.vocab_size is not None
        assert config.block_size is not None
        self.config = config
        self.lr_init = lr_init  # used to initialize embedding parameters
        self.n_layer = config.n_layer
        self.n_embd = config.n_embd

Initialize model layers
        self.rwkv = nn.ModuleDict(dict(
            wte=nn.Embedding(config.vocab_size, config.n_embd),
            ln_p=LayerNorm(config.n_embd, bias=config.bias),
            h=nn.ModuleList([Block(config, layer_id) for layer_id in range(config.n_layer)]),
            ln_f=LayerNorm(config.n_embd, bias=config.bias),
        ))

Output linear layer
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
    def forward(self, idx, targets=None, state=None, return_state=False):
        b, t = idx.size()
        assert t <= self.config.block_size, f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"

Embedding layer
        x = self.rwkv.wte(idx)

Layer norm
        x = self.rwkv.ln_p(x)

RWKV blocks
        for block_idx, block in enumerate(self.rwkv.h):
            x, state = block(x, state)
        x = self.rwkv.ln_f(x)

Logit layer and loss function (for training)
        if targets is not None:
If we are given some desired targets, also calculate the loss
            logits = self.lm_head(x)
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
            if self.training:
                loss = L2Wrap.apply(loss, logits)
        else:
Inference-time mini-optimization: only forward the lm_head on the very last position
            logits = self.lm_head(x[:, [-1], :])  # note: using list [-1] to preserve the time dim
            loss = None

Return logits and loss
        if return_state:
            return logits, loss, state
        else:
            return logits, loss
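A minimal usage sketch (the SimpleNamespace config and sizes below are assumptions for illustration; in the experiment the model is built from Configs and the mixing parameters are filled in by _init_weights, so here they are zeroed just to make the forward pass run deterministically):

import torch
from types import SimpleNamespace

config = SimpleNamespace(vocab_size=65, block_size=128, n_embd=64, n_layer=2,
                         bias=False, intermediate_size=None)
model = RWKV(config)

# The time-mixing/channel-mixing vectors are created with torch.empty,
# so give them deterministic values before running the forward pass.
with torch.no_grad():
    for name, p in model.named_parameters():
        if 'time_' in name:
            p.zero_()

idx = torch.randint(0, config.vocab_size, (2, 16))
logits, loss = model(idx, targets=idx)   # training: logits (2, 16, 65), scalar loss
logits, loss = model(idx)                # generation: logits (2, 1, 65), loss is None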