diff --git a/pyproject.toml b/pyproject.toml index 3bd874ff0..f021b0834 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -42,8 +42,10 @@ dependencies = [ "scipy <= 1.14.0,>= 1.14.0", # nipype needs networkx, which needs scipy > 1.8.0 "seaborn", # for plots "sentry-sdk ~= 2.10.0", # for usage reports + "surfplot ~= 0.2.0", # for surface plots "templateflow ~= 24.2.0", "toml", + "vtk ~= 9.2.6", ] dynamic = ["version"] diff --git a/xcp_d/interfaces/plotting.py b/xcp_d/interfaces/plotting.py index 7abf8b8ce..86f288d14 100644 --- a/xcp_d/interfaces/plotting.py +++ b/xcp_d/interfaces/plotting.py @@ -3,16 +3,15 @@ """Plotting interfaces.""" import json import os +import re import matplotlib.pyplot as plt import nibabel as nb import numpy as np import pandas as pd import seaborn as sns -from matplotlib.cm import ScalarMappable -from matplotlib.colors import Normalize -from matplotlib.gridspec import GridSpec, GridSpecFromSubplotSpec -from nilearn.plotting import plot_anat, plot_stat_map, plot_surf_stat_map +import svgutils.transform as sg +from nilearn.plotting import plot_anat, plot_stat_map from nipype import logging from nipype.interfaces.base import ( BaseInterfaceInputSpec, @@ -27,6 +26,7 @@ traits, ) from nipype.interfaces.fsl.base import FSLCommand, FSLCommandInputSpec +from surfplot import Plot from templateflow.api import get as get_template from xcp_d.utils.confounds import load_motion @@ -886,6 +886,14 @@ class _PlotCiftiParcellationInputSpec(BaseInterfaceInputSpec): mandatory=True, desc="Labels for the CIFTI files.", ) + atlas_files = traits.List( + traits.Str, + mandatory=True, + desc=( + "The atlas files. Same length as 'labels' and to be reduced to match " + "'cortical_atlases'." + ), + ) out_file = File( exists=False, mandatory=False, @@ -963,159 +971,150 @@ def _run_interface(self, runtime): rh = self.inputs.rh_underlay lh = self.inputs.lh_underlay - # Create Figure and GridSpec. - # One subplot for each file. Each file will then have four subplots, arranged in a square. - cortical_files = [ + data_files = [ self.inputs.in_files[i] for i, atlas in enumerate(self.inputs.labels) if atlas in self.inputs.cortical_atlases ] - cortical_atlases = [ - atlas for atlas in self.inputs.labels if atlas in self.inputs.cortical_atlases + keep_idx = [ + i + for i, atlas in enumerate(self.inputs.labels) + if atlas in self.inputs.cortical_atlases ] - n_files = len(cortical_files) - fig = plt.figure(constrained_layout=False) - - if n_files == 1: - fig.set_size_inches(6.5, 6) - # Add an additional column for the colorbar - gs = GridSpec(1, 2, figure=fig, width_ratios=[1, 0.05]) - gs_list = [gs[0, 0]] - subplots = [fig.add_subplot(gs) for gs in gs_list] - cbar_gs_list = [gs[0, 1]] - else: - nrows = np.ceil(n_files / 2).astype(int) - fig.set_size_inches(12.5, 6 * nrows) - # Add an additional column for the colorbar - gs = GridSpec(nrows, 3, figure=fig, width_ratios=[1, 1, 0.05]) - gs_list = [gs[i, j] for i in range(nrows) for j in range(2)] - subplots = [fig.add_subplot(gs) for gs in gs_list] - cbar_gs_list = [gs[i, 2] for i in range(nrows)] - - for subplot in subplots: - subplot.set_axis_off() + atlas_names = [self.inputs.labels[i] for i in keep_idx] + atlas_files = [self.inputs.atlas_files[i] for i in keep_idx] vmin, vmax = self.inputs.vmin, self.inputs.vmax - threshold = 0.01 if vmin == vmax: - threshold = None - # Define vmin and vmax based on all of the files vmin, vmax = np.inf, -np.inf - for cortical_file in cortical_files: - img_data = nb.load(cortical_file).get_fdata() + for data_file in data_files: + img_data = nb.load(data_file).get_fdata() vmin = np.min([np.nanmin(img_data), vmin]) vmax = np.max([np.nanmax(img_data), vmax]) vmin = 0 - for i_file in range(n_files): - subplot = subplots[i_file] - subplot.set_title(cortical_atlases[i_file]) - subplot_gridspec = gs_list[i_file] - - # Create 4 Axes (2 rows, 2 columns) from the subplot - gs_inner = GridSpecFromSubplotSpec(2, 2, subplot_spec=subplot_gridspec) - inner_subplots = [ - fig.add_subplot(gs_inner[i, j], projection="3d") - for i in range(2) - for j in range(2) - ] + figure_files = [] + for i_file, atlas_name in enumerate(atlas_names): + data_file = data_files[i_file] + atlas_file = atlas_files[i_file] + temp_file = fname_presuffix( + f"{atlas_name}.svg", + newpath=runtime.cwd, + ) + lh_img = nb.load(lh) + rh_img = nb.load(rh) + LOGGER.info(f"Underlay files: {lh} {rh}") + LOGGER.info( + "Underlay sizes: " + f"{lh_img.agg_data()[0].shape[0]} {rh_img.agg_data()[0].shape[0]}" + ) + + plot_obj = Plot(lh, rh) - img = nb.load(cortical_files[i_file]) - img_data = img.get_fdata() - img_axes = [img.header.get_axis(i) for i in range(img.ndim)] - lh_surf_data = surf_data_from_cifti( + LOGGER.info(f"Adding {atlas_name} to the plot.") + data_img = nb.load(data_file) + img_data = data_img.get_fdata() + img_axes = [data_img.header.get_axis(i) for i in range(data_img.ndim)] + lh_data = surf_data_from_cifti( img_data, img_axes[1], "CIFTI_STRUCTURE_CORTEX_LEFT", ) - rh_surf_data = surf_data_from_cifti( + rh_data = surf_data_from_cifti( img_data, img_axes[1], "CIFTI_STRUCTURE_CORTEX_RIGHT", ) - - plot_surf_stat_map( - lh, - lh_surf_data, - threshold=threshold, - vmin=vmin, - vmax=vmax, - hemi="left", - view="lateral", - engine="matplotlib", - cmap="cool", - colorbar=False, - axes=inner_subplots[0], - figure=fig, - ) - plot_surf_stat_map( - rh, - rh_surf_data, - threshold=threshold, - vmin=vmin, - vmax=vmax, - hemi="right", - view="lateral", - engine="matplotlib", + LOGGER.info(f"Data sizes: {lh_data.shape} {rh_data.shape}") + plot_obj.add_layer( + {"left": np.squeeze(lh_data), "right": np.squeeze(rh_data)}, cmap="cool", - colorbar=False, - axes=inner_subplots[1], - figure=fig, + color_range=(vmin, vmax), + cbar=True, ) - plot_surf_stat_map( - lh, - lh_surf_data, - threshold=threshold, - vmin=vmin, - vmax=vmax, - hemi="left", - view="medial", - engine="matplotlib", - cmap="cool", - colorbar=False, - axes=inner_subplots[2], - figure=fig, + + # Add parcel boundaries + atlas_img = nb.load(atlas_file) + atlas_data = atlas_img.get_fdata() + atlas_axes = [atlas_img.header.get_axis(i) for i in range(atlas_img.ndim)] + lh_atlas = surf_data_from_cifti( + atlas_data, + atlas_axes[1], + "CIFTI_STRUCTURE_CORTEX_LEFT", ) - plot_surf_stat_map( - rh, - rh_surf_data, - threshold=threshold, - vmin=vmin, - vmax=vmax, - hemi="right", - view="medial", - engine="matplotlib", - cmap="cool", - colorbar=False, - axes=inner_subplots[3], - figure=fig, + rh_atlas = surf_data_from_cifti( + atlas_data, + atlas_axes[1], + "CIFTI_STRUCTURE_CORTEX_RIGHT", ) - for ax in inner_subplots: - ax.set_rasterized(True) - - # Create a ScalarMappable with the "cool" colormap and the specified vmin and vmax - sm = ScalarMappable(cmap="cool", norm=Normalize(vmin=vmin, vmax=vmax)) - - for colorbar_gridspec in cbar_gs_list: - colorbar_ax = fig.add_subplot(colorbar_gridspec) - # Add a colorbar to colorbar_ax using the ScalarMappable - fig.colorbar(sm, cax=colorbar_ax) + LOGGER.info(f"Atlas sizes: {lh_atlas.shape} {rh_atlas.shape}") + plot_obj.add_layer( + {"left": np.squeeze(lh_atlas), "right": np.squeeze(rh_atlas)}, + cmap="gray", + as_outline=True, + cbar=False, + ) + fig = plot_obj.build() + fig.suptitle(atlas_name, fontsize=16) + fig.tight_layout() + fig.savefig(temp_file) + figure_files.append(temp_file) + plt.close(fig) + + # Now build the combined figure + # Load SVG files and get their sizes + widths, heights = [], [] + for figure_file in figure_files: + svg_obj = sg.fromfile(figure_file) + fig = svg_obj.getroot() + + # Original size is represented as string (example: '600px'); convert to float + width = float(re.sub("[^0-9]", "", svg_obj.width)) + height = float(re.sub("[^0-9]", "", svg_obj.height)) + widths.append(width) + heights.append(height) + + cell_width, cell_height = max(widths), max(heights) + max_columns = 2 + if len(figure_files) == 1: + total_width = cell_width + total_height = cell_height + loc_idx = [(0, 0)] + else: + n_rows = int(np.ceil(len(figure_files) / max_columns)) + n_columns = max_columns + total_width = cell_width * n_columns + total_height = cell_height * n_rows + + loc_idx = [] + for i_row in range(n_rows): + row_idx = [j for j in range(len(figure_files)) if (j // max_columns) == i_row] + for j_col in range(len(row_idx)): + loc_idx.append((i_row, j_col)) + + # Create new SVG figure + new_svg = sg.SVGFigure(width=f"{total_width}px", height=f"{total_height}px") + # for some reason, the width and height params aren't retained, so set them again + new_svg.set_size((f"{total_width}px", f"{total_height}px")) + + # Add each SVG to the new figure + for i_fig, figure_file in enumerate(figure_files): + svg_obj = sg.fromfile(figure_file) + fig = svg_obj.getroot() + offset0 = loc_idx[i_fig][1] * cell_width + offset1 = loc_idx[i_fig][0] * cell_height + fig.moveto(offset0, offset1) + new_svg.append(fig) self._results["out_file"] = fname_presuffix( - cortical_files[0], + data_files[0], suffix="_file.svg", newpath=runtime.cwd, use_ext=False, ) - fig.savefig( - self._results["out_file"], - bbox_inches="tight", - pad_inches=None, - format="svg", - ) - plt.close(fig) + new_svg.save(self._results["out_file"]) return runtime @@ -1181,106 +1180,31 @@ def _run_interface(self, runtime): rh = self.inputs.rh_underlay lh = self.inputs.lh_underlay - cifti = nb.load(self.inputs.in_file) - cifti_data = cifti.get_fdata() - cifti_axes = [cifti.header.get_axis(i) for i in range(cifti.ndim)] - - # Create Figure and GridSpec. - fig = plt.figure(constrained_layout=False) - fig.set_size_inches(6.5, 6) - # Add an additional column for the colorbar - gs = GridSpec(1, 2, figure=fig, width_ratios=[1, 0.05]) - subplot_gridspec = gs[0, 0] - subplot = fig.add_subplot(subplot_gridspec) - colorbar_gridspec = gs[0, 1] - - subplot.set_axis_off() - - # Create 4 Axes (2 rows, 2 columns) from the subplot - gs_inner = GridSpecFromSubplotSpec(2, 2, subplot_spec=subplot_gridspec) - inner_subplots = [ - fig.add_subplot(gs_inner[i, j], projection="3d") for i in range(2) for j in range(2) - ] + data_img = nb.load(self.inputs.in_file) + img_data = data_img.get_fdata() + data_axes = [data_img.header.get_axis(i) for i in range(data_img.ndim)] - lh_surf_data = surf_data_from_cifti( - cifti_data, - cifti_axes[1], + plot_obj = Plot(lh, rh) + lh_data = surf_data_from_cifti( + img_data, + data_axes[1], "CIFTI_STRUCTURE_CORTEX_LEFT", ) - rh_surf_data = surf_data_from_cifti( - cifti_data, - cifti_axes[1], + rh_data = surf_data_from_cifti( + img_data, + data_axes[1], "CIFTI_STRUCTURE_CORTEX_RIGHT", ) + vmax = np.nanmax([np.nanmax(lh_data), np.nanmax(rh_data)]) + vmin = np.nanmin([np.nanmin(lh_data), np.nanmin(rh_data)]) - vmax = np.nanmax([np.nanmax(lh_surf_data), np.nanmax(rh_surf_data)]) - vmin = np.nanmin([np.nanmin(lh_surf_data), np.nanmin(rh_surf_data)]) - - plot_surf_stat_map( - lh, - lh_surf_data, - vmin=vmin, - vmax=vmax, - hemi="left", - view="lateral", - engine="matplotlib", - cmap="cool", - colorbar=False, - axes=inner_subplots[0], - figure=fig, - ) - plot_surf_stat_map( - rh, - rh_surf_data, - vmin=vmin, - vmax=vmax, - hemi="right", - view="lateral", - engine="matplotlib", - cmap="cool", - colorbar=False, - axes=inner_subplots[1], - figure=fig, - ) - plot_surf_stat_map( - lh, - lh_surf_data, - vmin=vmin, - vmax=vmax, - hemi="left", - view="medial", - engine="matplotlib", - cmap="cool", - colorbar=False, - axes=inner_subplots[2], - figure=fig, - ) - plot_surf_stat_map( - rh, - rh_surf_data, - vmin=vmin, - vmax=vmax, - hemi="right", - view="medial", - engine="matplotlib", + plot_obj.add_layer( + {"left": np.squeeze(lh_data), "right": np.squeeze(rh_data)}, cmap="cool", - colorbar=False, - axes=inner_subplots[3], - figure=fig, + color_range=(vmin, vmax), + cbar=True, ) - - inner_subplots[0].set_title("Left Hemisphere", fontsize=10) - inner_subplots[1].set_title("Right Hemisphere", fontsize=10) - - for ax in inner_subplots: - ax.set_rasterized(True) - - # Create a ScalarMappable with the "cool" colormap and the specified vmin and vmax - sm = ScalarMappable(cmap="cool", norm=Normalize(vmin=vmin, vmax=vmax)) - - colorbar_ax = fig.add_subplot(colorbar_gridspec) - # Add a colorbar to colorbar_ax using the ScalarMappable - fig.colorbar(sm, cax=colorbar_ax) + fig = plot_obj.build() self._results["out_file"] = fname_presuffix( self.inputs.in_file, diff --git a/xcp_d/workflows/connectivity.py b/xcp_d/workflows/connectivity.py index 1cf912f31..d5f02f012 100644 --- a/xcp_d/workflows/connectivity.py +++ b/xcp_d/workflows/connectivity.py @@ -674,6 +674,7 @@ def init_functional_connectivity_cifti_wf(mem_gb, exact_scans, name="connectivit workflow.connect([ (inputnode, plot_coverage, [ ("atlases", "labels"), + ("atlas_files", "atlas_files"), ("lh_midthickness", "lh_underlay"), ("rh_midthickness", "rh_underlay"), ]), @@ -857,6 +858,7 @@ def init_functional_connectivity_cifti_wf(mem_gb, exact_scans, name="connectivit workflow.connect([ (inputnode, plot_parcellated_reho, [ ("atlases", "labels"), + ("atlas_files", "atlas_files"), ("lh_midthickness", "lh_underlay"), ("rh_midthickness", "rh_underlay"), ]), @@ -912,6 +914,7 @@ def init_functional_connectivity_cifti_wf(mem_gb, exact_scans, name="connectivit workflow.connect([ (inputnode, plot_parcellated_alff, [ ("atlases", "labels"), + ("atlas_files", "atlas_files"), ("lh_midthickness", "lh_underlay"), ("rh_midthickness", "rh_underlay"), ]),