diff --git a/chapter5_time/gtrxl.py b/chapter5_time/gtrxl.py
new file mode 100644
index 0000000..0e85f13
--- /dev/null
+++ b/chapter5_time/gtrxl.py
@@ -0,0 +1,285 @@
+"""
+Gated Transformer-XL (GTrXL) is a stabilized transformer architecture for reinforcement learning.
+This document mainly includes:
+- A PyTorch implementation of GTrXL.
+- An example to test GTrXL.
+"""
+from typing import Optional, Dict
+import warnings
+import numpy as np
+import torch
+import torch.nn as nn
+import treetensor
+from ding.torch_utils import GRUGatingUnit, build_normalization
+
+from ding.torch_utils.network.nn_module import fc_block
+from ding.torch_utils.network.gtrxl import PositionalEmbedding, Memory, AttentionXL
+
+
+class GatedTransformerXLLayer(torch.nn.Module):
+ """
+ **Overview:**
+ Attention layer of GTrXL
+ """
+ def __init__(
+ self,
+ input_dim: int,
+ head_dim: int,
+ hidden_dim: int,
+ head_num: int,
+ mlp_num: int,
+ dropout: nn.Module,
+ activation: nn.Module,
+ gru_gating: bool = True,
+ gru_bias: float = 2.
+ ) -> None:
+ super(GatedTransformerXLLayer, self).__init__()
+ self.dropout = dropout
+ # Decide whether to use GRU-gating.
+ self.gating = gru_gating
+ if self.gating is True:
+ self.gate1 = GRUGatingUnit(input_dim, gru_bias)
+ self.gate2 = GRUGatingUnit(input_dim, gru_bias)
+ # Build attention block.
+ self.attention = AttentionXL(
+ input_dim,
+ head_dim,
+ head_num,
+ dropout,
+ )
+ # Build Feed-Forward-Network.
+ layers = []
+ dims = [input_dim] + [hidden_dim] * (mlp_num - 1) + [input_dim]
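+        # A small illustration (hypothetical numbers): with input_dim=256, hidden_dim=512 and mlp_num=2,
+        # dims = [256, 512, 256], i.e. two fc blocks: 256 -> 512 and 512 -> 256.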
+ for i in range(mlp_num):
+ layers.append(fc_block(dims[i], dims[i + 1], activation=activation))
+ if i != mlp_num - 1:
+ layers.append(self.dropout)
+ layers.append(self.dropout)
+ self.mlp = nn.Sequential(*layers)
+ # Build layer norm.
+ self.layernorm1 = build_normalization('LN')(input_dim)
+ self.layernorm2 = build_normalization('LN')(input_dim)
+ self.activation = activation
+
+ def forward(
+ self,
+ inputs: torch.Tensor,
+ pos_embedding: torch.Tensor,
+ u: torch.nn.Parameter,
+ v: torch.nn.Parameter,
+ memory: torch.Tensor,
+ mask: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ # Concat memory with input across sequence dimension. The shape is: [full_sequence, batch_size, input_dim]
+ full_input = torch.cat([memory, inputs], dim=0)
+ # Forward calculation for GTrXL layer.
+ x1 = self.layernorm1(full_input)
+ # Attention module.
+ a1 = self.dropout(self.attention(inputs, pos_embedding, x1, u, v, mask=mask))
+ a1 = self.activation(a1)
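+        # When gru_gating is enabled, the GRU gate below replaces the plain residual connection (inputs + a1), as proposed in the GTrXL paper.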
+ o1 = self.gate1(inputs, a1) if self.gating else inputs + a1
+ x2 = self.layernorm2(o1)
+ # Feed Forward Network.
+ m2 = self.dropout(self.mlp(x2))
+ o2 = self.gate2(o1, m2) if self.gating else o1 + m2
+ return o2
+
+
+class GTrXL(nn.Module):
+ """
+ **Overview:**
+ PyTorch implementation for GTrXL.
+ """
+ def __init__(
+ self,
+ input_dim: int,
+ head_dim: int = 128,
+ embedding_dim: int = 256,
+ head_num: int = 2,
+ mlp_num: int = 2,
+ layer_num: int = 3,
+ memory_len: int = 64,
+ dropout_ratio: float = 0.,
+ activation: nn.Module = nn.ReLU(),
+ gru_gating: bool = True,
+ gru_bias: float = 2.,
+ use_embedding_layer: bool = True,
+ ) -> None:
+ super(GTrXL, self).__init__()
+        assert embedding_dim % 2 == 0, 'embedding_dim={} should be even'.format(embedding_dim)
+ self.head_num = head_num
+ self.head_dim = head_dim
+ self.layer_num = layer_num
+ self.embedding_dim = embedding_dim
+ if isinstance(input_dim, list):
+ input_dim = np.prod(input_dim)
+ # Initialize embedding layer.
+ self.use_embedding_layer = use_embedding_layer
+ if use_embedding_layer:
+ self.embedding = fc_block(input_dim, embedding_dim, activation=activation)
+ # Initialize activate function.
+ self.activation = activation
+ # Initialize position embedding.
+ self.pos_embedding = PositionalEmbedding(embedding_dim)
+ # Memory to save hidden states of past segments. It will be initialized in the forward method to get its size dynamically.
+ self.memory = None
+ self.memory_len = memory_len
+ # Initialize GTrXL layers.
+ layers = []
+ dims = [embedding_dim] + [embedding_dim] * layer_num
+ self.dropout = nn.Dropout(dropout_ratio) if dropout_ratio > 0 else nn.Identity()
+ for i in range(layer_num):
+ layers.append(
+ GatedTransformerXLLayer(
+ dims[i], head_dim, embedding_dim, head_num, mlp_num, self.dropout, self.activation, gru_gating,
+ gru_bias
+ )
+ )
+ self.layers = nn.Sequential(*layers)
+ # u and v are the parameters to compute global content bias and global positional bias.
+ self.u, self.v = (
+ torch.nn.Parameter(torch.zeros(self.head_num, self.head_dim)),
+ torch.nn.Parameter(torch.zeros(self.head_num, self.head_dim)),
+ )
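+        # Both u and v have shape [head_num, head_dim] and are shared across all GTrXL layers.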
+ # Create an attention mask for each different seq_len. In this way we don't need to create a new one each time we call the forward method.
+ self.att_mask = {}
+ # Create a pos embedding for each different seq_len. In this way we don't need to create a new one each time we call the forward method.
+ self.pos_embedding_dict = {}
+
+ def reset_memory(self, batch_size: Optional[int] = None, state: Optional[torch.Tensor] = None):
+ # Reset the memory of GTrXL.
+ self.memory = Memory(memory_len=self.memory_len, layer_num=self.layer_num, embedding_dim=self.embedding_dim)
+ # If batch_size is not None, specify the batch_size when initializing the memory.
+ if batch_size is not None:
+ self.memory = Memory(self.memory_len, batch_size, self.embedding_dim, self.layer_num)
+ # If state is not None, add state into the memory.
+ elif state is not None:
+ self.memory.init(state)
+
+ def get_memory(self):
+ # Get the memory of GTrXL.
+ if self.memory is None:
+ return None
+ else:
+ return self.memory.get()
+
+ def forward(self, x: torch.Tensor, batch_first: bool = False, return_mem: bool = True) -> Dict[str, torch.Tensor]:
+        # If the first dimension of the input x is batch_size, reshape x from [batch_size, sequence_length, input_dim] to [sequence_length, batch_size, input_dim].
+ if batch_first:
+ x = torch.transpose(x, 1, 0)
+ cur_seq, bs = x.shape[:2]
+ # Get back memory.
+ memory = None if self.memory is None else self.memory.get()
+ # Abnormal case: no memory or memory shape mismatch.
+ if memory is None:
+ self.reset_memory(bs)
+ elif memory.shape[-2] != bs or memory.shape[-1] != self.embedding_dim:
+ warnings.warn(
+ "Memory {} and Input {} dimensions don't match,"
+ " this will cause the memory to be initialized to fit your input!".format(
+ list(memory.shape[-2:]), [x.shape[-2]] + [self.embedding_dim]
+ )
+ )
+ self.reset_memory(bs)
+ self.memory.to(x.device)
+ memory = self.memory.get()
+ # Pass through embedding layer.
+ if self.use_embedding_layer:
+ x = self.dropout(self.embedding(x))
+ # Get full sequence length: memory length + current length
+ prev_seq = self.memory_len
+ full_seq = cur_seq + prev_seq
+        # If the attention mask for the current sequence length has already been created, reuse the one stored in self.att_mask.
+ if cur_seq in self.att_mask.keys():
+ attn_mask = self.att_mask[cur_seq]
+ # Otherwise, create a new attention mask and store it into self.att_mask.
+ else:
+            # For example, if cur_seq = 3 and full_seq = 7, then the mask is:
+            # $$ \begin{matrix} 0 & 0 & 0 & 0 & 0 & 1 & 1 \\ 0 & 0 & 0 & 0 & 0 & 0 & 1 \\ 0 & 0 & 0 & 0 & 0 & 0 & 0 \end{matrix}$$
+            # This ensures that the hidden state of the current token only attends to previous tokens.
+ attn_mask = (
+ torch.triu(
+ torch.ones((cur_seq, full_seq)),
+ diagonal=1 + prev_seq,
+ ).bool().unsqueeze(-1).to(x.device)
+ )
+ self.att_mask[cur_seq] = attn_mask
+        # If the positional encoding for the current sequence length has already been created, reuse the one stored in self.pos_embedding_dict.
+ if cur_seq in self.pos_embedding_dict.keys():
+ pos_embedding = self.pos_embedding_dict[cur_seq]
+ # Otherwise, create a new position encoding and store it into self.pos_embedding_dict.
+ else:
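+            # Relative position indices run from full_seq - 1 down to 0, as used by the Transformer-XL style positional encoding.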
+ pos_ips = torch.arange(full_seq - 1, -1, -1.0, dtype=torch.float) # full_seq
+ pos_embedding = self.pos_embedding(pos_ips.to(x.device))
+ self.pos_embedding_dict[cur_seq] = pos_embedding
+ pos_embedding = self.dropout(pos_embedding) # full_seq x 1 x embedding_dim
+
+ hidden_state = [x]
+ out = x
+ # Calculate results for each GTrXL layer.
+ for i in range(self.layer_num):
+ layer = self.layers[i]
+ out = layer(
+ out,
+ pos_embedding,
+ self.u,
+ self.v,
+ mask=attn_mask,
+ memory=memory[i],
+ )
+ hidden_state.append(out.clone())
+ out = self.dropout(out)
+ # Update the GTrXL memory.
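+        # hidden_state contains layer_num + 1 tensors: the embedded input plus the output of each layer.
+        # The memory keeps the most recent memory_len timesteps of each entry for the next forward call.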
+ self.memory.update(hidden_state)
+        # If the first dimension of the output is required to be batch_size, reshape out from [sequence_length, batch_size, input_dim] to [batch_size, sequence_length, input_dim].
+ if batch_first:
+ out = torch.transpose(out, 1, 0)
+        # Return the memory if needed.
+ if return_mem:
+ output = treetensor.Object({"logit": out, "memory": memory})
+ else:
+ output = treetensor.Object({"logit": out})
+ return output
+
+
+def test_gtrxl() -> None:
+ # Generate data for testing.
+ input_dim = 128
+ seq_len = 64
+ bs = 32
+ embedding_dim = 256
+ layer_num = 5
+ mem_len = 40
+ memory = [None, torch.rand(layer_num + 1, mem_len, bs, embedding_dim)]
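+    # The memory shape is [layer_num + 1, memory_len, batch_size, embedding_dim]:
+    # one entry for the embedded input plus one for each layer's output.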
+
+ # Test GTrXL under different situations.
+ for i in range(2):
+ m = memory[i]
+ model = GTrXL(
+ input_dim=input_dim,
+ head_dim=2,
+ embedding_dim=embedding_dim,
+ memory_len=mem_len,
+ head_num=2,
+ mlp_num=2,
+ layer_num=layer_num,
+ )
+ # Input shape: [sequence_length, batch_size, input_dim]
+ input = torch.rand(seq_len, bs, input_dim, requires_grad=True)
+ # Reset the model memory.
+ if m is None:
+ model.reset_memory(batch_size=bs)
+ else:
+ model.reset_memory(state=m)
+ output = model(input)
+ # Check the calculation results.
+ assert output['logit'].shape == (seq_len, bs, embedding_dim)
+ assert output['memory'].shape == (layer_num + 1, mem_len, bs, embedding_dim)
+ torch.sum(output['logit']).backward()
+ assert isinstance(input.grad, torch.Tensor)
+ # Check memory.
+ memory_out = output['memory']
+ if m is not None:
+ assert torch.all(torch.eq(memory_out, m))
+
+
+if __name__ == '__main__':
+ test_gtrxl()
diff --git a/chapter5_time/lstm.py b/chapter5_time/lstm.py
new file mode 100644
index 0000000..fc23eef
--- /dev/null
+++ b/chapter5_time/lstm.py
@@ -0,0 +1,162 @@
+"""
+Long Short-Term Memory (LSTM) is a kind of recurrent neural network that can capture both long-term and short-term dependencies.
+This document mainly includes:
+- A PyTorch implementation of LSTM.
+- An example to test LSTM.
+For beginners, an introductory tutorial on recurrent neural networks is helpful for learning the basics of how LSTM works.
+"""
+from typing import Optional, Union, Tuple, List, Dict
+import math
+import torch
+import torch.nn as nn
+from ding.torch_utils.network.rnn import is_sequence
+from ding.torch_utils import build_normalization
+
+
+class LSTM(nn.Module):
+ """
+ **Overview:**
+ Implementation of LSTM cell with layer norm.
+ """
+
+ def __init__(
+ self,
+ input_size: int,
+ hidden_size: int,
+ num_layers: int,
+ norm_type: Optional[str] = 'LN',
+ dropout: float = 0.
+ ) -> None:
+ # Initialize arguments.
+ super(LSTM, self).__init__()
+ self.input_size = input_size
+ self.hidden_size = hidden_size
+ self.num_layers = num_layers
+ # Initialize normalization functions.
+ norm_func = build_normalization(norm_type)
+ self.norm = nn.ModuleList([norm_func(hidden_size * 4) for _ in range(2 * num_layers)])
+ # Initialize LSTM parameters.
+ self.wx = nn.ParameterList()
+ self.wh = nn.ParameterList()
+ dims = [input_size] + [hidden_size] * num_layers
+ for l in range(num_layers):
+ self.wx.append(nn.Parameter(torch.zeros(dims[l], dims[l + 1] * 4)))
+ self.wh.append(nn.Parameter(torch.zeros(hidden_size, hidden_size * 4)))
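+        # A single bias vector per layer covers all four gates at once (hence hidden_size * 4).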
+ self.bias = nn.Parameter(torch.zeros(num_layers, hidden_size * 4))
+ # Initialize the Dropout Layer.
+ self.use_dropout = dropout > 0.
+ if self.use_dropout:
+ self.dropout = nn.Dropout(dropout)
+ self._init()
+
+    # Deal with different types of prev_state and return a preprocessed version.
+    def _before_forward(self, inputs: torch.Tensor, prev_state: Union[None, List[Dict]]) -> Union[Tuple, List]:
+ seq_len, batch_size = inputs.shape[:2]
+ # If prev_state is None, it indicates that this is the beginning of a sequence. In this case, prev_state will be initialized as zero.
+ if prev_state is None:
+ zeros = torch.zeros(
+ self.num_layers,
+ batch_size,
+ self.hidden_size,
+ dtype=inputs.dtype,
+ device=inputs.device
+ )
+ prev_state = (zeros, zeros)
+ # If prev_state is not None, then preprocess it into one batch.
+ else:
+ assert len(prev_state) == batch_size
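+            # Here prev_state is a list (length batch_size) of dicts with keys 'h' and 'c',
+            # each value of shape [num_layers, 1, hidden_size]. Merge them back into a single
+            # (h, c) pair of tensors, each with shape [num_layers, batch_size, hidden_size].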
+ state = [[v for v in prev.values()] for prev in prev_state]
+ state = list(zip(*state))
+ prev_state = [torch.cat(t, dim=1) for t in state]
+
+ return prev_state
+
+ def _init(self):
+        # Initialize parameters. Each parameter is initialized using a uniform distribution: $$U(-\sqrt {\frac 1 {HiddenSize}}, \sqrt {\frac 1 {HiddenSize}})$$
+ gain = math.sqrt(1. / self.hidden_size)
+ for l in range(self.num_layers):
+ torch.nn.init.uniform_(self.wx[l], -gain, gain)
+ torch.nn.init.uniform_(self.wh[l], -gain, gain)
+ if self.bias is not None:
+ torch.nn.init.uniform_(self.bias[l], -gain, gain)
+
+ def forward(self,
+ inputs: torch.Tensor,
+                prev_state: Union[None, List[Dict]],
+ ) -> Tuple[torch.Tensor, Union[torch.Tensor, list]]:
+ # The shape of input is: [sequence length, batch size, input size]
+ seq_len, batch_size = inputs.shape[:2]
+ prev_state = self._before_forward(inputs, prev_state)
+
+ H, C = prev_state
+ x = inputs
+ next_state = []
+ for l in range(self.num_layers):
+ h, c = H[l], C[l]
+ new_x = []
+ for s in range(seq_len):
+ # Calculate $$z, z^i, z^f, z^o$$ simultaneously.
+ gate = self.norm[l * 2](torch.matmul(x[s], self.wx[l])
+ ) + self.norm[l * 2 + 1](torch.matmul(h, self.wh[l]))
+ if self.bias is not None:
+ gate += self.bias[l]
+ gate = list(torch.chunk(gate, 4, dim=1))
+ i, f, o, z = gate
+ # $$z^i = \sigma (Wx^ix^t + Wh^ih^{t-1})$$
+ i = torch.sigmoid(i)
+ # $$z^f = \sigma (Wx^fx^t + Wh^fh^{t-1})$$
+ f = torch.sigmoid(f)
+ # $$z^o = \sigma (Wx^ox^t + Wh^oh^{t-1})$$
+ o = torch.sigmoid(o)
+ # $$z = tanh(Wxx^t + Whh^{t-1})$$
+ z = torch.tanh(z)
+ # $$c^t = z^f \odot c^{t-1}+z^i \odot z$$
+ c = f * c + i * z
+ # $$h^t = z^o \odot tanh(c^t)$$
+ h = o * torch.tanh(c)
+ new_x.append(h)
+ next_state.append((h, c))
+ x = torch.stack(new_x, dim=0)
+ # Dropout layer.
+ if self.use_dropout and l != self.num_layers - 1:
+ x = self.dropout(x)
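+        # next_state is a list of (h, c) pairs, one per layer; stack them into two tensors,
+        # each with shape [num_layers, batch_size, hidden_size].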
+ next_state = [torch.stack(t, dim=0) for t in zip(*next_state)]
+        # Return list type: split next_state into per-sample entries.
+ h, c = next_state
+ batch_size = h.shape[1]
+ # Split h with shape [num_layers, batch_size, hidden_size] to a list with length batch_size and each element is a tensor with shape [num_layers, 1, hidden_size]. The same operation is performed on c.
+ next_state = [torch.chunk(h, batch_size, dim=1), torch.chunk(c, batch_size, dim=1)]
+ next_state = list(zip(*next_state))
+ next_state = [{k: v for k, v in zip(['h', 'c'], item)} for item in next_state]
+ return x, next_state
+
+
+def test_lstm():
+ # Randomly generate test data.
+ seq_len = 2
+ num_layers = 3
+ input_size = 4
+ hidden_size = 5
+ batch_size = 6
+ norm_type = 'LN'
+ dropout = 0.1
+ input = torch.rand(seq_len, batch_size, input_size).requires_grad_(True)
+ lstm = LSTM(input_size, hidden_size, num_layers, norm_type, dropout)
+
+ # Test the LSTM recurrently, using the hidden states of last input as new prev_state.
+ prev_state = None
+ for s in range(seq_len):
+ input_step = input[s:s + 1]
+        # prev_state is None when input_step is the first step of the sequence. Otherwise, it is a list
+        # of dicts with keys 'h' and 'c', whose values are tensors of shape [num_layers, 1, hidden_size].
+        # The length of the list equals the batch_size.
+ output, prev_state = lstm(input_step, prev_state)
+
+ # Check whether the output is correct.
+ assert output.shape == (1, batch_size, hidden_size)
+ assert len(prev_state) == batch_size
+ assert prev_state[0]['h'].shape == (num_layers, 1, hidden_size)
+ torch.mean(output).backward()
+ assert isinstance(input.grad, torch.Tensor)
+
+
+if __name__ == '__main__':
+ test_lstm()