Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Template-based smartSPIM-CCF registration #19

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion code/aind_ccf_reg/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
"""CCF Registration package.
"""
__version__ = "0.0.17"

__version__ = "0.0.18"
132 changes: 132 additions & 0 deletions code/aind_ccf_reg/configs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
"""
This config file points to data directories, defines global variables,
specify schema format for Preprocess and Registration.
"""

from pathlib import Path
from typing import Union

import dask.array as da
import numpy as np
from argschema import ArgSchema
from argschema.fields import Dict as sch_dict
from argschema.fields import Int
from argschema.fields import List as sch_list
from argschema.fields import Str

PathLike = Union[str, Path]
ArrayLike = Union[da.core.Array, np.ndarray]

VMIN = 0
VMAX = 1.5


class RegSchema(ArgSchema):
"""
Schema format for Preprocess and Registration.
"""

input_data = Str(
metadata={
"required": True,
"description": "Input data without timestamp",
}
)

input_channel = Str(
metadata={"required": True, "description": "Channel to register"}
)

input_scale = Int(
metadata={"required": True, "description": "Zarr scale to start with"}
)

input_orientation = sch_list(
cls_or_instance=sch_dict,
metadata={
"required": True,
"description": "Brain orientation during aquisition",
},
)

template_path = Str(
metadata={"required": True, "description": "Path to the SPIM template"}
)

ccf_reference_path = Str(
metadata={"required": True, "description": "Path to the CCF template"}
)

template_to_ccf_transform_path = sch_list(
cls_or_instance=Str,
metadata={
"required": True,
"description": "Path to the template-to-ccf transform",
},
)

ccf_annotation_to_template_moved_path = Str(
metadata={
"required": True,
"description": "Path to CCF annotation in SPIM template space",
}
)

output_data = Str(
metadata={"required": True, "description": "Output file"}
)

results_folder = Str(
metadata={
"required": True,
"description": "Folder to save registration results",
}
)

reg_folder = Str(
metadata={
"required": True,
"description": "Folder to save derivative results of registration",
}
)

bucket_path = Str(
required=True,
metadata={"description": "Amazon Bucket or Google Bucket name"},
)

code_url = Str(
metadata={"required": True, "description": "CCF registration URL"}
)

metadata_folder = Str(
metadata={"required": True, "description": "Metadata folder"}
)

OMEZarr_params = sch_dict(
metadata={
"required": True,
"description": "OMEZarr writing parameters",
}
)

prep_params = sch_dict(
metadata={
"required": True,
"description": "raw data preprocessing parameters",
}
)

ants_params = sch_dict(
metadata={
"required": True,
"description": "ants registering parameters",
}
)

reference_res = Int(
metadata={
"required": True,
"description": "Voxel Resolution of reference in microns",
}
)
134 changes: 134 additions & 0 deletions code/aind_ccf_reg/plots.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
"""
Plot functions for easy and fast visualiztaion of images and
regsitration results
"""

import matplotlib.pyplot as plt
import numpy as np


def plot_antsimgs(ants_img, figpath, title="", vmin=0, vmax=500):
"""
Plot ANTs image

Parameters
------------
ants_img: ANTsImage
figpath: PathLike
Path where the plot is going to be saved
title: str
Figure title
vmin: float
Set the color limits of the current image.
vmax: float
Set the color limits of the current image.
"""

if figpath:
ants_img = ants_img.numpy()
half_size = np.array(ants_img.shape) // 2
fig, ax = plt.subplots(1, 3, figsize=(10, 6))
ax[0].imshow(
ants_img[half_size[0], :, :], cmap="gray", vmin=vmin, vmax=vmax
)
ax[1].imshow(
ants_img[:, half_size[1], :], cmap="gray", vmin=vmin, vmax=vmax
)
im = ax[2].imshow(
ants_img[
:,
:,
half_size[2],
],
cmap="gray",
vmin=vmin,
vmax=vmax,
)
fig.suptitle(title, y=0.9)
plt.colorbar(
im, ax=ax.ravel().tolist(), fraction=0.1, pad=0.025, shrink=0.7
)
plt.savefig(figpath, bbox_inches="tight", pad_inches=0.1)


def plot_reg(
moving, fixed, warped, figpath, title="", loc=0, vmin=0, vmax=1.5
):
"""
Plot registration results: moving, fixed, deformed,
overlay and difference images after registration

Parameters
------------
moving: ANTsImage
Moving image
fixed: ANTsImage
Fixed image
warped: ANTsImage
Deformed image
figpath: PathLike
Path where the plot is going to be saved
title: str
Figure title
loc: int
Visualization direction
vmin, vmax: float
Set the color limits of the current image.
"""

if loc >= len(moving.shape):
raise ValueError(
f"loc {loc} is not allowed, should less than {len(moving.shape)}"
)

half_size_moving = np.array(moving.shape) // 2
half_size_fixed = np.array(fixed.shape) // 2
half_size_warped = np.array(warped.shape) // 2

if loc == 0:
moving = moving.view()[half_size_moving[0], :, :]
fixed = fixed.view()[half_size_fixed[0], :, :]
warped = warped.view()[half_size_warped[0], :, :]
y = 0.75
elif loc == 1:
# moving = np.rot90(moving.view()[:,half_size[1], :], 3)
moving = moving.view()[:, half_size_moving[1], :]
fixed = fixed.view()[:, half_size_fixed[1], :]
warped = warped.view()[:, half_size_warped[1], :]
y = 0.82
elif loc == 2:
moving = np.rot90(np.fliplr(moving.view()[:, :, half_size_moving[2]]))
fixed = np.rot90(np.fliplr(fixed.view()[:, :, half_size_fixed[2]]))
warped = np.rot90(np.fliplr(warped.view()[:, :, half_size_warped[2]]))
y = 0.82
else:
raise ValueError(
f"loc {loc} is not allowed. Allowed values are: 0, 1, 2"
)

# combine deformed and fixed images to an RGB image
overlay = np.stack((warped, fixed, warped), axis=2)
diff = fixed - warped

fontsize = 14

fig, ax = plt.subplots(1, 5, figsize=(16, 6))
ax[0].imshow(moving, cmap="gray", vmin=vmin, vmax=vmax)
ax[1].imshow(fixed, cmap="gray", vmin=vmin, vmax=vmax)
ax[2].imshow(warped, cmap="gray", vmin=vmin, vmax=vmax)
ax[3].imshow(overlay)
ax[4].imshow(diff, cmap="gray", vmin=-(vmax), vmax=vmax)

ax[0].set_title("Moving", fontsize=fontsize)
ax[1].set_title("Fixed", fontsize=fontsize)
ax[2].set_title("Deformed", fontsize=fontsize)
ax[3].set_title("Deformed Overlay Fixed", fontsize=fontsize)
ax[4].set_title("Fixed - Deformed", fontsize=fontsize)

fig.suptitle(title, size=18, y=y)

if figpath:
plt.savefig(figpath, bbox_inches="tight", pad_inches=0.01)
plt.close()
else:
fig.show()
Loading
Loading