-
Notifications
You must be signed in to change notification settings - Fork 0
/
vae_extract.py
71 lines (57 loc) · 1.99 KB
/
vae_extract.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import argparse
import glob
import os
import time
from multiprocessing import Process

import numpy as np
import torch
from joblib import Parallel, delayed

from common import Logger
from config import cfg
from model import VAE
def extract(fs, idx, N):
    """Encode every rollout file in *fs* with the VAE on GPU *idx* and save its latents.

    For each ``.npz`` file, the ``'sx'`` array (rank 4, transposed with
    ``(0, 3, 1, 2)`` -- presumably channels-last frames to channels-first;
    TODO confirm) is divided by 255 (assumes 8-bit pixels) and run through the
    VAE.  The resulting ``mu`` / ``logvar`` are written, together with the
    original ``ax`` / ``rx`` / ``dx`` arrays (as ``actions`` / ``rewards`` /
    ``dones``), to ``cfg.seq_extract_dir`` under the same file name.

    Args:
        fs: paths of rollout ``.npz`` files to encode.
        idx: CUDA device index used by this worker; also shown in progress output.
        N: total shown in the progress messages (the size of this worker's share).
    """
    model = VAE()
    ckpt = torch.load(cfg.vae_save_ckpt, map_location=lambda storage, loc: storage)
    model.load_state_dict(ckpt['model'])
    model = model.cuda(idx)
    # Inference only: put dropout / batch-norm layers (if the VAE has any) in eval mode.
    model.eval()

    for n, f in enumerate(fs):
        # NpzFile keeps the zip archive open until closed; the context manager
        # releases the handle instead of leaking one per episode.
        with np.load(f) as data:
            imgs = data['sx'].transpose(0, 3, 1, 2)
            actions = data['ax']
            rewards = data['rx']
            dones = data['dx']

        # No gradients are needed, so skip building an autograd graph (saves GPU memory).
        with torch.no_grad():
            x = torch.from_numpy(imgs).float().cuda(idx) / 255.0
            mu, logvar, _, _ = model(x)

        # basename() also handles platform-specific separators returned by glob.
        save_path = os.path.join(cfg.seq_extract_dir, os.path.basename(f))
        np.savez_compressed(save_path,
                            mu=mu.cpu().numpy(),
                            logvar=logvar.cpu().numpy(),
                            dones=dones,
                            rewards=rewards,
                            actions=actions)
        if n % 10 == 0:
            print('Process %d: %5d / %5d' % (idx, n, N))
def vae_extract():
logger = Logger("{}/vae_extract_{}.log".format(cfg.logger_save_dir, cfg.timestr))
logger.log(cfg.info)
print("Loading Dataset")
data_list = glob.glob(cfg.seq_save_dir +'/*.npz')
data_list.sort()
N = len(data_list) // 4
procs = []
for idx in range(4):
p = Process(target=extract, args=(data_list[idx*N:(idx+1)*N], idx, N))
procs.append(p)
p.start()
time.sleep(1)
for p in procs:
p.join()
def load_init(f):
    """Return the index-0 entries of ``mu`` and ``logvar`` from an extracted ``.npz`` file.

    Args:
        f: path to a file containing ``'mu'`` and ``'logvar'`` arrays
            (as written by extract()).

    Returns:
        ``(mu[0], logvar[0])``.
    """
    # Close the NpzFile deterministically; the indexed arrays are already in memory.
    with np.load(f) as data:
        return data['mu'][0], data['logvar'][0]
def save_init_z(out_path='init_z.npz'):
    """Collect every episode's first-step ``mu`` / ``logvar`` into a single file.

    Reads each ``.npz`` in ``cfg.seq_extract_dir`` via load_init(), using
    joblib with ``cfg.num_cpus`` jobs. Files are sorted first so the row
    order of the output is reproducible. The results are saved as the
    stacked arrays ``mus`` and ``logvars``.

    Args:
        out_path: destination file (default ``'init_z.npz'``, relative to the CWD).

    Raises:
        FileNotFoundError: if ``cfg.seq_extract_dir`` contains no ``.npz`` files.
    """
    # glob order is filesystem-dependent; sort so rows line up run to run.
    data_list = sorted(glob.glob(os.path.join(cfg.seq_extract_dir, '*.npz')))
    if not data_list:
        raise FileNotFoundError(
            'no .npz files found in %s; run vae_extract() first' % cfg.seq_extract_dir)
    datas = Parallel(n_jobs=cfg.num_cpus, verbose=1)(delayed(load_init)(f) for f in data_list)
    mus = np.array([mu for mu, _ in datas])
    logvars = np.array([logvar for _, logvar in datas])
    np.savez_compressed(out_path, mus=mus, logvars=logvars)
if __name__ == '__main__':
    # Select the pipeline stage on the command line instead of (un)commenting calls.
    # The default ('init_z') matches the previous behaviour: only save_init_z() runs.
    parser = argparse.ArgumentParser(
        description='Encode rollouts with the trained VAE and/or collect initial latents.')
    parser.add_argument(
        'stage', nargs='?', default='init_z', choices=('extract', 'init_z', 'all'),
        help="'extract' runs vae_extract(), 'init_z' runs save_init_z() (default), "
             "'all' runs both in that order")
    args = parser.parse_args()
    if args.stage in ('extract', 'all'):
        vae_extract()
    if args.stage in ('init_z', 'all'):
        save_init_z()