"""
This module contains smaller helper functions used by the main model
module (vae_reg_GP.py) and throughout the rest of the code base.
"""
import argparse
import re
from copy import deepcopy

import numpy as np
import pandas as pd
import torch
from torch import nn
from torch.utils.tensorboard import SummaryWriter
from scipy import ndimage
from scipy.stats import gamma, norm
import matplotlib.pyplot as plt
plt.switch_backend('agg')
def hrf(times):
"""
Args: time points for which we wish to estimate the HRF.
Returns:
Values for HRF at given time points.
This is used to account for HRF when modeling biological/neural
covariates -- e.g. visual stim covariate in checker dset.
"""
# Gamma pdf for the peak
peak_values = gamma.pdf(times, 6)
# Gamma pdf for the undershoot
undershoot_values = gamma.pdf(times, 12)
# Combine them
values = peak_values - 0.35 * undershoot_values
return values / np.max(values) * 0.6
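# Illustrative use of hrf() (sketch only; the TR, run length, and the idea of
# convolving with stimulus_to_neural() below are assumptions, not taken from
# this module):
#
#     TR = 1.5                                # assumed repetition time, seconds
#     hrf_vals = hrf(np.arange(0, 30, TR))    # HRF sampled on a 30 s window
#     neural = stimulus_to_neural(np.arange(98) * TR)   # assumed 98 volumes
#     task_cov = np.convolve(neural, hrf_vals)[:len(neural)]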
def get_xu_ranges(csv_files, eps = 1e-3):
"""
Gets ranges for x values for GP inducing pts by rounding min/max values
for each covariate across the entire dset.
Args
----
csv_file: file containing data. This is the same file passed to Data Class
and loaders.
"""
train_df = pd.read_csv(csv_files[0])
test_df = pd.read_csv(csv_files[1])
    mot_regressors = ['x', 'y', 'z', 'rot_x', 'rot_y', 'rot_z']
    xu_ranges = []
    for reg in mot_regressors:
min_val = min(train_df[reg].min(), test_df[reg].min())
max_val = max(train_df[reg].max(), test_df[reg].max())
xu_ranges.append([(min_val-eps), (max_val+eps)])
return xu_ranges
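# Illustrative call to get_xu_ranges() (hypothetical csv names -- the real
# paths are supplied by the wrapper/training script):
#
#     xu_ranges = get_xu_ranges(['train_data.csv', 'test_data.csv'])
#     # -> list of six [min - eps, max + eps] pairs, one per motion regressor,
#     #    in the order x, y, z, rot_x, rot_y, rot_z.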
def str2bool(v):
"""
Str to Bool converter for wrapper script.
This is used both for --from_ckpt and for --recons_only flags, which
are False by default but can be turned on either by listing the flag (without args)
or by listing with an appropriate arg (which can be converted to a corresponding boolean)
"""
if isinstance(v, bool):
return v
if v.lower() in ('yes', 'true', 't', 'y', '1'):
return True
elif v.lower() in ('no', 'false', 'f', 'n', '0'):
return False
else:
raise argparse.ArgumentTypeError('Boolean value expected.')
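# Illustrative argparse wiring for str2bool (sketch only -- the actual wrapper
# script is not part of this module; the flag name is taken from the docstring):
#
#     parser = argparse.ArgumentParser()
#     parser.add_argument('--from_ckpt', type=str2bool, nargs='?',
#                         const=True, default=False)
#     args = parser.parse_args(['--from_ckpt', 'yes'])
#     # args.from_ckpt == True; a bare '--from_ckpt' also yields True.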
def stimulus_to_neural(vol_times):
"""
Creates binary sequence representing task variable in checker experiment/dataset.
Each experimental block in this dset has 20s. First task block begins AFTER a block
of NO-TASK.
"""
    # Integer block index for each volume time (20 s blocks). Even blocks
    # (including the first) are rest; odd blocks are task.
    t = vol_times // 20
    res = []
    for i in t:
        task = 1 if i % 2 != 0 else 0
        res.append(task)
    return np.array(res)
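# Illustrative call (hypothetical volume times -- 10 volumes, 5 s apart):
#
#     stimulus_to_neural(np.arange(10) * 5.0)
#     # -> array([0, 0, 0, 0, 1, 1, 1, 1, 0, 0])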
def control_stimulus_to_neural(vol_times):
"""
Almost identical to stimulus_to_neural, except this is intended to create
binary sequence for control experiments involving large 3.
Here, the first stim block starts at time ==0. This was done so as to
place artificial signal preferentially in volumes where no real V1 signal
was present.
"""
    # Integer block index for each volume time (20 s blocks). Here even blocks
    # (including the first) carry the artificial signal; odd blocks do not.
    t = vol_times // 20
    res = []
    for i in t:
        task = 0 if i % 2 != 0 else 1
        res.append(task)
    return np.array(res)
def zscore(df):
"""
Takes a df with samples, zscores each one of the motion regressor columns
and replaces raw mot regressor inputs by their z-scored vals.
Z-scoring is done for ALL vols and subjects at once in this case.
"""
    mot_regressors = df[['x', 'y', 'z', 'rot_x', 'rot_y', 'rot_z']]
    cols = list(mot_regressors.columns)
    for col in cols:
        df[col] = (mot_regressors[col] - mot_regressors[col].mean()) / mot_regressors[col].std(ddof=0)
return df
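# Illustrative use of zscore() (hypothetical dataframe holding just the six
# motion regressors; the real csv also carries other columns, which are left
# untouched):
#
#     df = pd.DataFrame({'x': [0.1, 0.2, 0.3], 'y': [0.0, -0.1, 0.1],
#                        'z': [-0.2, 0.0, 0.2], 'rot_x': [0.01, 0.02, 0.03],
#                        'rot_y': [0.0, 0.01, 0.02], 'rot_z': [0.0, 0.01, 0.02]})
#     df = zscore(df)   # each motion column now has mean 0 and population std 1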
def mk_spherical_mask(size, radius):
"""
Creates spherical masks to be used inside add_control_signal script.
Args
-----
size :: size of original 3D numpy matrix A.
radius :: radius of sphere inside A which will be filled with ones.
"""
s, r = size, radius
#A : numpy.ndarray of shape size*size*size.
A = np.zeros((size,size, size))
#AA : copy of A
AA = deepcopy(A)
#(x0, y0, z0) : coordinates of center of circle inside A.
x0, y0, z0 = int(np.floor(A.shape[0]/2)), \
int(np.floor(A.shape[1]/2)), int(np.floor(A.shape[2]/2))
for x in range(x0-radius, x0+radius+1):
for y in range(y0-radius, y0+radius+1):
for z in range(z0-radius, z0+radius+1):
#deb: measures how far a coordinate in A is far from the center.
#deb>=0: inside the sphere.
#deb<0: outside the sphere.
deb = radius - abs(x0-x) - abs(y0-y) - abs(z0-z)
if (deb)>=0: AA[x,y,z] = 1
return AA
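# Illustrative call (arbitrary size/radius, chosen only for this example):
#
#     mask = mk_spherical_mask(41, 5)
#     mask.shape        # (41, 41, 41)
#     int(mask.sum())   # 231 voxels inside an L1 ball of radius 5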
def read_design_mat(mat_file_path):
"""
Reads Design Matrix Files generated by FSL's feat module.
These are used to find Least Squares Solution to beta maps
which is then used to regularize model.
"""
    with open(mat_file_path) as f:
        content = f.readlines()
    design_mat = []
    # Skip the file header (first 5 lines) and parse the remaining
    # tab-separated rows into floats.
    for i in range(5, len(content)):
        row = content[i].rstrip()
        entries = re.split(r'\t+', row)
        design_mat.append([float(entry) for entry in entries])
    design_mat = np.array(design_mat)
    return design_mat
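# Minimal sketch of the least-squares step mentioned in the docstring above
# (illustrative only; the file name and data matrix Y are hypothetical, and the
# actual regularization happens in the model module). Assumes Y has shape
# (n_volumes, n_voxels) with n_volumes matching the design matrix rows:
#
#     X = read_design_mat('design.mat')
#     beta_maps, *_ = np.linalg.lstsq(X, Y, rcond=None)   # (n_regressors, n_voxels)
#     beta_maps = scale_beta_maps(beta_maps)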
def scale_beta_maps(beta_maps):
"""
Performs min-max scaling for least squares maps used in regularization.
This helps regularizer portion of loss (and overall model) to behave better.
"""
for i in range(beta_maps.shape[0]):
map_max = np.amax(beta_maps[i, :].flatten())
beta_maps[i, :] = beta_maps[i, :]/map_max
return beta_maps
# Methods to log maps, GP parameters, etc. during training.
def log_qu_plots(epoch, gp_params, writer, log_type):
    """
    Creates q(u) plots which can be passed as figures to TensorBoard.
    Should be called after each epoch update.
    """
    # For each covariate, plot the variational posterior mean (qu_m) over the
    # inducing point locations (xu), with a 2-sigma band taken from the
    # diagonal of the posterior covariance (qu_S). There are 6 inducing points
    # per covariate, hence the reshape(6).
    covariates = ['x', 'y', 'z', 'xrot', 'yrot', 'zrot']
    fig, axs = plt.subplots(3, 2, figsize=(15, 15))
    band_kwargs = {'color': 'lightblue', 'alpha': 0.3, 'label': '2 sigma'}
    for ax, cov in zip(axs.flat, covariates):
        qu_m = gp_params[cov]['qu_m'].detach().cpu().numpy().reshape(6)
        qu_S = np.diag(gp_params[cov]['qu_S'].detach().cpu().numpy())
        xu = gp_params[cov]['xu'].detach().cpu().numpy()
        two_sigma = 2 * np.sqrt(qu_S)
        ax.plot(xu, qu_m, c='darkblue', alpha=0.5, label='q(u) posterior mean')
        ax.fill_between(xu, (qu_m - two_sigma), (qu_m + two_sigma), **band_kwargs)
        ax.legend(loc='best')
        ax.set_title('q(u) {} covariate at epoch {}'.format(cov, epoch))
        ax.set_xlabel('Covariate {} -- x vals'.format(cov))
        ax.set_ylabel('q(u)')
    # Pass the figure to the TensorBoard writer.
    writer.add_figure("q(u)_{}".format(log_type), fig)
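# Illustrative call (hypothetical gp_params structure -- in the real code these
# tensors come from the GP module used by vae_reg_GP.py):
#
#     writer = SummaryWriter(log_dir='runs/demo')     # hypothetical log dir
#     gp_params = {cov: {'qu_m': torch.zeros(6),
#                        'qu_S': torch.eye(6),
#                        'xu': torch.linspace(-1.0, 1.0, 6)}
#                  for cov in ['x', 'y', 'z', 'xrot', 'yrot', 'zrot']}
#     log_qu_plots(epoch=0, gp_params=gp_params, writer=writer, log_type='train')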
def log_qkappa_plots(gp_params, writer, log_type):
    """
    Logs q(k) plots to TensorBoard.
    Plots only the posterior --> the prior is N(1, 0.5^2).
    """
    # For each covariate (task + 6 motion covariates), build the Gaussian
    # posterior over kappa from its mean (sa) and log standard deviation
    # (logstd), and plot its pdf between the 1st and 99th percentiles.
    covariates = ['task', 'x', 'y', 'z', 'xrot', 'yrot', 'zrot']
    colors = ['green', 'blue', 'orange', 'red', 'violet', 'magenta', 'purple']
    fig, axs = plt.subplots(3, 3, figsize=(15, 15))
    for ax, cov, color in zip(axs.flat, covariates, colors):
        sa = gp_params[cov]['sa'].detach().cpu().numpy().reshape(1)
        std = np.exp(gp_params[cov]['logstd'].detach().cpu().numpy())
        gauss = norm(sa[0], scale=std[0])
        x = np.linspace(gauss.ppf(0.01), gauss.ppf(0.99), 100)
        ax.plot(x, gauss.pdf(x), lw=2, alpha=0.5, color=color)
        ax.set_title('{} q(k)'.format(cov.capitalize()))
    # Pass the figure to the TensorBoard writer.
    writer.add_figure("q(k)_{}".format(log_type), fig)
def log_beta(writer, xq, beta_mean, beta_cov, covariate_name, log_type):
"""
Logs beta dist plots to TB.
This is done from within fwd method.
"""
cov_dict = {}
xq = xq.cpu().numpy()
beta_mean = beta_mean.detach().cpu().numpy()
two_sigma = 2*np.sqrt(np.diag(beta_cov.detach().cpu().numpy()))
cov_dict['xq'] = xq
cov_dict['mean'] = beta_mean
cov_dict['two_sig'] = two_sigma
cov_data = pd.DataFrame.from_dict(cov_dict)
sorted_cov_data = cov_data.sort_values(by=["xq"])
fig = plt.figure()
plt.plot(sorted_cov_data['xq'], sorted_cov_data['mean'], \
c='darkblue', alpha=0.5, label='Beta posterior mean')
kwargs = {'color':'lightblue', 'alpha':0.3, 'label':'2 sigma'}
plt.fill_between(sorted_cov_data['xq'], (sorted_cov_data['mean'] - sorted_cov_data['two_sig']), \
(sorted_cov_data['mean'] + sorted_cov_data['two_sig']), **kwargs)
plt.legend(loc='best')
plt.title('Beta_{}'.format(covariate_name))
plt.xlabel('Covariate')
    plt.ylabel('Beta Output')
writer.add_figure("Beta/{}_{}".format(covariate_name, log_type), fig)
def log_map(writer, img_shape, map, slice, map_name, batch_size, log_type):
"""
Logs a particular brain map reconstruction to TB.
Args
----
Map: (np array) map reconstructions for a given minibatch.
slice: (int) specific slice we wish to log.
map_name: (string) Name of map (e.g., base, task)
batch_size: (int) Size of minibatch.
For now am logging slices only in saggital view.
"""
map = map.reshape((batch_size, img_shape[0], img_shape[1], img_shape[2]))
for i in range(batch_size):
slc = map[i, slice, :, :]
slc = ndimage.rotate(slc, 90)
fig_name = '{}_{}_{}/{}'.format(map_name, log_type, slice, i)
writer.add_image(fig_name, slc, dataformats='HW')
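# Illustrative call (hypothetical image shape and writer -- both come from the
# training loop in the real code):
#
#     writer = SummaryWriter(log_dir='runs/demo')
#     img_shape = (41, 49, 35)                          # assumed volume shape
#     recon = np.random.rand(2, np.prod(img_shape))     # fake minibatch of maps
#     log_map(writer, img_shape, recon, slice=20, map_name='task',
#             batch_size=2, log_type='train')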