From c609bf7b9b5da105db2b5eabd6c22d835fd60b91 Mon Sep 17 00:00:00 2001
From: Wang hl <59834623+kxzxvbk@users.noreply.github.com>
Date: Sun, 5 Mar 2023 00:33:20 -0500
Subject: [PATCH] feature(whl): add demo code for lstm and gtrxl (#45)

* init commit

* polish

* polish according to comments
---
 chapter5_time/gtrxl.py | 285 +++++++++++++++++++++++++++++++++++++++++
 chapter5_time/lstm.py  | 162 +++++++++++++++++++++++
 2 files changed, 447 insertions(+)
 create mode 100644 chapter5_time/gtrxl.py
 create mode 100644 chapter5_time/lstm.py

diff --git a/chapter5_time/gtrxl.py b/chapter5_time/gtrxl.py
new file mode 100644
index 0000000..0e85f13
--- /dev/null
+++ b/chapter5_time/gtrxl.py
@@ -0,0 +1,285 @@
+"""
+Gated Transformer-XL (GTrXL) is a stabilized transformer architecture for reinforcement learning.
+This document mainly includes:
+- PyTorch implementation for GTrXL.
+- An example to test GTrXL.
+"""
+from typing import Optional, Dict
+import warnings
+import numpy as np
+import torch
+import torch.nn as nn
+import treetensor
+from ding.torch_utils import GRUGatingUnit, build_normalization
+
+from ding.torch_utils.network.nn_module import fc_block
+from ding.torch_utils.network.gtrxl import PositionalEmbedding, Memory, AttentionXL
+
+
+class GatedTransformerXLLayer(torch.nn.Module):
+    """
+    **Overview:**
+    Attention layer of GTrXL.
+    """
+    def __init__(
+        self,
+        input_dim: int,
+        head_dim: int,
+        hidden_dim: int,
+        head_num: int,
+        mlp_num: int,
+        dropout: nn.Module,
+        activation: nn.Module,
+        gru_gating: bool = True,
+        gru_bias: float = 2.
+    ) -> None:
+        super(GatedTransformerXLLayer, self).__init__()
+        self.dropout = dropout
+        # Decide whether to use GRU-gating.
+        self.gating = gru_gating
+        if self.gating is True:
+            self.gate1 = GRUGatingUnit(input_dim, gru_bias)
+            self.gate2 = GRUGatingUnit(input_dim, gru_bias)
+        # Build attention block.
+        self.attention = AttentionXL(
+            input_dim,
+            head_dim,
+            head_num,
+            dropout,
+        )
+        # Build Feed-Forward Network.
+        layers = []
+        dims = [input_dim] + [hidden_dim] * (mlp_num - 1) + [input_dim]
+        for i in range(mlp_num):
+            layers.append(fc_block(dims[i], dims[i + 1], activation=activation))
+            if i != mlp_num - 1:
+                layers.append(self.dropout)
+        layers.append(self.dropout)
+        self.mlp = nn.Sequential(*layers)
+        # Build layer norm.
+        self.layernorm1 = build_normalization('LN')(input_dim)
+        self.layernorm2 = build_normalization('LN')(input_dim)
+        self.activation = activation
+
+    def forward(
+        self,
+        inputs: torch.Tensor,
+        pos_embedding: torch.Tensor,
+        u: torch.nn.Parameter,
+        v: torch.nn.Parameter,
+        memory: torch.Tensor,
+        mask: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        # Concat memory with input across the sequence dimension. The shape is: [full_sequence, batch_size, input_dim]
+        full_input = torch.cat([memory, inputs], dim=0)
+        # Forward calculation for the GTrXL layer.
+        x1 = self.layernorm1(full_input)
+        # Attention module.
+        a1 = self.dropout(self.attention(inputs, pos_embedding, x1, u, v, mask=mask))
+        a1 = self.activation(a1)
+        o1 = self.gate1(inputs, a1) if self.gating else inputs + a1
+        x2 = self.layernorm2(o1)
+        # Feed-Forward Network.
+        m2 = self.dropout(self.mlp(x2))
+        o2 = self.gate2(o1, m2) if self.gating else o1 + m2
+        return o2
+
+
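+# The short sketch below is an illustrative shape check for a single GatedTransformerXLLayer.
+# It is not used by the rest of this file; the helper name and the sizes are assumptions made
+# for demonstration, mirroring the shapes prepared in GTrXL.forward below: inputs are
+# [cur_seq, bs, input_dim], memory is [prev_seq, bs, input_dim], the positional embedding covers
+# the full (memory + current) sequence, and the mask has shape [cur_seq, full_seq, 1].
+def _gtrxl_layer_shape_demo() -> None:
+    cur_seq, prev_seq, bs, dim, head_num, head_dim = 4, 6, 2, 16, 2, 8
+    full_seq = cur_seq + prev_seq
+    layer = GatedTransformerXLLayer(
+        input_dim=dim, head_dim=head_dim, hidden_dim=dim, head_num=head_num,
+        mlp_num=2, dropout=nn.Identity(), activation=nn.ReLU()
+    )
+    inputs = torch.rand(cur_seq, bs, dim)
+    memory = torch.zeros(prev_seq, bs, dim)
+    # Positional indices run from full_seq - 1 down to 0, as in GTrXL.forward.
+    pos_embedding = PositionalEmbedding(dim)(torch.arange(full_seq - 1, -1, -1.0))
+    u = torch.nn.Parameter(torch.zeros(head_num, head_dim))
+    v = torch.nn.Parameter(torch.zeros(head_num, head_dim))
+    # Each current token may only attend to itself, previous tokens and the memory.
+    mask = torch.triu(torch.ones((cur_seq, full_seq)), diagonal=1 + prev_seq).bool().unsqueeze(-1)
+    out = layer(inputs, pos_embedding, u, v, memory, mask)
+    assert out.shape == (cur_seq, bs, dim)
+
+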
+class GTrXL(nn.Module):
+    """
+    **Overview:**
+    PyTorch implementation of GTrXL.
+    """
+    def __init__(
+        self,
+        input_dim: int,
+        head_dim: int = 128,
+        embedding_dim: int = 256,
+        head_num: int = 2,
+        mlp_num: int = 2,
+        layer_num: int = 3,
+        memory_len: int = 64,
+        dropout_ratio: float = 0.,
+        activation: nn.Module = nn.ReLU(),
+        gru_gating: bool = True,
+        gru_bias: float = 2.,
+        use_embedding_layer: bool = True,
+    ) -> None:
+        super(GTrXL, self).__init__()
+        assert embedding_dim % 2 == 0, 'embedding_dim={} should be even'.format(embedding_dim)
+        self.head_num = head_num
+        self.head_dim = head_dim
+        self.layer_num = layer_num
+        self.embedding_dim = embedding_dim
+        if isinstance(input_dim, list):
+            input_dim = np.prod(input_dim)
+        # Initialize the embedding layer.
+        self.use_embedding_layer = use_embedding_layer
+        if use_embedding_layer:
+            self.embedding = fc_block(input_dim, embedding_dim, activation=activation)
+        # Initialize the activation function.
+        self.activation = activation
+        # Initialize the positional embedding.
+        self.pos_embedding = PositionalEmbedding(embedding_dim)
+        # Memory to save hidden states of past segments. It will be initialized in the forward method to get its size dynamically.
+        self.memory = None
+        self.memory_len = memory_len
+        # Initialize GTrXL layers.
+        layers = []
+        dims = [embedding_dim] + [embedding_dim] * layer_num
+        self.dropout = nn.Dropout(dropout_ratio) if dropout_ratio > 0 else nn.Identity()
+        for i in range(layer_num):
+            layers.append(
+                GatedTransformerXLLayer(
+                    dims[i], head_dim, embedding_dim, head_num, mlp_num, self.dropout, self.activation, gru_gating,
+                    gru_bias
+                )
+            )
+        self.layers = nn.Sequential(*layers)
+        # u and v are the parameters to compute the global content bias and the global positional bias.
+        self.u, self.v = (
+            torch.nn.Parameter(torch.zeros(self.head_num, self.head_dim)),
+            torch.nn.Parameter(torch.zeros(self.head_num, self.head_dim)),
+        )
+        # Create an attention mask for each different seq_len. In this way we don't need to create a new one each time we call the forward method.
+        self.att_mask = {}
+        # Create a positional embedding for each different seq_len. In this way we don't need to create a new one each time we call the forward method.
+        self.pos_embedding_dict = {}
+
+    def reset_memory(self, batch_size: Optional[int] = None, state: Optional[torch.Tensor] = None):
+        # Reset the memory of GTrXL.
+        self.memory = Memory(memory_len=self.memory_len, layer_num=self.layer_num, embedding_dim=self.embedding_dim)
+        # If batch_size is not None, specify the batch_size when initializing the memory.
+        if batch_size is not None:
+            self.memory = Memory(self.memory_len, batch_size, self.embedding_dim, self.layer_num)
+        # If state is not None, add state into the memory.
+        elif state is not None:
+            self.memory.init(state)
+
+    def get_memory(self):
+        # Get the memory of GTrXL.
+        if self.memory is None:
+            return None
+        else:
+            return self.memory.get()
+
+    def forward(self, x: torch.Tensor, batch_first: bool = False, return_mem: bool = True) -> Dict[str, torch.Tensor]:
+        # If the first dimension of input x is batch_size, then reshape x from [batch_size, sequence_length, input_dim] to [sequence_length, batch_size, input_dim].
+        if batch_first:
+            x = torch.transpose(x, 1, 0)
+        cur_seq, bs = x.shape[:2]
+        # Fetch the stored memory.
+        memory = None if self.memory is None else self.memory.get()
+        # Abnormal case: no memory yet, or the memory shape does not match the input.
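+        # If no memory exists yet, initialize it with the current batch size. If the stored
+        # memory does not match the batch size or embedding dimension of the input, raise a
+        # warning and re-initialize the memory to fit the input.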
+        if memory is None:
+            self.reset_memory(bs)
+        elif memory.shape[-2] != bs or memory.shape[-1] != self.embedding_dim:
+            warnings.warn(
+                "Memory {} and Input {} dimensions don't match,"
+                " this will cause the memory to be initialized to fit your input!".format(
+                    list(memory.shape[-2:]), [x.shape[-2]] + [self.embedding_dim]
+                )
+            )
+            self.reset_memory(bs)
+        self.memory.to(x.device)
+        memory = self.memory.get()
+        # Pass through the embedding layer.
+        if self.use_embedding_layer:
+            x = self.dropout(self.embedding(x))
+        # Get the full sequence length: memory length + current length.
+        prev_seq = self.memory_len
+        full_seq = cur_seq + prev_seq
+        # If the attention mask for the current sequence length has already been created, reuse the mask stored in self.att_mask.
+        if cur_seq in self.att_mask.keys():
+            attn_mask = self.att_mask[cur_seq]
+        # Otherwise, create a new attention mask and store it into self.att_mask.
+        else:
+            # For example, if cur_seq = 3 and full_seq = 7, the mask is: $$ \begin{matrix} 0 & 0 & 0 & 0 & 0 & 1 & 1 \\ 0 & 0 & 0 & 0 & 0 & 0 & 1 \\ 0 & 0 & 0 & 0 & 0 & 0 & 0 \end{matrix}$$ This ensures that the hidden state of the current token only attends to previous tokens (including the memory).
+            attn_mask = (
+                torch.triu(
+                    torch.ones((cur_seq, full_seq)),
+                    diagonal=1 + prev_seq,
+                ).bool().unsqueeze(-1).to(x.device)
+            )
+            self.att_mask[cur_seq] = attn_mask
+        # If the positional encoding for the current sequence length has already been created, reuse the one stored in self.pos_embedding_dict.
+        if cur_seq in self.pos_embedding_dict.keys():
+            pos_embedding = self.pos_embedding_dict[cur_seq]
+        # Otherwise, create a new positional encoding and store it into self.pos_embedding_dict.
+        else:
+            pos_ips = torch.arange(full_seq - 1, -1, -1.0, dtype=torch.float)  # full_seq
+            pos_embedding = self.pos_embedding(pos_ips.to(x.device))
+            self.pos_embedding_dict[cur_seq] = pos_embedding
+        pos_embedding = self.dropout(pos_embedding)  # full_seq x 1 x embedding_dim
+
+        hidden_state = [x]
+        out = x
+        # Calculate results for each GTrXL layer.
+        for i in range(self.layer_num):
+            layer = self.layers[i]
+            out = layer(
+                out,
+                pos_embedding,
+                self.u,
+                self.v,
+                mask=attn_mask,
+                memory=memory[i],
+            )
+            hidden_state.append(out.clone())
+        out = self.dropout(out)
+        # Update the GTrXL memory.
+        self.memory.update(hidden_state)
+        # If the first dimension of the output is required to be batch_size, then reshape x from [sequence_length, batch_size, input_dim] to [batch_size, sequence_length, input_dim].
+        if batch_first:
+            out = torch.transpose(out, 1, 0)
+        # Return the memory if needed.
+        if return_mem:
+            output = treetensor.Object({"logit": out, "memory": memory})
+        else:
+            output = treetensor.Object({"logit": out})
+        return output
+
+
+def test_gtrxl() -> None:
+    # Generate data for testing.
+    input_dim = 128
+    seq_len = 64
+    bs = 32
+    embedding_dim = 256
+    layer_num = 5
+    mem_len = 40
+    memory = [None, torch.rand(layer_num + 1, mem_len, bs, embedding_dim)]
+
+    # Test GTrXL under different situations.
+    for i in range(2):
+        m = memory[i]
+        model = GTrXL(
+            input_dim=input_dim,
+            head_dim=2,
+            embedding_dim=embedding_dim,
+            memory_len=mem_len,
+            head_num=2,
+            mlp_num=2,
+            layer_num=layer_num,
+        )
+        # Input shape: [sequence_length, batch_size, input_dim]
+        input = torch.rand(seq_len, bs, input_dim, requires_grad=True)
+        # Reset the model memory.
+        if m is None:
+            model.reset_memory(batch_size=bs)
+        else:
+            model.reset_memory(state=m)
+        output = model(input)
+        # Check the calculation results.
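+        # The output 'logit' has shape [sequence_length, batch_size, embedding_dim]. The returned
+        # 'memory' has shape [layer_num + 1, memory_len, batch_size, embedding_dim] and is the
+        # memory state fetched before this forward pass updates it, so it should equal the state
+        # used to reset the model.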
+        assert output['logit'].shape == (seq_len, bs, embedding_dim)
+        assert output['memory'].shape == (layer_num + 1, mem_len, bs, embedding_dim)
+        torch.sum(output['logit']).backward()
+        assert isinstance(input.grad, torch.Tensor)
+        # Check memory.
+        memory_out = output['memory']
+        if m is not None:
+            assert torch.all(torch.eq(memory_out, m))
+
+
+if __name__ == '__main__':
+    test_gtrxl()
diff --git a/chapter5_time/lstm.py b/chapter5_time/lstm.py
new file mode 100644
index 0000000..fc23eef
--- /dev/null
+++ b/chapter5_time/lstm.py
@@ -0,0 +1,162 @@
+"""
+Long Short-Term Memory (LSTM) is a kind of recurrent neural network that can capture both long-term and short-term information.
+This document mainly includes:
+- PyTorch implementation for LSTM.
+- An example to test LSTM.
+For beginners, you can refer to related tutorials to learn the basics about how LSTM works.
+"""
+from typing import Optional, Union, Tuple, List, Dict
+import math
+import torch
+import torch.nn as nn
+from ding.torch_utils.network.rnn import is_sequence
+from ding.torch_utils import build_normalization
+
+
+class LSTM(nn.Module):
+    """
+    **Overview:**
+    Implementation of an LSTM cell with layer norm.
+    """
+
+    def __init__(
+        self,
+        input_size: int,
+        hidden_size: int,
+        num_layers: int,
+        norm_type: Optional[str] = 'LN',
+        dropout: float = 0.
+    ) -> None:
+        # Initialize arguments.
+        super(LSTM, self).__init__()
+        self.input_size = input_size
+        self.hidden_size = hidden_size
+        self.num_layers = num_layers
+        # Initialize normalization functions.
+        norm_func = build_normalization(norm_type)
+        self.norm = nn.ModuleList([norm_func(hidden_size * 4) for _ in range(2 * num_layers)])
+        # Initialize LSTM parameters.
+        self.wx = nn.ParameterList()
+        self.wh = nn.ParameterList()
+        dims = [input_size] + [hidden_size] * num_layers
+        for l in range(num_layers):
+            self.wx.append(nn.Parameter(torch.zeros(dims[l], dims[l + 1] * 4)))
+            self.wh.append(nn.Parameter(torch.zeros(hidden_size, hidden_size * 4)))
+        self.bias = nn.Parameter(torch.zeros(num_layers, hidden_size * 4))
+        # Initialize the dropout layer.
+        self.use_dropout = dropout > 0.
+        if self.use_dropout:
+            self.dropout = nn.Dropout(dropout)
+        self._init()
+
+    # Deal with different types of input and return the preprocessed prev_state.
+    def _before_forward(self, inputs: torch.Tensor, prev_state: Union[None, List[Dict]]) -> torch.Tensor:
+        seq_len, batch_size = inputs.shape[:2]
+        # If prev_state is None, it indicates that this is the beginning of a sequence. In this case, prev_state will be initialized as zero.
+        if prev_state is None:
+            zeros = torch.zeros(
+                self.num_layers,
+                batch_size,
+                self.hidden_size,
+                dtype=inputs.dtype,
+                device=inputs.device
+            )
+            prev_state = (zeros, zeros)
+        # If prev_state is not None, then preprocess it into one batch.
+        else:
+            assert len(prev_state) == batch_size
+            state = [[v for v in prev.values()] for prev in prev_state]
+            state = list(zip(*state))
+            prev_state = [torch.cat(t, dim=1) for t in state]
+
+        return prev_state
+
+    def _init(self):
+        # Initialize parameters. Each parameter is initialized using a uniform distribution: $$U(-\sqrt {\frac 1 {HiddenSize}}, \sqrt {\frac 1 {HiddenSize}})$$
+        gain = math.sqrt(1. / self.hidden_size)
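+        # This bound matches the default initialization of torch.nn.LSTM, whose weights and
+        # biases are drawn from U(-sqrt(1/hidden_size), sqrt(1/hidden_size)).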
+        for l in range(self.num_layers):
+            torch.nn.init.uniform_(self.wx[l], -gain, gain)
+            torch.nn.init.uniform_(self.wh[l], -gain, gain)
+            if self.bias is not None:
+                torch.nn.init.uniform_(self.bias[l], -gain, gain)
+
+    def forward(self,
+                inputs: torch.Tensor,
+                prev_state: torch.Tensor,
+                ) -> Tuple[torch.Tensor, Union[torch.Tensor, list]]:
+        # The shape of input is: [sequence length, batch size, input size]
+        seq_len, batch_size = inputs.shape[:2]
+        prev_state = self._before_forward(inputs, prev_state)
+
+        H, C = prev_state
+        x = inputs
+        next_state = []
+        for l in range(self.num_layers):
+            h, c = H[l], C[l]
+            new_x = []
+            for s in range(seq_len):
+                # Calculate $$z, z^i, z^f, z^o$$ simultaneously.
+                gate = self.norm[l * 2](torch.matmul(x[s], self.wx[l])
+                                        ) + self.norm[l * 2 + 1](torch.matmul(h, self.wh[l]))
+                if self.bias is not None:
+                    gate += self.bias[l]
+                gate = list(torch.chunk(gate, 4, dim=1))
+                i, f, o, z = gate
+                # $$z^i = \sigma (W_x^i x^t + W_h^i h^{t-1})$$
+                i = torch.sigmoid(i)
+                # $$z^f = \sigma (W_x^f x^t + W_h^f h^{t-1})$$
+                f = torch.sigmoid(f)
+                # $$z^o = \sigma (W_x^o x^t + W_h^o h^{t-1})$$
+                o = torch.sigmoid(o)
+                # $$z = tanh(W_x x^t + W_h h^{t-1})$$
+                z = torch.tanh(z)
+                # $$c^t = z^f \odot c^{t-1} + z^i \odot z$$
+                c = f * c + i * z
+                # $$h^t = z^o \odot tanh(c^t)$$
+                h = o * torch.tanh(c)
+                new_x.append(h)
+            next_state.append((h, c))
+            x = torch.stack(new_x, dim=0)
+            # Dropout layer.
+            if self.use_dropout and l != self.num_layers - 1:
+                x = self.dropout(x)
+        next_state = [torch.stack(t, dim=0) for t in zip(*next_state)]
+        # Return in list form: split next_state along the batch dimension.
+        h, c = next_state
+        batch_size = h.shape[1]
+        # Split h with shape [num_layers, batch_size, hidden_size] into a list of length batch_size, where each element is a tensor with shape [num_layers, 1, hidden_size]. The same operation is performed on c.
+        next_state = [torch.chunk(h, batch_size, dim=1), torch.chunk(c, batch_size, dim=1)]
+        next_state = list(zip(*next_state))
+        next_state = [{k: v for k, v in zip(['h', 'c'], item)} for item in next_state]
+        return x, next_state
+
+
+def test_lstm():
+    # Randomly generate test data.
+    seq_len = 2
+    num_layers = 3
+    input_size = 4
+    hidden_size = 5
+    batch_size = 6
+    norm_type = 'LN'
+    dropout = 0.1
+    input = torch.rand(seq_len, batch_size, input_size).requires_grad_(True)
+    lstm = LSTM(input_size, hidden_size, num_layers, norm_type, dropout)
+
+    # Test the LSTM recurrently, using the hidden states of the last input as the new prev_state.
+    prev_state = None
+    for s in range(seq_len):
+        input_step = input[s:s + 1]
+        # prev_state is None if input_step is the first step of the sequence. Otherwise, prev_state is a list of dictionaries with keys 'h' and 'c', whose values are tensors with shape [num_layers, 1, hidden_size]. The length of the list equals batch_size.
+        output, prev_state = lstm(input_step, prev_state)
+
+    # Check whether the output is correct.
+    assert output.shape == (1, batch_size, hidden_size)
+    assert len(prev_state) == batch_size
+    assert prev_state[0]['h'].shape == (num_layers, 1, hidden_size)
+    torch.mean(output).backward()
+    assert isinstance(input.grad, torch.Tensor)
+
+
+if __name__ == '__main__':
+    test_lstm()