-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathconstruct_transform.py
115 lines (105 loc) · 4.1 KB
/
construct_transform.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
#*******************************************************************************
# Imports and Setup
#*******************************************************************************
# packages
import torch
# nflows imports
from nflows.nn.nets.resnet import ResidualNet
from nflows.transforms.autoregressive import \
MaskedPiecewiseRationalQuadraticAutoregressiveTransform
from nflows.transforms.coupling import \
PiecewiseRationalQuadraticCouplingTransform
from nflows.transforms.base import CompositeTransform
from nflows.transforms import LULinear
from nflows.transforms.permutations import ReversePermutation
#*******************************************************************************
# Function Definitions
#*******************************************************************************
def create_alternating_binary_mask(features, even=True):
    '''
    Build an alternating 0/1 mask for coupling layers. This code is inspired
    by the nflows package (https://github.com/bayesiains/nflows).
    Args:
        features: number of features (length of the mask)
        even: if truthy, positions 0, 2, 4, ... are set to 1; otherwise
            positions 1, 3, 5, ... are set to 1
    Returns:
        mask: uint8 tensor of length `features` holding the alternating pattern
    '''
    first_active = 0 if even else 1
    # Allocate the byte tensor directly rather than casting a float tensor.
    mask = torch.zeros(features, dtype=torch.uint8)
    mask[first_active::2] = 1
    return mask
def create_linear_transform(args):
    '''
    Create a linear transformation, which can be stacked as part of a flow. This
    code is inspired by the nflows package
    (https://github.com/bayesiains/nflows).
    Args:
        args: dictionary of flow hyperparameters. Reads 'linear' (either
            'permutation' or 'lu') and 'features' (number of features).
    Returns:
        a list of flow transformations: a ReversePermutation alone for
        'permutation', or a ReversePermutation followed by an
        identity-initialised LULinear for 'lu'
    Raises:
        ValueError: if args['linear'] is not a recognised option
    '''
    linear = args['linear']
    if linear == 'permutation':
        return [ReversePermutation(features=args['features'])]
    if linear == 'lu':
        return [
            ReversePermutation(features=args['features']),
            LULinear(args['features'], identity_init=True),
        ]
    # Name the offending value so a misconfigured hyperparameter dict is
    # diagnosable; the exception type is unchanged for existing callers.
    raise ValueError(
        f"Unknown linear transform {linear!r}; "
        "expected 'permutation' or 'lu'."
    )
def create_base_transform(i, args):
    '''
    Create a neural spline base transformation, which can be stacked as part of
    a flow. This code is inspired by the nflows package
    (https://github.com/bayesiains/nflows).
    Args:
        i: index of flow block; for 'rq-c' its parity chooses which
            alternating half of the features the coupling mask marks
        args: dictionary of flow hyperparameters. Reads 'base' ('rq-ar' or
            'rq-c'), 'features', 'hidden_features', 'context_features',
            'num_bins', 'tail_bound', 'dropout_probability' and
            'use_batch_norm'.
    Returns:
        a single-element list containing the flow transformation
    Raises:
        ValueError: if args['base'] is not a recognised option
    '''
    base = args['base']
    if base == 'rq-ar':
        # Autoregressive rational-quadratic spline with linear tails.
        return [MaskedPiecewiseRationalQuadraticAutoregressiveTransform(
            features=args['features'],
            hidden_features=args['hidden_features'],
            context_features=args['context_features'],
            num_bins=args['num_bins'],
            tails='linear',
            tail_bound=args['tail_bound'],
            dropout_probability=args['dropout_probability'],
            use_batch_norm=args['use_batch_norm'],
        )]
    if base == 'rq-c':
        # Coupling rational-quadratic spline. The mask parity flips between
        # consecutive block indices (even i -> even positions marked).
        return [PiecewiseRationalQuadraticCouplingTransform(
            mask=create_alternating_binary_mask(
                args['features'], even=(i % 2 == 0)
            ),
            # Factory handed to nflows; it receives the conditioner
            # network's input and output sizes and returns a ResidualNet.
            transform_net_create_fn=lambda in_features, out_features: ResidualNet(
                in_features=in_features,
                out_features=out_features,
                hidden_features=args['hidden_features'],
                context_features=args['context_features'],
                dropout_probability=args['dropout_probability'],
                use_batch_norm=args['use_batch_norm'],
            ),
            num_bins=args['num_bins'],
            tails='linear',
            tail_bound=args['tail_bound'],
        )]
    # Name the offending value so a misconfigured hyperparameter dict is
    # diagnosable; the exception type is unchanged for existing callers.
    raise ValueError(
        f"Unknown base transform {base!r}; expected 'rq-ar' or 'rq-c'."
    )
def create_transform(args):
    '''
    Assemble a flow from args['num_flow_steps'] pairs of (linear layer(s),
    base layer), followed by one closing linear layer. This code is inspired
    by the nflows package (https://github.com/bayesiains/nflows).
    Args:
        args: dictionary of flow hyperparameters; 'num_flow_steps' sets the
            number of (linear, base) pairs, and the remaining keys are
            forwarded to the linear/base layer factories
    Returns:
        transform: a CompositeTransform built from the layers in the order
            listed above
    '''
    layers = []
    for step in range(args['num_flow_steps']):
        layers += create_linear_transform(args)
        layers += create_base_transform(step, args)
    # Close the stack with one more linear layer after the last base layer.
    layers += create_linear_transform(args)
    return CompositeTransform(layers)