Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pairwise distances example #339

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
384 changes: 384 additions & 0 deletions examples/compute_pairwise_distances.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,384 @@
# ruff: noqa: E402
"""Compute distances between keypoints.
====================================

Compute pairwise distances between keypoints, within and across individuals.
"""

# %%
# Imports
# -------

import numpy as np

# For interactive plots: install ipympl with `pip install ipympl` and uncomment
# the following line in your notebook
# %matplotlib widget
from matplotlib import pyplot as plt

from movement import sample_data
from movement.kinematics import (
compute_forward_vector,
compute_pairwise_distances,
)

# %%
# Load sample dataset
# -------------------
# We start by fetching an example dataset. Here we pick the
# ``DLC_two-mice.predictions.csv`` sample data.
ds = sample_data.fetch_dataset("DLC_two-mice.predictions.csv")

# Print a summary of the dataset:
# 2 individuals, 12 keypoints, 2D space, time in seconds, 59999 frames
print(ds)

# %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
# Visually inspect the data on top of the sample frame

# the centroid of each individual is the mean position over its keypoints;
# it is stored in the dataset so that later cells can reuse it
ds["centroid"] = ds.position.mean(dim="keypoints")

# read the sample frame
im = plt.imread(ds.frame_path)

# scatter the centroid of every individual over the frame
fig, ax = plt.subplots()
ax.imshow(im)
for ind in ds.coords["individuals"].data:
    centroid_ind = ds.centroid.sel(individuals=ind)
    ax.scatter(
        x=centroid_ind.sel(space="x"),
        y=centroid_ind.sel(space="y"),
        s=5,
        label=f"{ind}",
        alpha=0.05,
    )
ax.set(xlabel="x (pixels)", ylabel="y (pixels)")
ax.legend()


# %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
# Get reference length

# Measure the long side of the box in pixels.
# Note that the lens is a bit distorted.
# TODO: should the diagonal be used instead?
start_point = np.array([[209, 382]])
end_point = np.array([[213, 1022]])

# Euclidean length of the segment joining both points
reference_length = np.linalg.norm(end_point - start_point)

# stack the two end points: one row per point, one column per coordinate
segment = np.concatenate([start_point, end_point])
x_sum, y_sum = segment.sum(axis=0)

# draw the measured segment and its length on top of the sample frame
fig, ax = plt.subplots()
ax.imshow(im)
ax.plot(segment[:, 0], segment[:, 1], "r")
ax.text(
    1.01 * x_sum,
    0.49 * y_sum,
    f"{reference_length:.2f} pixels",
    color="r",
    horizontalalignment="center",
)
ax.set(xlabel="x (pixels)", ylabel="y (pixels)", title="Reference length")


# %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
# Compute distances between keypoints on different individuals

# The key and the value of ``pairs`` set the dims of the output
# (the keypoints will be the coordinates along them).
individual_pair = {"individual1": "individual2"}
inter_individual_kpt_distances = compute_pairwise_distances(
    ds.position, dim="individuals", pairs=individual_pair
)  # pixels, dimensions are individual1 and individual2

# For each frame, this matrix holds the distance from every keypoint on
# individual 1 to every keypoint on individual 2
print(inter_individual_kpt_distances.shape)

# TODO: normalise with the reference length?
# inter_individual_kpt_distances_norm = (
#     inter_individual_kpt_distances / reference_length
# )

# %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
# Plot matrix of distances and keypoints
# Show different patterns / positions between the two animals
# Note that the colorbars vary across plots aka frames!

# time points (in seconds) at which the data is plotted
time_sel = [50.0, 100.0, 250.0]

# get colormap tab20 for keypoints
cmap = plt.get_cmap("tab20")

# get list of keypoints per individual
# (it may not be the same)
list_kpts_individual_1 = list(
    inter_individual_kpt_distances.coords["individual1"].data
)
list_kpts_individual_2 = list(
    inter_individual_kpt_distances.coords["individual2"].data
)
# keypoint labels along each dimension of the distance matrix
kpts_per_dim = {
    "individual1": list_kpts_individual_1,
    "individual2": list_kpts_individual_2,
}

for t in time_sel:
    fig, axs = plt.subplots(1, 2, figsize=(13, 5))
    fig.subplots_adjust(wspace=0.5)

    # left panel: plot keypoints, one colour per keypoint
    for kpt_i, kpt in enumerate(ds.coords["keypoints"].data):
        axs[0].scatter(
            x=ds.position.sel(keypoints=kpt, space="x", time=t),
            y=ds.position.sel(keypoints=kpt, space="y", time=t),
            s=10,
            label=f"{kpt}",
            color=cmap(kpt_i),
        )

    # add text per individual, placed at its centroid
    for ind in ds.coords["individuals"].data:
        axs[0].text(
            ds.centroid.sel(individuals=ind, space="x", time=t),
            ds.centroid.sel(individuals=ind, space="y", time=t),
            ind,
            horizontalalignment="left",
        )
    # flip the y-axis so the plot follows image coordinates
    axs[0].invert_yaxis()
    axs[0].set_xlabel("x (pixels)")
    axs[0].set_ylabel("y (pixels)")
    axs[0].set_title(f"Keypoints at {t} s")
    axs[0].axis("equal")
    axs[0].legend()

    # right panel: plot the matrix of distances (in pixels) for this frame.
    # ``imshow`` draws the first dimension of a 2D array along the y-axis
    # (rows) and the second one along the x-axis (columns), so the ticks
    # and the axis labels are taken from the matching dimension.
    distances_t = inter_individual_kpt_distances.sel(time=t)
    row_dim, col_dim = distances_t.dims
    # a separate name keeps the sample frame stored in ``im`` intact
    heatmap = axs[1].imshow(distances_t)
    axs[1].set_xticks(range(len(kpts_per_dim[col_dim])))
    axs[1].set_yticks(range(len(kpts_per_dim[row_dim])))
    axs[1].set_xticklabels(kpts_per_dim[col_dim], rotation=45)
    axs[1].set_yticklabels(kpts_per_dim[row_dim], rotation=0)

    axs[1].set_xlabel(col_dim)
    axs[1].set_ylabel(row_dim)
    axs[1].set_title(f"Inter-individual keypoint distances at {t} s")

    # TODO: fix vmin/vmax to share the colour scale across frames?
    fig.colorbar(
        heatmap,
        ax=axs[1],
        label="distance (pixels)",
    )


# %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
# To get distance between homologous keypoints
# get the diagonal of the previous matrix at each frame
inter_individual_same_kpts = np.diagonal(
    inter_individual_kpt_distances, axis1=1, axis2=2
)
print(inter_individual_same_kpts.shape)  # (59999, 12)


# sanity check: every column of the diagonal should match the result of
# selecting the same keypoint on both individuals by label
for column, kpt in enumerate(list_kpts_individual_1):
    selected_by_label = inter_individual_kpt_distances.sel(
        individual1=kpt, individual2=kpt
    )
    np.testing.assert_almost_equal(
        selected_by_label, inter_individual_same_kpts[:, column]
    )

# # plot matrix as sparse matrix?
# # plot vectors on top of a given frame?
# for k in range(len(time_sel)):
# fig, axs = plt.subplots(1, 2, figsize=(13, 5))
# fig.subplots_adjust(wspace=0.5)

# # plot keypoints
# for kpt_i, kpt in enumerate(ds.coords["keypoints"].data):
# axs[0].scatter(
# x=ds.position.sel(keypoints=kpt, space="x", time=time_sel[k]),
# y=ds.position.sel(keypoints=kpt, space="y", time=time_sel[k]),
# s=10,
# label=f"{kpt}",
# color=cmap(kpt_i),
# )
# # connect matching keypoints
# axs[0].plot(
# ds.position.sel(keypoints=kpt, space="x", time=time_sel[k]),
# ds.position.sel(keypoints=kpt, space="y", time=time_sel[k]),
# "r",
# )

# %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
# To get the distance between specific keypoints on different individuals
# (e.g. snout of individual 1 to tail base of individual 2), select the
# relevant keypoint coordinates along the dimensions "individual1" and
# "individual2"

distance_snout_1_to_tail_2 = inter_individual_kpt_distances.sel(
    {"individual1": "snout", "individual2": "tailbase"}
)

# plot the distance from snout 1 to tailbase 2 over time, divided by the
# reference length
# TODO: plot in a short time window?
normalised_distance = distance_snout_1_to_tail_2 / reference_length

fig, ax = plt.subplots()
ax.plot(distance_snout_1_to_tail_2.time, normalised_distance)  # x: seconds
ax.set(xlabel="time (seconds)", ylabel="distance snout-to-tail normalised")


# %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
# Compute distances between the keypoints on the same individual
# compute average bodylength = snout to tailbase

# The key and the value of ``pairs`` set the dims of the output
# (the individuals will be the coordinates along them).
distance_snout_to_tailbase_all = compute_pairwise_distances(
    ds.position, dim="keypoints", pairs={"snout": "tailbase"}
)  # pixels

print(distance_snout_to_tailbase_all)  # dimensions are snout and tailbase!

# distances within an individual:
# pick the same individual along both dimensions
bodylength_individual_1 = distance_snout_to_tailbase_all.sel(
    {"snout": "individual1", "tailbase": "individual1"}
)
bodylength_individual_2 = distance_snout_to_tailbase_all.sel(
    {"snout": "individual2", "tailbase": "individual2"}
)

# distances across individuals (an alternative way to the above):
# pick a different individual along each dimension
snout_1_to_tail_2 = distance_snout_to_tailbase_all.sel(
    {"snout": "individual1", "tailbase": "individual2"}
)
snout_2_to_tail_1 = distance_snout_to_tailbase_all.sel(
    {"snout": "individual2", "tailbase": "individual1"}
)

# check that this approach is equivalent to the previous one
np.testing.assert_almost_equal(
    snout_1_to_tail_2.data,
    distance_snout_1_to_tail_2.data,
)

# %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
# Plot bodylength over time
# TODO: show it as a histogram instead?

# one figure per individual; individuals are numbered from 1 in the titles
for ind_number, bodylength_data_array in enumerate(
    [bodylength_individual_1, bodylength_individual_2],
    start=1,
):
    fig, ax = plt.subplots()
    ax.plot(
        bodylength_data_array.time,
        bodylength_data_array,
    )
    # horizontal line at the mean body length, spanning the full time range
    ax.hlines(
        bodylength_data_array.mean(dim="time"),
        bodylength_data_array.time.min(),
        bodylength_data_array.time.max(),
        "r",
        label="mean length",
    )
    ax.set_title(f"Body length of individual {ind_number}")
    ax.set_xlabel("time (seconds)")
    ax.set_ylabel("length (pixels)")
    # without a legend the label given to the mean line is never displayed
    ax.legend()

# %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
# Try usage of 'all' and plot distance matrix with four quadrants

# %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
# Compute distances between centroids

distances_between_centroids = compute_pairwise_distances(
    ds.centroid, dim="individuals", pairs={"individual1": "individual2"}
)

print(distances_between_centroids.shape)  # (59999,)

# histogram of the centroid-to-centroid distances over all frames
# TODO: make the counts relative to the total number of frames?
fig, ax = plt.subplots()
ax.hist(distances_between_centroids)
ax.set(
    xlabel="distance (pixels)",
    ylabel="frames",
    title="Distances between centroids",
)

# %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
# Try a different metric, e.g cosine distance
# https://en.wikipedia.org/wiki/Cosine_similarity

# compute the forward vector of each individual from its two ear keypoints
# and store it in the dataset
head_vector = compute_forward_vector(
    ds.position,
    left_keypoint="leftear",
    right_keypoint="rightear",
    camera_view="top_down",
)
ds["head_vector"] = head_vector

# compute the cosine distance between the forward vectors of both
# individuals (1 - dot product of unit vectors)
cosine_distance_head_vectors = compute_pairwise_distances(
    ds.head_vector,
    dim="individuals",
    pairs={"individual1": "individual2"},
    metric="cosine",
)

# plot histogram
# TODO: most of the time the vectors are antiparallel?
fig, ax = plt.subplots()
ax.hist(cosine_distance_head_vectors)
ax.set(xlabel="cosine distance", ylabel="frames")
# %%
Loading