diff --git a/src/neuron_proofreader/config.py b/src/neuron_proofreader/config.py index 1f8baed..ffab1de 100644 --- a/src/neuron_proofreader/config.py +++ b/src/neuron_proofreader/config.py @@ -8,7 +8,7 @@ aspects of a system involving graphs, proposals, and machine learning (ML). """ -from dataclasses import dataclass, field +from dataclasses import dataclass from typing import Tuple @@ -24,13 +24,6 @@ class GraphConfig: Scaling factors applied to xyz coordinates to account for anisotropy of microscope. Note this instance of "anisotropy" is only used while reading SWC files. Default is (1.0, 1.0, 1.0). - complex_bool : bool - Indication of whether to generate complex proposals, meaning proposals - between leaf and non-leaf nodes. Default is False. - long_range_bool : bool, optional - Indication of whether to generate simple proposals within a scaled - distance of "search_radius" from leaves without any proposals. Default - is False. min_size : float, optional Minimum path length (in microns) of swc files which are stored as connected components in the FragmentsGraph. Default is 30. @@ -48,27 +41,23 @@ class GraphConfig: remove_high_risk_merges : bool, optional Indication of whether to remove high risk merge sites (i.e. close branching points). Default is False. - smooth_bool : bool, optional - Indication of whether to smooth branches in the graph. Default is - True. trim_endpoints_bool : bool, optional Indication of whether to endpoints of branches with exactly one proposal. Default is True. """ - anisotropy: Tuple[float] = field(default_factory=tuple) - complex_bool: bool = False - long_range_bool: bool = True - min_size: float = 30.0 - min_size_with_proposals: float = 0 - node_spacing: int = 5 + anisotropy: Tuple[float, float, float] = (1.0, 1.0, 1.0) + max_proposals_per_leaf: int = 3 + min_size: float = 40.0 + min_size_with_proposals: float = 40.0 + node_spacing: int = 1 proposals_per_leaf: int = 3 prune_depth: float = 24.0 - remove_doubles: bool = False + remove_doubles: bool = True remove_high_risk_merges: bool = False search_radius: float = 20.0 - smooth_bool: bool = True trim_endpoints_bool: bool = True + verbose: bool = True @dataclass @@ -87,7 +76,9 @@ class MLConfig: batch_size: int = 64 brightness_clip: int = 400 device: str = "cuda" - patch_shape : tuple = (96, 96, 96) + patch_shape: tuple = (96, 96, 96) + shuffle: bool = False + transform: bool = False threshold: float = 0.8 @@ -112,5 +103,5 @@ def __init__(self, graph_config: GraphConfig, ml_config: MLConfig): An instance of the "MLConfig" class that includes configuration parameters for machine learning models. """ - self.graph_config = graph_config - self.ml_config = ml_config + self.graph = graph_config + self.ml = ml_config diff --git a/src/neuron_proofreader/machine_learning/gnn_models.py b/src/neuron_proofreader/machine_learning/gnn_models.py index 1db956e..2a3e577 100644 --- a/src/neuron_proofreader/machine_learning/gnn_models.py +++ b/src/neuron_proofreader/machine_learning/gnn_models.py @@ -16,7 +16,7 @@ import torch.nn.init as init from neuron_proofreader.machine_learning.vision_models import CNN3D -from neuron_proofreader.split_proofreading import feature_extraction +from neuron_proofreader.split_proofreading import split_feature_extraction from neuron_proofreader.utils.ml_util import FeedForwardNet @@ -125,7 +125,7 @@ def init_node_embedding(output_dim): features. 
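
# Hedged usage sketch (not part of this patch): reading the renamed Config
# attributes downstream. Assumes the wrapper class is named "Config" and is
# importable from neuron_proofreader.config; values are the defaults shown above.
from neuron_proofreader.config import Config, GraphConfig, MLConfig

config = Config(GraphConfig(), MLConfig())
print(config.graph.min_size)      # 40.0  (previously config.graph_config.min_size)
print(config.graph.node_spacing)  # 1
print(config.ml.patch_shape)      # (96, 96, 96)  (previously config.ml_config.patch_shape)
print(config.ml.batch_size)       # 64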
""" # Get feature dimensions - node_input_dims = feature_extraction.get_feature_dict() + node_input_dims = split_feature_extraction.get_feature_dict() dim_b = node_input_dims["branch"] dim_p = node_input_dims["proposal"] diff --git a/src/neuron_proofreader/proposal_graph.py b/src/neuron_proofreader/proposal_graph.py index 8a68367..e7272de 100644 --- a/src/neuron_proofreader/proposal_graph.py +++ b/src/neuron_proofreader/proposal_graph.py @@ -31,7 +31,9 @@ trim_endpoints_at_proposal ) from neuron_proofreader.utils import ( - geometry_util as geometry, graph_util as gutil, util, + geometry_util as geometry, + graph_util as gutil, + util, ) @@ -48,7 +50,7 @@ class ProposalGraph(SkeletonGraph): def __init__( self, anisotropy=(1.0, 1.0, 1.0), - filter_transitive_proposals=False, + gt_path=None, max_proposals_per_leaf=3, min_size=0, min_size_with_proposals=0, @@ -57,7 +59,7 @@ def __init__( remove_high_risk_merges=False, segmentation_path=None, soma_centroids=None, - verbose=False, + verbose=True, ): """ Instantiates a ProposalGraph object. @@ -93,6 +95,7 @@ def __init__( # Instance attributes - Graph self.anisotropy = anisotropy self.component_id_to_swc_id = dict() + self.gt_path = gt_path self.leaf_kdtree = None self.soma_ids = set() self.verbose = verbose @@ -192,6 +195,42 @@ def _add_edge(self, edge_id, attrs): self.add_edge(i, j, radius=attrs["radius"], xyz=attrs["xyz"]) self.xyz_to_edge.update({tuple(xyz): edge_id for xyz in attrs["xyz"]}) + def connect_soma_fragments(self, soma_centroids): + merge_cnt = 0 + soma_cnt = 0 + self.set_kdtree() + for soma_xyz in soma_centroids: + node_ids = self.find_fragments_near_xyz(soma_xyz, 25) + if len(node_ids) > 1: + # Find closest node to soma location + soma_cnt += 1 + best_dist = np.inf + best_node = None + for i in node_ids: + dist = geometry.dist(soma_xyz, self.node_xyz[i]) + if dist < best_dist: + best_dist = dist + best_node = i + soma_component_id = self.node_component_id[best_node] + self.soma_ids.add(soma_component_id) + node_ids.remove(best_node) + + # Merge fragments to soma + soma_xyz = self.node_xyz[best_node] + for i in node_ids: + attrs = { + "radius": np.array([2, 2]), + "xyz": np.array([soma_xyz, self.node_xyz[i]]), + } + self._add_edge((best_node, i), attrs) + self.update_component_ids(soma_component_id, i) + merge_cnt += 1 + + # Summarize results + results = [f"# Somas Connected: {soma_cnt}"] + results.append(f"# Soma Fragments Merged: {merge_cnt}") + return "\n".join(results) + def relabel_nodes(self): """ Reassigns contiguous node IDs and update all dependent structures. @@ -296,30 +335,8 @@ def get_kdtree(self, node_type=None): else: return KDTree(list(self.xyz_to_edge.keys())) - def query_kdtree(self, xyz, d, node_type=None): - """ - Parameters - ---------- - xyz : int - Node id. - d : float - Distance from "xyz" that is searched. - - Returns - ------- - generator[tuple] - Generator that generates the xyz coordinates cooresponding to all - nodes within a distance of "d" from "xyz". - """ - if node_type == "leaf": - return geometry.query_ball(self.leaf_kdtree, xyz, d) - elif node_type == "proposal": - return geometry.query_ball(self.proposal_kdtree, xyz, d) - else: - return geometry.query_ball(self.kdtree, xyz, d) - # --- Proposal Generation --- - def generate_proposals(self, search_radius, gt_graph=None): + def generate_proposals(self, search_radius): """ Generates proposals from leaf nodes. @@ -330,16 +347,17 @@ def generate_proposals(self, search_radius, gt_graph=None): gt_graph : networkx.Graph, optional Ground truth graph. 
Default is None. """ - # Generate proposals + # Proposal pipeline proposals = self.proposal_generator(search_radius) + self.search_radius = search_radius self.store_proposals(proposals) self.trim_proposals() # Set groundtruth - if gt_graph: + if self.gt_path: + gt_graph = ProposalGraph(anisotropy=self.anisotropy) + gt_graph.load(self.gt_path) self.gt_accepts = groundtruth_generation.run(gt_graph, self) - else: - self.gt_accepts = set() def add_proposal(self, i, j): """ @@ -600,7 +618,7 @@ def n_nearby_leafs(self, proposal, radius): a proposal. """ xyz = self.proposal_midpoint(proposal) - return len(self.query_kdtree(xyz, radius, "leaf")) - 1 + return len(geometry.query_ball(self.leaf_kdtree, xyz, radius)) - 1 # --- Helpers --- def node_attr(self, i, key): @@ -662,7 +680,8 @@ def edge_length(self, edge): def find_fragments_near_xyz(self, query_xyz, max_dist): hits = dict() - for xyz in self.query_kdtree(query_xyz, max_dist): + xyz_list = geometry.query_ball(self.kdtree, query_xyz, max_dist) + for xyz in xyz_list: i, j = self.xyz_to_edge[tuple(xyz)] dist_i = geometry.dist(self.node_xyz[i], query_xyz) dist_j = geometry.dist(self.node_xyz[j], query_xyz) diff --git a/src/neuron_proofreader/skeleton_graph.py b/src/neuron_proofreader/skeleton_graph.py index fcf7717..67b37db 100644 --- a/src/neuron_proofreader/skeleton_graph.py +++ b/src/neuron_proofreader/skeleton_graph.py @@ -592,6 +592,21 @@ def find_closest_node(self, xyz): _, node = self.kdtree.query(xyz) return node + def get_summary(self, prefix=""): + # Compute values + n_components = format(nx.number_connected_components(self), ",") + n_nodes = format(self.number_of_nodes(), ",") + n_edges = format(self.number_of_edges(), ",") + memory = util.get_memory_usage() + + # Compile results + summary = [f"{prefix} Graph"] + summary.append(f"# Connected Components: {n_components}") + summary.append(f"# Nodes: {n_nodes}") + summary.append(f"# Edges: {n_edges}") + summary.append(f"Memory Consumption: {memory:.2f} GBs") + return "\n".join(summary) + def path_length(self, max_depth=np.inf, root=None): """ Computes the path length of the connected component that contains the diff --git a/src/neuron_proofreader/split_proofreading/split_datasets.py b/src/neuron_proofreader/split_proofreading/split_datasets.py index adee321..84391dc 100644 --- a/src/neuron_proofreader/split_proofreading/split_datasets.py +++ b/src/neuron_proofreader/split_proofreading/split_datasets.py @@ -10,13 +10,13 @@ from torch.utils.data import IterableDataset -import numpy as np import os import pandas as pd from neuron_proofreader.proposal_graph import ProposalGraph from neuron_proofreader.machine_learning.augmentation import ImageTransforms from neuron_proofreader.machine_learning.subgraph_sampler import ( + SeededSubgraphSampler, SubgraphSampler ) from neuron_proofreader.split_proofreading.split_feature_extraction import ( @@ -26,114 +26,70 @@ from neuron_proofreader.utils import geometry_util, img_util, util +# --- Single Brain Dataset --- class FragmentsDataset(IterableDataset): - """ - A dataset class for storing graphs used to train models to perform split - correction. Graphs are stored in the "self.graphs" attribute, which is a - dictionary containing the followin items: - - Key: (brain_id, segmentation_id, example_id) - - Value: graph that is an instance of ProposalGraph - - This dataset is populated using the "self.add_graph" method, which - requires the following inputs: - (1) key: Unique identifier of graph. - (2) gt_pointer: Path to ground truth SWC files. 
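
# Hedged sketch: ground-truth accepts are now derived inside
# generate_proposals() from the "gt_path" supplied at construction, replacing
# the old gt_graph argument. Paths below are placeholders.
from neuron_proofreader.proposal_graph import ProposalGraph

graph = ProposalGraph(gt_path="path/to/gt_swcs")
graph.load("path/to/pred_swcs")
graph.generate_proposals(search_radius=20.0)
print(len(graph.gt_accepts))  # proposals that agree with the ground truth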
- (2) pred_pointer: Path to predicted SWC files. - (3) img_path: Path to whole-brain image stored in cloud bucket. - - Note: This dataset supports graphs from multiple whole-brain datasets. - """ def __init__( self, - config, - brightness_clip=400, - patch_shape=(128, 128, 128), - shuffle=True, - transform=False - ): - """ - Instantiates a FragmentsDataset object. - - Parameters - ---------- - config : GraphConfig - Config object that stores parameters used to build graphs. - patch_shape : Tuple[int], optional - Shape of image patch input to a vision model. Default is (128, 128, - 128). - transform : bool, optional - Indication of whether to apply augmentation to input images. - Default is False. - """ - # Instance attributes - self.brightness_clip = brightness_clip - self.config = config - self.feature_extractors = dict() - self.graphs = dict() - self.patch_shape = patch_shape - self.shuffle = shuffle - self.transform = ImageTransforms() if transform else False - - # --- Load Data --- - def add_graph( - self, - key, - gt_pointer, - pred_pointer, + fragments_path, img_path, + config, + gt_path=None, metadata_path=None, - segmentation_path=None + segmentation_path=None, + soma_centroids=None ): """ - Loads a fragments graph, generates proposals, and initializes feature - extraction. + Instantiates a FragmentsDataset object. Parameters ---------- - key : Tuple[str] - Unique identifier used to register the graph and its feature - pipeline. - gt_pointer : str - Path to ground-truth SWC files to be loaded. - pred_pointer : str + fragments_path : str Path to predicted SWC files to be loaded. img_path : str - Path to the raw image associated with the graph. + Path to the raw image associated with the fragments. + config : Config + ... + gt_pointer : str, optional + Path to ground-truth SWC files to be loaded. Default is None. metadata_path : str ... segmentation_path : str - Path to the segmentation mask associated with the graph. + Path to the segmentation that fragments were generated from. + Default is None. """ - # Add graph - gt_graph = self.load_graph(gt_pointer, is_gt=True) - self.graphs[key] = self.load_graph(pred_pointer, metadata_path) - self.graphs[key].generate_proposals( - self.config.search_radius, gt_graph=gt_graph + # Instance attributes + self.config = config + self.transform = ImageTransforms() if config.ml.transform else False + self.graph = self._load_graph( + fragments_path, metadata_path, segmentation_path, soma_centroids ) - # Generate features - self.feature_extractors[key] = FeaturePipeline( - self.graphs[key], + # Feature extractor + self.feature_extractor = FeaturePipeline( + self.graph, img_path, - self.config.search_radius, - brightness_clip=self.brightness_clip, - patch_shape=self.patch_shape, + brightness_clip=self.config.ml.brightness_clip, + patch_shape=self.config.ml.patch_shape, segmentation_path=segmentation_path ) - def load_graph(self, swc_pointer, is_gt=False, metadata_path=None): + def _load_graph( + self, fragments_path, metadata_path, segmentation_path, soma_centroids + ): """ - Loads a graph by reading and processing SWC files specified by - "swc_pointer". + Loads a graph by reading and processing SWC files specified by the + given path. Parameters ---------- - swc_pointer : str + fragments_path : str Path to SWC files to be loaded. metadata_path : str Patch to JSON file containing metadata on block that fragments were extracted from. + soma_centroids : List[Tuple[float]] + List of physical coordinates that represent soma centers. 
Returns ------- @@ -142,14 +98,22 @@ def load_graph(self, swc_pointer, is_gt=False, metadata_path=None): """ # Build graph graph = ProposalGraph( - anisotropy=self.config.anisotropy, - min_size=self.config.min_size + anisotropy=self.config.graph.anisotropy, + min_size=self.config.graph.min_size, + min_size_with_proposals=self.config.graph.min_size_with_proposals, + node_spacing=self.config.graph.node_spacing, + prune_depth=self.config.graph.prune_depth, + remove_high_risk_merges=self.config.graph.remove_high_risk_merges, + segmentation_path=segmentation_path, + soma_centroids=soma_centroids ) - graph.load(swc_pointer) + graph.load(fragments_path) # Post process fragments - if not is_gt: + if metadata_path: self.clip_fragments(graph, metadata_path) + + if self.config.graph.remove_doubles: geometry_util.remove_doubles(graph, 200) return graph @@ -165,27 +129,93 @@ def __iter__(self): targets : torch.Tensor Ground truth labels. """ - # Initialize subgraph samplers - samplers = dict() - for key, graph in self.graphs.items(): - samplers[key] = iter(SubgraphSampler(graph, max_proposals=32)) + for subgraph in self.get_sampler(): + yield self.get_inputs(subgraph) + + def get_inputs(self, subgraph): + features = self.feature_extractor(subgraph) + data = HeteroGraphData(features) + if self.graph.gt_path: + return data.get_inputs(), data.get_targets() + else: + return data.get_inputs() + + # --- Helpers --- + @staticmethod + def clip_fragments(graph, metadata_path): + # Extract bounding box + bucket_name, path = util.parse_cloud_path(metadata_path) + metadata = util.read_json_from_gcs(bucket_name, path) + origin = metadata["chunk_origin"][::-1] + shape = metadata["chunk_shape"][::-1] + + # Clip graph + nodes = list() + for i in graph.nodes: + voxel = graph.get_voxel(i) + if not img_util.is_contained(voxel - origin, shape): + nodes.append(i) + graph.remove_nodes_from(nodes) + graph.relabel_nodes() - # Iterate over dataset + def get_sampler(self): + batch_size = self.config.ml.batch_size + if len(self.graph.soma_ids) > 0: + sampler = SeededSubgraphSampler( + self.graph, max_proposals=batch_size + ) + else: + sampler = SubgraphSampler(self.graph, max_proposals=batch_size) + return iter(sampler) + + +# --- Multi-Brain Dataset --- +class MultiBrainFragmentsDataset: + """ + A dataset class for storing graphs used to train models to perform split + correction. Graphs are stored in the "self.graphs" attribute, which is a + dictionary containing the followin items: + - Key: (brain_id, segmentation_id, example_id) + - Value: graph that is an instance of ProposalGraph + + This dataset is populated using the "self.add_graph" method, which + requires the following inputs: + (1) key: Unique identifier of graph. + (2) gt_pointer: Path to ground truth SWC files. + (2) pred_pointer: Path to predicted SWC files. + (3) img_path: Path to whole-brain image stored in cloud bucket. + + Note: This dataset supports graphs from multiple whole-brain datasets. + """ + def __init__(self, shuffle=True): + # Instance attributes + self.datasets = dict() + self.shuffle = shuffle + + def add_dataset(self, key, dataset): + self.datasets[key] = dataset + + def __iter__(self): + """ + Iterates over the dataset and yields model-ready inputs and targets. + + Yields + ------ + inputs : HeteroGraphData + Heterogeneous graph data. + targets : torch.Tensor + Ground truth labels. 
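
# Hedged sketch: iterating the single-brain dataset. Paths, "config", "model",
# and "criterion" are placeholders; (inputs, targets) pairs are yielded only
# when the underlying graph carries a ground-truth path, otherwise inputs alone.
from neuron_proofreader.split_proofreading.split_datasets import FragmentsDataset

dataset = FragmentsDataset("path/to/pred_swcs", "path/to/image.zarr", config)
for inputs, targets in dataset:  # assumes ground truth is available
    loss = criterion(model(inputs), targets)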
+ """ + samplers = self.init_samplers() while len(samplers) > 0: key = self.get_next_key(samplers) try: - # Feature extraction subgraph = next(samplers[key]) - features = self.feature_extractors[key](subgraph) - - # Get model inputs - data = HeteroGraphData(features) - inputs = data.get_inputs() - targets = data.get_targets() - yield inputs, targets + yield self.datasets.get_inputs(subgraph) except StopIteration: del samplers[key] + # --- Helpers --- def get_next_key(self, samplers): """ Gets the next key to sample from a dictionary of samplers. @@ -206,23 +236,14 @@ def get_next_key(self, samplers): keys = sorted(samplers.keys()) return keys[0] - # --- Helpers --- - @staticmethod - def clip_fragments(graph, metadata_path): - # Extract bounding box - bucket_name, path = util.parse_cloud_path(metadata_path) - metadata = util.read_json_from_gcs(bucket_name, path) - origin = metadata["chunk_origin"][::-1] - shape = metadata["chunk_shape"][::-1] - - # Clip graph - nodes = list() - for i in graph.nodes: - voxel = graph.get_voxel(i) - if not img_util.is_contained(voxel - origin, shape): - nodes.append(i) - graph.remove_nodes_from(nodes) - graph.relabel_nodes() + def init_samplers(self): + samplers = dict() + for key, dataset in self.datasets.items(): + batch_size = dataset.config.ml.batch_size + samplers[key] = iter( + SubgraphSampler(dataset.graph, max_proposals=batch_size) + ) + return samplers def n_proposals(self): """ @@ -233,7 +254,10 @@ def n_proposals(self): int Number of proposals. """ - return np.sum([graph.n_proposals() for graph in self.graphs.values()]) + cnt = 0 + for dataset in self.datasets.values(): + cnt += dataset.graph.n_proposals() + return cnt def p_accepts(self): """ @@ -244,8 +268,10 @@ def p_accepts(self): float Percentage of accepted proposals in ground truth. """ - cnts = [len(graph.gt_accepts) for graph in self.graphs.values()] - return np.sum(cnts) / self.n_proposals() + accepts_cnt = 0 + for dataset in self.datasets.values(): + accepts_cnt += len(dataset.graph.gt_accepts) + return accepts_cnt / self.n_proposals() def save_examples_summary(self, path): """ @@ -257,8 +283,9 @@ def save_examples_summary(self, path): Output path for the CSV file. """ examples_summary = list() - for key in sorted(self.graphs.keys()): - examples_summary.extend([key] * self.graphs[key].n_proposals()) + for key in sorted(self.datasets.keys()): + n_proposals = self.datasets[key].graph.n_proposals() + examples_summary.extend([key] * n_proposals) pd.DataFrame(examples_summary).to_csv(path) diff --git a/src/neuron_proofreader/split_proofreading/split_feature_extraction.py b/src/neuron_proofreader/split_proofreading/split_feature_extraction.py index a3940d4..1b7f6d8 100644 --- a/src/neuron_proofreader/split_proofreading/split_feature_extraction.py +++ b/src/neuron_proofreader/split_proofreading/split_feature_extraction.py @@ -29,7 +29,6 @@ def __init__( self, graph, img_path, - search_radius, brightness_clip=400, padding=50, patch_shape=(96, 96, 96), @@ -44,8 +43,6 @@ def __init__( Graph to extract features from. img_path : str Path to image of whole-brain dataset. - search_radius : float - Search radius used to generate proposals. brightness_clip : int, optional ... padding : int, optional @@ -58,7 +55,7 @@ def __init__( Path to segmentation of whole-brain dataset. 
""" self.extractors = [ - SkeletonFeatureExtractor(graph, search_radius), + SkeletonFeatureExtractor(graph), ImageFeatureExtractor( graph, img_path, @@ -89,7 +86,7 @@ class SkeletonFeatureExtractor: A class for extracting skeleton-based features. """ - def __init__(self, graph, search_radius): + def __init__(self, graph): """ Instantiates a SkeletonFeatureExtractor object. @@ -97,12 +94,9 @@ def __init__(self, graph, search_radius): ---------- graph : ProposalGraph Graph to extract features from. - search_radius : float - Search radius used to generate edge proposals. """ # Instance attributes self.graph = graph - self.search_radius = search_radius # Build KD-tree from leaf nodes self.graph.set_kdtree(node_type="leaf") @@ -184,8 +178,8 @@ def extract_proposal_features(self, subgraph, features): for p in subgraph.proposals: proposal_features[p] = np.concatenate( ( - self.graph.proposal_length(p) / self.search_radius, - self.graph.n_nearby_leafs(p, self.search_radius), + self.graph.proposal_length(p) / self.graph.search_radius, + self.graph.n_nearby_leafs(p, self.graph.search_radius), self.graph.proposal_attr(p, "radius"), self.graph.proposal_directionals(p, 16), self.graph.proposal_directionals(p, 32), diff --git a/src/neuron_proofreader/split_proofreading/split_inference.py b/src/neuron_proofreader/split_proofreading/split_inference.py index 959baae..212f4b0 100644 --- a/src/neuron_proofreader/split_proofreading/split_inference.py +++ b/src/neuron_proofreader/split_proofreading/split_inference.py @@ -35,21 +35,14 @@ from tqdm import tqdm import networkx as nx -import numpy as np import os import torch -from neuron_proofreader.proposal_graph import ProposalGraph -from neuron_proofreader.machine_learning.subgraph_sampler import ( - SubgraphSampler, - SeededSubgraphSampler -) -from neuron_proofreader.split_proofreading.split_feature_extraction import ( - FeaturePipeline, - HeteroGraphData +from neuron_proofreader.split_proofreading.split_datasets import ( + FragmentsDataset ) from neuron_proofreader.machine_learning.gnn_models import VisionHGAT -from neuron_proofreader.utils import geometry_util, ml_util, util +from neuron_proofreader.utils import ml_util, util class InferencePipeline: @@ -63,15 +56,14 @@ class InferencePipeline: def __init__( self, - brain_id, - segmentation_id, + fragments_path, img_path, model_path, output_dir, config, + log_preamble="", segmentation_path=None, - soma_centroids=None, - s3_dict=None, + soma_centroids=list(), ): """ Initializes an object that executes the full split correction @@ -95,173 +87,103 @@ def __init__( segmentation_path : str, optional Path to segmentation stored in GCS bucket. The default is None. soma_centroids : List[Tuple[float]] or None, optional - Physcial coordinates of soma centroids. The default is None. - s3_dict : dict, optional - ... + Physcial coordinates of soma centroids. The default is an empty + list. 
""" # Instance attributes self.accepted_proposals = list() + self.config = config self.img_path = img_path - self.model_path = model_path - self.brain_id = brain_id - self.segmentation_id = segmentation_id - self.segmentation_path = segmentation_path - self.soma_centroids = soma_centroids or list() - self.s3_dict = s3_dict - - # Extract config settings - self.graph_config = config.graph_config - self.ml_config = config.ml_config - - # Set output directory + self.model = VisionHGAT(config.ml.patch_shape) self.output_dir = output_dir - util.mkdir(self.output_dir) + self.soma_centroids = soma_centroids - # Initialize logger + # Logger + util.mkdir(self.output_dir) log_path = os.path.join(self.output_dir, "runtimes.txt") self.log_handle = open(log_path, 'a') + self.log(log_preamble) - # --- Core --- - def run(self, swcs_path, search_radius): + # Load data + self._load_data(fragments_path, img_path, segmentation_path) + ml_util.load_model(self.model, model_path, device=config.ml.device) + + def _load_data(self, fragments_path, img_path, segmentation_path): """ - Executes the full inference pipeline. + Builds a graph from the given fragments. Parameters ---------- - swcs_path : str - Path to SWC files used to build an instance of FragmentGraph, - see "swc_util.Reader" for further documentation. + fragments_path : str + Path to SWC files to be loaded into graph. """ - # Initializations - self.log_experiment() - self.write_metadata() - t0 = time() - - # Main - self.build_graph(swcs_path) - self.connect_soma_fragments() if self.soma_centroids else None - self.generate_proposals(search_radius) - self.classify_proposals(self.ml_config.threshold, search_radius) - - # Finish - t, unit = util.time_writer(time() - t0) - self.log_graph_specs(prefix="\nFinal") - self.log(f"Total Runtime: {t:.2f} {unit}\n") - self.save_results() - - def run_schedule(self, swcs_path, radius_schedule, threshold_schedule): - # Initializations - self.log_experiment() - self.write_metadata() + # Load data t0 = time() - - # Main - self.build_graph(swcs_path) - schedules = zip(radius_schedule, threshold_schedule) - for i, (radius, threshold) in enumerate(schedules): - self.log(f"\n--- Round {i + 1}: Radius = {radius} ---") - self.generate_proposals(radius) - self.classify_proposals(threshold) - self.log_graph_specs(prefix="Current") - - # Finish - t, unit = util.time_writer(time() - t0) - self.log_graph_specs(prefix="\nFinal") - self.log(f"Total Runtime: {t:.2f} {unit}\n") - self.save_results() - - def build_graph(self, swcs_path): + self.log("Step 1: Build Graph") + self.dataset = FragmentsDataset( + fragments_path, + img_path, + self.config, + segmentation_path=segmentation_path, + soma_centroids=self.soma_centroids + ) + self.save_fragment_ids() + + # Connect fragments very close to soma + self.log(f"# Soma Fragments: {len(self.dataset.graph.soma_ids)}") + if len(self.soma_centroids) > 0: + somas = self.soma_centroids + results = self.dataset.graph.connect_soma_fragments(somas) + self.log(results) + + # Log results + elapsed, unit = util.time_writer(time() - t0) + self.log(self.dataset.graph.get_summary(prefix="\nInitial")) + self.log(f"Module Runtime: {elapsed:.2f} {unit}\n") + + # --- Core Routines --- + def run(self, search_radius): """ - Builds a graph from the given fragments. + Executes the full inference pipeline. Parameters ---------- - fragment_pointer : str - Path to SWC files to be loaded into graph. 
+ swcs_path : str + Path to SWC files used to build an instance of FragmentGraph, + see "swc_util.Reader" for further documentation. """ - self.log("Step 1: Build Graph") + # Main t0 = time() - - # Initialize graph - self.graph = ProposalGraph( - anisotropy=self.graph_config.anisotropy, - min_size=self.graph_config.min_size, - min_size_with_proposals=self.graph_config.min_size_with_proposals, - node_spacing=self.graph_config.node_spacing, - prune_depth=self.graph_config.prune_depth, - remove_high_risk_merges=self.graph_config.remove_high_risk_merges, - segmentation_path=self.segmentation_path, - soma_centroids=self.soma_centroids, - verbose=True, - ) - self.graph.load(swcs_path) - - # Filter fragments - if self.graph_config.remove_doubles: - geometry_util.remove_doubles(self.graph, 160) + self.generate_proposals(search_radius) + self.classify_proposals(self.config.ml.threshold) # Report results - path = f"{self.output_dir}/segment_ids.txt" - swc_ids = list(self.graph.component_id_to_swc_id.values()) - util.write_list(path, swc_ids) - print("# Soma Fragments:", len(self.graph.soma_ids)) - t, unit = util.time_writer(time() - t0) - self.log_graph_specs(prefix="\nInitial") - self.log(f"Module Runtime: {t:.2f} {unit}\n") - - def connect_soma_fragments(self): - # Initializations - self.graph.set_kdtree() - - # Parse locations - merge_cnt, soma_cnt = 0, 0 - for soma_xyz in self.soma_centroids: - node_ids = self.graph.find_fragments_near_xyz(soma_xyz, 25) - if len(node_ids) > 1: - # Find closest node to soma location - soma_cnt += 1 - best_dist = np.inf - best_node = None - for i in node_ids: - dist = geometry_util.dist(soma_xyz, self.graph.node_xyz[i]) - if dist < best_dist: - best_dist = dist - best_node = i - soma_component_id = self.graph.node_component_id[best_node] - self.graph.soma_ids.add(soma_component_id) - node_ids.remove(best_node) - - # Merge fragments to soma - soma_xyz = self.graph.node_xyz[best_node] - for i in node_ids: - attrs = { - "radius": np.array([2, 2]), - "xyz": np.array([soma_xyz, self.graph.node_xyz[i]]), - } - self.graph._add_edge((best_node, i), attrs) - self.graph.update_component_ids(soma_component_id, i) - merge_cnt += 1 - - print("# Somas Connected:", soma_cnt) - print("# Soma Fragment Merges:", merge_cnt) - del self.graph.kdtree + self.log(self.dataset.graph.get_summary(prefix="\nFinal")) + self.log(f"Total Runtime: {t:.2f} {unit}\n") + self.save_results() def generate_proposals(self, search_radius): """ Generates proposals for the fragments graph based on the specified configuration. + + Parameters + ---------- + search_radius : float + Search radius (in microns) used to generate proposals. """ # Main t0 = time() self.log("\nStep 2: Generate Proposals") - self.graph.generate_proposals(search_radius) - n_proposals = format(self.graph.n_proposals(), ",") + self.dataset.graph.generate_proposals(search_radius) + n_proposals = format(self.dataset.graph.n_proposals(), ",") + n_proposals_blocked = self.dataset.graph.n_proposals_blocked # Report results t, unit = util.time_writer(time() - t0) self.log(f"# Proposals: {n_proposals}") - self.log(f"# Proposals Blocked: {self.graph.n_proposals_blocked}") + self.log(f"# Proposals Blocked: {n_proposals_blocked}") self.log(f"Module Runtime: {t:.2f} {unit}\n") def classify_proposals(self, accept_threshold, search_radius): @@ -271,34 +193,58 @@ def classify_proposals(self, accept_threshold, search_radius): a prediction above "self.threshold" are accepted and added to the graph as an edge. 
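
# Hedged sketch of the decaying-threshold acceptance loop that Step 3 is meant
# to perform: start near 1.0, merge proposals whose prediction clears the
# current threshold and whose endpoints are not yet connected, then lower the
# threshold toward the configured floor. "graph", "preds" (proposal -> score),
# and "config" are placeholders supplied by the surrounding pipeline.
import networkx as nx

threshold = 0.99
floor = config.ml.threshold  # 0.8 by default
accepts = set()
while threshold >= floor:
    for proposal, prob in preds.items():
        if prob >= threshold and proposal not in accepts:
            i, j = tuple(proposal)
            if not nx.has_path(graph, i, j):  # skip merges that would create a cycle
                graph.merge_proposal(proposal)
                accepts.add(proposal)
    threshold -= 0.025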
""" - # Initializations - self.log("Step 3: Run Inference") t0 = time() + self.log("Step 3: Run Inference") - # Generate model predictions - n_proposals = self.graph.n_proposals() - inference_engine = InferenceEngine( - self.graph, - self.img_path, - self.model_path, - self.ml_config, - search_radius, - segmentation_path=self.segmentation_path, - ) - preds_dict = inference_engine.run() - path = os.path.join(self.output_dir, "proposal_predictions.json") - util.write_json(path, reformat_preds(preds_dict)) - - # Update graph - stop + # Main + accepts = set() + new_threshold = 0.99 + preds = self.predict_proposals() + while True: + # Update graph + cur_threshold = new_threshold + accepts.update(self.merge_proposals(cur_threshold)) + + # Update threshold + new_threshold = cur_threshold - 0.025 + if cur_threshold == new_threshold: + break # Report results t, unit = util.time_writer(time() - t0) self.log(f"# Merges Blocked: {self.graph.n_merges_blocked}") self.log(f"# Accepted: {format(len(accepts), ',')}") - self.log(f"% Accepted: {len(accepts) / n_proposals:.4f}") + self.log(f"% Accepted: {len(accepts) / len(preds):.4f}") self.log(f"Module Runtime: {t:.4f} {unit}\n") + def predict_proposals(self): + # Main + preds = dict() + pbar = tqdm(total=self.dataset.graph.n_proposals(), desc="Inference") + for subgraph in self.dataset: + data = self.dataset.get_inputs(subgraph) + preds.update(self.predict(data)) + pbar.update(subgraph.n_proposals()) + + # Save results + path = os.path.join(self.output_dir, "proposal_predictions.json") + util.write_json(path, reformat_preds(preds)) + return preds + + def merge_proposals(self, preds, threshold): + accepts = list() + for proposal in self.dataset.graph.get_sorted_proposals(): + # Check if proposal satifies threshold + i, j = tuple(proposal) + if preds[proposal] < threshold: + continue + + # Check if proposal creates a loop + if not nx.has_path(self.graph, i, j): + self.graph.merge_proposal(proposal) + accepts.append(proposal) + return accepts + def save_results(self): """ Saves the processed results from running the inference pipeline, @@ -330,143 +276,12 @@ def save_results(self): self.s3_dict["prefix"] ) - # --- io --- - def save_connections(self, round_id=None): - """ - Writes the accepted proposals from the graph to a text file. Each line - contains the two swc ids as comma separated values. - """ - suffix = f"-{round_id}" if round_id else "" - path = os.path.join(self.output_dir, f"connections{suffix}.txt") - with open(path, "w") as f: - for id_1, id_2 in self.graph.merged_ids: - f.write(f"{id_1}, {id_2}" + "\n") - - def write_metadata(self): - """ - Writes metadata about the current pipeline run to a JSON file. 
- """ - metadata = { - "date": datetime.today().strftime("%Y-%m-%d"), - "brain_id": self.brain_id, - "segmentation_id": self.segmentation_id, - "min_fragment_size": f"{self.graph_config.min_size}um", - "min_fragment_size_with_proposals": f"{self.graph_config.min_size_with_proposals}um", - "node_spacing": self.graph_config.node_spacing, - "remove_doubles": self.graph_config.remove_doubles, - "use_somas": len(self.soma_centroids) > 0, - "complex_proposals": self.graph_config.complex_bool, - "long_range_bool": self.graph_config.long_range_bool, - "proposals_per_leaf": self.graph_config.proposals_per_leaf, - "search_radius": f"{self.graph_config.search_radius}um", - "model_name": os.path.basename(self.model_path), - "accept_threshold": self.ml_config.threshold, - } - path = os.path.join(self.output_dir, "metadata.json") - util.write_json(path, metadata) - - # --- Summaries --- + # --- Helpers --- def log(self, txt): print(txt) self.log_handle.write(txt) self.log_handle.write("\n") - def log_experiment(self): - self.log("\nExperiment Overview") - self.log("-" * len(self.segmentation_id)) - self.log(f"Brain_ID: {self.brain_id}") - self.log(f"Segmentation_ID: {self.segmentation_id}") - self.log("\n") - - def log_graph_specs(self, prefix="\n"): - """ - Prints an overview of the graph's structure and memory usage. - """ - # Compute values - n_components = nx.number_connected_components(self.graph) - n_components = format(n_components, ",") - n_nodes = format(self.graph.number_of_nodes(), ",") - n_edges = format(self.graph.number_of_edges(), ",") - - # Report results - self.log(f"{prefix} Graph") - self.log(f"# Connected Components: {n_components}") - self.log(f"# Nodes: {n_nodes}") - self.log(f"# Edges: {n_edges}") - self.log(f"Memory Consumption: {util.get_memory_usage():.2f} GBs") - - -class InferenceEngine: - """ - Class that runs inference with a machine learning model that has been - trained to classify edge proposals. - """ - - def __init__( - self, - graph, - img_path, - model_path, - ml_config, - search_radius, - segmentation_path=None, - ): - """ - Initializes an inference engine by loading images and setting class - attributes. - - Parameters - ---------- - img_path : str - Path to image. - model_path : str - Path to machine learning model weights. - ml_config : MLConfig - Configuration object containing parameters and settings required - for the inference. - search_radius : float - Search radius used to generate proposals. - segmentation_path : str, optional - Path to segmentation stored in GCS bucket. Default is None. 
- """ - # Instance attributes - self.batch_size = ml_config.batch_size - self.device = ml_config.device - self.model = VisionHGAT(ml_config.patch_shape) - self.pbar = tqdm(total=graph.n_proposals(), desc="Inference") - self.subgraph_sampler = self.get_subgraph_sampler(graph) - - # Feature generator - self.feature_extractor = FeaturePipeline( - graph, - img_path, - search_radius, - brightness_clip=ml_config.brightness_clip, - patch_shape=ml_config.patch_shape, - segmentation_path=segmentation_path - ) - - # Load weights - ml_util.load_model(self.model, model_path, device=ml_config.device) - - def get_subgraph_sampler(self, graph): - if len(graph.soma_ids) > 0: - return SeededSubgraphSampler(graph, self.batch_size) - else: - return SubgraphSampler(graph, self.batch_size) - - def run(self): - preds = dict() - for subgraph in self.subgraph_sampler: - # Get model inputs - features = self.feature_extractor(subgraph) - data = HeteroGraphData(features) - - # Run model - preds.update(self.predict(data)) - self.pbar.update(subgraph.n_proposals()) - return preds - def predict(self, data): """ ... @@ -491,95 +306,46 @@ def predict(self, data): hat_y = ml_util.tensor_to_list(hat_y) return {idx_to_id[i]: y_i for i, y_i in enumerate(hat_y)} - def update_graph(self, preds, high_threshold=0.9): - """ - Determines which proposals to accept based on prediction scores and - the specified threshold. - - Parameters - ---------- - preds : dict - Dictionary that maps proposal ids to probability generated from - machine learning model. - high_threshold : float, optional - Threshold value for separating the best proposals from the rest. - Default is 0.9. - - Returns - ------- - list - Proposals to be added as edges to "graph". + def save_connections(self, round_id=None): """ - # Partition proposals into best and the rest - preds = {k: v for k, v in preds.items() if v > self.threshold} - best_proposals, proposals = self.separate_best(preds, high_threshold) - - # Determine which proposals to accept - accepts = list() - accepts.extend(self.add_accepts(best_proposals)) - accepts.extend(self.add_accepts(proposals)) - return accepts - - def separate_best(self, preds, high_threshold): + Writes the accepted proposals from the graph to a text file. Each line + contains the two swc ids as comma separated values. """ - Splits "preds" into two separate dictionaries such that one contains - the best proposals (i.e. simple proposals with high confidence) and - the other contains all other proposals. + suffix = f"-{round_id}" if round_id else "" + path = os.path.join(self.output_dir, f"connections{suffix}.txt") + with open(path, "w") as f: + for id_1, id_2 in self.graph.merged_ids: + f.write(f"{id_1}, {id_2}" + "\n") - Parameters - ---------- - preds : dict - Dictionary that maps proposal ids to probability generated from - machine learning model. - high_threshold : float - Threshold on acceptance probability for proposals. + def save_fragment_ids(self): + path = f"{self.output_dir}/segment_ids.txt" + segment_ids = list(self.dataset.graph.component_id_to_swc_id.values()) + util.write_list(path, segment_ids) - Returns - ------- - list - Proposal IDs determined to be the best. - list - All other proposal IDs. 
- """ - best_probs, probs = list(), list() - best_proposals, proposals = list(), list() - simple_proposals = self.graph.simple_proposals() - for proposal, prob in preds.items(): - if proposal in simple_proposals and prob > high_threshold: - best_proposals.append(proposal) - best_probs.append(prob) - else: - proposals.append(proposal) - probs.append(prob) - best_idxs = np.argsort(best_probs) - idxs = np.argsort(probs) - return np.array(best_proposals)[best_idxs], np.array(proposals)[idxs] - - def add_accepts(self, proposals): + def write_metadata(self): """ - ... - - Parameters - ---------- - proposals : list[frozenset] - Proposals with predicted probability above threshold to be added - to the graph. - - Returns - ------- - List[frozenset] - List of proposals that do not create a cycle when iteratively - added to "graph". + Writes metadata about the current pipeline run to a JSON file. """ - accepts = list() - for proposal in proposals: - i, j = tuple(proposal) - if not nx.has_path(self.graph, i, j): - self.graph.merge_proposal(proposal) - accepts.append(proposal) - return accepts + metadata = { + "date": datetime.today().strftime("%Y-%m-%d"), + "brain_id": self.brain_id, + "segmentation_id": self.segmentation_id, + "min_fragment_size": f"{self.graph_config.min_size}um", + "min_fragment_size_with_proposals": f"{self.graph_config.min_size_with_proposals}um", + "node_spacing": self.graph_config.node_spacing, + "remove_doubles": self.graph_config.remove_doubles, + "use_somas": len(self.soma_centroids) > 0, + "complex_proposals": self.graph_config.complex_bool, + "long_range_bool": self.graph_config.long_range_bool, + "proposals_per_leaf": self.graph_config.proposals_per_leaf, + "search_radius": f"{self.graph_config.search_radius}um", + "model_name": os.path.basename(self.model_path), + "accept_threshold": self.ml_config.threshold, + } + path = os.path.join(self.output_dir, "metadata.json") + util.write_json(path, metadata) # --- Helpers --- def reformat_preds(preds_dict): - return {tuple(k): v for k, v in preds_dict.items()} + return {str(k): v for k, v in preds_dict.items()}