
Commit 3638eab

feature(pu): add three variants of Bilinear classes and a FiLM class (#703)
* feature(pu): add three variants of Bilinear classes and a FiLM class into merge.py
* polish(pu): polish TorchBilinear
* style(pu): yapf format
* style(pu): yapf format
* style(pu): flake8 format
* style(pu): flake8 format
1 parent e3a7935 commit 3638eab

File tree

2 files changed: +289 −6 lines changed


ding/torch_utils/network/merge.py

+158 −6
@@ -1,22 +1,174 @@
 """
-This file provides two components for consolidating data streams, SumMerge and VectorMerge.
+This file provides an implementation of several different neural network modules that are used for merging and
+transforming input data in various ways. The following components can be used when we are dealing with
+data from multiple modes, or when we need to merge multiple intermediate embedded representations in
+the forward process of a model.
 
-The following components can be used when we are dealing with data from multiple modes,
-or when we need to merge multiple intermediate embedded representations in the forward process of a model.
+The main classes defined in this code are:
 
-While SumMerge simply sums multiple data streams in the first dimension,
-VectorMerge provides three more complex weighted summations.
+- BilinearGeneral: This class implements a bilinear transformation layer that applies a bilinear transformation to
+  incoming data, as described in the paper "Multiplicative Interactions and Where to Find Them", published at
+  ICLR 2020, https://openreview.net/forum?id=rylnK6VtDH. The transformation involves two input features and an
+  output feature, and also includes an optional bias term.
+
+- TorchBilinearCustomized: This class implements a bilinear layer similar to the one provided by PyTorch
+  (torch.nn.Bilinear), but with additional customizations. This class can be used as an alternative to the
+  BilinearGeneral class.
+
+- TorchBilinear: This class is a simple wrapper around PyTorch's built-in nn.Bilinear module. It provides the
+  same functionality as PyTorch's nn.Bilinear but within the structure of the current module.
+
+- FiLM: This class implements a Feature-wise Linear Modulation (FiLM) layer. FiLM layers apply an affine
+  transformation to the input data, conditioned on some additional context information.
+
+- GatingType: This is an enumeration class that defines different types of gating mechanisms that can be used in
+  the modules.
+
+- SumMerge: This class provides a simple summing mechanism to merge input streams.
+
+- VectorMerge: This class implements a more complex merging mechanism for vector streams.
+  The streams are first transformed using layer normalization, a ReLU activation, and a linear layer.
+  Then they are merged either by simple summing or by using a gating mechanism.
+
+The implementation of these classes involves the PyTorch and NumPy libraries, and the classes use PyTorch's
+nn.Module as the base class, making them compatible with PyTorch's neural network modules and functionalities.
+These modules can be useful building blocks in more complex deep learning architectures.
 """
 
 import enum
-from typing import List, Dict
+import math
 from collections import OrderedDict
+from typing import List, Dict
+
+import numpy as np
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from torch import Tensor
 
 
+class BilinearGeneral(nn.Module):
+    """
+    Overview:
+        Bilinear implementation as in:
+        Multiplicative Interactions and Where to Find Them, ICLR 2020, https://openreview.net/forum?id=rylnK6VtDH
+    Arguments:
+        - in1_features (:obj:`int`): size of each first input sample
+        - in2_features (:obj:`int`): size of each second input sample
+        - out_features (:obj:`int`): size of each output sample
+    """
+
+    def __init__(self, in1_features, in2_features, out_features):
+        super(BilinearGeneral, self).__init__()
+        # Initialize the weight tensor W, the weight matrices U and V, and the bias vector b
+        self.W = nn.Parameter(torch.Tensor(out_features, in1_features, in2_features))
+        self.U = nn.Parameter(torch.Tensor(out_features, in2_features))
+        self.V = nn.Parameter(torch.Tensor(out_features, in1_features))
+        self.b = nn.Parameter(torch.Tensor(out_features))
+        self.in1_features = in1_features
+        self.in2_features = in2_features
+        self.out_features = out_features
+        self.reset_parameters()
+
+    def reset_parameters(self):
+        stdv = 1. / np.sqrt(self.in1_features)
+        self.W.data.uniform_(-stdv, stdv)
+        self.U.data.uniform_(-stdv, stdv)
+        self.V.data.uniform_(-stdv, stdv)
+        self.b.data.uniform_(-stdv, stdv)
+
+    def forward(self, x, z):
+        # Compute the bilinear function
+        # x^T W z
+        out_W = torch.einsum('bi,kij,bj->bk', x, self.W, z)
+        # U z
+        out_U = z.matmul(self.U.t())
+        # V x
+        out_V = x.matmul(self.V.t())
+        # x^T W z + U z + V x + b
+        out = out_W + out_U + out_V + self.b
+        return out
+
+
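For readers reviewing this change, here is a minimal editorial usage sketch of BilinearGeneral (not part of the diff). The layer computes out = x^T W z + U z + V x + b; the batch size and feature dimensions below are illustrative, and the import path follows the test file added in this commit.

# Editorial sketch: minimal BilinearGeneral usage (dimensions are illustrative).
import torch
from ding.torch_utils.network.merge import BilinearGeneral

layer = BilinearGeneral(in1_features=4, in2_features=8, out_features=16)
x = torch.randn(2, 4)  # (batch, in1_features)
z = torch.randn(2, 8)  # (batch, in2_features)
out = layer(x, z)      # (batch, out_features) -> torch.Size([2, 16])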
+class TorchBilinearCustomized(nn.Module):
+    """
+    Overview:
+        Customized Torch Bilinear implementation.
+    Arguments:
+        - in1_features (:obj:`int`): size of each first input sample
+        - in2_features (:obj:`int`): size of each second input sample
+        - out_features (:obj:`int`): size of each output sample
+    """
+
+    def __init__(self, in1_features, in2_features, out_features):
+        super(TorchBilinearCustomized, self).__init__()
+        self.in1_features = in1_features
+        self.in2_features = in2_features
+        self.out_features = out_features
+        self.weight = nn.Parameter(torch.Tensor(out_features, in1_features, in2_features))
+        self.bias = nn.Parameter(torch.Tensor(out_features))
+        self.reset_parameters()
+
+    def reset_parameters(self):
+        bound = 1 / math.sqrt(self.in1_features)
+        nn.init.uniform_(self.weight, -bound, bound)
+        nn.init.uniform_(self.bias, -bound, bound)
+
+    def forward(self, x, z):
+        # Using torch.einsum for the bilinear operation
+        out = torch.einsum('bi,oij,bj->bo', x, self.weight, z) + self.bias
+        return out.squeeze(-1)
+
+
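As an editorial sanity check (not part of the diff): TorchBilinearCustomized uses the same weight layout (out_features, in1_features, in2_features) and bias shape as torch.nn.Bilinear, so copying parameters across the two should give matching outputs up to floating-point tolerance; the commit's own consistency test measures this difference via MSE. The dimensions below are illustrative.

# Editorial sketch: compare TorchBilinearCustomized against torch.nn.Bilinear with shared parameters.
import torch
import torch.nn as nn
from ding.torch_utils.network.merge import TorchBilinearCustomized

custom = TorchBilinearCustomized(in1_features=4, in2_features=8, out_features=16)
builtin = nn.Bilinear(4, 8, 16)
with torch.no_grad():
    builtin.weight.copy_(custom.weight)  # both weights have shape (out, in1, in2)
    builtin.bias.copy_(custom.bias)

x, z = torch.randn(2, 4), torch.randn(2, 8)
print(torch.allclose(custom(x, z), builtin(x, z), atol=1e-6))  # expected to print True, up to numerical tolerance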
+"""
+Overview:
+    Implementation of the Bilinear layer as in PyTorch:
+    https://pytorch.org/docs/stable/generated/torch.nn.Bilinear.html#torch.nn.Bilinear
+Arguments:
+    - in1_features (:obj:`int`): size of each first input sample
+    - in2_features (:obj:`int`): size of each second input sample
+    - out_features (:obj:`int`): size of each output sample
+    - bias (:obj:`bool`): If set to False, the layer will not learn an additive bias. Default: ``True``.
+"""
+TorchBilinear = nn.Bilinear
+
+
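Since TorchBilinear is a plain alias, the usual nn.Bilinear call signature applies; a minimal editorial sketch (dimensions are illustrative, not from the commit):

# Editorial sketch: TorchBilinear is nn.Bilinear under another name.
import torch
from ding.torch_utils.network.merge import TorchBilinear

layer = TorchBilinear(in1_features=4, in2_features=8, out_features=16, bias=True)
out = layer(torch.randn(2, 4), torch.randn(2, 8))  # shape: (2, 16)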
+class FiLM(nn.Module):
+    """
+    Overview:
+        Feature-wise Linear Modulation (FiLM) Layer.
+        This layer applies feature-wise affine transformation based on context.
+    Arguments:
+        - feature_dim (:obj:`int`): The dimension of the input feature vector.
+        - context_dim (:obj:`int`): The dimension of the input context vector.
+    """
+
+    def __init__(self, feature_dim, context_dim):
+        super(FiLM, self).__init__()
+        # Define the fully connected layer for context
+        # The output dimension is twice the feature dimension for gamma and beta
+        self.context_layer = nn.Linear(context_dim, 2 * feature_dim)
+
+    def forward(self, feature, context):
+        """
+        Overview:
+            Forward propagation.
+        Arguments:
+            - feature (:obj:`torch.Tensor`): The input feature, shape (batch_size, feature_dim)
+            - context (:obj:`torch.Tensor`): The input context, shape (batch_size, context_dim)
+        Returns:
+            - conditioned_feature (:obj:`torch.Tensor`): The output feature after FiLM, shape (batch_size, feature_dim)
+        """
+        # Pass context through the fully connected layer
+        out = self.context_layer(context)
+        # Split the output into two parts: gamma and beta
+        # The dimension for splitting is 1 (feature dimension)
+        gamma, beta = torch.split(out, out.shape[1] // 2, dim=1)
+        # Apply feature-wise affine transformation
+        conditioned_feature = gamma * feature + beta
+        return conditioned_feature
+
+
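A minimal editorial usage sketch for FiLM (not part of the diff); the dimensions mirror the unit test added in this commit and are otherwise illustrative.

# Editorial sketch: condition a feature vector on a context vector with FiLM.
import torch
from ding.torch_utils.network.merge import FiLM

film = FiLM(feature_dim=128, context_dim=256)
feature = torch.randn(32, 128)        # (batch, feature_dim)
context = torch.randn(32, 256)        # (batch, context_dim)
conditioned = film(feature, context)  # (32, 128), same shape as the input feature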
 class GatingType(enum.Enum):
     r"""
     Overview:

ding/torch_utils/tests/test_merge.py

+131
@@ -0,0 +1,131 @@
+import pytest
+import torch
+from ding.torch_utils.network.merge import TorchBilinearCustomized, TorchBilinear, BilinearGeneral, FiLM
+
+
+@pytest.mark.unittest
+def test_torch_bilinear_customized():
+    batch_size = 10
+    in1_features = 20
+    in2_features = 30
+    out_features = 40
+    bilinear_customized = TorchBilinearCustomized(in1_features, in2_features, out_features)
+    x = torch.randn(batch_size, in1_features)
+    z = torch.randn(batch_size, in2_features)
+    out = bilinear_customized(x, z)
+    assert out.shape == (batch_size, out_features), "Output shape does not match expected shape."
+
+
+@pytest.mark.unittest
+def test_torch_bilinear():
+    batch_size = 10
+    in1_features = 20
+    in2_features = 30
+    out_features = 40
+    torch_bilinear = TorchBilinear(in1_features, in2_features, out_features)
+    x = torch.randn(batch_size, in1_features)
+    z = torch.randn(batch_size, in2_features)
+    out = torch_bilinear(x, z)
+    assert out.shape == (batch_size, out_features), "Output shape does not match expected shape."
+
+
+@pytest.mark.unittest
+def test_bilinear_consistency():
+    batch_size = 10
+    in1_features = 20
+    in2_features = 30
+    out_features = 40
+
+    # Initialize weights and biases with set values
+    weight = torch.randn(out_features, in1_features, in2_features)
+    bias = torch.randn(out_features)
+
+    # Create and initialize TorchBilinearCustomized and TorchBilinear models
+    bilinear_customized = TorchBilinearCustomized(in1_features, in2_features, out_features)
+    bilinear_customized.weight.data = weight.clone()
+    bilinear_customized.bias.data = bias.clone()
+
+    torch_bilinear = TorchBilinear(in1_features, in2_features, out_features)
+    torch_bilinear.weight.data = weight.clone()
+    torch_bilinear.bias.data = bias.clone()
+
+    # Provide same input to both models
+    x = torch.randn(batch_size, in1_features)
+    z = torch.randn(batch_size, in2_features)
+
+    # Compute outputs
+    out_bilinear_customized = bilinear_customized(x, z)
+    out_torch_bilinear = torch_bilinear(x, z)
+
+    # Compute the mean squared error between outputs
+    mse = torch.mean((out_bilinear_customized - out_torch_bilinear) ** 2)
+
+    print(f"Mean Squared Error between outputs: {mse.item()}")
+
+    # Check if outputs are the same
+    # assert torch.allclose(out_bilinear_customized, out_torch_bilinear),
+    # "Outputs of TorchBilinearCustomized and TorchBilinear are not the same."
+
+
+def test_bilinear_general():
+    """
+    Overview:
+        Test for the `BilinearGeneral` class.
+    """
+    # Define the input dimensions and batch size
+    in1_features = 20
+    in2_features = 30
+    out_features = 40
+    batch_size = 10
+
+    # Create a BilinearGeneral instance
+    bilinear_general = BilinearGeneral(in1_features, in2_features, out_features)
+
+    # Create random inputs
+    input1 = torch.randn(batch_size, in1_features)
+    input2 = torch.randn(batch_size, in2_features)
+
+    # Perform forward pass
+    output = bilinear_general(input1, input2)
+
+    # Check output shape
+    assert output.shape == (batch_size, out_features), "Output shape does not match expected shape."
+
+    # Check parameter shapes
+    assert bilinear_general.W.shape == (
+        out_features, in1_features, in2_features
+    ), "Weight W shape does not match expected shape."
+    assert bilinear_general.U.shape == (out_features, in2_features), "Weight U shape does not match expected shape."
+    assert bilinear_general.V.shape == (out_features, in1_features), "Weight V shape does not match expected shape."
+    assert bilinear_general.b.shape == (out_features, ), "Bias shape does not match expected shape."
+
+    # Check parameter types
+    assert isinstance(bilinear_general.W, torch.nn.Parameter), "Weight W is not an instance of torch.nn.Parameter."
+    assert isinstance(bilinear_general.U, torch.nn.Parameter), "Weight U is not an instance of torch.nn.Parameter."
+    assert isinstance(bilinear_general.V, torch.nn.Parameter), "Weight V is not an instance of torch.nn.Parameter."
+    assert isinstance(bilinear_general.b, torch.nn.Parameter), "Bias is not an instance of torch.nn.Parameter."
+
+
+@pytest.mark.unittest
+def test_film_forward():
+    # Set the feature and context dimensions
+    feature_dim = 128
+    context_dim = 256
+
+    # Initialize the FiLM layer
+    film_layer = FiLM(feature_dim, context_dim)
+
+    # Create random feature and context vectors
+    feature = torch.randn((32, feature_dim))  # batch size is 32
+    context = torch.randn((32, context_dim))  # batch size is 32
+
+    # Forward propagation
+    conditioned_feature = film_layer(feature, context)
+
+    # Check the output shape
+    assert conditioned_feature.shape == feature.shape, \
+        f'Expected output shape {feature.shape}, but got {conditioned_feature.shape}'
+
+    # Check that the output is different from the input
+    assert not torch.all(torch.eq(feature, conditioned_feature)), \
+        'The output feature is the same as the input feature'
