Skip to content

Commit

Permalink
Merge pull request #6 from PennLINC/fix/xfm
Browse files Browse the repository at this point in the history
ants recompute xfm
  • Loading branch information
smeisler authored Sep 19, 2024
2 parents 56752c3 + 7cb3353 commit 693d328
Show file tree
Hide file tree
Showing 5 changed files with 128 additions and 71 deletions.
3 changes: 2 additions & 1 deletion ingress2qsirecon/cli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ def _ingress2qsirecon(**kwargs):
input_pipeline = kwargs["input_pipeline"]
participant_label = kwargs["participant_label"]
work_dir = Path(kwargs["work_dir"])
skip_mni2009c_norm = kwargs["skip_mni2009c_norm"]
check_gradients = kwargs["check_gradients"]
dry_run = kwargs["dry_run"]
symlink = kwargs["symlink"]
Expand Down Expand Up @@ -55,7 +56,7 @@ def _ingress2qsirecon(**kwargs):
layouts = create_layout(input_dir, output_dir, input_pipeline, participant_label)

# Create and run overall workflow, which will be broken down to single subject workflows
ingress2qsirecon_wf = create_ingress2qsirecon_wf(layouts, base_dir=work_dir)
ingress2qsirecon_wf = create_ingress2qsirecon_wf(layouts, base_dir=work_dir, skip_mni2009c_norm=skip_mni2009c_norm)
ingress2qsirecon_wf.run()


Expand Down
7 changes: 7 additions & 0 deletions ingress2qsirecon/cli/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,13 @@ def _drop_sub(value):
help="The working directory for the processing. "
"If not specified, the current working directory will be used.",
)
optional.add_argument(
"--skip-mni2009c-norm",
"--skip_mni2009c_norm",
action="store_true",
default=False,
help="Skip MNI normalization step. MNI normalization is not required for all pipelines in QSIRecon.",
)
optional.add_argument(
"--check_gradients",
"--check-gradients",
Expand Down
70 changes: 44 additions & 26 deletions ingress2qsirecon/utils/interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import SimpleITK as sitk
from nilearn import image as nim
from nipype import logging
from nipype.interfaces import ants
from nipype.interfaces.base import (
BaseInterface,
BaseInterfaceInputSpec,
Expand All @@ -22,7 +23,16 @@
TraitedSpec,
traits,
)
from nipype.interfaces.mixins import reporting
from nipype.interfaces.workbench.base import WBCommand
from niworkflows.interfaces.norm import (
SpatialNormalization,
_SpatialNormalizationInputSpec,
)
from niworkflows.interfaces.reportlets.base import (
RegistrationRC,
_SVGReportCapableInputSpec,
)

from ingress2qsirecon.utils.functions import to_lps

Expand Down Expand Up @@ -610,36 +620,44 @@ def _convert_fsl_to_mrtrix(bval_file, bvec_file, output_fname):
np.savetxt(output_fname, gtab, fmt=["%.8f", "%.8f", "%.8f", "%d"])


class _ScansTSVWriterInputSpec(BaseInterfaceInputSpec):
filenames = traits.List(traits.Str, mandatory=True, desc="List of filenames")
source_files = traits.List(traits.Str, mandatory=True, desc="List of source files")
out_file = File("output.tsv", usedefault=True, desc="Output TSV file")

class RobustMNINormalizationInputSpecRPT(
_SVGReportCapableInputSpec,
_SpatialNormalizationInputSpec,
):
# Template orientation.
orientation = traits.Enum(
"LPS",
mandatory=True,
usedefault=True,
desc="modify template orientation (should match input image)",
)

class _ScansTSVWriterOutputSpec(TraitedSpec):
out_file = File(desc="Output TSV file")

class RobustMNINormalizationOutputSpecRPT(
reporting.ReportCapableOutputSpec,
ants.registration.RegistrationOutputSpec,
):
# Try to work around TraitError of "undefined 'reference_image' attribute"
reference_image = traits.File(desc="the output reference image")

class ScansTSVWriter(BaseInterface):
input_spec = _ScansTSVWriterInputSpec
output_spec = _ScansTSVWriterOutputSpec

def _run_interface(self, runtime):
filenames = self.inputs.filenames
source_files = self.inputs.source_files
class RobustMNINormalizationRPT(RegistrationRC, SpatialNormalization):
input_spec = RobustMNINormalizationInputSpecRPT
output_spec = RobustMNINormalizationOutputSpecRPT

# Check if lengths match
if len(filenames) != len(source_files):
raise ValueError("filenames and source_files must have the same length")
def _post_run_hook(self, runtime):
# We need to dig into the internal ants.Registration interface
self._fixed_image = self._get_ants_args()["fixed_image"]
if isinstance(self._fixed_image, (list, tuple)):
self._fixed_image = self._fixed_image[0] # get first item if list

# Create DataFrame
df = pd.DataFrame({"filename": filenames, "source_file": source_files})

# Write to TSV
df.to_csv(self.inputs.out_file, sep='\t', index=False)
return runtime
if self._get_ants_args().get("fixed_image_mask") is not None:
self._fixed_image_mask = self._get_ants_args().get("fixed_image_mask")
self._moving_image = self.aggregate_outputs(runtime=runtime).warped_image
LOGGER.info(
"Report - setting fixed (%s) and moving (%s) images",
self._fixed_image,
self._moving_image,
)

def _list_outputs(self):
outputs = self.output_spec().get().copy()
outputs['out_file'] = self.inputs.out_file
return outputs
return super(RobustMNINormalizationRPT, self)._post_run_hook(runtime)
117 changes: 74 additions & 43 deletions ingress2qsirecon/utils/workflows.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,14 @@
from templateflow import api as tflow

from ingress2qsirecon.utils.interfaces import (
ComposeTransforms,
Conform,
ConformDwi,
ConvertWarpfield,
ExtractB0s,
FSLBVecsToTORTOISEBmatrix,
MRTrixGradientTable,
NIFTItoH5,
RobustMNINormalizationRPT,
)


Expand All @@ -26,7 +26,7 @@ def parse_layout(subject_layout):
return tuple(subject_layout.values())


def create_single_subject_wf(subject_layout):
def create_single_subject_wf(subject_layout, skip_mni2009c_norm=False):
"""
Create a nipype workflow to ingest a single subject.
Expand Down Expand Up @@ -84,7 +84,7 @@ def create_single_subject_wf(subject_layout):

# Define input node for the single subject workflow
input_node = Node(
IdentityInterface(fields=['subject_layout', "MNI2009cAsym_to_MNINLin6", "MNINLin6_to_MNI2009cAsym"]),
IdentityInterface(fields=['subject_layout']),
name='input_node',
)
input_node.inputs.subject_layout = subject_layout
Expand Down Expand Up @@ -252,57 +252,88 @@ def create_single_subject_wf(subject_layout):

# Now get transform to MNI2009cAsym
MNI_template = subject_layout["MNI_template"]
if MNI_template == "MNI152NLin6Asym":
# Get the relevant transforms from templateflow
MNI2009cAsym_to_MNINLin6 = tflow.get('MNI152NLin6Asym', desc=None, suffix='xfm', extension='h5')
input_node.inputs.MNI2009cAsym_to_MNINLin6 = MNI2009cAsym_to_MNINLin6
MNINLin6_to_MNI2009cAsym = tflow.get('MNI152NLin2009cAsym', desc=None, suffix='xfm', extension='h5')
input_node.inputs.MNINLin6_to_MNI2009cAsym = MNINLin6_to_MNI2009cAsym

# Define a function to make a list of two warp files for input to ComposeTransforms
def combine_warp_files(file1, file2):
return [file1, file2]

# Create a Function node to make a list of warp files for MNI2subject
warp_files_list_MNI2subject = Node(
Function(input_names=['file1', 'file2'], output_names=['combined_files'], function=combine_warp_files),
name='list_warp_files_MNI2subject',
if MNI_template != "MNI152NLin2009cAsym" and skip_mni2009c_norm == False:
# Get MNI brain and mask
MNI2009cAsym_brain_path = str(
tflow.get('MNI152NLin2009cAsym', desc="brain", suffix="T1w", resolution=1, extension=".nii.gz")
)

# Create a Function node to make a list of warp files for subject2MNI
warp_files_list_subject2MNI = Node(
Function(input_names=['file1', 'file2'], output_names=['combined_files'], function=combine_warp_files),
name='list_warp_files_subject2MNI',
MNI2009cAsym_mask_path = str(
tflow.get('MNI152NLin2009cAsym', desc="brain", suffix="mask", resolution=1, extension=".nii.gz")
)

# Make the compute nodes for combining transforms
compose_transforms_node_MNI2subject = Node(ComposeTransforms(), name="compose_transforms_MNI2subject")
compose_transforms_node_MNI2subject.inputs.output_warp = str(subject_layout["bids_MNI2subject"]).replace(
MNI_template, "MNI152NLin2009cAsym"
# Create transform node
anat_norm_interface = RobustMNINormalizationRPT(float=True, generate_report=True, flavor="precise")
anat_nlin_normalization = Node(anat_norm_interface, name="anat_nlin_normalization")
# Set inputs
anat_nlin_normalization.inputs.template = MNI2009cAsym_brain_path
anat_nlin_normalization.inputs.reference_image = MNI2009cAsym_brain_path
anat_nlin_normalization.inputs.reference_mask = MNI2009cAsym_mask_path
anat_nlin_normalization.inputs.orientation = "LPS"

# Create output node to save out relevant files
def save_xfm_outputs(
to_template_nonlinear_transform_in,
from_template_nonlinear_transform_in,
to_template_nonlinear_transform_out,
from_template_nonlinear_transform_out,
):
import shutil

# Dictionary of inputs to save
files = {
to_template_nonlinear_transform_out: to_template_nonlinear_transform_in,
from_template_nonlinear_transform_out: from_template_nonlinear_transform_in,
}

for filename, file_content in files.items():
# Copy or move files to the output directory
shutil.copy(file_content, filename)

save_outputs_node = Node(
Function(
input_names=[
"to_template_nonlinear_transform_in",
"from_template_nonlinear_transform_in",
"to_template_nonlinear_transform_out",
"from_template_nonlinear_transform_out",
],
function=save_xfm_outputs,
),
name="save_outputs_node",
)
compose_transforms_node_subject2MNI = Node(ComposeTransforms(), name="compose_transforms_subject2MNI")
compose_transforms_node_subject2MNI.inputs.output_warp = str(subject_layout["bids_subject2MNI"]).replace(
MNI_template, "MNI152NLin2009cAsym"
save_outputs_node.inputs.to_template_nonlinear_transform_out = str(subject_layout["bids_subject2MNI"]).replace(
subject_layout["MNI_template"], "MNI152NLin2009cAsym"
)

# Connect the nodes
save_outputs_node.inputs.from_template_nonlinear_transform_out = str(
subject_layout["bids_MNI2subject"]
).replace(subject_layout["MNI_template"], "MNI152NLin2009cAsym")
# Link T1w brain and mask to node
wf.connect(
[
# For MNI2subject
(nii_to_h5_node_MNI2subject, warp_files_list_MNI2subject, [("xfm_h5_out", "file1")]),
(input_node, warp_files_list_MNI2subject, [("MNI2009cAsym_to_MNINLin6", "file2")]),
(warp_files_list_MNI2subject, compose_transforms_node_MNI2subject, [("combined_files", "warp_files")]),
# For subject2MNI
(input_node, warp_files_list_subject2MNI, [("MNINLin6_to_MNI2009cAsym", "file1")]),
(nii_to_h5_node_subject2MNI, warp_files_list_subject2MNI, [("xfm_h5_out", "file2")]),
(warp_files_list_subject2MNI, compose_transforms_node_subject2MNI, [("combined_files", "warp_files")]),
(
conform_t1w_node,
anat_nlin_normalization,
[("out_file", "moving_image")],
),
(
conform_mask_node,
anat_nlin_normalization,
[("out_file", "moving_mask")],
),
(
anat_nlin_normalization,
save_outputs_node,
[
('composite_transform', 'to_template_nonlinear_transform_in'),
('inverse_composite_transform', 'from_template_nonlinear_transform_in'),
],
),
]
)

return wf


def create_ingress2qsirecon_wf(layouts, name="ingress2qsirecon_wf", base_dir=os.getcwd()):
def create_ingress2qsirecon_wf(layouts, name="ingress2qsirecon_wf", base_dir=os.getcwd(), skip_mni2009c_norm=False):
"""
Creates the overall ingress2qsirecon workflow.
Expand Down Expand Up @@ -330,7 +361,7 @@ def create_ingress2qsirecon_wf(layouts, name="ingress2qsirecon_wf", base_dir=os.
print(f"Subject(s) to run: {subjects_to_run}")

for subject_layout in layouts:
single_subject_wf = create_single_subject_wf(subject_layout)
single_subject_wf = create_single_subject_wf(subject_layout, skip_mni2009c_norm=skip_mni2009c_norm)
wf.add_nodes([single_subject_wf])

return wf
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "Ingress2QSIRecon"
version = "0.1.4"
version = "0.2.0"
description = "Tool to ingress data from other pipelines for use in QSIRecon"
authors = ["Steven Meisler <[email protected]>"]
readme = "README.md"
Expand Down

0 comments on commit 693d328

Please sign in to comment.