bitEstimator.py
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


class Bitparm(nn.Module):
    '''
    One stage of the factorized entropy model: per-channel parameters
    h, b (and a for non-final stages) of a learned monotone transform.
    '''
    def __init__(self, channel, final=False):
        super(Bitparm, self).__init__()
        self.final = final
        # Parameters have shape (1, channel) and broadcast over the batch dim.
        self.h = nn.Parameter(torch.nn.init.normal_(torch.empty(channel).view(1, -1), 0, 0.01))
        self.b = nn.Parameter(torch.nn.init.normal_(torch.empty(channel).view(1, -1), 0, 0.01))
        if not final:
            self.a = nn.Parameter(torch.nn.init.normal_(torch.empty(channel).view(1, -1), 0, 0.01))
        else:
            self.a = None

    def forward(self, x):
        if self.final:
            # The last stage squashes to (0, 1) so the output can act as a CDF.
            return torch.sigmoid(x * F.softplus(self.h) + self.b)
        else:
            # softplus(h) keeps the slope positive and |tanh(a)| < 1, so each
            # stage stays monotonically increasing in x.
            x = x * F.softplus(self.h) + self.b
            return x + torch.tanh(x) * torch.tanh(self.a)
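
# A hedged sanity check (not part of the original file): each Bitparm stage is
# monotonically non-decreasing in its input, which is what lets the composed
# model below behave like a cumulative distribution function.
def _check_bitparm_monotone(channel=4):
    layer = Bitparm(channel)
    # 101 evenly spaced inputs per channel, shape (101, channel).
    xs = torch.linspace(-5.0, 5.0, steps=101).view(-1, 1).repeat(1, channel)
    ys = layer(xs)
    assert (ys[1:] >= ys[:-1]).all(), 'Bitparm should be monotone in x'
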
class BitEstimator(nn.Module):
    '''
    Model the cumulative distribution of the latent, per channel, as a
    composition of four Bitparm stages (the learned factorized prior of
    Ballé et al., 2018); bit costs are derived from CDF differences.
    '''
    def __init__(self, channel):
        super(BitEstimator, self).__init__()
        self.f1 = Bitparm(channel)
        self.f2 = Bitparm(channel)
        self.f3 = Bitparm(channel)
        self.f4 = Bitparm(channel, True)

    def forward(self, x):
        x = self.f1(x)
        x = self.f2(x)
        x = self.f3(x)
        return self.f4(x)
        # Equivalent unrolled form, kept from the original for reference:
        # for i in range(3):
        #     x = x * F.softplus(self.h[i]) + self.b[i]
        #     x = x + torch.tanh(x) * torch.tanh(self.a[i])
        # return torch.sigmoid(x * F.softplus(self.h[3]) + self.b[3])
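
# A rough usage sketch, assuming the common learned-compression recipe: the
# bit cost of integer-quantized latents y is estimated from CDF differences,
# bits = -log2(c(y + 0.5) - c(y - 0.5)). The function name and the 1e-10
# clamp are illustrative choices, not taken from this file.
def estimate_bits(bit_estimator, y):
    # y: (N, C) quantized latents; Bitparm parameters are shaped (1, C), so
    # flatten any spatial dims onto the batch axis before calling this.
    prob = bit_estimator(y + 0.5) - bit_estimator(y - 0.5)
    bits = -torch.log(prob.clamp(min=1e-10)) / np.log(2.0)
    return bits.sum()


if __name__ == '__main__':
    _check_bitparm_monotone()
    est = BitEstimator(channel=8)
    y = torch.round(torch.randn(16, 8))  # fake quantized latents, shape (N, C)
    print('estimated bits:', estimate_bits(est, y).item())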