Skip to content

Commit

Permalink
Update again.
Browse files Browse the repository at this point in the history
  • Loading branch information
tsalo committed Nov 14, 2023
1 parent 4b718b3 commit dfb09c5
Show file tree
Hide file tree
Showing 2 changed files with 107 additions and 91 deletions.
5 changes: 1 addition & 4 deletions aslprep/interfaces/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,8 @@
)
from nipype.utils.filemanip import fname_presuffix
from niworkflows.utils.timeseries import _cifti_timeseries, _nifti_timeseries
from niworkflows.viz.plots import fMRIPlot

from aslprep.utils.plotting import CBFPlot, CBFtsPlot
from aslprep.utils.plotting import CBFPlot, CBFtsPlot, fMRIPlot


class _ASLSummaryInputSpec(BaseInterfaceInputSpec):
Expand Down Expand Up @@ -119,8 +118,6 @@ def _run_interface(self, runtime):
units=units,
nskip=self.inputs.drop_trs,
paired_carpet=has_cifti,
# The main change from fMRIPrep's usage is that detrend is False for ASL.
detrend=False,
).plot()
fig.savefig(self._results["out_file"], bbox_inches="tight")
return runtime
Expand Down
193 changes: 106 additions & 87 deletions aslprep/utils/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
from nilearn._utils.niimg import _safe_get_data, load_niimg
from nilearn._utils.niimg_conversions import check_niimg_4d
from niworkflows import NIWORKFLOWS_LOG
from niworkflows.interfaces.plotting import _get_tr
from niworkflows.viz.plots import confoundplot, spikesplot
from niworkflows.viz.utils import (
compose_view,
Expand All @@ -24,92 +23,6 @@
from svgutils.transform import SVGFigure


class ASLPlot:
"""Generates the ASL Summary Plot."""

__slots__ = ("func_file", "mask_data", "tr", "seg_data", "confounds", "spikes")

def __init__(
self,
func_file,
mask_file=None,
data=None,
confounds_file=None,
seg_file=None,
tr=None,
usecols=None,
units=None,
vlines=None,
spikes_files=None,
):
func_img = nb.load(func_file)
self.func_file = func_file
self.tr = tr or _get_tr(func_img)
self.mask_data = None
self.seg_data = None

if not isinstance(func_img, nb.Cifti2Image):
self.mask_data = nb.fileslice.strided_scalar(func_img.shape[:3], np.uint8(1))
if mask_file:
self.mask_data = np.asanyarray(nb.load(mask_file).dataobj).astype("uint8")
if seg_file:
self.seg_data = np.asanyarray(nb.load(seg_file).dataobj)

if units is None:
units = {}
if vlines is None:
vlines = {}
self.confounds = {}
if data is None and confounds_file:
data = pd.read_csv(confounds_file, sep=r"[\t\s]+", usecols=usecols, index_col=False)

if data is not None:
for name in data.columns.ravel():
self.confounds[name] = {
"values": data[[name]].values.ravel().tolist(),
"units": units.get(name),
"cutoff": vlines.get(name),
}

self.spikes = []
if spikes_files:
self.spikes.extend((np.loadtxt(sp_file), None, False) for sp_file in spikes_files)

def plot(self, figure=None):
"""Generate the plot."""
sns.set_style("whitegrid")
sns.set_context("paper", font_scale=0.8)

if figure is None:
figure = plt.gcf()

nconfounds = len(self.confounds)
nspikes = len(self.spikes)
nrows = 1 + nconfounds + nspikes

# Create grid
grid = mgs.GridSpec(
nrows, 1, wspace=0.0, hspace=0.05, height_ratios=[1] * (nrows - 1) + [5]
)

grid_id = 0
for tsz, name, iszs in self.spikes:
spikesplot(tsz, title=name, outer_gs=grid[grid_id], tr=self.tr, zscored=iszs)
grid_id += 1

if self.confounds:
palette = color_palette("husl", nconfounds)

for i, (name, kwargs) in enumerate(self.confounds.items()):
tseries = kwargs.pop("values")
confoundplot(tseries, grid[grid_id], tr=self.tr, color=palette[i], name=name, **kwargs)
grid_id += 1

plot_carpet(self.func_file, atlaslabels=self.seg_data, subplot=grid[-1], tr=self.tr)
# spikesplot_cb([0.7, 0.78, 0.2, 0.008])
return figure


class CBFtsPlot(object):
"""Generate the CBF time series Summary Plot."""

Expand Down Expand Up @@ -462,6 +375,112 @@ def confoundplotx(
return ax_ts, gs


class fMRIPlot:
"""Generates the fMRI Summary Plot."""

__slots__ = (
"timeseries",
"segments",
"tr",
"confounds",
"spikes",
"nskip",
"sort_carpet",
"paired_carpet",
)

def __init__(
self,
timeseries,
segments,
confounds=None,
conf_file=None,
tr=None,
usecols=None,
units=None,
vlines=None,
spikes_files=None,
nskip=0,
sort_carpet=True,
paired_carpet=False,
):
self.timeseries = timeseries
self.segments = segments
self.tr = tr
self.nskip = nskip
self.sort_carpet = sort_carpet
self.paired_carpet = paired_carpet

if units is None:
units = {}
if vlines is None:
vlines = {}
self.confounds = {}
if confounds is None and conf_file:
confounds = pd.read_csv(conf_file, sep=r"[\t\s]+", usecols=usecols, index_col=False)

if confounds is not None:
for name in confounds.columns:
self.confounds[name] = {
"values": confounds[[name]].values.squeeze().tolist(),
"units": units.get(name),
"cutoff": vlines.get(name),
}

self.spikes = []
if spikes_files:
for sp_file in spikes_files:
self.spikes.append((np.loadtxt(sp_file), None, False))

def plot(self, figure=None):
"""Main plotter"""
import seaborn as sns
from niworkflows.viz.plots import plot_carpet

sns.set_style("whitegrid")
sns.set_context("paper", font_scale=0.8)

if figure is None:
figure = plt.gcf()

nconfounds = len(self.confounds)
nspikes = len(self.spikes)
nrows = 1 + nconfounds + nspikes

# Create grid
grid = mgs.GridSpec(
nrows, 1, wspace=0.0, hspace=0.05, height_ratios=[1] * (nrows - 1) + [5]
)

grid_id = 0
for tsz, name, iszs in self.spikes:
spikesplot(tsz, title=name, outer_gs=grid[grid_id], tr=self.tr, zscored=iszs)
grid_id += 1

if self.confounds:
from seaborn import color_palette

palette = color_palette("husl", nconfounds)

for i, (name, kwargs) in enumerate(self.confounds.items()):
tseries = kwargs.pop("values")
confoundplot(tseries, grid[grid_id], tr=self.tr, color=palette[i], name=name, **kwargs)
grid_id += 1

plot_carpet(
self.timeseries,
segments=self.segments,
subplot=grid[-1],
tr=self.tr,
sort_rows=self.sort_carpet,
drop_trs=self.nskip,
cmap="paired" if self.paired_carpet else None,
# This is the only modification we need for ASLPrep
detrend=False,
)
return figure


def plot_carpet(
func,
atlaslabels=None,
Expand Down

0 comments on commit dfb09c5

Please sign in to comment.