-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmodel.py
99 lines (65 loc) · 2.88 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from a2c_ppo_acktr.model import NNBase
from a2c_ppo_acktr.utils import init
class IAMBase(NNBase):
    """Influence-Aware Memory base network.

    Runs a feed-forward branch over the full observation and a GRU branch
    over a subset of observation features, then concatenates both outputs
    as the policy/critic feature vector.

    Args:
        num_inputs: size of the flat observation vector.
        dset: optional sequence of feature indices routed to the GRU
            branch; when None the full observation feeds the GRU.
        fnn_hidden_layer: width of the FNN hidden layer.
        fnn_last_layer: width of the FNN output layer.
        rnn_last_layer: GRU hidden size.
    """

    def __init__(self, num_inputs, dset=None, fnn_hidden_layer=512, fnn_last_layer=256, rnn_last_layer=128):
        # PEP 8: compare against None with `is`, not `==`.
        if dset is None:
            rnn_input_size = num_inputs
        else:
            rnn_input_size = len(dset)
        super(IAMBase, self).__init__(True, rnn_input_size, rnn_last_layer)
        self.dset = dset
        # Output features = FNN branch output concatenated with GRU output.
        self._output_size = fnn_last_layer+rnn_last_layer
        init_ = lambda m: init(m, nn.init.orthogonal_, lambda x: nn.init.
                        constant_(x, 0), np.sqrt(2))
        self.fnn = nn.Sequential(
            init_(nn.Linear(num_inputs, fnn_hidden_layer)), nn.Tanh(),
            init_(nn.Linear(fnn_hidden_layer, fnn_last_layer)), nn.Tanh())
        self.critic_linear = init_(nn.Linear(self._output_size, 1))
        self.gru_tanh = nn.Tanh()
        self.train()

    @property
    def output_size(self):
        # Size of the concatenated feature vector returned by forward().
        return self._output_size

    def forward(self, inputs, rnn_hxs, masks):
        """Return (value, features, rnn_hxs) for a batch of observations."""
        x = inputs
        # Bug fix: the original always computed x[..., self.dset]; with
        # dset=None that is x[..., None], which unsqueezes a new trailing
        # axis instead of selecting features, producing a GRU input of the
        # wrong shape. Pass the full observation when no dset is given.
        rnn_in = x if self.dset is None else x[..., self.dset]
        y_rnn, rnn_hxs = self._forward_gru(rnn_in, rnn_hxs, masks)
        y_rnn = self.gru_tanh(y_rnn)
        y_fnn = self.fnn(x)
        y = torch.cat((y_fnn, y_rnn), dim=-1)
        return self.critic_linear(y), y, rnn_hxs
class GRUBase(NNBase):
    """Recurrent base: one tanh feed-forward layer followed by a GRU.

    The critic head reads the (tanh-squashed) GRU output, which is also
    returned as the actor feature vector.
    """

    def __init__(self, num_inputs, fnn_hidden_layer=640, rnn_last_layer=128):
        super(GRUBase, self).__init__(True, fnn_hidden_layer, rnn_last_layer)
        init_ = lambda m: init(m, nn.init.orthogonal_, lambda x: nn.init.
                        constant_(x, 0), np.sqrt(2))
        # Observation -> hidden embedding feeding the GRU.
        self.fnn = nn.Sequential(
            init_(nn.Linear(num_inputs, fnn_hidden_layer)), nn.Tanh()
        )
        self.critic_linear = init_(nn.Linear(rnn_last_layer, 1))
        self.gru_tanh = nn.Tanh()
        self.train()

    def forward(self, inputs, rnn_hxs, masks):
        """Return (value, features, rnn_hxs) for a batch of observations."""
        hidden = self.fnn(inputs)
        hidden, rnn_hxs = self._forward_gru(hidden, rnn_hxs, masks)
        features = self.gru_tanh(hidden)
        return self.critic_linear(features), features, rnn_hxs
class FNNBase(NNBase):
    """Purely feed-forward base: two tanh layers, no recurrence."""

    def __init__(self, num_inputs, fnn_hidden_layer=640, fnn_last_layer=256):
        # First argument False: this base is non-recurrent.
        super(FNNBase, self).__init__(False, fnn_hidden_layer, fnn_last_layer)
        init_ = lambda m: init(m, nn.init.orthogonal_, lambda x: nn.init.
                        constant_(x, 0), np.sqrt(2))
        self.fnn = nn.Sequential(
            init_(nn.Linear(num_inputs, fnn_hidden_layer)), nn.Tanh(),
            init_(nn.Linear(fnn_hidden_layer, fnn_last_layer)), nn.Tanh()
        )
        self.critic_linear = init_(nn.Linear(fnn_last_layer, 1))
        self.train()

    def forward(self, inputs, rnn_hxs, masks):
        """Return (value, features, rnn_hxs); rnn_hxs passes through unchanged."""
        features = self.fnn(inputs)
        return self.critic_linear(features), features, rnn_hxs