diff --git a/docs/RWKV/configs.html b/docs/RWKV/configs.html new file mode 100644 index 00000000..909ebb18 --- /dev/null +++ b/docs/RWKV/configs.html @@ -0,0 +1,210 @@ + + + + + + + + + + + + + + + + + + + + + + + configs.py + + + + + + + + + + +
+
+
+
+

+ home + RWKV +

+

+ + Github + + Twitter +

+

+ + View code on Github +

+
+
+
+
+ + +
+
+
1from labml.configs import BaseConfigs
+
+
+
+
+ +

Transformer Configurations

+

This defines configurations for a transformer. The configurations are calculate using option functions. These are lazy loaded and therefore only the necessary modules are calculated.

+ +
+
+
4class RWKVConfigs(BaseConfigs):
+
+
+
+
+ +

Number of attention heads

+ +
+
+
14    n_heads: int = 8
+
+
+
+
+ +

Transformer embedding size

+ +
+
+
16    d_model: int = 512
+
+
+
+
+ +

Number of layers

+ +
+
+
18    n_layers: int = 6
+
+
+
+
+ +

Dropout probability

+ +
+
+
20    dropout: float = 0.1
+
+
+
+
+ +

Number of tokens in the source vocabulary (for token embeddings)

+ +
+
+
22    n_src_vocab: int
+
+
+
+
+ +

Number of tokens in the target vocabulary (to generate logits for prediction)

+ +
+
+
24    n_tgt_vocab: int
+
+
+ +
+ + + + \ No newline at end of file diff --git a/docs/RWKV/experiment.html b/docs/RWKV/experiment.html new file mode 100644 index 00000000..b431f2ff --- /dev/null +++ b/docs/RWKV/experiment.html @@ -0,0 +1,624 @@ + + + + + + + + + + + + + + + + + + + + + + + experiment.py + + + + + + + + + + +
+
+
+
+

+ home + RWKV +

+

+ + Github + + Twitter +

+

+ + View code on Github +

+
+
+
+
+ + +
+
+
1import inspect
+2import math
+3
+4import torch
+5import torch.nn as nn
+6from labml_nn.RWKV.configs import RWKVConfigs
+7
+8from labml_nn.RWKV import RWKV
+9from labml_nn.RWKV import TimeMixing
+10from labml import experiment
+11from labml.configs import option
+12from labml_nn.experiments.nlp_autoregression import NLPAutoRegressionConfigs
+
+
+
+
+ +

Configurations

+

This inherits from NLPAutoRegressionConfigs +

+ +
+
+
15class Configs(NLPAutoRegressionConfigs):
+
+
+
+
+ +

RWKV model

+ +
+
+
24    model: RWKV
+25
+26    rwkv: RWKVConfigs
+
+
+
+
+ +

number of warmup iterations

+ +
+
+
28    warmup_iters: int = 2000
+
+
+
+
+ +

total number of training iterations

+ +
+
+
30    max_iters: int = 600000
+
+
+
+
+ +

weight decay

+ +
+
+
32    weight_decay: float = 1e-1
+
+
+
+
+ +

Custom optimizer

+ +
+
+
34    beta1: float = 0.9
+35    beta2: float = 0.95
+36    optimizer = 'rwkv_optimizer'
+
+
+
+
+ +

RWKV configurations

+ +
+
+
39@option(Configs.rwkv, 'RWKV')
+40def _rwkv_configs(c: Configs):
+
+
+
+
+ +

We use our configurable RWKV implementation

+ +
+
+
47    conf = RWKVConfigs()
+
+
+
+
+ +

Set the vocabulary sizes for embeddings and generating logits

+ +
+
+
49    conf.n_src_vocab = c.n_tokens
+50    conf.n_tgt_vocab = c.n_tokens
+51
+52    return conf
+
+
+
+
+ + +
+
+
55def _init_weights(module, rwkv: RWKVConfigs):
+
+
+
+
+ +

initialize Vector Parameters in TimeMixing

+ +
+
+
57    if isinstance(module, TimeMixing):
+58        layer_id = module.layer_id
+59        n_layer = module.n_layer
+60        n_embd = module.n_embd
+61        attn_sz = n_embd
+62
+63        with torch.no_grad():
+64            ratio_0_to_1 = layer_id / (n_layer - 1)  # 0 to 1
+65            ratio_1_to_almost0 = 1.0 - (layer_id / n_layer)  # 1 to ~0
+66            ddd = torch.ones(1, 1, n_embd)
+67            for i in range(n_embd):
+68                ddd[0, 0, i] = i / n_embd
+69
+70            decay_speed = torch.ones(attn_sz)
+71            for h in range(attn_sz):
+72                decay_speed[h] = -5 + 8 * (h / (attn_sz - 1)) ** (0.7 + 1.3 * ratio_0_to_1)
+73            module.time_decay = nn.Parameter(decay_speed)
+74
+75            zigzag = torch.tensor([(i + 1) % 3 - 1 for i in range(attn_sz)]) * 0.5
+76            module.time_first = nn.Parameter(torch.ones(attn_sz) * math.log(0.3) + zigzag)
+77            module.time_mix_key = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0))
+78            module.time_mix_value = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0) + 0.3 * ratio_0_to_1)
+79            module.time_mix_receptance = nn.Parameter(torch.pow(ddd, 0.5 * ratio_1_to_almost0))
+
+
+
+
+ +

Create RWKV model and initialize weights

+ +
+
+
82@option(Configs.model)
+83def _model(c: Configs):
+
+
+
+
+ + +
+
+
87    m = RWKV(c.rwkv).to(c.device)
+
+
+
+
+ +

Apply custom weight initialization

+ +
+
+
90    m.apply(_init_weights, c.rwkv)
+91
+92    return m
+
+
+
+
+ + +
+
+
95@option(NLPAutoRegressionConfigs.optimizer)
+96def _configure_optimizers(c: NLPAutoRegressionConfigs):
+
+
+
+
+ +

start with all of the candidate parameters

+ +
+
+
98    param_dict = {pn: p for pn, p in c.model.named_parameters()}
+
+
+
+
+ +

filter out those that do not require grad

+ +
+
+
100    param_dict = {pn: p for pn, p in param_dict.items() if p.requires_grad}
+
+
+
+
+ +

create optim groups. Any parameters that is 2D will be weight decayed, otherwise no. i.e. all weight tensors in matmuls + embeddings decay, all biases and layernorms don't.

+ +
+
+
103    decay_params = [p for n, p in param_dict.items() if p.dim() >= 2]
+104    nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2]
+105    optim_groups = [
+106        {'params': decay_params, 'weight_decay': c.weight_decay},
+107        {'params': nodecay_params, 'weight_decay': 0.0}
+108    ]
+109    num_decay_params = sum(p.numel() for p in decay_params)
+110    num_nodecay_params = sum(p.numel() for p in nodecay_params)
+111    print(f"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters")
+112    print(f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters")
+
+
+
+
+ +

Create AdamW optimizer and use the fused version if it is available

+ +
+
+
114    fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters
+115    use_fused = fused_available and c.device_type == 'cuda'
+116    extra_args = dict(fused=True) if use_fused else dict()
+117    optimizer = torch.optim.AdamW(optim_groups, lr=c.learning_rate, betas=c.betas, **extra_args)
+118    print(f"using fused AdamW: {use_fused}")
+119
+120    return optimizer
+
+
+
+
+ + +
+
+
123def main():
+
+
+
+
+ +

Create experiment

+ +
+
+
125    experiment.create(name="RWKV")
+
+
+
+
+ +

Create configs

+ +
+
+
127    conf = Configs()
+128    print(conf.model)
+
+
+
+
+ +

Override configurations

+ +
+
+
130    experiment.configs(conf, {
+
+
+
+
+ +

Use character level tokenizer

+ +
+
+
132        'tokenizer': 'character',
+
+
+
+
+ +

Prompt separator is blank

+ +
+
+
134        'prompt_separator': '',
+
+
+
+
+ +

Starting prompt for sampling

+ +
+
+
136        'prompt': 'It is ',
+
+
+
+
+ +

Use Tiny Shakespeare dataset

+ +
+
+
138        'text': 'tiny_shakespeare',
+
+
+
+
+ +

Use a context size of

+ +
+
+
141        'seq_len': 128,
+
+
+
+
+ +

Train for epochs

+ +
+
+
143        'epochs': 32,
+
+
+
+
+ +

Batch size

+ +
+
+
145        'batch_size': 128,
+
+
+
+
+ +

Switch between training and validation for times per epoch

+ +
+
+
148        'inner_iterations': 10,
+149
+150        'rwkv.block_size': 1024,
+
+
+
+
+ +

model

+ +
+
+
152        'rwkv.n_layer': 12,
+153        'rwkv.n_heads': 12,
+154        'rwkv.n_embd': 768
+155    })
+156
+157    print(conf.model)
+
+
+
+
+ +

Set models for saving and loading

+ +
+
+
159    experiment.add_pytorch_models({'model': conf.model})
+
+
+
+
+ +

Start the experiment

+ +
+
+
162    with experiment.start():
+
+
+
+
+ +

Run training

+ +
+
+
164        conf.run()
+
+
+
+
+ +

+ +
+
+
168if __name__ == '__main__':
+169    main()
+
+
+ +
+ + + + \ No newline at end of file diff --git a/docs/RWKV/index.html b/docs/RWKV/index.html new file mode 100644 index 00000000..bf808f4f --- /dev/null +++ b/docs/RWKV/index.html @@ -0,0 +1,836 @@ + + + + + + + + + + + + + + + + + + + + + + + Receptance Weighted Key Value (RWKV) + + + + + + + + + + +
+
+
+
+

+ home + RWKV +

+

+ + Github + + Twitter +

+

+ + View code on Github +

+
+
+
+
+ +

Receptance Weighted Key Value (RWKV)

+

This is a tutorial/implementation of RWKV from paper RWKV: Reinventing RNNs for the Transformer Era in PyTorch.

+

Full definition of a RWKV Language Model, all of it in this single file. References: 1) the official RWKV PyTorch implementation released by Bo Peng 2) huggingface/transformers PyTorch implementation

+ +
+
+
22import torch
+23import torch.nn as nn
+24from torch.nn import functional as F
+25
+26from labml_helpers.module import Module
+27
+28PREV_X_TIME = 0
+29NUM_STATE = 1
+30DEN_STATE = 2
+31MAX_STATE = 3
+32PREV_X_CHANNEL = 4
+
+
+
+
+ +

Layer normalization with bias

+ +
+
+
35class LayerNorm(Module):
+
+
+
+
+ + +
+
+
40    def __init__(self, ndim, bias):
+41        super().__init__()
+42        self.weight = nn.Parameter(torch.ones(ndim))
+43        self.bias = nn.Parameter(torch.zeros(ndim)) if bias else None
+
+
+
+
+ + +
+
+
45    def forward(self, input):
+46        return F.layer_norm(input, self.weight.shape, self.weight, self.bias, 1e-5)
+
+
+
+
+ +

L2 loss wrapper

+

ref

+ +
+
+
49class L2Wrap(torch.autograd.Function):
+
+
+
+
+ + +
+
+
56    @staticmethod
+57    def forward(ctx, loss, y):
+58        ctx.save_for_backward(y)
+59        return loss
+60
+61    @staticmethod
+62    def backward(ctx, grad_output):
+63        y = ctx.saved_tensors[0]
+
+
+
+
+ +

to encourage the logits to be close to 0

+ +
+
+
65        factor = 1e-4 / (y.shape[0] * y.shape[1])
+66        maxx, ids = torch.max(y, -1, keepdim=True)
+67        gy = torch.zeros_like(y)
+68        gy.scatter_(-1, ids, maxx * factor)
+69        return grad_output, gy
+
+
+
+
+ +

Channel Mixing

+ +
+
+
72class ChannelMixing(Module):
+
+
+
+
+ + +
+
+
77    def __init__(self, config, layer_id):
+78        super().__init__()
+79        self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
+
+
+
+
+ +

token shifting

+ +
+
+
81        self.layer_id = layer_id
+82
+83        n_embd = config.n_embd
+84        intermediate_size = (
+85            config.intermediate_size if config.intermediate_size is not None else 4 * n_embd
+86        )
+
+
+
+
+ +

Learnable Matrix

+ +
+
+
89        self.key_proj = nn.Linear(n_embd, intermediate_size, bias=False)
+90        self.value_proj = nn.Linear(intermediate_size, n_embd, bias=False)
+91        self.receptance_proj = nn.Linear(n_embd, n_embd, bias=False)
+
+
+
+
+ +

Learnable Vector

+ +
+
+
94        self.time_mix_key = nn.Parameter(torch.empty(1, 1, n_embd))
+95        self.time_mix_receptance = nn.Parameter(torch.empty(1, 1, n_embd))
+
+
+
+
+ +

x = (Batch,Time,Channel)

+ +
+
+
97    def forward(self, x, state=None):
+
+
+
+
+ + +
+
+
101        if state is not None:
+102            prev_x = state[self.layer_id, :, [PREV_X_CHANNEL], :]
+103            state[self.layer_id, :, [PREV_X_CHANNEL], :] = x
+104        else:
+105            prev_x = self.time_shift(x)
+
+
+
+
+ +

+ +
+
+
108        receptance = x * self.time_mix_receptance + prev_x * (1 - self.time_mix_receptance)
+109        receptance = self.receptance_proj(receptance)
+
+
+
+
+ +

+ +
+
+
112        key = x * self.time_mix_key + prev_x * (1 - self.time_mix_key)
+113        key = self.key_proj(key)
+
+
+
+
+ +

+ +
+
+
116        value = self.value_proj(torch.square(torch.relu(key)))
+
+
+
+
+ +

+ +
+
+
119        out = F.sigmoid(receptance) * value
+120        return out, state
+
+
+
+
+ +

Time Mixing

+ +
+
+
123class TimeMixing(Module):
+
+
+
+
+ + +
+
+
128    def __init__(self, config, layer_id):
+129        super().__init__()
+130        self.config = config
+131        self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
+132        self.layer_id = layer_id
+133
+134        n_embd = config.n_embd
+135        attn_sz = n_embd
+
+
+
+
+ +

learnable matrix

+ +
+
+
138        self.key_proj = nn.Linear(n_embd, attn_sz, bias=False)
+139        self.value_proj = nn.Linear(n_embd, attn_sz, bias=False)
+140        self.receptance_proj = nn.Linear(n_embd, attn_sz, bias=False)
+141        self.output_proj = nn.Linear(attn_sz, n_embd, bias=False)
+
+
+
+
+ +

learnable vector

+ +
+
+
144        self.time_decay = nn.Parameter(torch.empty(attn_sz))
+145        self.time_first = nn.Parameter(torch.empty(attn_sz))
+146        self.time_mix_key = nn.Parameter(torch.empty(1, 1, n_embd))
+147        self.time_mix_value = nn.Parameter(torch.empty(1, 1, n_embd))
+148        self.time_mix_receptance = nn.Parameter(torch.empty(1, 1, n_embd))
+
+
+
+
+ +

x = (Batch,Time,Channel)

+ +
+
+
150    def forward(self, x, state=None):
+
+
+
+
+ + +
+
+
154        if state is not None:
+155            prev_x = state[self.layer_id, :, [PREV_X_TIME], :]
+156            state[self.layer_id, :, [PREV_X_TIME], :] = x
+157        else:
+158            prev_x = self.time_shift(x)
+
+
+
+
+ +

+ +
+
+
161        receptance = x * self.time_mix_receptance + prev_x * (1 - self.time_mix_receptance)
+162        receptance = self.receptance_proj(receptance)
+
+
+
+
+ +

+ +
+
+
165        key = x * self.time_mix_key + prev_x * (1 - self.time_mix_key)
+166        key = self.key_proj(key)
+
+
+
+
+ +

+ +
+
+
169        value = x * self.time_mix_value + prev_x * (1 - self.time_mix_value)
+170        value = self.value_proj(value)
+
+
+
+
+ +

WKV calculation

+ +
+
+
173        _, seq_length, _ = key.size()
+174        output = torch.zeros_like(key)
+175
+176        if state is None:
+177            num_state = torch.zeros_like(key[:, 0], dtype=torch.float32)
+178            den_state = torch.zeros_like(key[:, 0], dtype=torch.float32)
+179            max_state = torch.zeros_like(key[:, 0], dtype=torch.float32) - 1e38
+180        else:
+181            num_state = state[self.layer_id, :, NUM_STATE, :]
+182            den_state = state[self.layer_id, :, DEN_STATE, :]
+183            max_state = state[self.layer_id, :, MAX_STATE, :]
+184
+185        time_decay = -torch.exp(self.time_decay)
+186
+187        for current_index in range(seq_length):
+188            current_key = key[:, current_index].float()
+189            current_value = value[:, current_index]
+
+
+
+
+ +

+ +
+
+
192            max_for_output = torch.maximum(max_state, current_key + self.time_first)
+193            e1 = torch.exp(max_state - max_for_output)
+194            e2 = torch.exp(current_key + self.time_first - max_for_output)
+195            numerator = e1 * num_state + e2 * current_value
+196            denominator = e1 * den_state + e2
+197            output[:, current_index] = (numerator / denominator).to(output.dtype)
+
+
+
+
+ +

Update state for next iteration

+ +
+
+
200            max_for_state = torch.maximum(max_state + time_decay, current_key)
+201            e1 = torch.exp(max_state + time_decay - max_for_state)
+202            e2 = torch.exp(current_key - max_for_state)
+203            num_state = e1 * num_state + e2 * current_value
+204            den_state = e1 * den_state + e2
+205            max_state = max_for_state
+
+
+
+
+ +

update states

+ +
+
+
208        state[self.layer_id, :, NUM_STATE, :] = num_state
+209        state[self.layer_id, :, DEN_STATE, :] = den_state
+210        state[self.layer_id, :, MAX_STATE, :] = max_state
+211        wkv, state = self.wkv_function(key, value, use_customized_cuda_kernel=self.config.use_customized_cuda_kernel,
+212                                       state=state)
+
+
+
+
+ +

+ +
+
+
215        rwkv = F.sigmoid(receptance) * wkv
+216        rwkv = self.output_proj(rwkv)
+217
+218        return rwkv, state
+
+
+
+
+ +

RWKV block element

+ +
+
+
221class Block(Module):
+
+
+
+
+ + +
+
+
226    def __init__(self, config, layer_id):
+227        super().__init__()
+228        self.ln_1 = LayerNorm(config.n_embd, bias=config.bias)
+229        self.attn = TimeMixing(config, layer_id)
+230        self.ln_2 = LayerNorm(config.n_embd, bias=config.bias)
+231        self.ffn = ChannelMixing(config, layer_id)
+
+
+
+
+ + +
+
+
233    def forward(self, x, state=None):
+
+
+
+
+ +

state: batch_size, 5 , n_embd

+ +
+
+
+
+
+
+
+ +

time mixing

+ +
+
+
237        residual = x
+238        x, state = self.attn(self.ln_1(x), state=state)
+239        x = x + residual
+
+
+
+
+ +

channel mixing

+ +
+
+
242        residual = x
+243        x, state = self.ffn(self.ln_2(x), state=state)
+244        x = x + residual
+245        return x, state
+
+
+
+
+ +

RWKV

+ +
+
+
248class RWKV(Module):
+
+
+
+
+ + +
+
+
252    def __init__(self, config, lr_init=0.0008):
+253        super().__init__()
+254        assert config.vocab_size is not None
+255        assert config.block_size is not None
+256        self.config = config
+257        self.lr_init = lr_init  ## used to initialize embedding parameters
+258        self.n_layer = config.n_layer
+259        self.n_embd = config.n_embd
+
+
+
+
+ +

Initiate model layers

+ +
+
+
262        self.rwkv = nn.ModuleDict(dict(
+263            wte=nn.Embedding(config.vocab_size, config.n_embd),
+264            ln_p=LayerNorm(config.n_embd, bias=config.bias),
+265            h=nn.ModuleList([Block(config, layer_id) for layer_id in range(config.n_layer)]),
+266            ln_f=LayerNorm(config.n_embd, bias=config.bias),
+267        ))
+
+
+
+
+ +

Output linear layer

+ +
+
+
270        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
+
+
+
+
+ + +
+
+
272    def forward(self, idx, targets=None, state=None, return_state=False):
+273        b, t = idx.size()
+274        assert t <= self.config.block_size, f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"
+
+
+
+
+ +

Embedding Layer

+ +
+
+
277        x = self.rwkv.wte(idx)
+
+
+
+
+ +

Layer Norm

+ +
+
+
280        x = self.rwkv.ln_p(x)
+
+
+
+
+ +

RWKV Blocks

+ +
+
+
283        for block_idx, block in enumerate(self.rwkv.h):
+284            x, state = block(x, state)
+285        x = self.rwkv.ln_f(x)
+
+
+
+
+ +

Logit Layer and loss Function (for training)

+ +
+
+
288        if targets is not None:
+
+
+
+
+ +

if we are given some desired targets also calculate the loss

+ +
+
+
290            logits = self.lm_head(x)
+291            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
+292            if self.training:
+293                loss = L2Wrap.apply(loss, logits)
+294        else:
+
+
+
+
+ +

inference-time mini-optimization: only forward the lm_head on the very last position

+ +
+
+
296            logits = self.lm_head(x[:, [-1], :])  # note: using list [-1] to preserve the time dim
+297            loss = None
+
+
+
+
+ +

Return Logits and loss

+ +
+
+
300        if return_state:
+301            return logits, loss, state
+302        else:
+303            return logits, loss
+
+
+ +
+ + + + \ No newline at end of file