Skip to content

Commit 089d2c9

Browse files
authored
Merge pull request #482 from pynapple-org/476-template-matching
476 template matching
2 parents 96e0d52 + dbe1746 commit 089d2c9

File tree

8 files changed

+853
-139
lines changed

8 files changed

+853
-139
lines changed

.github/workflows/main.yml

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,13 @@ jobs:
9090
with:
9191
directory: "doc/_build/html"
9292
# The directory to scan
93-
arguments: --checks Links,Scripts --ignore-urls "https://fonts.gstatic.com,https://mkdocs-gallery.github.io,./doc/_build/html/_static/,https://www.nature.com/articles/s41593-022-01020-w" --assume-extension --check-external-hash --ignore-status-codes 403 --ignore-files "/.+\/html\/_static\/.+/"
93+
arguments:
94+
--checks Links,Scripts
95+
--ignore-urls "https://fonts.gstatic.com,https://mkdocs-gallery.github.io,./doc/_build/html/_static/,https://www.nature.com/articles/s41593-022-01020-w,https://elifesciences.org/reviewed-preprints/85786"
96+
--assume-extension
97+
--check-external-hash
98+
--ignore-status-codes 403
99+
--ignore-files "/.+\/html\/_static\/.+/"
94100
# The arguments to pass to HTMLProofer
95101

96102
check:

doc/examples.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ Streaming data from Dandi <examples/tutorial_pynapple_dandi>
2424
:::{card}
2525
```{toctree}
2626
:maxdepth: 3
27-
Computing calcium imaging tuning curves <examples/tutorial_calcium_imaging>
27+
Analyzing calcium imaging data <examples/tutorial_calcium_imaging>
2828
```
2929
:::
3030

doc/examples/tutorial_HD_dataset.md

Lines changed: 13 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -139,17 +139,12 @@ color = xr.DataArray(
139139
To make the tuning curves look nice, we will smooth them before plotting:
140140

141141
```{code-cell} ipython3
142-
from scipy.ndimage import gaussian_filter1d
143-
144-
tmp = np.concatenate(
145-
[
146-
tuning_curves.values,
147-
tuning_curves.values,
148-
tuning_curves.values
149-
],
150-
axis=1)
151-
tmp = gaussian_filter1d(tmp, sigma=3, axis=1)
152-
tuning_curves.values = tmp[:, tuning_curves.shape[1]:2*tuning_curves.shape[1]]
142+
tuning_curves.values = scipy.ndimage.gaussian_filter1d(
143+
tuning_curves.values,
144+
sigma=3,
145+
axis=1,
146+
mode="wrap" # important for circular variables!
147+
)
153148
```
154149

155150
What does this look like? Let's plot them!
@@ -180,7 +175,7 @@ Now that we have HD tuning curves, we can go one step further. Using only the po
180175
We will then compare this to the real head-direction of the animal, and discover that population activity in the ADn indeed codes for HD.
181176

182177
To decode the population activity, we will be using a bayesian decoder as implemented in Pynapple.
183-
Again, just a single line of code!
178+
Again, just a single line of code:
184179

185180
```{code-cell} ipython3
186181
decoded, proba_feature = nap.decode_bayes(
@@ -197,7 +192,7 @@ What does this look like?
197192
print(decoded)
198193
```
199194

200-
The variable 'decoded' contains the most probable angle, and 'proba_feature' contains the probability of a given angular bin at a given time point:
195+
The variable ``decoded`` contains the most probable angle, and ``proba_feature`` contains the probability of a given angular bin at a given time point:
201196

202197
```{code-cell} ipython3
203198
print(proba_feature)
@@ -227,15 +222,15 @@ plt.ylabel("Neurons")
227222
plt.show()
228223
```
229224

230-
From this plot, we can see that the decoder is able to estimate the head-direction based on the population activity in ADn. Amazing!
225+
From this plot, we can see that the decoder is able to estimate the head-direction based on the population activity in ADn.
231226

232-
What does the probability distribution in this example event look like?
233-
Ideally, the bins with the highest probability will correspond to the bins having the most spikes. Let's plot the probability matrix to visualize this.
227+
We can also visualize the probability distribution.
228+
Ideally, the bins with the highest probability correspond to the bins with the most spikes.
234229

235230
```{code-cell} ipython3
236231
smoothed = scipy.ndimage.gaussian_filter(
237232
proba_feature, 1
238-
) # Smoothening the probability distribution
233+
) # Smoothing the probability distribution
239234
240235
# Create a DataFrame with the smoothed distribution
241236
p_feature = pd.DataFrame(
@@ -270,8 +265,7 @@ plt.show()
270265
```
271266

272267
<!-- #region -->
273-
From this probability distribution, we observe that the decoded HD closely matches the actual HD.
274-
Hence, the population activity in ADn is a reliable estimate of the heading direction of the animal.
268+
The decoded HD (dashed grey line) closely matches the actual HD (solid white line), and thus the population activity in ADn is a reliable estimate of the heading direction of the animal.
275269

276270
I hope this tutorial was helpful. If you have any questions, comments or suggestions, please feel free to reach out to the Pynapple Team!
277271

doc/examples/tutorial_calcium_imaging.md

Lines changed: 162 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ The NWB file for the example is hosted on [OSF](https://osf.io/sbnaw). We show b
2323

2424
```{code-cell} ipython3
2525
:tags: [hide-output]
26+
import numpy as np
2627
import pynapple as nap
2728
import matplotlib.pyplot as plt
2829
import seaborn as sns
@@ -91,7 +92,6 @@ As you can see, we have a longer recording for our tracking of the animal's head
9192

9293
```{code-cell} ipython3
9394
transients.time_support
94-
angle.time_support
9595
```
9696

9797
***
@@ -133,25 +133,180 @@ We start by finding the midpoint of the recording, using the function [`get_inte
133133
center = transients.time_support.get_intervals_center()
134134
135135
halves = nap.IntervalSet(
136-
start = [transients.time_support.start[0], center.t[0]],
136+
start = [transients.time_support.start[0], center.t[0]],
137137
end = [center.t[0], transients.time_support.end[0]]
138-
)
138+
)
139139
```
140140

141141
Now, we can compute the tuning curves for each half of the recording and plot the tuning curves again.
142142

143143
```{code-cell} ipython3
144-
half1 = nap.compute_tuning_curves(transients, angle, bins = 120, epochs = halves.loc[[0]])
145-
half2 = nap.compute_tuning_curves(transients, angle, bins = 120, epochs = halves.loc[[1]])
144+
tuning_curves_half1 = nap.compute_tuning_curves(transients, angle, bins = 120, epochs = halves.loc[[0]])
145+
tuning_curves_half2 = nap.compute_tuning_curves(transients, angle, bins = 120, epochs = halves.loc[[1]])
146146
147147
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))
148-
set_metadata(half1[4]).plot(ax=ax1)
148+
set_metadata(tuning_curves_half1[4]).plot(ax=ax1)
149149
ax1.set_title("First half")
150-
set_metadata(half2[4]).plot(ax=ax2)
150+
set_metadata(tuning_curves_half2[4]).plot(ax=ax2)
151151
ax2.set_title("Second half")
152152
plt.show()
153153
```
154154

155+
***
156+
Calcium decoding
157+
---------------------
158+
159+
Given some tuning curves, we can also try to decode head direction from the population activity.
160+
For calcium imaging data, Pynapple has `decode_template`, which implements a template matching algorithm.
161+
162+
```{code-cell} ipython3
163+
epochs = nap.IntervalSet([50, 150])
164+
decoded, dist = nap.decode_template(
165+
tuning_curves=tuning_curves,
166+
data=transients,
167+
epochs=epochs,
168+
bin_size=0.1,
169+
metric="correlation",
170+
)
171+
```
172+
173+
```{code-cell} ipython3
174+
:tags: [hide-input]
175+
# normalize distance for better visualization
176+
dist_norm = (dist - np.min(dist.values, axis=1, keepdims=True)) / np.ptp(
177+
dist.values, axis=1, keepdims=True
178+
)
179+
180+
fig, (ax1, ax2, ax3) = plt.subplots(figsize=(8, 8), nrows=3, ncols=1, sharex=True)
181+
ax1.plot(angle.restrict(epochs), label="True")
182+
ax1.scatter(decoded.times(), decoded.values, label="Decoded", c="orange")
183+
ax1.legend(frameon=False, bbox_to_anchor=(1.0, 1.0))
184+
ax1.set_ylabel("Angle [rad]")
185+
186+
im = ax2.imshow(
187+
dist.values.T,
188+
aspect="auto",
189+
origin="lower",
190+
cmap="inferno_r",
191+
extent=(epochs.start[0], epochs.end[0], 0.0, 2*np.pi)
192+
)
193+
ax2.set_ylabel("Angle [rad]")
194+
cbar_ax2 = fig.add_axes([0.95, ax2.get_position().y0, 0.015, ax2.get_position().height])
195+
fig.colorbar(im, cax=cbar_ax2, label="Distance")
196+
197+
im = ax3.imshow(
198+
dist_norm.values.T,
199+
aspect="auto",
200+
origin="lower",
201+
cmap="inferno_r",
202+
extent=(epochs.start[0], epochs.end[0], 0.0, 2*np.pi)
203+
)
204+
cbar_ax3 = fig.add_axes([0.95, ax3.get_position().y0, 0.015, ax3.get_position().height])
205+
fig.colorbar(im, cax=cbar_ax3, label="Norm. distance")
206+
ax3.set_xlabel("Time (s)")
207+
ax3.set_ylabel("Angle [rad]")
208+
plt.show()
209+
```
210+
211+
The distance metric you choose can influence how well we decode.
212+
Internally, ``decode_template`` uses `scipy.spatial.distance.cdist` to compute the distance matrix;
213+
you can take a look at [its documentation](https://docs.scipy.org/doc/scipy/reference/generated/scipy.spatial.distance.cdist.html)
214+
to see which metrics are supported. Here are a couple examples:
215+
216+
```{code-cell} ipython3
217+
:tags: [hide-input]
218+
metrics = [
219+
"chebyshev",
220+
"dice",
221+
"canberra",
222+
"sqeuclidean",
223+
"minkowski",
224+
"euclidean",
225+
"cityblock",
226+
"mahalanobis",
227+
"correlation",
228+
"cosine",
229+
"seuclidean",
230+
"braycurtis",
231+
"jensenshannon",
232+
]
233+
234+
fig, axs = plt.subplots(5, 1, figsize=(8,12), sharex=True, sharey=True)
235+
for metric, ax in zip(metrics[-5:], axs.flatten()):
236+
decoded, dist = nap.decode_template(
237+
tuning_curves=tuning_curves,
238+
data=transients,
239+
bin_size=0.1,
240+
metric=metric,
241+
epochs=epochs,
242+
)
243+
# normalize distance for better visualization
244+
dist_norm = (dist - np.min(dist.values, axis=1, keepdims=True)) / np.ptp(
245+
dist.values, axis=1, keepdims=True
246+
)
247+
ax.plot(angle.restrict(epochs), label="True")
248+
im = ax.imshow(
249+
dist_norm.values.T,
250+
aspect="auto",
251+
origin="lower",
252+
cmap="inferno_r",
253+
extent=(epochs.start[0], epochs.end[0], 0.0, 2*np.pi)
254+
)
255+
if metric != metrics[-1]:
256+
ax.spines['bottom'].set_visible(False)
257+
ax.tick_params(axis='x', which='both', bottom=False, top=False, labelbottom=False)
258+
ax.set_yticks([])
259+
ax.spines['left'].set_visible(False)
260+
ax.set_ylabel(metric)
261+
cbar_ax = fig.add_axes([0.92, ax.get_position().y0, 0.015, ax.get_position().height])
262+
cbar=fig.colorbar(im, cax=cbar_ax)
263+
cbar.set_label("Norm. distance")
264+
ax.set_xlabel("Time (s)")
265+
plt.show()
266+
```
267+
268+
We recommend trying a bunch to see which one works best for you.
269+
In the case of head direction, we can quantify how well we decode using the absolute angular error.
270+
To get a fair estimate of error, we will compute the tuning curves on the first half of the data
271+
and compute the error for predictions of the second half.
272+
273+
```{code-cell} ipython3
274+
def absolute_angular_error(x, y):
275+
return np.abs(np.angle(np.exp(1j * (x - y))))
276+
277+
# Compute errors
278+
errors = {}
279+
for metric in metrics:
280+
decoded, dist = nap.decode_template(
281+
tuning_curves=tuning_curves_half1,
282+
data=transients,
283+
bin_size=0.1,
284+
metric=metric,
285+
epochs=halves.loc[[1]],
286+
)
287+
errors[metric] = absolute_angular_error(
288+
angle.interpolate(decoded).values, decoded.values
289+
)
290+
```
291+
292+
```{code-cell} ipython3
293+
:tags: [hide-input]
294+
sorted_items = sorted(errors.items(), key=lambda item: np.median(item[1]))
295+
sorted_labels, sorted_values = zip(*sorted_items)
296+
297+
fig, ax = plt.subplots(figsize=(8, 8))
298+
bp = ax.boxplot(
299+
x=sorted_values,
300+
tick_labels=sorted_labels,
301+
vert=False,
302+
showfliers=False
303+
)
304+
ax.set_xlabel("Angular error [rad]")
305+
plt.show()
306+
```
307+
308+
In this case, `jensenshannon` yields the lowest angular error.
309+
155310
:::{card}
156311
Authors
157312
^^^

0 commit comments

Comments
 (0)