-
Notifications
You must be signed in to change notification settings - Fork 3
Expand file tree
/
Copy pathppo.py
More file actions
695 lines (600 loc) · 27.2 KB
/
ppo.py
File metadata and controls
695 lines (600 loc) · 27.2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
from dataclasses import dataclass, field
from typing import Literal, Optional, List
import tyro
import math
import os
import time
import warnings
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
try:
from collections.abc import Mapping
except ImportError:
from collections import Mapping
@dataclass
class PPOConfig:
    """
    Configuration class for PPO.

    The three `*_batch_size` fields at the bottom are derived values. `__post_init__` fills
    them in when they are left as `None`; values set explicitly by the caller are kept.
    """

    seed: int = 0
    """Seed value for random generations"""
    log_with: Optional[Literal["wandb", "tensorboard"]] = None
    """Log with either 'wandb' or 'tensorboard', check https://huggingface.co/docs/accelerate/usage_guides/tracking for more details"""

    # hyperparameters
    steps: int = 20000
    """Number of training steps"""
    learning_rate: float = 1e-5
    """Adam learning rate"""
    adap_kl_ctrl: bool = False
    """Use adaptive KL control, otherwise a fixed KL coefficient"""
    init_kl_coef: Optional[float] = 0.2
    """Initial KL penalty coefficient (used for adaptive and fixed control)"""
    kl_penalty: Literal["kl", "abs", "mse", "full"] = "kl"
    """kl penalty options: 'kl': model_logp - ref_logp, 'abs': abs(kl), 'mse': mean squared error mse(kl) and 'full': the actual kl for all tokens in the distribution"""
    target: Optional[float] = 6
    """Target KL value for adaptive KL control"""
    horizon: Optional[float] = 10000
    """Horizon for adaptive KL control"""
    gamma: float = 1
    """Gamma parameter for advantage calculation"""
    lam: float = 0.95
    """Lambda parameter for advantage calculation"""
    cliprange: float = 0.2
    """Range for clipping in PPO policy gradient loss"""
    cliprange_value: float = 0.2
    """Range for clipping values in loss calculation"""
    vf_coef: float = 0.1
    """Scaling factor for value loss"""
    batch_size: int = 256
    """Number of samples per optimisation step"""
    mini_batch_size: int = 1
    """Number of samples optimized in each mini batch"""
    gradient_accumulation_steps: int = 1
    """The number of gradient accumulation steps"""
    world_size: Optional[int] = None
    """The world size for distributed training"""
    ppo_epochs: int = 4
    """Number of optimisation epochs per batch of samples"""
    max_grad_norm: Optional[float] = None
    """Maximum gradient norm for gradient clipping"""
    early_stopping: bool = False
    """Whether to stop the PPO optimization loop early if the KL too high"""
    target_kl: float = 0.1
    """Stop early if we exceed this value by over 50%"""
    compare_steps: int = 1
    """Number of steps between comparison of the current reward with the best seen so far"""
    ratio_threshold: float = 1.5
    """Skip mini-batches with high PPO ratios that can cause loss spikes"""
    whiten_rewards: bool = False
    """Whiten the rewards before compute advantages"""

    # computed hyperparameters at runtime; we use `tyro.conf.Suppress` to hide them from the help text
    backward_batch_size: tyro.conf.Suppress[int] = None
    """TO BE FILLED In RUNTIME: Number of samples optimized in an `optimizer.step()` call"""
    global_backward_batch_size: tyro.conf.Suppress[int] = None
    """TO BE FILLED In RUNTIME: the effective `backward_batch_size` across all processes"""
    global_batch_size: tyro.conf.Suppress[int] = None
    """TO BE FILLED In RUNTIME: the effective `batch_size` across all processes"""

    def __post_init__(self):
        # Derive the runtime batch sizes unless the caller already provided them.
        # `PPOTrainer.step` uses `backward_batch_size` as a `range` step, so leaving it as
        # `None` would make every optimisation step fail.
        if self.backward_batch_size is None:
            self.backward_batch_size = self.mini_batch_size * self.gradient_accumulation_steps
        # The global sizes can only be derived when the world size is known.
        if self.world_size is not None:
            if self.global_backward_batch_size is None:
                self.global_backward_batch_size = self.backward_batch_size * self.world_size
            if self.global_batch_size is None:
                self.global_batch_size = self.batch_size * self.world_size
class PPOTrainer:
    """
    The PPOTrainer uses Proximal Policy Optimization to optimise language models.
    Note, this trainer is heavily based on the original OpenAI learning to summarize work here:
    https://github.com/openai/summarize-from-feedback
    and the HuggingFace implementation trl here: https://github.com/huggingface/trl/tree/main
    """

    def __init__(
        self,
        config: PPOConfig = None,
        model: nn.Module = None,
        ref_model: Optional[nn.Module] = None,
        tokenizer=None,
        optimizer: Optional[torch.optim.Optimizer] = None,
        scaler: Optional[torch.cuda.amp.GradScaler] = None,
        is_ddp: bool = False,
    ):
        """
        Args:
            config (`PPOConfig`, *optional*):
                PPO hyperparameters. A default `PPOConfig()` is used when omitted.
            model (`nn.Module`):
                Policy model to optimise. It is called as `model(input_ids)` and must return a
                4-tuple whose first item is the logits and third item the per-token values.
            ref_model (`nn.Module`):
                Reference model for the KL penalty, called the same way as `model`.
            tokenizer:
                Object exposing `pad(list_of_dicts)`; presumably an HF tokenizer returning a
                dict of tensors — TODO confirm.
            optimizer (`torch.optim.Optimizer`):
                Optimizer over `model`'s parameters.
            scaler (`torch.cuda.amp.GradScaler`, *optional*):
                AMP gradient scaler. When `None`, plain `backward()` / `optimizer.step()` are used.
            is_ddp (`bool`):
                Whether `model` is wrapped in `DistributedDataParallel`.
        """
        # Resolve the config before seeding: seeding from `config.seed` first crashed when
        # `config` was omitted.
        self.config = config or PPOConfig()
        np.random.seed(self.config.seed)
        self.model = model
        self.ref_model = ref_model
        self.tokenizer = tokenizer
        self.optimizer = optimizer
        self.scaler = scaler
        self.is_ddp = is_ddp

        if self.config.adap_kl_ctrl:
            self.kl_ctl = AdaptiveKLController(self.config.init_kl_coef, self.config.target, self.config.horizon)
        else:
            self.kl_ctl = FixedKLController(self.config.init_kl_coef)

    def step(
        self,
        queries: List[torch.LongTensor],
        responses: List[torch.LongTensor],
        scores: List[torch.FloatTensor],
        response_masks: Optional[List[torch.LongTensor]] = None,
    ):
        """
        Run a PPO optimisation step given a list of queries, model responses, and rewards.

        Args:
            queries (List[`torch.LongTensor`]):
                List of tensors containing the encoded queries of shape (`query_length`)
            responses (List[`torch.LongTensor`]):
                List of tensors containing the encoded responses of shape (`response_length`)
            scores (List[`torch.FloatTensor`]):
                List of tensors containing the scores.
            response_masks (List[`torch.FloatTensor`], *optional*):
                List of tensors masking the response tokens, each of shape (`response_length`).

        Returns:
            `dict[str, Any]`: A summary of the training statistics

        Notes:
            Gradients of the `gradient_accumulation_steps` mini-batches inside one backward
            batch are accumulated, and the optimizer steps once per backward batch.
            The learning-rate scheduler is NOT stepped here; that is the caller's job.
        """
        bs = self.config.batch_size
        assert len(queries) == len(responses) == len(scores) == bs, "All lists must have the same length"
        timing = dict()
        t0 = time.time()

        # --- rollout scoring: run the policy and the reference model on query+response ---
        t = time.time()
        input_ids = [torch.cat([q, r]) for q, r in zip(queries, responses)]
        input_data = [{"input_ids": ids, "attention_mask": torch.ones_like(ids)} for ids in input_ids]
        model_inputs = self.tokenizer.pad(input_data)
        model_inputs = {k: v.to(queries[0].device) for k, v in model_inputs.items()}
        model_inputs_names = list(model_inputs.keys())
        full_kl_penalty = self.config.kl_penalty == "full"

        self.model.eval()
        with torch.no_grad():
            all_logprobs, logits_or_none, values, masks = self.batched_forward_pass(
                self.model,
                queries,
                responses,
                model_inputs,
                response_masks=response_masks,
                return_logits=full_kl_penalty,
            )
            ref_logprobs, ref_logits_or_none, _, _ = self.batched_forward_pass(
                self.ref_model,
                queries,
                responses,
                model_inputs,
                return_logits=full_kl_penalty,
            )
            # Per-sequence summed log-ratio, averaged over the batch; logged as `ppo/kl_ref`.
            kl_ref = ((all_logprobs - ref_logprobs) * masks).sum(axis=-1).mean()
        timing["time/ppo/forward_pass"] = time.time() - t

        with torch.no_grad():
            t = time.time()
            if full_kl_penalty:
                active_full_logprobs = logprobs_from_logits(logits_or_none, None, gather=False)
                ref_full_logprobs = logprobs_from_logits(ref_logits_or_none, None, gather=False)
                rewards, _ = self.compute_rewards(scores, active_full_logprobs, ref_full_logprobs, masks)
            else:
                rewards, _ = self.compute_rewards(scores, all_logprobs, ref_logprobs, masks)
            reward_all = (rewards * masks).sum(axis=-1).mean()
            timing["time/ppo/compute_rewards"] = time.time() - t

            t = time.time()
            values, advantages, returns = self.compute_advantages(values, rewards, masks)
            timing["time/ppo/compute_advantages"] = time.time() - t

        # upcast to float32 to avoid dataset issues
        batch_dict = {
            "queries": queries,
            "responses": responses,
            "logprobs": all_logprobs.to(torch.float32),
            "values": values.to(torch.float32),
            "masks": masks,
            "advantages": advantages,
            "returns": returns,
        }
        batch_dict.update(model_inputs)

        # --- optimisation: `ppo_epochs` passes over shuffled mini-batches ---
        t = time.time()
        all_stats = []
        early_stop = False
        self.model.train()
        for ep in range(self.config.ppo_epochs):
            if early_stop:
                break
            b_inds = np.random.permutation(bs)
            # backward_batch_size = mini_batch_size * gradient_accumulation_steps
            for backward_batch_start in range(0, bs, self.config.backward_batch_size):
                backward_batch_end = backward_batch_start + self.config.backward_batch_size
                backward_batch_inds = b_inds[backward_batch_start:backward_batch_end]
                # The final backward batch is shorter when bs is not a multiple of
                # backward_batch_size; iterate over what is there so no empty mini-batch occurs.
                n_backward = len(backward_batch_inds)
                for mini_batch_start in range(0, n_backward, self.config.mini_batch_size):
                    mini_batch_end = mini_batch_start + self.config.mini_batch_size
                    mini_batch_inds = backward_batch_inds[mini_batch_start:mini_batch_end]
                    # Gradients accumulate over the mini-batches of one backward batch; only the
                    # last one synchronises (DDP) and steps the optimizer.
                    is_last_mini_batch = mini_batch_end >= n_backward
                    mini_batch_dict = {
                        "logprobs": batch_dict["logprobs"][mini_batch_inds],
                        "values": batch_dict["values"][mini_batch_inds],
                        "masks": batch_dict["masks"][mini_batch_inds],
                        # hacks: the queries and responses are ragged.
                        "queries": [batch_dict["queries"][i] for i in mini_batch_inds],
                        "responses": [batch_dict["responses"][i] for i in mini_batch_inds],
                        "advantages": batch_dict["advantages"][mini_batch_inds],
                        "returns": batch_dict["returns"][mini_batch_inds],
                    }
                    for k in model_inputs_names:
                        mini_batch_dict[k] = batch_dict[k][mini_batch_inds]
                    mini_batch_inputs = {k: mini_batch_dict[k] for k in model_inputs_names}

                    if self.is_ddp:
                        # DDP decides during the *forward* pass whether the following backward
                        # all-reduces gradients (cf. `DistributedDataParallel.no_sync`), so the
                        # flag must be set before the forward call.
                        self.model.require_backward_grad_sync = is_last_mini_batch
                    logprobs, logits, vpreds, _ = self.batched_forward_pass(
                        self.model,
                        mini_batch_dict["queries"],
                        mini_batch_dict["responses"],
                        mini_batch_inputs,
                        return_logits=True,
                    )
                    loss, train_stats = self.compute_loss(
                        mini_batch_dict["logprobs"],
                        mini_batch_dict["values"],
                        logprobs,
                        logits,
                        vpreds,
                        mini_batch_dict["masks"],
                        mini_batch_dict["advantages"],
                        mini_batch_dict["returns"],
                    )
                    all_stats.append(train_stats)

                    self._backward(loss / self.config.gradient_accumulation_steps)
                    if is_last_mini_batch:
                        self._optimizer_step()

            # typically, early stopping is done at the epoch level
            if self.config.early_stopping:
                policykl = train_stats["policy/policykl"]
                early_stop = self._early_stop(policykl)
                if early_stop:
                    print(f"Early stopping at epoch {ep}, policykl is {policykl} larger than {self.config.target_kl*1.5}")
                    break
        timing["time/ppo/optimize_step"] = time.time() - t

        t = time.time()
        train_stats = stack_dicts(all_stats)
        stats = {}
        for k, v in train_stats.items():
            stats[f"ppo/{k}"] = float(torch.mean(v, axis=0).detach().cpu().numpy()[0])
        timing["time/ppo/calc_stats"] = time.time() - t
        stats["ppo/learning_rate"] = self.optimizer.param_groups[0]["lr"]
        # if kl_ref < 0, then we are in trouble. (KL should be positive. (ln(x) <= x - 1))
        stats["ppo/kl_ref"] = float(kl_ref.item())  # this is actually the penalty term in reward we want to minimize.
        stats["ppo/reward_all"] = float(reward_all.item())  # this is the actual reward we want to maximize. (reward - kl_ref)

        # Log the total ppo time
        timing["time/ppo/total"] = time.time() - t0
        stats.update(timing)
        # NOTE: lr_scheduler is updated outside of this function
        return stats

    def _backward(self, loss: torch.Tensor):
        """Backpropagate `loss`, going through the AMP scaler when one was provided."""
        if self.scaler is not None:
            self.scaler.scale(loss).backward()
        else:
            loss.backward()

    def _optimizer_step(self):
        """Optionally clip gradients, step the optimizer (and scaler), then free the gradients."""
        if self.config.max_grad_norm is not None:
            if self.scaler is not None:
                # Clipping must see the real (unscaled) gradients.
                self.scaler.unscale_(self.optimizer)
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.max_grad_norm)
        if self.scaler is not None:
            # step the optimizer and scaler if training in fp16
            self.scaler.step(self.optimizer)
            self.scaler.update()
        else:
            self.optimizer.step()
        # flush the gradients as soon as we can, no need for this memory anymore
        self.optimizer.zero_grad(set_to_none=True)

    def batched_forward_pass(
        self,
        model: torch.nn.Module,
        queries: List[torch.LongTensor],
        responses: List[torch.LongTensor],
        model_inputs: dict,
        return_logits: bool = False,
        response_masks: Optional[List[torch.Tensor]] = None,
    ):
        """
        Calculate model outputs in chunks of `config.mini_batch_size` sequences.

        Args:
            model (`torch.nn.Module`):
                Model called as `model(input_ids)`, returning a 4-tuple (logits, _, values, _).
            queries (List[`torch.LongTensor`]):
                Encoded queries, one 1-D tensor per sample.
            responses (List[`torch.LongTensor`]):
                Encoded responses, one 1-D tensor per sample.
            model_inputs (`dict`):
                Padded batch with at least `input_ids` and `attention_mask`, rows aligned with
                `queries`.
            return_logits (`bool`, *optional*, defaults to `False`):
                Whether to return all_logits. Set to `False` if logits are not needed to reduce memory consumption.
            response_masks (List[`torch.Tensor`], *optional*):
                Per-sample masks over the response tokens (length `len(response)`), multiplied
                into the returned masks.

        Returns:
            (tuple):
                - all_logprobs (`torch.FloatTensor`): log-probabilities of each next token,
                  shape (`batch_size`, `seq_len - 1`)
                - all_logits (`torch.FloatTensor` or `None`): logits aligned with `all_logprobs`,
                  only when `return_logits` is set
                - all_values (`torch.FloatTensor`): value predictions, shape (`batch_size`, `seq_len - 1`)
                - all_masks (`torch.LongTensor`): 1 on positions that predict a response token,
                  shape (`batch_size`, `seq_len - 1`)
        """
        bs = len(queries)
        fbs = self.config.mini_batch_size
        all_logprobs = []
        all_logits = []
        all_masks = []
        all_values = []

        for i in range(math.ceil(bs / fbs)):
            chunk = slice(i * fbs, (i + 1) * fbs)
            input_kwargs = {key: value[chunk] for key, value in model_inputs.items()}
            query_batch = queries[chunk]
            response_batch = responses[chunk]
            response_masks_batch = response_masks[chunk] if response_masks is not None else None
            input_ids = input_kwargs["input_ids"]
            attention_mask = input_kwargs["attention_mask"]
            logits, _, values, _ = model(input_ids)
            # Position t predicts token t + 1.
            logprobs = logprobs_from_logits(logits[:, :-1, :], input_ids[:, 1:])
            masks = torch.zeros_like(attention_mask)
            masks[:, :-1] = attention_mask[:, 1:]

            for j in range(len(query_batch)):
                start = len(query_batch[j]) - 1  # logprobs starts from the second query token
                if attention_mask[j, 0] == 0:  # offset left padding
                    start += int(attention_mask[j, :].nonzero()[0])
                end = start + len(response_batch[j])
                # Keep only the positions that predict response tokens.
                masks[j, :start] = 0
                masks[j, end:] = 0
                if response_masks_batch is not None:
                    # `masks[j, start:end]` covers exactly the response tokens, so the response
                    # mask applies as-is; this stays aligned under left padding as well.
                    masks[j, start:end] = masks[j, start:end] * response_masks_batch[j]

            if return_logits:
                all_logits.append(logits)
            else:
                del logits
            all_values.append(values)
            all_logprobs.append(logprobs)
            all_masks.append(masks)

        return (
            torch.cat(all_logprobs),
            torch.cat(all_logits)[:, :-1] if return_logits else None,
            torch.cat(all_values)[:, :-1],
            torch.cat(all_masks)[:, :-1],
        )

    def compute_rewards(
        self,
        scores: torch.FloatTensor,
        logprobs: torch.FloatTensor,
        ref_logprobs: torch.FloatTensor,
        masks: torch.LongTensor,
    ):
        """
        Compute per token rewards from scores and KL-penalty.

        Every position receives `-kl_coef * kl`; the sequence score is additionally added at
        the last position where `mask` is non-zero.

        Args:
            scores (`torch.FloatTensor`):
                Scores from the reward model, shape (`batch_size`)
            logprobs (`torch.FloatTensor`):
                Log probabilities of the model, shape (`batch_size`, `response_length`)
            ref_logprobs (`torch.FloatTensor`):
                Log probabilities of the reference model, shape (`batch_size`, `response_length`)
            masks (`torch.LongTensor`):
                Per-position masks; each row must contain at least one non-zero entry.

        Returns:
            (rewards, non_score_rewards): stacked per-token rewards and the KL-only part.
        """
        rewards, non_score_rewards = [], []
        for score, logprob, ref_logprob, mask in zip(scores, logprobs, ref_logprobs, masks):
            # compute KL penalty (from difference in logprobs)
            kl = self._kl_penalty(logprob, ref_logprob)
            non_score_reward = -self.kl_ctl.value * kl
            non_score_rewards.append(non_score_reward)
            reward = non_score_reward.clone()
            last_non_masked_index = mask.nonzero()[-1]
            # reward is preference model score + KL penalty
            reward[last_non_masked_index] += score
            rewards.append(reward)
        return torch.stack(rewards), torch.stack(non_score_rewards)

    def _kl_penalty(self, logprob: torch.FloatTensor, ref_logprob: torch.FloatTensor) -> torch.FloatTensor:
        """Per-position KL penalty selected by `config.kl_penalty`."""
        if self.config.kl_penalty == "kl":
            return logprob - ref_logprob

        if self.config.kl_penalty == "abs":
            return (logprob - ref_logprob).abs()

        if self.config.kl_penalty == "mse":
            return 0.5 * (logprob - ref_logprob).square()

        if self.config.kl_penalty == "full":
            # `F.kl_div(input, target)` computes KL(target || input), hence the argument order.
            # Flip is required due to this issue? :https://github.com/pytorch/pytorch/issues/57459
            return F.kl_div(ref_logprob, logprob, log_target=True, reduction="none").sum(-1)

        raise NotImplementedError

    def compute_advantages(
        self,
        values: torch.FloatTensor,
        rewards: torch.FloatTensor,
        mask: torch.FloatTensor,
    ):
        """
        Generalised Advantage Estimation over the last dimension.

        Returns:
            (values, advantages, returns): masked values, whitened and detached advantages,
            and `returns = advantages + values` (computed before whitening).
        """
        lastgaelam = 0
        advantages_reversed = []
        gen_len = rewards.shape[-1]

        values = values * mask
        rewards = rewards * mask

        if self.config.whiten_rewards:
            rewards = masked_whiten(rewards, mask, shift_mean=False)

        for t in reversed(range(gen_len)):
            nextvalues = values[:, t + 1] if t < gen_len - 1 else 0.0
            delta = rewards[:, t] + self.config.gamma * nextvalues - values[:, t]
            lastgaelam = delta + self.config.gamma * self.config.lam * lastgaelam
            advantages_reversed.append(lastgaelam)
        advantages = torch.stack(advantages_reversed[::-1]).transpose(0, 1)

        returns = advantages + values
        advantages = masked_whiten(advantages, mask)
        advantages = advantages.detach()
        return values, advantages, returns

    def compute_loss(
        self,
        old_logprobs: torch.FloatTensor,
        values: torch.FloatTensor,
        logprobs: torch.FloatTensor,
        logits: torch.FloatTensor,
        vpreds: torch.FloatTensor,
        mask: torch.LongTensor,
        advantages: torch.FloatTensor,
        returns: torch.FloatTensor,
    ):
        """
        Compute the clipped PPO loss and training statistics.

        Args:
            old_logprobs / values: log-probs and values recorded during the rollout.
            logprobs / logits / vpreds: outputs of the current policy on the same mini-batch.
            mask: positions that contribute to the loss.
            advantages / returns: outputs of `compute_advantages`.

        Returns:
            (loss, stats): scalar loss `pg_loss + vf_coef * vf_loss` (zeroed when the average
            ratio exceeds `ratio_threshold`) and a flat dict of detached statistics.
        """
        vpredclipped = clip_by_value(
            vpreds,
            values - self.config.cliprange_value,
            values + self.config.cliprange_value,
        )

        vf_losses1 = (vpreds - returns) ** 2
        vf_losses2 = (vpredclipped - returns) ** 2
        vf_loss = 0.5 * masked_mean(torch.max(vf_losses1, vf_losses2), mask)
        vf_clipfrac = masked_mean(torch.gt(vf_losses2, vf_losses1).float(), mask)

        ratio = torch.exp(logprobs - old_logprobs)

        pg_losses = -advantages * ratio
        pg_losses2 = -advantages * torch.clamp(ratio, 1.0 - self.config.cliprange, 1.0 + self.config.cliprange)

        pg_loss = masked_mean(torch.max(pg_losses, pg_losses2), mask)
        pg_clipfrac = masked_mean(torch.gt(pg_losses2, pg_losses).float(), mask)

        loss = pg_loss + self.config.vf_coef * vf_loss

        avg_ratio = masked_mean(ratio, mask).item()
        if avg_ratio > self.config.ratio_threshold:
            warnings.warn(
                f"The average ratio of batch ({avg_ratio:.2f}) exceeds threshold {self.config.ratio_threshold:.2f}. Skipping batch."
            )
            pg_loss = pg_loss * 0.0
            vf_loss = vf_loss * 0.0
            loss = loss * 0.0

        entropy = masked_mean(entropy_from_logits(logits), mask)

        policykl = masked_mean(old_logprobs - logprobs, mask)

        return_mean, return_var = masked_mean(returns, mask), masked_var(returns, mask)
        value_mean, value_var = masked_mean(values, mask), masked_var(values, mask)

        stats = dict(
            loss=dict(policy=pg_loss.detach(), value=vf_loss.detach(), total=loss.detach()),
            policy=dict(
                entropy=entropy.detach(),
                policykl=policykl.detach(),
                clipfrac=pg_clipfrac.detach(),
                advantages_mean=masked_mean(advantages, mask).detach(),
            ),
            returns=dict(mean=return_mean.detach(), var=return_var.detach()),
            val=dict(
                vpred=masked_mean(vpreds, mask).detach(),
                error=masked_mean((vpreds - returns) ** 2, mask).detach(),
                clipfrac=vf_clipfrac.detach(),
                mean=value_mean.detach(),
                var=value_var.detach(),
            ),
        )
        return loss, flatten_dict(stats)

    def _early_stop(self, policykl):
        r"""
        Handles the early stopping logic. If the policy KL is greater than 1.5x the target KL,
        the gradients are zeroed and early stopping is signalled.
        In the DDP case the policy KL is first averaged across all processes (in place).

        Args:
            policykl (torch.Tensor):
                the policy KL

        Returns:
            `bool`: whether to early stop or not
        """
        early_stop = False
        if not self.config.early_stopping:
            return early_stop

        if not self.is_ddp and policykl > 1.5 * self.config.target_kl:
            self.optimizer.zero_grad()
            early_stop = True
        elif self.is_ddp:
            import torch.distributed as dist

            # Wait for all processes to finish
            dist.barrier()

            # all gather the policykl
            dist.all_reduce(policykl, dist.ReduceOp.SUM)
            # Average over the process group that `all_reduce` just summed over; this does not
            # rely on the WORLD_SIZE environment variable being set by the launcher.
            policykl /= dist.get_world_size()

            if policykl > 1.5 * self.config.target_kl:
                self.optimizer.zero_grad()
                early_stop = True
        return early_stop
def logprobs_from_logits(logits, labels, gather=True):
    """
    Compute log-probabilities from logits over the last (vocabulary) dimension.

    See: https://github.com/pytorch/pytorch/issues/563#issuecomment-330103591

    Args:
        logits: tensor of shape `(..., vocab_size)`; any number of leading dimensions.
        labels: integer tensor of shape `(...)` with the ids whose log-probability is
            returned. Ignored when `gather` is `False`.
        gather: when `True`, return the log-probability of each label (shape `(...)`);
            otherwise return the full log-softmax (same shape as `logits`).
    """
    # Reduce over the last axis instead of a hard-coded dim=2, so 2-D and 4-D inputs work
    # too; for the usual (batch, seq, vocab) input this is the same axis.
    logp = F.log_softmax(logits, dim=-1)
    if not gather:
        return logp
    return torch.gather(logp, -1, labels.unsqueeze(-1)).squeeze(-1)
def masked_mean(values, mask, axis=None):
    """Average `values` over the entries where `mask` is non-zero.

    With `axis=None` the reduction runs over every element; otherwise only along `axis`.
    An all-zero mask divides by zero.
    """
    weighted = values * mask
    if axis is None:
        return weighted.sum() / mask.sum()
    return weighted.sum(axis=axis) / mask.sum(axis=axis)
def masked_var(values, mask, unbiased=True):
    """Compute the variance of `values` over the entries where `mask` is non-zero.

    Args:
        values: tensor of values.
        mask: tensor broadcastable against `values`; non-zero entries are included.
        unbiased: apply Bessel's correction `n / (n - 1)` with `n = mask.sum()`.

    Raises:
        ValueError: if `unbiased` is set and the mask sums to zero.
    """
    mean = masked_mean(values, mask)
    centered_values = values - mean
    variance = masked_mean(centered_values**2, mask)
    if unbiased:
        mask_sum = mask.sum()
        if mask_sum == 0:
            # The two literals are concatenated, so the separating space must be explicit.
            raise ValueError(
                "The sum of the mask is zero, which can happen when `mini_batch_size=1`; "
                "try increasing the `mini_batch_size` or `gradient_accumulation_steps`"
            )
        # NOTE: mask_sum == 1 divides by zero below (the result is not finite); a larger
        # mini-batch avoids it.
        bessel_correction = mask_sum / (mask_sum - 1)
        variance = variance * bessel_correction
    return variance
def masked_whiten(values, mask, shift_mean=True):
    """Rescale `values` to zero mean and unit variance using statistics of the masked entries.

    The mean and (unbiased) variance come only from entries where `mask` is non-zero, but the
    transform is applied to every element. With `shift_mean=False` the mean is added back,
    so only the spread is normalised.
    """
    mu = masked_mean(values, mask)
    var = masked_var(values, mask)
    whitened = (values - mu) * torch.rsqrt(var + 1e-8)
    return whitened if shift_mean else whitened + mu
def clip_by_value(x, tensor_min, tensor_max):
    """
    Element-wise clamp of `x` between tensor-valued bounds (an extension of `torch.clamp`).
    The upper bound is applied first, then the lower bound.
    https://github.com/pytorch/pytorch/issues/2793#issuecomment-428784713
    """
    upper_bounded = torch.minimum(x, tensor_max)
    return torch.maximum(upper_bounded, tensor_min)
def entropy_from_logits(logits):
    """Entropy of the softmax distribution defined by `logits` along the last axis.

    Computed as `logsumexp(z) - sum(softmax(z) * z)`, which never takes the log of a probability.
    """
    probs = torch.nn.functional.softmax(logits, dim=-1)
    log_partition = torch.logsumexp(logits, dim=-1)
    expected_logit = (probs * logits).sum(dim=-1)
    return log_partition - expected_logit
def flatten_dict(nested, sep="/"):
    """Flatten a nested mapping into one level, joining the key path with `sep`.

    Keys keep depth-first insertion order.

    Raises:
        ValueError: if any key already contains `sep`.
    """

    def _walk(mapping, prefix):
        for key, value in mapping.items():
            if sep in key:
                raise ValueError(f"separator '{sep}' not allowed to be in key '{key}'")
            if isinstance(value, Mapping):
                yield from _walk(value, prefix + key + sep)
            else:
                yield prefix + key, value

    return dict(_walk(nested, ""))
def stack_dicts(stats_dicts):
    """Stack a list of stats dicts key by key.

    Each value is flattened to 1-D. The rows for a key are right-padded with `-1` into a
    `(len(stats_dicts), max_len)` tensor. The keys are taken from the first dict.
    """
    return {
        key: pad_sequence(
            [torch.flatten(stats[key]) for stats in stats_dicts],
            batch_first=True,
            padding_value=-1,
        )
        for key in stats_dicts[0]
    }
class AdaptiveKLController:
    """
    Adaptive KL controller described in the paper:
    https://arxiv.org/pdf/1909.08593.pdf

    The coefficient `value` is nudged up when the observed KL is above `target` and down
    when it is below.
    """

    def __init__(self, init_kl_coef, target, horizon):
        self.value = init_kl_coef
        self.target = target
        self.horizon = horizon

    def update(self, current, n_steps):
        """Multiply `value` by `1 + clip(current / target - 1, -0.2, 0.2) * n_steps / horizon`."""
        error = np.clip(current / self.target - 1, -0.2, 0.2)
        self.value *= 1 + error * n_steps / self.horizon
class FixedKLController:
    """KL controller whose coefficient `value` never changes."""

    def __init__(self, kl_coef):
        self.value = kl_coef

    def update(self, current, n_steps):
        """Intentionally a no-op; mirrors `AdaptiveKLController.update` so callers can swap them."""
class LengthSampler:
    """
    Draws a length uniformly at random from `range(min_value, max_value)` (upper bound excluded).
    """

    def __init__(self, min_value=0, max_value=10):
        self.values = [*range(min_value, max_value)]

    def __call__(self):
        return np.random.choice(self.values)
class LengthReward:
    """
    Reward that penalises distance from a target length: `-|target_length - length| / 100`.
    """

    def __init__(self, target_length=200):
        self.target_length = target_length

    def __call__(self, sequence_length):
        distance = abs(self.target_length - sequence_length)
        return -distance / 100.