forked from jarrycyx/UNN
-
Notifications
You must be signed in to change notification settings - Fork 0
/
generate.py
76 lines (65 loc) · 3 KB
/
generate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import torch
import numpy as np
import os
import random
def _generate_data(model, gen_length, seed_ori, device):
    """Deterministically roll ``model`` forward ``gen_length`` steps from ``seed_ori``.

    At each step the current window is fed to the model, the window slides by one
    row (the prediction is appended, the oldest row dropped) and the prediction is
    collected. Returns the collected pieces concatenated along dim 1.
    """
    model.eval()
    model = model.to(device)
    seed = seed_ori.to(device)
    width = seed.shape[2]  # constant: every row pushed into the window has `width` channels
    data = []
    # Inference only: no_grad avoids building an autograd graph at every step.
    with torch.no_grad():
        for step in range(gen_length):
            output = model(seed).detach()
            if output.shape[1] != width:
                # The model emits more channels than the window holds; only the first
                # `width` are fed back. Presumably the output is 2 * width wide, which the
                # zero-padded seed block below requires for the final cat — TODO confirm.
                seed = torch.cat((seed[:, 1:, :], output.unsqueeze(1)[:, :, :width]), dim=1)
                if step == 0:
                    # Prepend the (already advanced) seed window, zero-padded to 2 * width.
                    data.append(torch.cat((seed, torch.zeros_like(seed)), dim=2))
            else:
                seed = torch.cat((seed[:, 1:, :], output.unsqueeze(1)), dim=1)
            # NOTE(review): unlike _generate_data_radom, the wide-output path also appends
            # the step-0 prediction here, so that prediction appears twice (zero-padded at
            # the end of the seed block, then in full). Kept as-is to preserve outputs.
            data.append(output.unsqueeze(1))
    return torch.cat(data, dim=1)


def _generate_data_radom(model, gen_length, seed_ori, device, variance=0.1, generator=None):
    """Like ``_generate_data`` but adds Gaussian noise to the model input at every step.

    ``variance`` is passed to ``torch.normal`` as the standard deviation. ``generator``
    is an optional CPU ``torch.Generator`` that makes the noise reproducible; ``None``
    uses torch's default generator.
    """
    model.eval()
    model = model.to(device)
    seed = seed_ori.to(device)
    width = seed.shape[2]
    size = seed_ori.shape
    data = []
    with torch.no_grad():
        for step in range(gen_length):
            # Noise is drawn on CPU (where `generator` lives) and then moved to `device`.
            noise = torch.normal(0, variance, size, generator=generator).to(device)
            output = model(seed + noise).detach()
            if output.shape[1] != width:
                seed = torch.cat((seed[:, 1:, :], output.unsqueeze(1)[:, :, :width]), dim=1)
                if step == 0:
                    # Step 0 contributes the zero-padded seed window instead of the raw
                    # prediction (whose first `width` channels are now its last row).
                    data.append(torch.cat((seed, torch.zeros_like(seed)), dim=2))
                else:
                    data.append(output.unsqueeze(1))
            else:
                seed = torch.cat((seed[:, 1:, :], output.unsqueeze(1)), dim=1)
                data.append(output.unsqueeze(1))
    return torch.cat(data, dim=1)


def generate(model, test_loader, radom, batch_size, gen_length, save_path, device, variance=0.001,
             radom_seed=42, n=100, residual=False):
    """Roll a trained model forward from randomly chosen test batches and save the result.

    Args:
        model: torch module called as ``model(window)``; presumably maps a
            (batch, window_len, features) window to a (batch, out_features) next step —
            TODO confirm against the model definitions.
        test_loader: iterable of batches; element ``[0]`` of each batch is used as the
            seed window (presumably ``(x, y)`` pairs — verify against caller).
        radom: if true, perturb the model input with Gaussian noise at every step.
        batch_size: only used to turn ``n`` into a number of seed batches.
        gen_length: number of autoregressive steps per seed batch.
        save_path: directory (created if needed) that receives ``generated_datas.npy``
            and ``linked_datas.npy``.
        device: device the model and seed windows are moved to.
        variance: passed to ``torch.normal`` as the *standard deviation* of the noise
            (historical name; the effective variance is ``variance ** 2``).
        radom_seed: seeds seed-batch selection and the noise through local RNGs (global
            RNG state is left untouched). ``None`` uses the global ``random`` module and
            torch's default generator instead.
        n: approximate number of sequences wanted; ``max(1, n // batch_size)`` batches
            are sampled, capped at the number of batches in ``test_loader``.
        residual: unused; kept for interface compatibility.

    Returns:
        Tensor of the rollouts of all selected batches concatenated along dim 0.

    Raises:
        ValueError: if ``test_loader`` yields no batches.
    """
    # random.sample needs a sequence, so the loader is materialized once.
    all_batches = list(test_loader)
    if not all_batches:
        raise ValueError("test_loader yielded no batches")

    # At least one batch (the old `n // batch_size + 1 ... - 1` arithmetic could select
    # zero and crash in torch.cat), never more than the loader provides.
    num_batches = min(max(1, n // batch_size), len(all_batches))

    if radom_seed is None:
        sampler, noise_gen = random, None
    else:
        sampler = random.Random(radom_seed)
        noise_gen = torch.Generator().manual_seed(radom_seed)

    rollouts = []
    for batch in sampler.sample(all_batches, num_batches):
        seed = batch[0]
        if radom:
            rollouts.append(
                _generate_data_radom(model, gen_length, seed, device, variance, generator=noise_gen)
            )
        else:
            rollouts.append(_generate_data(model, gen_length, seed, device))

    generated_datas = torch.cat(rollouts, dim=0)
    # Flatten batch and time into one long sequence of feature rows.
    linked_datas = generated_datas.reshape(-1, generated_datas.shape[2])

    os.makedirs(save_path, exist_ok=True)
    np.save(os.path.join(save_path, 'generated_datas.npy'), generated_datas.cpu().detach().numpy())
    np.save(os.path.join(save_path, 'linked_datas.npy'), linked_datas.cpu().detach().numpy())
    return generated_datas