-
Notifications
You must be signed in to change notification settings - Fork 0
/
auxiliary_functions.py
130 lines (103 loc) · 4.44 KB
/
auxiliary_functions.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
@author: Melisa
Auxiliary functions for the neural analysis course, assignment 4 (calcium imaging analysis).
"""
import os
import caiman as cm
import matplotlib.pyplot as plt
import math
import numpy as np
from caiman.motion_correction import high_pass_filter_space
from caiman.source_extraction.cnmf.cnmf import load_CNMF
from matplotlib.patches import Rectangle
from random import randrange
def plot_FOV(FOV_file=None, ROI_file=None, output_file=None,
             roi_bounds=(100, 500, 200, 600)):
    """
    Save a two-panel figure showing the full field of view next to the ROI.

    Both movies are read with cm.load and only their frame at index 0 is
    displayed (grayscale). A red rectangle is drawn on the FOV panel at the
    position described by roi_bounds.

    :param FOV_file: path of the full field-of-view movie
    :param ROI_file: path of the cropped ROI movie
    :param output_file: path the figure is written to with savefig
    :param roi_bounds: (row_start, row_end, col_start, col_end) in pixels of
        the rectangle drawn on the FOV panel. The default reproduces the
        values that used to be hard-coded in this function.
    :return: None
    """
    FOV = cm.load(FOV_file)
    ROI = cm.load(ROI_file)

    figure = plt.figure(constrained_layout=True)
    gs = figure.add_gridspec(1, 2)

    fov_ax = figure.add_subplot(gs[0, 0])
    fov_ax.set_title('FOV', fontsize=15)
    roi_ax = figure.add_subplot(gs[0, 1])
    roi_ax.set_title('ROI', fontsize=15)

    fov_ax.imshow(FOV[0, :, :], cmap='gray')

    # Rectangle takes (x, y) = (column, row) for its anchor, then width and
    # height, hence the swapped order relative to the row-first bounds.
    row_start, row_end, col_start, col_end = roi_bounds
    rect = Rectangle((col_start, row_start), col_end - col_start, row_end - row_start,
                     fill=False, color='r', linestyle='-', linewidth=2)
    fov_ax.add_patch(rect)

    roi_ax.imshow(ROI[0, :, :], cmap='gray')

    figure.savefig(output_file)
    return
def _random_box_center(extent, half_box):
    """Draw a random index on an axis of length extent for a box of half-width half_box."""
    upper = extent - half_box
    if upper <= half_box:
        # Axis too short to keep the whole box inside the frame: fall back to
        # the previous behaviour, which only enforced the lower margin.
        upper = extent
    return randrange(half_box, upper)


def temporal_evolution(file_name=None, output_file_name=None, frame_rate=10):
    """
    Plot the time course and value histogram of five randomly chosen pixels.

    The saved figure contains:
      * the frame at index 0 of the movie, with one coloured 10x10 box
        centred on each selected pixel,
      * the value of each selected pixel over time,
      * one histogram (50 bins) of the values of each selected pixel.

    :param file_name: path of the movie, loaded with cm.load
    :param output_file_name: path the figure is written to with savefig
    :param frame_rate: frames per second used to convert the frame index into
        the 'Time [s]' axis. The default of 10 reproduces the value that used
        to be hard-coded in this function.
    :return: None
    """
    half_box = 5  # half side length of the box drawn around each pixel

    movie_original = cm.load(file_name)

    figure = plt.figure(constrained_layout=True)
    gs = figure.add_gridspec(5, 6)

    figure_ax1 = figure.add_subplot(gs[0:2, 0:3])
    figure_ax1.set_title('ROI', fontsize=15)
    figure_ax1.set_yticks([])
    figure_ax1.set_xticks([])

    figure_ax2 = figure.add_subplot(gs[2:5, 0:3])
    figure_ax2.set_xlabel('Time [s]', fontsize=15)
    figure_ax2.set_ylabel('Pixel value', fontsize=15)
    figure_ax2.set_title('Temporal Evolution', fontsize=15)
    figure_ax2.set_ylim((300, 1000))

    figure_ax1.imshow(movie_original[0, :, :], cmap='gray')

    # The time axis is the same for every pixel, so build it once.
    time_axis = np.arange(0, movie_original.shape[0]) / frame_rate

    color = ['b', 'r', 'g', 'c', 'm']
    for i, pixel_color in enumerate(color):
        # Keep the centre half_box away from BOTH edges so the marker box
        # stays inside the frame (previously only the lower edge was guarded).
        x = _random_box_center(movie_original.shape[1], half_box)
        y = _random_box_center(movie_original.shape[2], half_box)

        # Rectangle takes (column, row) for its anchor, then width and height.
        rect = Rectangle((y - half_box, x - half_box), 2 * half_box, 2 * half_box,
                         fill=False, color=pixel_color, linestyle='-', linewidth=2)
        figure_ax1.add_patch(rect)

        pixel_trace = movie_original[:, x, y]
        figure_ax2.plot(time_axis, pixel_trace, color=pixel_color)

        figure_ax_i = figure.add_subplot(gs[i, 4:])
        figure_ax_i.hist(pixel_trace, 50, color=pixel_color)
        figure_ax_i.set_xlim((300, 1000))
        figure_ax_i.set_ylabel('#')
        figure_ax_i.set_xlabel('Pixel value')

    figure.set_size_inches([5., 5.])
    figure.savefig(output_file_name)
    return
def get_fig_gSig_filt_vals(file_name=None, gSig_filt_vals=None, output_file=None):
    """
    Plot the unfiltered FOV next to several spatially filtered versions.

    The movie is loaded with cm.load. The first panel shows the image
    returned by cm.motion_correction.bin_median for the raw movie; each
    following panel shows the same summary after every frame has been passed
    through high_pass_filter_space with a square kernel of the given size.

    :param file_name: path of the movie to load
    :param gSig_filt_vals: sequence with the sizes of the spatial filters that
        will be applied; None is treated as an empty sequence
    :param output_file: path the figure is written to with savefig
    :return: None (the figure is saved, not returned)
    """
    m = cm.load(file_name)
    temp = cm.motion_correction.bin_median(m)

    filter_sizes = [] if gSig_filt_vals is None else gSig_filt_vals
    N = len(filter_sizes)

    # squeeze=False keeps `axes` two-dimensional even when the grid has a
    # single column (N <= 1); otherwise plt.subplots returns a 1-D array and
    # 2-D indexing fails.
    fig, axes = plt.subplots(2, int(math.ceil((N + 1) / 2)), squeeze=False)
    panels = axes.flatten()

    panels[0].imshow(temp, cmap='gray')
    panels[0].set_title('Unfiltered', fontsize=12)
    panels[0].axis('off')

    for panel, gSig_filt in zip(panels[1:], filter_sizes):
        m_filt = [high_pass_filter_space(m_, (gSig_filt, gSig_filt)) for m_ in m]
        temp_filt = cm.motion_correction.bin_median(m_filt)
        panel.imshow(temp_filt, cmap='gray')
        panel.set_title(f'gSig_filt = {gSig_filt}', fontsize=12)
        panel.axis('off')

    # Hide the leftover panel when the grid has more cells than images.
    for panel in panels[N + 1:]:
        panel.axis('off')

    # Save figure
    fig.suptitle('Spatial filtering', fontsize=15)
    fig.set_size_inches([7., 7.])
    fig.savefig(output_file)
    return
def summary_images(corr_image = None, pnr_image = None, output_path = None):
    """
    Save a figure with the correlation image and the PNR image side by side.

    :param corr_image: image shown in the left panel, titled 'Correlation image'
    :param pnr_image: image shown in the right panel, titled 'PNR image'
    :param output_path: path the figure is written to with savefig
    :return: None
    """
    figure, (corr_ax, pnr_ax) = plt.subplots(1, 2)

    # Both panels share the same colormap and title size.
    panels = (
        (corr_ax, corr_image, 'Correlation image'),
        (pnr_ax, pnr_image, 'PNR image'),
    )
    for panel, image, title in panels:
        panel.imshow(image, cmap='viridis')
        panel.set_title(title, fontsize=12)

    figure.set_size_inches([7., 7.])
    figure.savefig(output_path)
    return