-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmodel.py
More file actions
40 lines (30 loc) · 1.04 KB
/
model.py
File metadata and controls
40 lines (30 loc) · 1.04 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
#!/usr/bin/env python3
from torch import nn
def Conv(filter_size: int, filters: int, channels: int):
    """Return a stride-1 ``nn.Conv2d`` with symmetric zero padding.

    Args:
        filter_size: Side length of the square kernel.
        filters: Number of output feature maps.
        channels: Number of input feature maps.

    Padding is ``filter_size // 2`` on every side. Odd kernel sizes therefore
    keep the spatial size unchanged; even sizes grow it by one pixel.
    """
    pad = filter_size // 2
    return nn.Conv2d(
        in_channels=channels,
        out_channels=filters,
        kernel_size=filter_size,
        padding=pad,
    )
class FSRCNN(nn.Module):
    """FSRCNN super-resolution network.

    The pipeline has five stages: feature extraction, shrinking, non-linear
    mapping, expanding, and upsampling by transposed convolution. Every
    convolution before the deconvolution uses odd kernels with
    ``kernel // 2`` padding, so it preserves height/width. The final
    deconvolution scales height and width by exactly ``scale``. An input of
    shape (N, channels, H, W) therefore yields (N, channels, H * scale,
    W * scale).

    Keyword-only args:
        channels: Number of image channels in both the input and the output
            (default 1).
        d: Number of feature maps produced by the feature-extraction and
            expanding layers.
        s: Number of feature maps in the shrunk representation used by the
            mapping layers.
        m: Number of 3x3 mapping convolutions.
        scale: Integer upscaling factor used as the deconvolution stride.
    """

    def __init__(self, *, channels=1, d, s, m, scale):
        super().__init__()
        # 5x5 conv from the image channels to d feature maps. The input width
        # must be `channels` (it was hard-coded to 1). Otherwise the network
        # cannot consume the multi-channel images its deconvolution emits.
        self.feature_extraction = nn.Sequential(
            Conv(5, d, channels),
            nn.PReLU(),
        )
        # 1x1 conv shrinking d -> s feature maps.
        self.shrinking = nn.Sequential(
            Conv(1, s, d),
            nn.PReLU(),
        )
        # m stacked 3x3 s -> s convs followed by a single PReLU.
        # NOTE(review): the FSRCNN paper (Dong et al., 2016) applies a PReLU
        # after every mapping conv. Here the convs are consecutive linear ops,
        # so they compose into one affine map. The layout is kept as-is so the
        # state_dict keys (non_linear_mapping.0 .. .m) of existing checkpoints
        # still load. Change it deliberately if retraining from scratch.
        self.non_linear_mapping = nn.Sequential(
            *(Conv(3, s, s) for _ in range(m)),
            nn.PReLU(),
        )
        # 1x1 conv expanding s -> d feature maps before upsampling.
        self.expanding = nn.Sequential(
            Conv(1, d, s),
            nn.PReLU(),
        )
        # 9x9 transposed conv with stride=scale back to `channels` maps.
        # padding=4 and output_padding=scale-1 give an exact `scale`x output:
        # (H-1)*scale - 2*4 + (9-1) + (scale-1) + 1 == H*scale.
        self.deconvolution = nn.ConvTranspose2d(
            d, channels, 9, stride=scale, padding=9 // 2, output_padding=scale - 1
        )

    def forward(self, x):
        """Upscale ``x`` by ``scale`` through the five FSRCNN stages.

        ``x`` is presumably an (N, channels, H, W) batch; confirm this
        against callers.
        """
        x = self.feature_extraction(x)
        x = self.shrinking(x)
        x = self.non_linear_mapping(x)
        x = self.expanding(x)
        return self.deconvolution(x)