@@ -69,8 +69,10 @@ def _init_learn(self) -> None:
         self.act_dim = self._cfg.model.act_dim
 
         self._learn_model = self._model
+        self._atari_env = 'state_mean' not in self._cfg
+        self._basic_discrete_env = not self._cfg.model.continuous and 'state_mean' in self._cfg
 
-        if 'state_mean' not in self._cfg:
+        if self._atari_env:
             self._optimizer = self._learn_model.configure_optimizers(wt_decay, lr)
         else:
             self._optimizer = torch.optim.AdamW(self._learn_model.parameters(), lr=lr, weight_decay=wt_decay)
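The two new flags are derived purely from the policy config: Atari-style configs in this codebase carry no observation-normalization statistics, so the absence of a `state_mean` key is used as a proxy for "pixel-based Atari env", while `_basic_discrete_env` marks discrete-action envs that do ship `state_mean`/`state_std`. Below is a minimal sketch of that check on a made-up EasyDict config; the example values are illustrative assumptions, not taken from this repo.

```python
# Sketch only: how the flags in _init_learn read the config.
# The example config values below are illustrative assumptions.
from easydict import EasyDict

cfg = EasyDict(
    model=dict(continuous=False),        # discrete action space
    state_mean=[0.0, 0.0, 0.0, 0.0],     # normalization stats present -> not the Atari path
    state_std=[1.0, 1.0, 1.0, 1.0],
)

atari_env = 'state_mean' not in cfg                                    # False for this config
basic_discrete_env = not cfg.model.continuous and 'state_mean' in cfg  # True for this config
```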
@@ -93,22 +95,18 @@ def _forward_learn(self, data: list) -> Dict[str, Any]:
         self._learn_model.train()
 
         timesteps, states, actions, returns_to_go, traj_mask = data
-        if actions.dtype is not torch.long:
-            actions = actions.to(torch.long)
-        action_target = torch.clone(actions).detach().to(self._device)
 
         # The shape of `returns_to_go` may differ with different dataset (B x T or B x T x 1),
         # and we need a 3-dim tensor
         if len(returns_to_go.shape) == 2:
             returns_to_go = returns_to_go.unsqueeze(-1)
 
-        # if discrete
-        if not self._cfg.model.continuous and 'state_mean' in self._cfg:
-            # actions = one_hot(actions.squeeze(-1), num=self.act_dim)
+        if self._basic_discrete_env:
+            actions = actions.to(torch.long)
             actions = actions.squeeze(-1)
         action_target = torch.clone(actions).detach().to(self._device)
 
-        if 'state_mean' not in self._cfg:
+        if self._atari_env:
             state_preds, action_preds, return_preds = self._learn_model.forward(
                 timesteps=timesteps, states=states, actions=actions, returns_to_go=returns_to_go, tar=1
             )
@@ -117,7 +115,7 @@ def _forward_learn(self, data: list) -> Dict[str, Any]:
                 timesteps=timesteps, states=states, actions=actions, returns_to_go=returns_to_go
             )
 
-        if 'state_mean' not in self._cfg:
+        if self._atari_env:
             action_loss = F.cross_entropy(action_preds.reshape(-1, action_preds.size(-1)), action_target.reshape(-1))
         else:
             traj_mask = traj_mask.view(-1, )
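The loss branches follow the same flag: the Atari path flattens the action logits to `(B*T, act_dim)` and applies cross-entropy against integer targets, while the other path reshapes `traj_mask` so that padded timesteps can be excluded from the action loss (only the reshape is visible in this hunk). A rough sketch of both shapes follows, assuming a mean-squared error for the masked branch; the actual continuous loss in this file is not shown in the diff.

```python
# Sketch only: the two loss shapes, not the exact code in this policy.
import torch
import torch.nn.functional as F

B, T, act_dim = 4, 20, 6
action_preds = torch.randn(B, T, act_dim)

# Atari / discrete path: (B*T, act_dim) logits vs (B*T,) integer targets.
action_target = torch.randint(0, act_dim, (B, T))
ce_loss = F.cross_entropy(action_preds.reshape(-1, act_dim), action_target.reshape(-1))

# Masked path (MSE assumed): drop timesteps where traj_mask == 0.
cont_target = torch.randn(B, T, act_dim)
traj_mask = (torch.rand(B, T) > 0.2).float().view(-1, )
valid = traj_mask > 0
mse_loss = F.mse_loss(action_preds.view(-1, act_dim)[valid], cont_target.view(-1, act_dim)[valid])
```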
@@ -171,7 +169,9 @@ def _init_eval(self) -> None:
         self.actions = torch.zeros(
             (self.eval_batch_size, self.max_eval_ep_len, 1), dtype=torch.long, device=self._device
         )
-        if 'state_mean' not in self._cfg:
+        self._atari_env = 'state_mean' not in self._cfg
+        self._basic_discrete_env = not self._cfg.model.continuous and 'state_mean' in self._cfg
+        if self._atari_env:
             self.states = torch.zeros(
                 (
                     self.eval_batch_size,
@@ -201,7 +201,7 @@ def _forward_eval(self, data: Dict[int, Any]) -> Dict[int, Any]:
 
         self._eval_model.eval()
         with torch.no_grad():
-            if 'state_mean' not in self._cfg:
+            if self._atari_env:
                 states = torch.zeros(
                     (
                         self.eval_batch_size,
@@ -228,15 +228,15 @@ def _forward_eval(self, data: Dict[int, Any]) -> Dict[int, Any]:
                 (self.eval_batch_size, self.context_len, 1), dtype=torch.float32, device=self._device
             )
             for i in data_id:
-                if 'state_mean' not in self._cfg:
+                if self._atari_env:
                     self.states[i, self.t[i]] = data[i]['obs'].to(self._device)
                 else:
                     self.states[i, self.t[i]] = (data[i]['obs'].to(self._device) - self.state_mean) / self.state_std
                 self.running_rtg[i] = self.running_rtg[i] - data[i]['reward'].to(self._device)
                 self.rewards_to_go[i, self.t[i]] = self.running_rtg[i]
 
                 if self.t[i] <= self.context_len:
-                    if 'state_mean' not in self._cfg:
+                    if self._atari_env:
                         timesteps[i] = min(self.t[i], self._cfg.model.max_timestep) * torch.ones(
                             (1, 1), dtype=torch.int64
                         ).to(self._device)
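The unchanged lines in this hunk also show the evaluation-time return-to-go bookkeeping: each env starts from a target return, the reward just received is subtracted, and the result is written as the conditioning value for the current timestep. A tiny worked example with made-up numbers (target return 90, five steps):

```python
# Sketch only: return-to-go bookkeeping during evaluation; the numbers are made up.
import torch

running_rtg = torch.tensor(90.0)     # per-env running return-to-go
rewards_to_go = torch.zeros(5)       # conditioning value written at each step

for t, reward in enumerate([1.0, 0.0, 2.0, 1.0, 0.0]):
    running_rtg = running_rtg - reward   # mirrors self.running_rtg[i] -= reward
    rewards_to_go[t] = running_rtg       # mirrors self.rewards_to_go[i, self.t[i]]
print(rewards_to_go)  # tensor([89., 89., 87., 86., 86.])
```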
@@ -246,7 +246,7 @@ def _forward_eval(self, data: Dict[int, Any]) -> Dict[int, Any]:
                     actions[i] = self.actions[i, :self.context_len]
                     rewards_to_go[i] = self.rewards_to_go[i, :self.context_len]
                 else:
-                    if 'state_mean' not in self._cfg:
+                    if self._atari_env:
                         timesteps[i] = min(self.t[i], self._cfg.model.max_timestep) * torch.ones(
                             (1, 1), dtype=torch.int64
                         ).to(self._device)
@@ -255,15 +255,14 @@ def _forward_eval(self, data: Dict[int, Any]) -> Dict[int, Any]:
                     states[i] = self.states[i, self.t[i] - self.context_len + 1:self.t[i] + 1]
                     actions[i] = self.actions[i, self.t[i] - self.context_len + 1:self.t[i] + 1]
                     rewards_to_go[i] = self.rewards_to_go[i, self.t[i] - self.context_len + 1:self.t[i] + 1]
-            if not self._cfg.model.continuous and 'state_mean' in self._cfg:
-                # actions = one_hot(actions.squeeze(-1), num=self.act_dim)
+            if self._basic_discrete_env:
                 actions = actions.squeeze(-1)
             _, act_preds, _ = self._eval_model.forward(timesteps, states, actions, rewards_to_go)
             del timesteps, states, actions, rewards_to_go
 
             logits = act_preds[:, -1, :]
             if not self._cfg.model.continuous:
-                if 'state_mean' not in self._cfg:
+                if self._atari_env:
                     probs = F.softmax(logits, dim=-1)
                     act = torch.zeros((self.eval_batch_size, 1), dtype=torch.long, device=self._device)
                     for i in data_id:
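In the discrete Atari branch, the logits of the last context position are softmaxed and one action per env is filled in inside the `for i in data_id` loop; the loop body itself is outside this hunk, so the sketch below assumes multinomial sampling (greedy argmax would be the obvious alternative).

```python
# Sketch only: drawing discrete actions from the last-step logits.
# Multinomial sampling is an assumption; the real loop body is not shown in this diff.
import torch
import torch.nn.functional as F

eval_batch_size, act_dim = 8, 6
logits = torch.randn(eval_batch_size, act_dim)   # plays the role of act_preds[:, -1, :]
probs = F.softmax(logits, dim=-1)

act = torch.zeros((eval_batch_size, 1), dtype=torch.long)
for i in range(eval_batch_size):
    act[i] = torch.multinomial(probs[i], num_samples=1)
```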
@@ -297,7 +296,7 @@ def _reset_eval(self, data_id: List[int] = None) -> None:
                 dtype=torch.float32,
                 device=self._device
             )
-            if 'state_mean' not in self._cfg:
+            if self._atari_env:
                 self.states = torch.zeros(
                     (
                         self.eval_batch_size,
@@ -327,7 +326,7 @@ def _reset_eval(self, data_id: List[int] = None) -> None:
                 self.actions[i] = torch.zeros(
                     (self.max_eval_ep_len, self.act_dim), dtype=torch.float32, device=self._device
                 )
-                if 'state_mean' not in self._cfg:
+                if self._atari_env:
                     self.states[i] = torch.zeros(
                         (self.max_eval_ep_len, ) + tuple(self.state_dim), dtype=torch.float32, device=self._device
                     )