@@ -95,6 +95,7 @@ def forward(self, x):


class DecisionTransformer(nn.Module):
+
    def __init__(
        self,
        state_dim,
@@ -121,7 +122,7 @@ def __init__(
        self.embed_ln = nn.LayerNorm(h_dim)
        self.embed_timestep = nn.Embedding(max_timestep, h_dim)
        self.drop = nn.Dropout(drop_p)
-
+
        self.pos_emb = nn.Parameter(torch.zeros(1, input_seq_len + 1, self.h_dim))
        self.global_pos_emb = nn.Parameter(torch.zeros(1, max_timestep + 1, self.h_dim))

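The two positional tables above play different roles: `pos_emb` is a learned embedding over the local token window (the interleaved return/state/action sequence), while `global_pos_emb` is indexed by the absolute environment timestep. A standalone shape check, with made-up sizes rather than the repo's config:

```python
import torch
import torch.nn as nn

h_dim, input_seq_len, max_timestep = 8, 30, 100  # toy sizes, not the real config

# Local table: one row per token in the interleaved (r, s, a) window, plus one spare
pos_emb = nn.Parameter(torch.zeros(1, input_seq_len + 1, h_dim))
# Global table: one row per absolute environment timestep
global_pos_emb = nn.Parameter(torch.zeros(1, max_timestep + 1, h_dim))

print(pos_emb.shape)         # torch.Size([1, 31, 8])
print(global_pos_emb.shape)  # torch.Size([1, 101, 8])
```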
@@ -140,9 +141,11 @@ def __init__(
            # discrete actions
            self.embed_action = torch.nn.Embedding(act_dim, h_dim)
            use_action_tanh = False  # False for discrete actions
-            self.predict_action = nn.Sequential(*([nn.Linear(h_dim, act_dim)] + ([nn.Tanh()] if use_action_tanh else [])))
+            self.predict_action = nn.Sequential(
+                *([nn.Linear(h_dim, act_dim)] + ([nn.Tanh()] if use_action_tanh else []))
+            )
        else:
-            blocks = [Block(h_dim, input_seq_len + 1, n_heads, drop_p) for _ in range(n_blocks)]
+            blocks = [Block(h_dim, input_seq_len + 1, n_heads, drop_p) for _ in range(n_blocks)]
            self.state_encoder = state_encoder
            self.embed_rtg = nn.Sequential(nn.Linear(1, h_dim), nn.Tanh())
            self.head = nn.Linear(h_dim, act_dim, bias=False)
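The list-splat construction of `predict_action` appends a `Tanh` only when actions are continuous; for discrete actions the head stays a bare `Linear` producing logits. A quick standalone check of the idiom:

```python
import torch.nn as nn

h_dim, act_dim = 8, 4  # toy sizes

for use_action_tanh in (True, False):
    # Tanh is appended only when use_action_tanh is True (continuous actions)
    head = nn.Sequential(*([nn.Linear(h_dim, act_dim)] + ([nn.Tanh()] if use_action_tanh else [])))
    print(use_action_tanh, list(head))  # Linear followed by Tanh, then Linear alone
```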
@@ -161,9 +164,8 @@ def forward(self, timesteps, states, actions, returns_to_go, tar=None):

        # stack rtg, states and actions and reshape sequence as
        # (r_0, s_0, a_0, r_1, s_1, a_1, r_2, s_2, a_2 ...)
-        t_p = torch.stack(
-            (returns_embeddings, state_embeddings, action_embeddings), dim=1
-        ).permute(0, 2, 1, 3).reshape(B, 3 * T, self.h_dim)
+        t_p = torch.stack((returns_embeddings, state_embeddings, action_embeddings),
+                          dim=1).permute(0, 2, 1, 3).reshape(B, 3 * T, self.h_dim)
        h = self.embed_ln(t_p)
        # transformer and prediction
        h = self.transformer(h)
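To see why the stack/permute/reshape chain produces the (r_0, s_0, a_0, r_1, s_1, a_1, ...) order the comment promises, here is a small self-contained check with toy tensors:

```python
import torch

B, T, h_dim = 2, 3, 4  # toy sizes
r = torch.full((B, T, h_dim), 0.0)  # stand-in returns-to-go embeddings
s = torch.full((B, T, h_dim), 1.0)  # stand-in state embeddings
a = torch.full((B, T, h_dim), 2.0)  # stand-in action embeddings

# stack -> (B, 3, T, h); permute -> (B, T, 3, h); reshape -> (B, 3T, h)
t_p = torch.stack((r, s, a), dim=1).permute(0, 2, 1, 3).reshape(B, 3 * T, h_dim)

# Each consecutive triple along the sequence axis is (r_t, s_t, a_t)
print(t_p[0, :, 0])  # tensor([0., 1., 2., 0., 1., 2., 0., 1., 2.])
```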
@@ -183,20 +185,24 @@ def forward(self, timesteps, states, actions, returns_to_go, tar=None):
            state_embeddings = self.state_encoder(
                states.reshape(-1, 4, 84, 84).type(torch.float32).contiguous()
            )  # (batch * block_size, h_dim)
-            state_embeddings = state_embeddings.reshape(
-                B, T, self.h_dim
-            )  # (batch, block_size, h_dim)
+            state_embeddings = state_embeddings.reshape(B, T, self.h_dim)  # (batch, block_size, h_dim)
            returns_embeddings = self.embed_rtg(returns_to_go.type(torch.float32))
            action_embeddings = self.embed_action(actions.type(torch.long).squeeze(-1))  # (batch, block_size, h_dim)

-            token_embeddings = torch.zeros((B, T * 3 - int(tar is None), self.h_dim), dtype=torch.float32, device=state_embeddings.device)
-            token_embeddings[:,::3,:] = returns_embeddings
-            token_embeddings[:,1::3,:] = state_embeddings
-            token_embeddings[:,2::3,:] = action_embeddings[:,-T + int(tar is None):,:]
-
-            all_global_pos_emb = torch.repeat_interleave(self.global_pos_emb, B, dim=0)  # batch_size, traj_length, h_dim
+            token_embeddings = torch.zeros(
+                (B, T * 3 - int(tar is None), self.h_dim), dtype=torch.float32, device=state_embeddings.device
+            )
+            token_embeddings[:, ::3, :] = returns_embeddings
+            token_embeddings[:, 1::3, :] = state_embeddings
+            token_embeddings[:, 2::3, :] = action_embeddings[:, -T + int(tar is None):, :]
+
+            all_global_pos_emb = torch.repeat_interleave(
+                self.global_pos_emb, B, dim=0
+            )  # batch_size, traj_length, h_dim

-            position_embeddings = torch.gather(all_global_pos_emb, 1, torch.repeat_interleave(timesteps, self.h_dim, dim=-1)) + self.pos_emb[:, :token_embeddings.shape[1], :]
+            position_embeddings = torch.gather(
+                all_global_pos_emb, 1, torch.repeat_interleave(timesteps, self.h_dim, dim=-1)
+            ) + self.pos_emb[:, :token_embeddings.shape[1], :]

            t_p = token_embeddings + position_embeddings

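The strided writes fill the return/state/action slots of the interleaved sequence, and the `gather` call then looks up one global positional row per sample from its absolute timestep. A sketch of the gather step alone, assuming `timesteps` arrives with shape `(B, 1, 1)` (my reading of the calling code, not confirmed by this hunk):

```python
import torch

B, h_dim, max_timestep = 2, 8, 100  # toy sizes
global_pos_emb = torch.randn(1, max_timestep + 1, h_dim)
timesteps = torch.tensor([[[3]], [[42]]])  # (B, 1, 1), one absolute timestep per sample

# Broadcast the table across the batch, then pick row timesteps[b] for each sample b
all_global_pos_emb = torch.repeat_interleave(global_pos_emb, B, dim=0)  # (B, max_timestep+1, h_dim)
index = torch.repeat_interleave(timesteps, h_dim, dim=-1)               # (B, 1, h_dim)
picked = torch.gather(all_global_pos_emb, 1, index)                     # (B, 1, h_dim)

assert torch.equal(picked[0, 0], global_pos_emb[0, 3])
assert torch.equal(picked[1, 0], global_pos_emb[0, 42])
```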
@@ -207,7 +213,7 @@ def forward(self, timesteps, states, actions, returns_to_go, tar=None):

            return_preds = None
            state_preds = None
-            action_preds = logits[:, 1::3, :] # only keep predictions from state_embeddings
+            action_preds = logits[:, 1::3, :]  # only keep predictions from state_embeddings

        return state_preds, action_preds, return_preds

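In the interleaved layout, positions 1, 4, 7, ... of the transformer output are the state tokens, and the next action is read off each state token's output; the `1::3` slice in the hunk above keeps exactly those rows. A toy illustration:

```python
import torch

B, T, act_dim = 1, 3, 5  # toy sizes
# Fake transformer output over the interleaved (r_0, s_0, a_0, r_1, ...) sequence,
# with each row tagged by its sequence position
logits = torch.arange(3 * T).float().view(B, 3 * T, 1).expand(B, 3 * T, act_dim)

action_preds = logits[:, 1::3, :]  # rows 1, 4, 7 -> outputs of s_0, s_1, s_2
print(action_preds[0, :, 0])       # tensor([1., 4., 7.])
```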
@@ -227,7 +233,7 @@ def configure_optimizers(self, weight_decay, learning_rate, betas=(0.9, 0.95)):
        blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding)
        for mn, m in self.named_modules():
            for pn, p in m.named_parameters():
-                fpn = '%s.%s' % (mn, pn) if mn else pn # full param name
+                fpn = '%s.%s' % (mn, pn) if mn else pn  # full param name

                if pn.endswith('bias'):
                    # all biases will not be decayed
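This loop follows the minGPT `configure_optimizers` pattern: build a fully qualified name for every parameter, then sort parameters into decay/no-decay buckets. A condensed standalone sketch of that split (the whitelist/blacklist handling is my reconstruction of the pattern, only partially visible in this hunk):

```python
import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 4), nn.LayerNorm(4))
decay, no_decay = set(), set()
whitelist = (nn.Linear,)                  # weights that get weight decay
blacklist = (nn.LayerNorm, nn.Embedding)  # weights that never get decay

for mn, m in model.named_modules():
    for pn, p in m.named_parameters():
        fpn = '%s.%s' % (mn, pn) if mn else pn  # full param name
        if pn.endswith('bias'):
            no_decay.add(fpn)  # all biases skip decay
        elif pn.endswith('weight') and isinstance(m, whitelist):
            decay.add(fpn)
        elif pn.endswith('weight') and isinstance(m, blacklist):
            no_decay.add(fpn)

print(sorted(decay))     # ['0.weight']
print(sorted(no_decay))  # ['0.bias', '1.bias', '1.weight']
```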
@@ -253,8 +259,14 @@ def configure_optimizers(self, weight_decay, learning_rate, betas=(0.9, 0.95)):

        # create the pytorch optimizer object
        optim_groups = [
-            {"params": [param_dict[pn] for pn in sorted(list(decay))], "weight_decay": weight_decay},
-            {"params": [param_dict[pn] for pn in sorted(list(no_decay))], "weight_decay": 0.0},
+            {
+                "params": [param_dict[pn] for pn in sorted(list(decay))],
+                "weight_decay": weight_decay
+            },
+            {
+                "params": [param_dict[pn] for pn in sorted(list(no_decay))],
+                "weight_decay": 0.0
+            },
        ]
        optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas)
-        return optimizer
+        return optimizer
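The two parameter groups hand `AdamW` different `weight_decay` values, so decay applies only to the whitelisted weights. A minimal usage sketch with hypothetical parameters:

```python
import torch

w = torch.nn.Parameter(torch.randn(4, 4))  # stands in for a decayed weight
b = torch.nn.Parameter(torch.zeros(4))     # stands in for a bias / no-decay param

optim_groups = [
    {"params": [w], "weight_decay": 0.1},
    {"params": [b], "weight_decay": 0.0},
]
optimizer = torch.optim.AdamW(optim_groups, lr=3e-4, betas=(0.9, 0.95))
print([g["weight_decay"] for g in optimizer.param_groups])  # [0.1, 0.0]
```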