-
Notifications
You must be signed in to change notification settings - Fork 0
/
train_vqvae.py
108 lines (83 loc) · 3.19 KB
/
train_vqvae.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
import json
import os
import random
from typing import Any, Callable, Dict, Optional

import haiku as hk
import jax
import jax.numpy as jnp
import numpy as np
import optax
import tensorflow as tf
from absl import app, flags
from bax import TrainState
from bax import Trainer
from bax.callbacks import CheckpointCallback, Callback
from chex import ArrayTree, Array
from ml_collections.config_flags import config_flags

from posterior_matching.models.vqvae import VQVAE
from posterior_matching.utils import (
    configure_environment,
    load_datasets,
    make_run_dir,
    TensorBoardCallback,
)
# Module-level setup that runs at import time, i.e. before absl parses the
# command line in app.run(main).
# NOTE(review): configure_environment() is an opaque project helper from
# posterior_matching.utils; presumably it sets TF/JAX device or logging
# options -- confirm against its implementation.
configure_environment()
# Register a --config flag that loads an ml_collections config file.
# lock_config=False leaves the loaded ConfigDict unlocked so that main() can
# add a "seed" field before locking the config explicitly.
config_flags.DEFINE_config_file("config", lock_config=False)
class ReconstructionCallback(Callback):
    """Log original/reconstruction image pairs after each validation pass.

    The supplied reconstruction function is wrapped with
    ``hk.transform_with_state`` and jitted once at construction time. On every
    ``on_validation_end`` call a batch is drawn from the (re-batched, endlessly
    repeating) dataset and reconstructed with the current train state. The
    originals are concatenated next to their reconstructions, and the result
    is stored under ``logs["reconstructions"]``.

    Args:
        reconstruction_fn: Un-transformed Haiku function mapping a batch (a
            mapping with an ``"image"`` entry) to reconstructed images.
        dataset: Dataset of batches. It is unbatched and re-batched into groups
            of ``num_examples``.
        num_examples: Number of examples reconstructed per validation pass.
            Defaults to 3, the previously hard-coded value.
        seed: Seed for the PRNG sequence fed to the reconstruction function.
            When ``None`` (the default) a random seed is drawn, matching the
            previous behaviour. Pass an int for reproducible logs.

    Raises:
        ValueError: If ``num_examples`` is not positive.
    """

    def __init__(
        self,
        reconstruction_fn: Callable[[ArrayTree], Array],
        dataset: tf.data.Dataset,
        *,
        num_examples: int = 3,
        seed: Optional[int] = None,
    ):
        if num_examples < 1:
            raise ValueError(
                f"num_examples must be a positive integer, got {num_examples}."
            )
        if seed is None:
            seed = random.randint(0, int(2e9))
        apply_fn = hk.transform_with_state(reconstruction_fn).apply
        self._reconstruction_fn = jax.jit(apply_fn)
        # Re-batch so the number of logged examples does not depend on the
        # dataset's own batch size. repeat() without a count makes the iterator
        # endless, so it is never exhausted across validation passes.
        self._data_iter = (
            dataset.unbatch().batch(num_examples).repeat().as_numpy_iterator()
        )
        self._prng = hk.PRNGSequence(seed)

    def on_validation_end(
        self, train_state: TrainState, step: int, logs: Dict[str, Any]
    ):
        """Reconstruct one batch with the current parameters and add it to logs.

        Args:
            train_state: Current training state. Its ``params`` and ``state``
                are passed to the transformed reconstruction function.
            step: Current training step (unused here).
            logs: Mutable log dictionary. A ``"reconstructions"`` entry is
                added to it.

        Raises:
            ValueError: If any pixel of the originals or reconstructions lies
                outside ``[0, 1]``.
        """
        batch = next(self._data_iter)
        # The apply function of a transform_with_state module returns
        # (output, new_state). The updated state is discarded because this
        # is evaluation only.
        reconstructions, _ = self._reconstruction_fn(
            train_state.params, train_state.state, self._prng.next(), batch
        )
        originals = np.broadcast_to(batch["image"], reconstructions.shape)
        # NOTE(review): axis 2 is presumably the width axis of (N, H, W, C)
        # images, placing original | reconstruction side by side -- confirm.
        reconstructions = np.concatenate([originals, reconstructions], axis=2)
        # Explicit check rather than `assert`, which `python -O` strips out.
        in_range = np.logical_and(reconstructions <= 1.0, reconstructions >= 0.0)
        if not np.all(in_range):
            raise ValueError(
                "Reconstruction images must have pixel values in [0, 1]."
            )
        logs["reconstructions"] = reconstructions
def main(_):
    """Train a VQ-VAE configured by the ``--config`` flag.

    Loads the train/validation datasets described by ``config.data`` and
    trains ``VQVAE(**config.model)`` with Adam for ``config.steps`` steps.
    Artifacts are written into a fresh run directory:

    * ``model_config.json``: the model hyper-parameters.
    * ``train_state.pkl``: the path handed to ``CheckpointCallback``.
    * ``tb/``: the directory handed to ``TensorBoardCallback``.

    Args:
        _: Remaining command-line arguments from ``absl.app.run`` (unused).
    """
    config = flags.FLAGS.config
    # Both a missing seed and an explicit ``seed = None`` in the config file
    # mean "draw one at random". A plain ``"seed" not in config`` test would
    # pass a None seed straight through to the Trainer.
    if config.get("seed") is None:
        config.seed = random.randint(0, int(2e9))
    config.lock()
    # A randomly drawn seed is otherwise not recorded anywhere.
    print("Using seed:", config.seed)

    train_dataset, val_dataset = load_datasets(config.data)

    def loss_fn(step, is_training, batch):
        # `step` is unused. It is presumably required by the bax Trainer's
        # loss-function signature; confirm against bax.
        model = VQVAE(**config.model)
        out = model(batch["image"], is_training=is_training)
        aux = {
            "perplexity": jnp.mean(out["vq_output"]["perplexity"]),
            "reconstruction_loss": jnp.mean(out["reconstruction_loss"]),
            "vq_loss": jnp.mean(out["vq_output"]["loss"]),
        }
        return out["loss"], aux

    def reconstruction_fn(batch):
        model = VQVAE(**config.model)
        out = model(batch["image"], is_training=False)
        # Clip into [0, 1]: ReconstructionCallback rejects logged images whose
        # pixels fall outside that range.
        return jnp.clip(out["reconstruction"], 0.0, 1.0)

    optimizer = optax.adam(config.learning_rate)
    trainer = Trainer(loss_fn, optimizer, num_devices=1, seed=config.seed)

    run_dir = make_run_dir(prefix=f"vqvae-{config.data.dataset}")
    print("Using run directory:", run_dir)
    with open(
        os.path.join(run_dir, "model_config.json"), "w", encoding="utf-8"
    ) as fp:
        json.dump(config.model.to_dict(), fp)

    callbacks = [
        CheckpointCallback(os.path.join(run_dir, "train_state.pkl")),
        ReconstructionCallback(reconstruction_fn, val_dataset),
        TensorBoardCallback(os.path.join(run_dir, "tb")),
    ]

    trainer.fit(
        train_dataset,
        config.steps,
        val_dataset=val_dataset,
        validation_freq=config.validation_freq,
        callbacks=callbacks,
    )


if __name__ == "__main__":
    app.run(main)