"""
This file provides an implementation of several neural network modules that are used for merging and
transforming input data in various ways. The following components can be used when we are dealing with
data from multiple modalities, or when we need to merge multiple intermediate embedded representations in
the forward pass of a model.

The main classes defined in this module are:

    - BilinearGeneral: This class implements a bilinear transformation layer that applies a bilinear transformation
      to incoming data, as described in the paper "Multiplicative Interactions and Where to Find Them" (ICLR 2020,
      https://openreview.net/forum?id=rylnK6VtDH). The transformation involves two input features and one output
      feature, and also includes a bias term.

    - TorchBilinearCustomized: This class implements a bilinear layer similar to the one provided by PyTorch
      (torch.nn.Bilinear), but with additional customizations. This class can be used as an alternative to the
      BilinearGeneral class.

    - TorchBilinear: This class is a simple wrapper around PyTorch's built-in nn.Bilinear module. It provides the
      same functionality as PyTorch's nn.Bilinear, but within the structure of the current module.

    - FiLM: This class implements a Feature-wise Linear Modulation (FiLM) layer. FiLM layers apply an affine
      transformation to the input data, conditioned on some additional context information.

    - GatingType: This is an enumeration class that defines the different types of gating mechanisms that can be
      used in the merging modules.

    - SumMerge: This class provides a simple summing mechanism to merge input streams.

    - VectorMerge: This class implements a more complex merging mechanism for vector streams.
      The streams are first transformed using layer normalization, a ReLU activation, and a linear layer.
      Then they are merged either by simple summing or by using a gating mechanism.

The implementation of these classes relies on the PyTorch and NumPy libraries, and all of the classes use PyTorch's
nn.Module as the base class, making them compatible with PyTorch's other neural network modules and utilities.
These modules can serve as useful building blocks in more complex deep learning architectures.
"""

import enum
import math
from collections import OrderedDict
from typing import List, Dict

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor


class BilinearGeneral(nn.Module):
    """
    Overview:
        Bilinear implementation as in:
        Multiplicative Interactions and Where to Find Them, ICLR 2020, https://openreview.net/forum?id=rylnK6VtDH
    Arguments:
        - in1_features (:obj:`int`): size of each first input sample
        - in2_features (:obj:`int`): size of each second input sample
        - out_features (:obj:`int`): size of each output sample
    """

    def __init__(self, in1_features, in2_features, out_features):
        super(BilinearGeneral, self).__init__()
        # Initialize the weight tensor W, the weight matrices U and V, and the bias vector b
        self.W = nn.Parameter(torch.Tensor(out_features, in1_features, in2_features))
        self.U = nn.Parameter(torch.Tensor(out_features, in2_features))
        self.V = nn.Parameter(torch.Tensor(out_features, in1_features))
        self.b = nn.Parameter(torch.Tensor(out_features))
        self.in1_features = in1_features
        self.in2_features = in2_features
        self.out_features = out_features
        self.reset_parameters()

    def reset_parameters(self):
        # Uniform initialization scaled by the first input dimension
        stdv = 1. / np.sqrt(self.in1_features)
        self.W.data.uniform_(-stdv, stdv)
        self.U.data.uniform_(-stdv, stdv)
        self.V.data.uniform_(-stdv, stdv)
        self.b.data.uniform_(-stdv, stdv)

    def forward(self, x, z):
        # Compute the multiplicative interaction: x^T W z + U z + V x + b
        # Bilinear term x^T W z, evaluated per output unit
        out_W = torch.einsum('bi,kij,bj->bk', x, self.W, z)
        # Linear term in z: U z
        out_U = z.matmul(self.U.t())
        # Linear term in x: V x
        out_V = x.matmul(self.V.t())
        # Sum all terms and add the bias
        out = out_W + out_U + out_V + self.b
        return out
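
# A minimal usage sketch for ``BilinearGeneral`` (illustrative only; the feature sizes and batch size
# below are arbitrary assumptions, not values prescribed by this module):
#
#   layer = BilinearGeneral(in1_features=16, in2_features=32, out_features=8)
#   x = torch.randn(4, 16)   # first input, shape (batch, in1_features)
#   z = torch.randn(4, 32)   # second input, shape (batch, in2_features)
#   out = layer(x, z)        # multiplicative-interaction output, shape (batch, out_features) == (4, 8)
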
class TorchBilinearCustomized(nn.Module):
    """
    Overview:
        Customized Torch Bilinear implementation.
    Arguments:
        - in1_features (:obj:`int`): size of each first input sample
        - in2_features (:obj:`int`): size of each second input sample
        - out_features (:obj:`int`): size of each output sample
    """

    def __init__(self, in1_features, in2_features, out_features):
        super(TorchBilinearCustomized, self).__init__()
        self.in1_features = in1_features
        self.in2_features = in2_features
        self.out_features = out_features
        self.weight = nn.Parameter(torch.Tensor(out_features, in1_features, in2_features))
        self.bias = nn.Parameter(torch.Tensor(out_features))
        self.reset_parameters()

    def reset_parameters(self):
        bound = 1 / math.sqrt(self.in1_features)
        nn.init.uniform_(self.weight, -bound, bound)
        nn.init.uniform_(self.bias, -bound, bound)

    def forward(self, x, z):
        # Using torch.einsum for the bilinear operation
        out = torch.einsum('bi,oij,bj->bo', x, self.weight, z) + self.bias
        return out.squeeze(-1)
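
# A minimal usage sketch for ``TorchBilinearCustomized`` (illustrative only; the sizes below are
# arbitrary assumptions). The call signature mirrors ``BilinearGeneral``:
#
#   layer = TorchBilinearCustomized(in1_features=16, in2_features=32, out_features=8)
#   x = torch.randn(4, 16)   # first input, shape (batch, in1_features)
#   z = torch.randn(4, 32)   # second input, shape (batch, in2_features)
#   out = layer(x, z)        # output shape (batch, out_features) == (4, 8)
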
"""
Overview:
    Implementation of the Bilinear layer as in PyTorch:
    https://pytorch.org/docs/stable/generated/torch.nn.Bilinear.html#torch.nn.Bilinear
Arguments:
    - in1_features (:obj:`int`): size of each first input sample
    - in2_features (:obj:`int`): size of each second input sample
    - out_features (:obj:`int`): size of each output sample
    - bias (:obj:`bool`): If set to False, the layer will not learn an additive bias. Default: ``True``.
"""
TorchBilinear = nn.Bilinear
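
# A minimal usage sketch for ``TorchBilinear`` (illustrative only; the sizes are arbitrary assumptions).
# Since ``TorchBilinear`` is simply ``nn.Bilinear``, it follows PyTorch's standard Bilinear interface:
#
#   layer = TorchBilinear(in1_features=16, in2_features=32, out_features=8, bias=True)
#   x = torch.randn(4, 16)   # first input, shape (batch, in1_features)
#   z = torch.randn(4, 32)   # second input, shape (batch, in2_features)
#   out = layer(x, z)        # output shape (batch, out_features) == (4, 8)
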
class FiLM(nn.Module):
    """
    Overview:
        Feature-wise Linear Modulation (FiLM) Layer.
        This layer applies a feature-wise affine transformation based on context.
    Arguments:
        - feature_dim (:obj:`int`): The dimension of the input feature vector.
        - context_dim (:obj:`int`): The dimension of the input context vector.
    """

    def __init__(self, feature_dim, context_dim):
        super(FiLM, self).__init__()
        # Define the fully connected layer for context.
        # The output dimension is twice the feature dimension, for gamma and beta.
        self.context_layer = nn.Linear(context_dim, 2 * feature_dim)

    def forward(self, feature, context):
        """
        Overview:
            Forward propagation.
        Arguments:
            - feature (:obj:`torch.Tensor`): The input feature, shape (batch_size, feature_dim)
            - context (:obj:`torch.Tensor`): The input context, shape (batch_size, context_dim)
        Returns:
            - conditioned_feature (:obj:`torch.Tensor`): The output feature after FiLM, shape (batch_size, feature_dim)
        """
        # Pass context through the fully connected layer
        out = self.context_layer(context)
        # Split the output into two halves along the feature dimension (dim=1): gamma and beta
        gamma, beta = torch.split(out, out.shape[1] // 2, dim=1)
        # Apply the feature-wise affine transformation
        conditioned_feature = gamma * feature + beta
        return conditioned_feature
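
# A minimal usage sketch for ``FiLM`` (illustrative only; the dimensions below are arbitrary assumptions):
#
#   film = FiLM(feature_dim=64, context_dim=16)
#   feature = torch.randn(4, 64)            # input feature, shape (batch_size, feature_dim)
#   context = torch.randn(4, 16)            # conditioning context, shape (batch_size, context_dim)
#   conditioned = film(feature, context)    # gamma * feature + beta, shape (batch_size, feature_dim)
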
class GatingType(enum.Enum):
    r"""
    Overview: