forked from NM512/dreamerv3-torch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
models.py
517 lines (484 loc) · 21.3 KB
/
models.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
import copy
import torch
from torch import nn
import numpy as np
from PIL import ImageColor, Image, ImageDraw, ImageFont
import networks
import tools
def to_np(x):
    """Return ``x`` as a NumPy array: detached from autograd and copied to CPU."""
    detached = x.detach()
    return detached.cpu().numpy()
class RewardEMA:
    """Exponential moving average of the 5th and 95th percentiles of a tensor.

    Calling the instance folds the percentiles of a new batch into the running
    estimate and returns ``(offset, scale)``: ``offset`` is the smoothed 5th
    percentile and ``scale`` is the smoothed 5-95 range, floored at 1.0.
    """

    def __init__(self, device, alpha=1e-2):
        self.device = device
        # values[0] tracks the 5th percentile, values[1] the 95th.
        self.values = torch.zeros(2, device=device)
        self.alpha = alpha
        self.range = torch.tensor([0.05, 0.95], device=device)

    def __call__(self, x):
        quantiles = torch.quantile(input=torch.flatten(x.detach()), q=self.range)
        self.values = self.alpha * quantiles + (1 - self.alpha) * self.values
        low = self.values[0]
        high = self.values[1]
        # Never divide by a spread smaller than 1.
        scale = torch.clip(high - low, min=1.0)
        return low.detach(), scale.detach()
class WorldModel(nn.Module):
    """Dreamer world model: conv encoder, RSSM dynamics and prediction heads.

    The encoder, dynamics and all heads ("image", "reward", "cont") are
    optimized together by ``_train`` with a single ``tools.Optimizer``.

    Args:
        step: training-step counter passed to ``tools.schedule`` for the KL
            hyper-parameter schedules.
        config: experiment config; every hyper-parameter is read from it.
    """

    def __init__(self, step, config):
        super(WorldModel, self).__init__()
        self._step = step
        self._use_amp = config.precision == 16
        self._config = config
        self.encoder = networks.ConvEncoder(
            config.grayscale,
            config.cnn_depth,
            config.act,
            config.norm,
            config.encoder_kernels,
        )
        if config.size[0] == 64 and config.size[1] == 64:
            # Flattened encoder output size. The formula assumes every encoder
            # layer halves the resolution and the last layer has
            # cnn_depth * 2 ** (n_layers - 1) channels.
            embed_size = (
                (64 // 2 ** (len(config.encoder_kernels))) ** 2
                * config.cnn_depth
                * 2 ** (len(config.encoder_kernels) - 1)
            )
        else:
            # ``NotImplemented`` is a constant, not an exception class;
            # calling it used to raise an unrelated TypeError.
            raise NotImplementedError(f"{config.size} is not applicable now")
        self.dynamics = networks.RSSM(
            config.dyn_stoch,
            config.dyn_deter,
            config.dyn_hidden,
            config.dyn_input_layers,
            config.dyn_output_layers,
            config.dyn_rec_depth,
            config.dyn_shared,
            config.dyn_discrete,
            config.act,
            config.norm,
            config.dyn_mean_act,
            config.dyn_std_act,
            config.dyn_temp_post,
            config.dyn_min_std,
            config.dyn_cell,
            config.unimix_ratio,
            config.initial,
            config.num_actions,
            embed_size,
            config.device,
        )
        self.heads = nn.ModuleDict()
        channels = 1 if config.grayscale else 3
        shape = (channels,) + config.size
        # Size of the feature vector produced by dynamics.get_feat.
        if config.dyn_discrete:
            feat_size = config.dyn_stoch * config.dyn_discrete + config.dyn_deter
        else:
            feat_size = config.dyn_stoch + config.dyn_deter
        self.heads["image"] = networks.ConvDecoder(
            feat_size,  # pytorch version
            config.cnn_depth,
            config.act,
            config.norm,
            shape,
            config.decoder_kernels,
        )
        # The two-hot symlog reward head has a 255-wide output; every other
        # reward head type uses an empty output shape.
        reward_shape = (255,) if config.reward_head == "twohot_symlog" else []
        self.heads["reward"] = networks.DenseHead(
            feat_size,  # pytorch version
            reward_shape,
            config.reward_layers,
            config.units,
            config.act,
            config.norm,
            dist=config.reward_head,
            outscale=0.0,
            device=config.device,
        )
        self.heads["cont"] = networks.DenseHead(
            feat_size,  # pytorch version
            [],
            config.cont_layers,
            config.units,
            config.act,
            config.norm,
            dist="binary",
            device=config.device,
        )
        for name in config.grad_heads:
            assert name in self.heads, name
        self._model_opt = tools.Optimizer(
            "model",
            self.parameters(),
            config.model_lr,
            config.opt_eps,
            config.grad_clip,
            config.weight_decay,
            opt=config.opt,
            use_amp=self._use_amp,
        )
        # Per-head loss multipliers; heads not listed use 1.0.
        self._scales = dict(reward=config.reward_scale, cont=config.cont_scale)

    def _train(self, data):
        """Run one world-model optimization step on a batch.

        Expected input layout (from the original comments):
            action (batch_size, batch_length, act_dim)
            image (batch_size, batch_length, h, w, ch)
            reward (batch_size, batch_length)
            discount (batch_size, batch_length)

        Returns:
            ``(post, context, metrics)``: ``post`` is the posterior state dict
            with every tensor detached, ``context`` holds embed/feat/kl/postent
            tensors, and ``metrics`` holds the losses and diagnostics.
        """
        data = self.preprocess(data)

        with tools.RequiresGrad(self):
            with torch.cuda.amp.autocast(self._use_amp):
                embed = self.encoder(data)
                post, prior = self.dynamics.observe(
                    embed, data["action"], data["is_first"]
                )
                kl_free = tools.schedule(self._config.kl_free, self._step)
                dyn_scale = tools.schedule(self._config.dyn_scale, self._step)
                rep_scale = tools.schedule(self._config.rep_scale, self._step)
                kl_loss, kl_value, dyn_loss, rep_loss = self.dynamics.kl_loss(
                    post, prior, kl_free, dyn_scale, rep_scale
                )
                # Posterior features are the same for every head: compute once.
                feat = self.dynamics.get_feat(post)
                losses = {}
                likes = {}
                for name, head in self.heads.items():
                    # Heads not listed in grad_heads do not backpropagate
                    # into the representation.
                    grad_head = name in self._config.grad_heads
                    head_input = feat if grad_head else feat.detach()
                    pred = head(head_input)
                    like = pred.log_prob(data[name])
                    likes[name] = like
                    losses[name] = -torch.mean(like) * self._scales.get(name, 1.0)
                model_loss = sum(losses.values()) + kl_loss
            metrics = self._model_opt(model_loss, self.parameters())

        metrics.update({f"{name}_loss": to_np(loss) for name, loss in losses.items()})
        metrics["kl_free"] = kl_free
        metrics["dyn_scale"] = dyn_scale
        metrics["rep_scale"] = rep_scale
        metrics["dyn_loss"] = to_np(dyn_loss)
        metrics["rep_loss"] = to_np(rep_loss)
        metrics["kl"] = to_np(torch.mean(kl_value))
        with torch.cuda.amp.autocast(self._use_amp):
            metrics["prior_ent"] = to_np(
                torch.mean(self.dynamics.get_dist(prior).entropy())
            )
            metrics["post_ent"] = to_np(
                torch.mean(self.dynamics.get_dist(post).entropy())
            )
            context = dict(
                embed=embed,
                feat=feat,
                kl=kl_value,
                postent=self.dynamics.get_dist(post).entropy(),
            )
        post = {k: v.detach() for k, v in post.items()}
        return post, context, metrics  # metrics: have losses

    def preprocess(self, obs):
        """Convert a raw observation batch into float tensors on ``config.device``.

        - "image" is scaled from [0, 255] to [-0.5, 0.5].
        - "reward" and (if present) "discount" gain a trailing unit axis;
          "discount" is multiplied by ``config.discount``.
        - A "cont" target (1 - is_terminal, trailing unit axis) is added for
          the continuation head.

        The caller's dict and the arrays it holds are not modified.

        Raises:
            ValueError: if ``obs`` has no "is_terminal" entry.
        """
        obs = obs.copy()  # shallow: only rebind keys, never mutate values
        obs["image"] = torch.Tensor(obs["image"]) / 255.0 - 0.5
        # (batch_size, batch_length) -> (batch_size, batch_length, 1)
        obs["reward"] = torch.Tensor(obs["reward"]).unsqueeze(-1)
        if "discount" in obs:
            # Out-of-place multiply: ``*=`` would scale the caller's array in
            # place (the copy above is shallow), compounding on reuse.
            discount = obs["discount"] * self._config.discount
            # (batch_size, batch_length) -> (batch_size, batch_length, 1)
            obs["discount"] = torch.Tensor(discount).unsqueeze(-1)
        if "is_terminal" not in obs:
            raise ValueError('"is_terminal" was not found in observation.')
        # This label is necessary to train the cont head.
        obs["cont"] = torch.Tensor(1.0 - obs["is_terminal"]).unsqueeze(-1)
        obs = {k: torch.Tensor(v).to(self._config.device) for k, v in obs.items()}
        return obs

    def video_pred(self, data):
        """Build an open-loop video prediction for logging.

        For up to 6 sequences, the first 5 frames are reconstructed from the
        posterior. The remaining frames are imagined from the last posterior
        state using the recorded actions.

        Returns:
            Ground truth, model output and a rescaled error image, concatenated
            along dim 2 (image height for the (batch, time, h, w, ch) layout).
        """
        data = self.preprocess(data)
        embed = self.encoder(data)

        states, _ = self.dynamics.observe(
            embed[:6, :5], data["action"][:6, :5], data["is_first"][:6, :5]
        )
        recon = self.heads["image"](self.dynamics.get_feat(states)).mode()[:6]
        init = {k: v[:, -1] for k, v in states.items()}
        prior = self.dynamics.imagine(data["action"][:6, 5:], init)
        openl = self.heads["image"](self.dynamics.get_feat(prior)).mode()
        # observed image is given until 5 steps
        model = torch.cat([recon[:, :5], openl], 1)
        truth = data["image"][:6] + 0.5
        model = model + 0.5
        error = (model - truth + 1.0) / 2.0

        return torch.cat([truth, model, error], 2)
class ImagBehavior(nn.Module):
    """Actor-critic trained on trajectories imagined by the world model.

    Args:
        config: experiment config; hyper-parameters are read from it.
        world_model: the world model. Its ``dynamics`` rolls out imagined
            states, and its "cont" head (if present) supplies discounts.
        stop_grad_actor: if True, features are detached before being fed to
            the actor.
        reward: default objective ``f(feat, state, action)`` used by
            ``_train`` when no ``objective`` is given.
    """

    def __init__(self, config, world_model, stop_grad_actor=True, reward=None):
        super(ImagBehavior, self).__init__()
        self._use_amp = config.precision == 16
        self._config = config
        self._world_model = world_model
        self._stop_grad_actor = stop_grad_actor
        self._reward = reward
        if config.dyn_discrete:
            feat_size = config.dyn_stoch * config.dyn_discrete + config.dyn_deter
        else:
            feat_size = config.dyn_stoch + config.dyn_deter
        self.actor = networks.ActionHead(
            feat_size,  # pytorch version
            config.num_actions,
            config.actor_layers,
            config.units,
            config.act,
            config.norm,
            config.actor_dist,
            config.actor_init_std,
            config.actor_min_std,
            config.actor_max_std,
            config.actor_temp,
            outscale=1.0,
            unimix_ratio=config.action_unimix_ratio,
        )  # action_dist -> action_disc?
        # The two-hot symlog value head has a 255-wide output; other value
        # head types use an empty output shape.
        value_shape = (255,) if config.value_head == "twohot_symlog" else []
        self.value = networks.DenseHead(
            feat_size,  # pytorch version
            value_shape,
            config.value_layers,
            config.units,
            config.act,
            config.norm,
            config.value_head,
            outscale=0.0,
            device=config.device,
        )
        if config.slow_value_target:
            # Target critic, blended towards self.value by _update_slow_target.
            # It only exists when slow_value_target is enabled.
            self._slow_value = copy.deepcopy(self.value)
            self._updates = 0
        kw = dict(wd=config.weight_decay, opt=config.opt, use_amp=self._use_amp)
        self._actor_opt = tools.Optimizer(
            "actor",
            self.actor.parameters(),
            config.actor_lr,
            config.ac_opt_eps,
            config.actor_grad_clip,
            **kw,
        )
        self._value_opt = tools.Optimizer(
            "value",
            self.value.parameters(),
            config.value_lr,
            config.ac_opt_eps,
            config.value_grad_clip,
            **kw,
        )
        if self._config.reward_EMA:
            self.reward_ema = RewardEMA(device=self._config.device)

    @staticmethod
    def _match_rank(x, ref):
        """Append trailing singleton dims to ``x`` until it has ``ref``'s rank.

        Entropy terms may lack the trailing axis that rewards and targets
        carry. Without this, adding them would broadcast into the wrong shape,
        or fail for in-place ops.
        """
        while x.dim() < ref.dim():
            x = x.unsqueeze(-1)
        return x

    def _train(
        self,
        start,  # entire batch's posterior
        objective=None,
        action=None,
        reward=None,
        imagine=None,
        tape=None,
        repeats=None,
    ):
        """Run one actor update and one critic update on imagined rollouts.

        Args:
            start: posterior state dict for the whole batch. Its first two
                dims are flattened into independent rollout starts.
            objective: reward function ``f(feat, state, action)``. Defaults
                to the ``reward`` callable given to the constructor.
            action, reward, imagine, tape: unused; kept for interface
                compatibility.
            repeats: must be falsy (see ``_imagine``).

        Returns:
            ``(imag_feat, imag_state, imag_action, weights, metrics)``.
        """
        objective = objective or self._reward
        self._update_slow_target()
        metrics = {}

        with tools.RequiresGrad(self.actor):
            with torch.cuda.amp.autocast(self._use_amp):
                imag_feat, imag_state, imag_action = self._imagine(
                    start, self.actor, self._config.imag_horizon, repeats
                )
                reward = objective(imag_feat, imag_state, imag_action)
                actor_ent = self.actor(imag_feat).entropy()
                state_ent = self._world_model.dynamics.get_dist(imag_state).entropy()
                # This target is not scaled.
                target, weights, base = self._compute_target(
                    imag_feat, imag_state, imag_action, reward, actor_ent, state_ent
                )
                actor_loss, mets = self._compute_actor_loss(
                    imag_feat,
                    imag_state,
                    imag_action,
                    target,
                    actor_ent,
                    state_ent,
                    weights,
                    base,
                )
                metrics.update(mets)
                value_input = imag_feat

        with tools.RequiresGrad(self.value):
            with torch.cuda.amp.autocast(self._use_amp):
                value = self.value(value_input[:-1].detach())
                target = torch.stack(target, dim=1)
                # (time, batch, 1), (time, batch, 1) -> (time, batch)
                value_loss = -value.log_prob(target.detach())
                if self._config.slow_value_target:
                    # _slow_value is only created when slow_value_target is
                    # set; evaluating it unconditionally raised AttributeError.
                    slow_target = self._slow_value(value_input[:-1].detach())
                    value_loss = value_loss - value.log_prob(
                        slow_target.mode().detach()
                    )
                if self._config.value_decay:
                    # NOTE(review): value.mode() may carry a trailing axis
                    # that value_loss lacks. Confirm shapes before enabling.
                    value_loss += self._config.value_decay * value.mode()
                # (time, batch, 1), (time, batch, 1) -> (1,)
                value_loss = torch.mean(weights[:-1] * value_loss[:, :, None])

        metrics.update(tools.tensorstats(value.mode(), "value"))
        metrics.update(tools.tensorstats(target, "target"))
        metrics.update(tools.tensorstats(reward, "imag_reward"))
        if self._config.actor_dist in ["onehot"]:
            metrics.update(
                tools.tensorstats(
                    torch.argmax(imag_action, dim=-1).float(), "imag_action"
                )
            )
        else:
            metrics.update(tools.tensorstats(imag_action, "imag_action"))
        metrics["actor_ent"] = to_np(torch.mean(actor_ent))
        with tools.RequiresGrad(self):
            metrics.update(self._actor_opt(actor_loss, self.actor.parameters()))
            metrics.update(self._value_opt(value_loss, self.value.parameters()))
        return imag_feat, imag_state, imag_action, weights, metrics

    def _imagine(self, start, policy, horizon, repeats=None):
        """Roll ``policy`` forward ``horizon`` steps through the dynamics.

        Args:
            start: state dict whose first two dims, e.g. (batch, length), are
                flattened so every (batch, time) slice starts its own rollout.
            policy: callable returning an action distribution for features.
            horizon: number of imagination steps.
            repeats: not supported; must be falsy.

        Returns:
            ``(feats, states, actions)`` stacked along a leading time axis.
            ``states[t]`` is the state that produced ``feats[t]`` and
            ``actions[t]``: the start state followed by all but the last
            successor.

        Raises:
            NotImplementedError: if ``repeats`` is truthy.
        """
        if repeats:
            raise NotImplementedError("repeats is not implemented in this version")
        dynamics = self._world_model.dynamics
        flatten = lambda x: x.reshape([-1] + list(x.shape[2:]))
        start = {k: flatten(v) for k, v in start.items()}

        def step(prev, _):
            state, _, _ = prev
            feat = dynamics.get_feat(state)
            inp = feat.detach() if self._stop_grad_actor else feat
            action = policy(inp).sample()
            succ = dynamics.img_step(state, action, sample=self._config.imag_sample)
            # succ: the next (prior) state dict produced by img_step.
            return succ, feat, action

        # static_scan iterates ``step`` once per element of arange(horizon).
        succ, feats, actions = tools.static_scan(
            step, [torch.arange(horizon)], (start, None, None)
        )
        # Prepend the start state and drop the final successor so states
        # align with the feats/actions computed from them.
        states = {k: torch.cat([start[k][None], v[:-1]], 0) for k, v in succ.items()}
        return feats, states, actions

    def _compute_target(
        self, imag_feat, imag_state, imag_action, reward, actor_ent, state_ent
    ):
        """Compute lambda-return targets, trajectory weights and value baseline.

        Returns:
            ``(target, weights, base)``: ``target`` is the output of
            ``tools.lambda_return`` over the first ``horizon - 1`` steps;
            ``weights`` is the detached cumulative product of discounts, with
            1 at step 0; ``base`` is the critic's value on the first
            ``horizon - 1`` steps.
        """
        if "cont" in self._world_model.heads:
            inp = self._world_model.dynamics.get_feat(imag_state)
            discount = self._config.discount * self._world_model.heads["cont"](inp).mean
        else:
            discount = self._config.discount * torch.ones_like(reward)
        # Out-of-place additions: the caller's ``reward`` tensor is logged
        # afterwards and must not be modified here.
        if self._config.future_entropy and self._config.actor_entropy() > 0:
            reward = reward + self._config.actor_entropy() * self._match_rank(
                actor_ent, reward
            )
        if self._config.future_entropy and self._config.actor_state_entropy() > 0:
            reward = reward + self._config.actor_state_entropy() * self._match_rank(
                state_ent, reward
            )
        # Critic estimate; mode() of the value distribution.
        value = self.value(imag_feat).mode()
        # The target's leading dim is horizon - 1.
        target = tools.lambda_return(
            reward[:-1],
            value[:-1],
            discount[:-1],
            bootstrap=value[-1],  # base case for the lambda recursion
            lambda_=self._config.discount_lambda,
            axis=0,
        )
        # weights[t] = prod_{i < t} discount[i], with weights[0] = 1. Discounts
        # are continuous (config.discount times the cont head's mean), so the
        # weights softly down-weight steps after a predicted episode end.
        weights = torch.cumprod(
            torch.cat([torch.ones_like(discount[:1]), discount[:-1]], 0), 0
        ).detach()
        return target, weights, value[:-1]

    def _compute_actor_loss(
        self,
        imag_feat,
        imag_state,
        imag_action,
        target,
        actor_ent,
        state_ent,
        weights,
        base,
    ):
        """Return ``(actor_loss, metrics)`` for the configured ``imag_gradient``.

        ``target`` is the sequence returned by ``tools.lambda_return``; it is
        stacked along dim 1 here. When ``reward_EMA`` is enabled the advantage
        is normalised with ``self.reward_ema``. Otherwise the raw advantage
        ``target - base`` is used.

        Raises:
            NotImplementedError: for an unknown ``config.imag_gradient``.
        """
        metrics = {}
        inp = imag_feat.detach() if self._stop_grad_actor else imag_feat
        policy = self.actor(inp)
        # Recomputed from the (possibly detached) actor input on purpose: the
        # ``actor_ent`` argument was computed on non-detached features, so
        # gradients would flow differently.
        actor_ent = policy.entropy()
        # Q-val for actor is not transformed using symlog
        target = torch.stack(target, dim=1)
        if self._config.reward_EMA:
            offset, scale = self.reward_ema(target)
            normed_target = (target - offset) / scale
            normed_base = (base - offset) / scale
            adv = normed_target - normed_base
            metrics.update(tools.tensorstats(normed_target, "normed_target"))
            values = self.reward_ema.values
            metrics["EMA_005"] = to_np(values[0])
            metrics["EMA_095"] = to_np(values[1])
        else:
            # Without return normalisation use the raw advantage. Previously
            # ``adv`` was undefined here and "dynamics" raised NameError.
            adv = target - base

        if self._config.imag_gradient == "dynamics":
            actor_target = adv
        elif self._config.imag_gradient == "reinforce":
            # The last step has no advantage, hence [:-1].
            actor_target = (
                policy.log_prob(imag_action)[:-1][:, :, None]
                * (target - self.value(imag_feat[:-1]).mode()).detach()
            )
        elif self._config.imag_gradient == "both":
            actor_target = (
                policy.log_prob(imag_action)[:-1][:, :, None]
                * (target - self.value(imag_feat[:-1]).mode()).detach()
            )
            mix = self._config.imag_gradient_mix()
            # NOTE(review): mixes the raw ``target`` rather than ``adv``, unlike
            # the "dynamics" branch. Kept as-is; confirm this is intended.
            actor_target = mix * target + (1 - mix) * actor_target
            metrics["imag_gradient_mix"] = mix
        else:
            raise NotImplementedError(self._config.imag_gradient)

        if not self._config.future_entropy and (self._config.actor_entropy() > 0):
            actor_entropy = self._config.actor_entropy() * actor_ent[:-1][:, :, None]
            actor_target = actor_target + actor_entropy
            metrics["actor_entropy"] = to_np(torch.mean(actor_entropy))
        if not self._config.future_entropy and (self._config.actor_state_entropy() > 0):
            # Give state entropy the same trailing axis as actor_target
            # (it was previously added without one).
            state_entropy = self._config.actor_state_entropy() * self._match_rank(
                state_ent[:-1], actor_target
            )
            actor_target = actor_target + state_entropy
            metrics["actor_state_entropy"] = to_np(torch.mean(state_entropy))
        actor_loss = -torch.mean(weights[:-1] * actor_target)
        return actor_loss, metrics

    def _update_slow_target(self):
        """Blend the critic into the slow target every ``slow_target_update`` calls.

        The blend is d <- mix * s + (1 - mix) * d, where mix is
        ``slow_target_fraction``. Does nothing unless ``slow_value_target``
        is set.
        """
        if self._config.slow_value_target:
            if self._updates % self._config.slow_target_update == 0:
                mix = self._config.slow_target_fraction
                for s, d in zip(self.value.parameters(), self._slow_value.parameters()):
                    d.data = mix * s.data + (1 - mix) * d.data
            self._updates += 1