diff --git a/pipeline.py b/pipeline.py
new file mode 100644
index 0000000..613f44d
--- /dev/null
+++ b/pipeline.py
@@ -0,0 +1,172 @@
+import pathlib
+import tifffile
+import tqdm
+import warnings
+import textwrap
+
+from quant_test_v2 import (
+    validate_props,
+    validate_masks,
+    quantify_mask,
+    format_mask_table,
+    write_table,
+    validate_img,
+    load_marker_csv,
+    quantify_channel,
+    reorder_table_column
+)
+
+class Pipeline():
+    def __init__(
+        self,
+        mask_paths,
+        img_path=None,
+        marker_csv_path=None,
+        output_dir='.',
+        mask_props=None,
+        intensity_props=None,
+        img_preprocess_func=None,
+        img_preprocess_kwargs=None,
+        table_prefix='',
+        skip=None,
+        save_RAM=True
+    ):
+        self.mask_paths = mask_paths
+        self.img_path = img_path
+        self.marker_csv_path = marker_csv_path
+        self.output_dir = output_dir
+        self.mask_props = mask_props
+        self.intensity_props = intensity_props
+        self.img_preprocess_func = img_preprocess_func
+        self.img_preprocess_kwargs = img_preprocess_kwargs
+        self.table_prefix = table_prefix
+        self.skip = skip
+        self.save_RAM = save_RAM
+
+        assert skip in [None, 'morphology', 'intensity']
+
+        validate_props(mask_props)
+        self.mask_shape = validate_masks(mask_paths)
+
+    def run(self):
+        if self.skip != 'morphology':
+            self.run_mask()
+        if self.skip != 'intensity':
+            assert (self.img_path is not None) and (self.marker_csv_path is not None)
+            self.run_img()
+
+    def run_mask(self):
+        if self.skip == 'intensity':
+            for path in self.mask_paths:
+                self._run_mask(path)
+        else:
+            self._run_mask(self.mask_paths[0])
+
+    def _run_mask(self, path, flat=True):
+        mask_name = pathlib.Path(path).name
+        mask = tifffile.imread(path, level=0, key=0)
+        print(f"Quantifying mask <{mask_name}>")
+        mask_table = quantify_mask(mask, mask_props=self.mask_props)
+        mask_table = format_mask_table(mask_table)
+        mask_table = reorder_table_column(mask_table)
+        write_table(
+            mask_table, self.output_dir,
+            mask_path=path, img_path=self.img_path,
+            prefix=self.table_prefix, suffix='_morphology', flat=flat
+        )
+        print(
+            'Completed.',
+            'max id:', mask_table.index.max(),
+            'number of ids:', len(mask_table.index),
+            '\n'
+        )
+
+    def run_img(self):
+        validate_props(self.intensity_props)
+        img_shape = validate_img(self.img_path)
+        assert self.mask_shape == img_shape[1:3], (
+            f"Mask shape ({self.mask_shape}) does not match image shape ({img_shape})"
+        )
+        channel_names = load_marker_csv(self.marker_csv_path)
+        num_channel_names = len(channel_names)
+        num_channels = img_shape[0]
+        message = f'''
+        Number of channel names ({num_channel_names}) does not match number of
+        channels ({num_channels}) in image:
+        {self.img_path}
+        '''
+        assert num_channel_names <= num_channels, (
+            textwrap.dedent(message)
+        )
+        if num_channel_names != num_channels:
+            warnings.warn(
+                textwrap.dedent(message),
+                RuntimeWarning, stacklevel=2
+            )
+        self.channel_names = channel_names
+        if self.save_RAM:
+            for path in self.mask_paths:
+                self.masks = {pathlib.Path(path): tifffile.imread(path, level=0, key=0)}
+                self._run_img()
+        else:
+            self.masks = {
+                pathlib.Path(path): tifffile.imread(path, level=0, key=0)
+                for path in self.mask_paths
+            }
+            self._run_img()
+
+    def _run_img(self, flat=True):
+        self.intensity_tables = {}
+        mask_names = [p.name for p in self.masks.keys()]
+        print(f"Quantifying channels with mask(s) {mask_names}")
+        _intensity_props = self.intensity_props[:] if self.intensity_props else []
+        for i in tqdm.tqdm(range(len(self.channel_names))):
+            is_first_channel = i == 0
+            if is_first_channel:
+                intensity_props = _intensity_props + ['centroid']
+            else:
+                intensity_props = _intensity_props
+
+            for j, (mask_path, mask) in enumerate(self.masks.items()):
+                if j == 0:
+                    channel_table, processed_img = quantify_channel(
+                        mask, tifffile.imread(self.img_path, level=0, key=i),
+                        intensity_props=intensity_props,
+                        channel_name=self.channel_names[i],
+                        preprocess_func=self.img_preprocess_func,
+                        preprocess_func_kwargs=self.img_preprocess_kwargs,
+                        return_img=True
+                    )
+                    if len(self.masks) == 1:
+                        del processed_img
+                else:
+                    channel_table = quantify_channel(
+                        mask, processed_img,
+                        intensity_props=intensity_props,
+                        channel_name=self.channel_names[i],
+                        preprocess_func=None,
+                        preprocess_func_kwargs=None,
+                        return_img=False
+                    )
+                channel_table = format_mask_table(channel_table)
+                if mask_path not in self.intensity_tables:
+                    self.intensity_tables[mask_path] = channel_table
+                else:
+                    self.intensity_tables[mask_path] = self.intensity_tables[mask_path].join(channel_table)
+        for mask_path, table in self.intensity_tables.items():
+            write_table(
+                reorder_table_column(table), self.output_dir,
+                mask_path=mask_path, img_path=self.img_path,
+                prefix=self.table_prefix, suffix='_intensity', flat=flat
+            )
+            print(
+                'Completed.',
+                mask_path.name,
+                '-',
+                pathlib.Path(self.img_path).name,
+                '\n'
+                'max id:', table.index.max(),
+                'number of ids:', len(table.index),
+                '\n'
+            )
diff --git a/quant_test_v2.py b/quant_test_v2.py
new file mode 100644
index 0000000..b7318b1
--- /dev/null
+++ b/quant_test_v2.py
@@ -0,0 +1,230 @@
+# Functions for reading in single-cell imaging data
+# Joshua Hess
+
+# Import necessary modules
+import h5py
+# TODO Create a reader for hdf5 images
+import pandas as pd
+import numpy as np
+import skimage.measure as measure
+
+import tifffile
+import pathlib
+
+
+PROP_VALS = measure._regionprops.PROP_VALS
+
+NAME_MAP = {
+    'label': 'CellID',
+    'centroid-1': 'X_centroid',
+    'centroid-0': 'Y_centroid',
+    'area': 'Area',
+    'major_axis_length': 'MajorAxisLength',
+    'minor_axis_length': 'MinorAxisLength',
+    'eccentricity': 'Eccentricity',
+    'solidity': 'Solidity',
+    'extent': 'Extent',
+    'orientation': 'Orientation'
+}
+
+def gini_index(mask, intensity):
+    x = intensity[mask]
+    sorted_x = np.sort(x)
+    n = len(x)
+    cumx = np.cumsum(sorted_x, dtype=float)
+    return (n + 1 - 2 * np.sum(cumx) / cumx[-1]) / n
+
+def median_intensity(mask, intensity):
+    return np.median(intensity[mask])
+
+EXTRA_PROPS = {
+    'gini_index': gini_index,
+    'median_intensity': median_intensity
+}
+
+
+def quantify_channel(
+    mask, intensity_img,
+    intensity_props=None, channel_name=None,
+    preprocess_func=None, preprocess_func_kwargs=None,
+    return_img=False
+):
+    if intensity_props is None:
+        intensity_props = []
+
+    intensity_props = ['label', 'mean_intensity'] + intensity_props[:]
+    # Look for regionprops in skimage
+    builtin_props = set(intensity_props).intersection(PROP_VALS)
+    # Otherwise look for them in this module
+    extra_props = [
+        EXTRA_PROPS[p] for p in
+        set(intensity_props).difference(PROP_VALS)
+    ]
+    if preprocess_func is not None:
+        if preprocess_func_kwargs is None:
+            preprocess_func_kwargs = {}
+        intensity_img = preprocess_func(intensity_img, **preprocess_func_kwargs)
+    props_table = measure.regionprops_table(
+        mask, intensity_img,
+        properties=tuple(builtin_props),
+        extra_properties=extra_props
+    )
+    ordered_props_table = order_dictionary_by_list(props_table, intensity_props)
+    del props_table
+    if channel_name is not None:
+        def format_name(prop_name):
+            if len(extra_props) == 0:
+                return channel_name
+            else:
f"{channel_name}_{prop_name}" + renamed_table = { + k if k in NAME_MAP.keys() else format_name(k): v + for k, v in ordered_props_table.items() + } + else: + renamed_table = ordered_props_table + if return_img: + return renamed_table, intensity_img + else: + return renamed_table + + +def quantify_mask(mask, mask_props=None): + _all_mask_props = [ + "label", "centroid", "area", + "major_axis_length", "minor_axis_length", + "eccentricity", "solidity", "extent", "orientation" + ] + if mask_props is not None: + all_mask_props = set(_all_mask_props).union(mask_props) + _all_mask_props += mask_props + else: + all_mask_props = set(_all_mask_props) + table = measure.regionprops_table( + mask, + properties=all_mask_props + ) + ordered_table = order_dictionary_by_list(table, _all_mask_props) + return ordered_table + + +def order_dictionary_by_list(d, l): + return { + k: d[k] + for k in + sorted( + d.keys(), + key=lambda x: l.index(x.split('-')[0]) + ) + } + + +def load_marker_csv(csv_path): + csv_path = pathlib.Path(csv_path) + assert csv_path.suffix == '.csv', ( + f"{csv_path} is not a CSV file." + ) + marker_name_df = pd.read_csv(csv_path) + if 'marker_name' not in marker_name_df.columns: + print( + f"'marker_name' not in {csv_path.name} header\n" + f"Assuming legacy format, first column is used as marker names" + ) + marker_name_df = pd.read_csv( + csv_path, header=None, + usecols=[0], names=['marker_name'] + ) + has_duplicates = marker_name_df.duplicated(keep=False) + name_suffix = ( + marker_name_df.loc[has_duplicates] + .groupby('marker_name') + .cumcount() + .map(lambda x: f"_{x + 1}") + ) + marker_name_df.loc[has_duplicates, 'marker_name'] += name_suffix + return marker_name_df['marker_name'].to_list() + + +def validate_masks(mask_paths): + for p in mask_paths: + assert pathlib.Path(p).exists(), ( + f"{p} does not exist" + ) + mask_shapes = [] + for p in mask_paths: + with tifffile.TiffFile(pathlib.Path(p)) as tiff: + mask_shapes.append(tiff.series[0].shape) + assert len(set(mask_shapes)) == 1, ( + f"Masks must be the same shape\n" + '\n'.join([ f"{p} - {s}" + for p, s in zip(mask_paths, mask_shapes) + ]) + ) + shape = mask_shapes[0] + ndim = len(shape) + assert ndim == 2, ( + f"Only 2D masks are supported. Got a {ndim}D mask of shape {shape}" + ) + return shape + + +def validate_img(img_path): + suffixes = pathlib.Path(img_path).suffixes + assert '.ome' in suffixes + assert ('.tif' in suffixes) ^ ('.tiff' in suffixes) + with tifffile.TiffFile(img_path) as tiff: + img_shape = tiff.series[0].shape + ndim = len(img_shape) + assert (ndim == 2) ^ (ndim == 3), ( + f"Only 2D/3D images are supported. 
+        f"Only 2D/3D images are supported. Got a {ndim}D image of shape {img_shape}"
+    )
+    if len(img_shape) == 2:
+        img_shape = (1, *img_shape)
+    return img_shape
+
+
+def format_mask_table(mask_table):
+    mask_table = pd.DataFrame(mask_table)
+    mask_table.rename(columns=NAME_MAP, inplace=True)
+    mask_table.set_index('CellID', inplace=True)
+    return mask_table
+
+
+def reorder_table_column(df):
+    assert df.index.name == 'CellID'
+    col_ordering = tuple(NAME_MAP.values())
+    sort_key = lambda x: col_ordering.index(x) if x in col_ordering else -1
+    return df.reindex(columns=sorted(df.columns, key=sort_key))
+
+
+def write_table(
+    df, output_dir,
+    mask_path, img_path=None,
+    prefix='', suffix='', flat=False
+):
+    if not img_path:
+        img_path = ''
+    mask_name, img_name = [
+        pathlib.Path(p).stem.replace('.ome', '').replace('.', '_')
+        for p in [mask_path, img_path]
+    ]
+
+    output_dir = pathlib.Path(output_dir)
+    if not flat:
+        output_dir = output_dir / img_name
+
+    output_dir.mkdir(exist_ok=True, parents=True)
+
+    output_filename = f"{prefix}{img_name}_{mask_name}{suffix}.csv"
+    output_path = output_dir / output_filename
+    df.to_csv(output_path)
+    return output_path
+
+
+def validate_props(props):
+    if props is None:
+        return
+    for p in props:
+        assert (p in PROP_VALS) ^ (p in EXTRA_PROPS.keys()), (
+            f"{p} is not a valid property. Available properties are "
+            f"[{', '.join(PROP_VALS)}] and [{', '.join(EXTRA_PROPS.keys())}]"
+        )
+    return
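
For reference, a minimal usage sketch of the Pipeline class added above. The file names, marker CSV, and output directory are hypothetical placeholders, not files from this change; gini_index is one of the extra intensity properties defined in quant_test_v2.py.

# Minimal usage sketch; paths below are hypothetical and not part of the diff.
from pipeline import Pipeline

pipeline = Pipeline(
    mask_paths=['cellMask.tif', 'nucleiMask.tif'],  # segmentation masks of identical shape
    img_path='image.ome.tif',                       # multichannel OME-TIFF to quantify
    marker_csv_path='markers.csv',                  # CSV with a 'marker_name' column
    output_dir='quantification',
    intensity_props=['gini_index'],                 # extra per-channel property from quant_test_v2.py
    save_RAM=True                                   # quantify one mask at a time to limit memory use
)
pipeline.run()  # writes *_morphology.csv and *_intensity.csv tables into output_dir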