diff --git a/src/neuron_proofreader/config.py b/src/neuron_proofreader/config.py index ffab1de..69221c6 100644 --- a/src/neuron_proofreader/config.py +++ b/src/neuron_proofreader/config.py @@ -1,64 +1,94 @@ """ -Created on Sat Sept 16 16:00:00 2024 +Created on Fri Sept 15 16:00:00 2024 @author: Anna Grim @email: anna.grim@alleninstitute.org -This module defines a set of configuration classes used for setting up various -aspects of a system involving graphs, proposals, and machine learning (ML). +This module defines a set of configuration classes used for storing +parameters used in neuron proofreading pipelines. """ + from dataclasses import dataclass from typing import Tuple +import os + +from neuron_proofreader.utils import util + @dataclass -class GraphConfig: +class ProposalGraphConfig: """ Represents configuration settings related to graph properties and proposals generated. Attributes ---------- - anisotropy : Tuple[float], optional - Scaling factors applied to xyz coordinates to account for anisotropy - of microscope. Note this instance of "anisotropy" is only used while - reading SWC files. Default is (1.0, 1.0, 1.0). - min_size : float, optional - Minimum path length (in microns) of swc files which are stored as - connected components in the FragmentsGraph. Default is 30. - min_size_with_proposals : float, optional - Minimum fragment path length required for proposals. Default is 0. - node_spacing : int, optional - Spacing (in microns) between nodes. Default is 5. - proposals_per_leaf : int - Maximum number of proposals generated for each leaf. Default is 3. - prune_depth : int, optional + anisotropy : Tuple[float] + Scaling factors used to transform physical to image coordinates. + Default is (1.0, 1.0, 1.0). + max_proposals_per_leaf : int + Maximum number of proposals generated at leaf nodes. Default is 3. + min_size : float + Minimum path length (in microns) of SWC files loaded into a graph + object. Default is 40. 
+ min_size_with_proposals : float + Minimum path length (in microns) required for a fragment to have + proposals generated from its leaf nodes. Default is 40. + node_spacing : float + Physical spacing (in microns) between nodes. Default is 1. + prune_depth : float Branches in graph less than "prune_depth" microns are pruned. Default - is 16. - remove_doubles : bool, optional - ... - remove_high_risk_merges : bool, optional + is 24. + remove_doubles : bool + Indication of whether to remove fragments that are likely a double of + another fragment. Default is True. + remove_high_risk_merges : bool Indication of whether to remove high risk merge sites (i.e. close branching points). Default is False. - trim_endpoints_bool : bool, optional - Indication of whether to endpoints of branches with exactly one + trim_endpoints_bool : bool + Indication of whether to trim endpoints of branches with exactly one proposal. Default is True. + verbose : bool + Indication of whether to display a progress bar. Default is True. """ anisotropy: Tuple[float, float, float] = (1.0, 1.0, 1.0) max_proposals_per_leaf: int = 3 min_size: float = 40.0 min_size_with_proposals: float = 40.0 - node_spacing: int = 1 + node_spacing: float = 1.0 proposals_per_leaf: int = 3 prune_depth: float = 24.0 remove_doubles: bool = True remove_high_risk_merges: bool = False - search_radius: float = 20.0 trim_endpoints_bool: bool = True verbose: bool = True + def to_dict(self): + """ + Converts configuration attributes to a dictionary. + + Returns + ------- + dict + Dictionary containing configuration attributes. + """ + attributes = dict() + for k, v in vars(self).items(): + if isinstance(v, tuple): + attributes[k] = list(v) + else: + attributes[k] = v + return attributes + + def save(self, path): + """ + Saves configuration attributes to a JSON file. 
+ """ + util.write_json(path, self.to_dict()) + @dataclass class MLConfig: @@ -70,25 +100,65 @@ class MLConfig: batch_size : int The number of samples processed in one batch during training or inference. Default is 64. + brightness_clip : int + Maximum brightness value that image intensities are clipped to. + Default is 400. + device : str + Device to load model onto. Default is "cuda". + model_name : str + Name of model used to perform inference. Default is None. + patch_shape : Tuple[int] + Shape of image patch expected by vision model. Default is (96, 96, 96). + shuffle : bool + Indication of whether to shuffle batches. Default is False threshold : float - A general threshold value used for classification. Default is 0.6. + A general threshold value used in classification. Default is 0.8. + transform : bool + Indication of whether to apply data augmentation to image patches. + Default is False. """ + batch_size: int = 64 brightness_clip: int = 400 device: str = "cuda" + model_name: str = None patch_shape: tuple = (96, 96, 96) shuffle: bool = False - transform: bool = False threshold: float = 0.8 + transform: bool = False + def to_dict(self): + """ + Converts configuration attributes to a dictionary. + + Returns + ------- + dict + Dictionary containing configuration attributes. + """ + attributes = dict() + for k, v in vars(self).items(): + if isinstance(v, tuple): + attributes[k] = list(v) + else: + attributes[k] = v + return attributes + + def save(self, path): + """ + Saves configuration attributes to a JSON file. + """ + util.write_json(path, self.to_dict()) + +@dataclass class Config: """ A configuration class for managing and storing settings related to graph and machine learning models. """ - def __init__(self, graph_config: GraphConfig, ml_config: MLConfig): + def __init__(self, graph_config, ml_config): """ Initializes a Config object which is used to manage settings used to run the proofreading pipeline. 
@@ -97,11 +167,21 @@ def __init__(self, graph_config: GraphConfig, ml_config: MLConfig): ---------- graph_config : GraphConfig Instance of the "GraphConfig" class that contains configuration - parameters for graph and proposal operations, such as anisotropy, - node spacing, and other graph-specific settings. + parameters for graph and proposal operations. ml_config : MLConfig An instance of the "MLConfig" class that includes configuration parameters for machine learning models. """ self.graph = graph_config self.ml = ml_config + + def save(self, dir_path): + """ + Saves configuration attributes to a JSON file. + + dir_path : str + Path to directory to save JSON file. + """ + + self.graph.save(os.path.join(dir_path, "metadata_graph.json")) + self.ml.save(os.path.join(dir_path, "metadata_ml.json")) diff --git a/src/neuron_proofreader/proposal_graph.py b/src/neuron_proofreader/proposal_graph.py index 55afb1c..43e8e63 100644 --- a/src/neuron_proofreader/proposal_graph.py +++ b/src/neuron_proofreader/proposal_graph.py @@ -102,6 +102,7 @@ def __init__( self.xyz_to_edge = dict() # Instance attributes - Proposals + self.accepts = set() self.gt_accepts = set() self.merged_ids = set() self.n_merges_blocked = 0 @@ -335,7 +336,23 @@ def get_kdtree(self, node_type=None): else: return KDTree(list(self.xyz_to_edge.keys())) - # --- Proposal Generation --- + # --- Proposal Operations --- + def add_proposal(self, i, j): + """ + Adds proposal between nodes "i" and "j". + + Parameters + ---------- + i : int + Node ID. + j : int + Node ID + """ + proposal = frozenset({i, j}) + self.node_proposals[i].add(j) + self.node_proposals[j].add(i) + self.proposals.add(proposal) + def generate_proposals(self, search_radius): """ Generates proposals from leaf nodes. 
@@ -359,48 +376,41 @@ def generate_proposals(self, search_radius): gt_graph.load(self.gt_path) self.gt_accepts = groundtruth_generation.run(gt_graph, self) - def add_proposal(self, i, j): + def get_sorted_proposals(self): """ - Adds proposal between nodes "i" and "j". + Return proposals sorted by physical length. - Parameters - ---------- - i : int - Node ID. - j : int - Node ID + Returns + ------- + List[Frozenset[int]] + List of proposals sorted by physical length. """ - proposal = frozenset({i, j}) - self.node_proposals[i].add(j) - self.node_proposals[j].add(i) - self.proposals.add(proposal) + proposals = self.list_proposals() + lengths = [self.proposal_length(p) for p in proposals] + return [proposals[i] for i in np.argsort(lengths)] - def store_proposals(self, proposals): - self.node_proposals = defaultdict(set) - for proposal in proposals: - i, j = tuple(proposal) - self.add_proposal(i, j) + def is_mergeable(self, i, j): + one_leaf = self.degree[i] == 1 or self.degree[j] == 1 + branching = self.degree[i] > 2 or self.degree[j] > 2 + somas_check = not (self.is_soma(i) and self.is_soma(j)) + return somas_check and (one_leaf and not branching) - def remove_proposal(self, proposal): + def is_simple(self, proposal): """ - Removes an existing proposal between two nodes. + Checks if both nodes in a proposal are leafs. Parameters ---------- proposal : Frozenset[int] - Pair of node IDs corresponding to a proposal. + Pair of nodes that form a proposal. + + Returns + ------- + bool + Indication of whether both nodes in a proposal are leafs. 
""" i, j = tuple(proposal) - self.node_proposals[i].remove(j) - self.node_proposals[j].remove(i) - self.proposals.remove(proposal) - - def trim_proposals(self): - for proposal in self.list_proposals(): - is_simple = self.is_simple(proposal) - is_single = self.is_single_proposal(proposal) - if is_simple and is_single: - trim_endpoints_at_proposal(self, proposal) + return True if self.degree[i] == 1 and self.degree[j] == 1 else False def is_single_proposal(self, proposal): """ @@ -434,19 +444,30 @@ def list_proposals(self): """ return list(self.proposals) - # --- Proposal Helpers --- - def get_sorted_proposals(self): - """ - Return proposals sorted by physical length. + def merge_proposal(self, proposal): + i, j = tuple(proposal) + if self.is_mergeable(i, j): + # Update attributes + attrs = { + "radius": self.node_radius[np.array([i, j], dtype=int)], + "xyz": self.node_xyz[np.array([i, j], dtype=int)] + } - Returns - ------- - List[Frozenset[int]] - List of proposals sorted by phyiscal length. - """ - proposals = self.list_proposals() - lengths = [self.proposal_length(p) for p in proposals] - return [proposals[i] for i in np.argsort(lengths)] + # Update component_ids + self.merged_ids.add((self.get_swc_id(i), self.get_swc_id(j))) + if self.is_soma(i): + component_id = self.node_component_id[i] + self.update_component_ids(component_id, j) + else: + component_id = self.node_component_id[j] + self.update_component_ids(component_id, i) + + # Update graph + self._add_edge((i, j), attrs) + self.accepts.add(proposal) + self.proposals.remove(proposal) + else: + self.n_merges_blocked += 1 def n_proposals(self): """ @@ -459,36 +480,34 @@ def n_proposals(self): """ return len(self.proposals) - def is_simple(self, proposal): + def remove_proposal(self, proposal): """ - Checks if both nodes in a proposal are leafs. + Removes an existing proposal between two nodes. Parameters ---------- proposal : Frozenset[int] - Pair of nodes that form a proposal. 
- - Returns - ------- - bool - Indication of whether both nodes in a proposal are leafs. + Pair of node IDs corresponding to a proposal. """ i, j = tuple(proposal) - return True if self.degree[i] == 1 and self.degree[j] == 1 else False - - def simple_proposals(self): - return set([p for p in self.proposals if self.is_simple(p)]) - - def complex_proposals(self): - return set([p for p in self.proposals if not self.is_simple(p)]) + self.node_proposals[i].remove(j) + self.node_proposals[j].remove(i) + self.proposals.remove(proposal) - def proposal_length(self, proposal): - return self.dist(*tuple(proposal)) + def store_proposals(self, proposals): + self.node_proposals = defaultdict(set) + for proposal in proposals: + i, j = tuple(proposal) + self.add_proposal(i, j) - def proposal_midpoint(self, proposal): - i, j = tuple(proposal) - return geometry.midpoint(self.node_xyz[i], self.node_xyz[j]) + def trim_proposals(self): + for proposal in self.list_proposals(): + is_simple = self.is_simple(proposal) + is_single = self.is_single_proposal(proposal) + if is_simple and is_single: + trim_endpoints_at_proposal(self, proposal) + # --- Proposal Feature Generation --- def proposal_attr(self, proposal, key): """ Gets the attributes of nodes in "proposal". 
@@ -543,61 +562,16 @@ def proposal_directionals(self, proposal, depth): dot_ij = max(dot_ij, -dot_ij) return np.array([dot_i, dot_j, dot_ij]) - def truncated_edge_attr_xyz(self, i, depth): - xyz_path_list = self.edge_attr(i, "xyz") - return [geometry.truncate_path(path, depth) for path in xyz_path_list] + def proposal_length(self, proposal): + return self.dist(*tuple(proposal)) - def merge_proposal(self, proposal): + def proposal_midpoint(self, proposal): i, j = tuple(proposal) - if self.is_mergeable(i, j): - # Update attributes - attrs = { - "radius": self.node_radius[np.array([i, j], dtype=int)], - "xyz": self.node_xyz[np.array([i, j], dtype=int)] - } - self.node_radius[i] = 5.3141592 - self.node_radius[j] = 5.3141592 - - # Update component_ids - self.merged_ids.add((self.get_swc_id(i), self.get_swc_id(j))) - if self.is_soma(i): - component_id = self.node_component_id[i] - self.update_component_ids(component_id, j) - else: - component_id = self.node_component_id[j] - self.update_component_ids(component_id, i) - - # Update graph - self._add_edge((i, j), attrs) - self.proposals.remove(proposal) - else: - self.n_merges_blocked += 1 - - def is_mergeable(self, i, j): - one_leaf = self.degree[i] == 1 or self.degree[j] == 1 - branching = self.degree[i] > 2 or self.degree[j] > 2 - somas_check = not (self.is_soma(i) and self.is_soma(j)) - return somas_check and (one_leaf and not branching) - - def update_component_ids(self, component_id, root): - """ - Updates the component_id of all nodes connected to "root". + return geometry.midpoint(self.node_xyz[i], self.node_xyz[j]) - Parameters - ---------- - component_id : str - Connected component id. 
- root : int - Node ID - """ - queue = [root] - visited = set(queue) - while len(queue) > 0: - i = queue.pop() - self.node_component_id[i] = component_id - visited.add(i) - for j in [j for j in self.neighbors(i) if j not in visited]: - queue.append(j) + def truncated_edge_attr_xyz(self, i, depth): + xyz_path_list = self.edge_attr(i, "xyz") + return [geometry.truncate_path(path, depth) for path in xyz_path_list] def n_nearby_leafs(self, proposal, radius): """ @@ -712,6 +686,26 @@ def orient_edge_attr(self, edge, i, key="xyz"): else: return np.flip(self.edges[edge][key], axis=0) + def update_component_ids(self, component_id, root): + """ + Updates the component_id of all nodes connected to "root". + + Parameters + ---------- + component_id : str + Connected component id. + root : int + Node ID + """ + queue = [root] + visited = set(queue) + while len(queue) > 0: + i = queue.pop() + self.node_component_id[i] = component_id + visited.add(i) + for j in [j for j in self.neighbors(i) if j not in visited]: + queue.append(j) + def xyz_to_component_id(self, xyz, return_node=False): if tuple(xyz) in self.xyz_to_edge.keys(): edge = self.xyz_to_edge[tuple(xyz)] @@ -787,10 +781,7 @@ def nodes_to_zipped_swc( if len(node_to_idx) == 0: # Get attributes x, y, z = tuple(self.node_xyz[i]) - if preserve_radius: - r = self.node_radius[i] - else: - r = 6 if self.node_radius[i] == 5.3141592 else 2 + r = self.node_radius[i] if preserve_radius else 2 # Write entry text_buffer.write(f"\n1 2 {x} {y} {z} {r} -1") @@ -832,10 +823,9 @@ def branch_to_zip( node_id = n_entries + 1 parent = n_entries if k > 1 else parent x, y, z = tuple(branch_xyz[k]) - if preserve_radius: - r = branch_radius[k] - else: - r = 6 if branch_radius[k] == 5.3141592 else 2 + r = branch_radius[k] if preserve_radius else 2 + if frozenset({i, j}) in self.accepts: + r = 6 # Write entry text_buffer.write(f"\n{node_id} 2 {x} {y} {z} {r} {parent}") diff --git a/src/neuron_proofreader/split_proofreading/split_inference.py 
b/src/neuron_proofreader/split_proofreading/split_inference.py index aca5db4..b607217 100644 --- a/src/neuron_proofreader/split_proofreading/split_inference.py +++ b/src/neuron_proofreader/split_proofreading/split_inference.py @@ -28,7 +28,6 @@ """ -from datetime import datetime from time import time from torch.nn.functional import sigmoid from tqdm import tqdm @@ -70,23 +69,24 @@ def __init__( Parameters ---------- - brain_id : str - Identifier for the whole-brain dataset. - segmentation_id : str - Identifier for the segmentation model that generated fragments. + fragments_path : str + Path to SWC files to be loaded into graph. img_path : str - Path to the whole-brain image stored in a GCS or S3 bucket. + Path to whole-brain image corresponding to the given fragments. model_path : str - Path to machine learning model parameters. + Path to checkpoint file containing model weights. output_dir : str Directory where the results of the inference will be saved. config : Config Configuration object containing parameters and settings required for the inference pipeline. + log_preamble : str, optional + String to be added to the beginning of log. Default is an empty + string. segmentation_path : str, optional - Path to segmentation stored in GCS bucket. The default is None. + Path to segmentation stored in GCS bucket. Default is None. soma_centroids : List[Tuple[float]] or None, optional - Physcial coordinates of soma centroids. The default is an empty + Physical coordinates of soma centroids. Default is an empty list. """ # Instance attributes @@ -147,9 +147,8 @@ def __call__(self, search_radius): Parameters ---------- - swcs_path : str - Path to SWC files used to build an instance of FragmentGraph, - see "swc_util.Reader" for further documentation. + search_radius : float + Search radius (in microns) used to generate proposals. 
""" # Main t0 = time() @@ -175,6 +174,7 @@ def generate_proposals(self, search_radius): # Main t0 = time() self.log("\nStep 2: Generate Proposals") + self.log(f"Search Radius: {search_radius}") self.dataset.graph.generate_proposals(search_radius) n_proposals = format(self.dataset.graph.n_proposals(), ",") n_proposals_blocked = self.dataset.graph.n_proposals_blocked @@ -190,24 +190,24 @@ def classify_proposals(self, accept_threshold, dt=0.05): self.log("Step 3: Run Inference") # Main - accepts = set() new_threshold = 0.99 preds = self.predict_proposals() while True: # Update graph cur_threshold = new_threshold - accepts.update(self.merge_proposals(preds, cur_threshold)) + self.merge_proposals(preds, cur_threshold) # Update threshold new_threshold = max(cur_threshold - dt, accept_threshold) if cur_threshold == new_threshold: break + n_accepts = len(self.dataset.graph.accepts) # Report results t, unit = util.time_writer(time() - t0) self.log(f"# Merges Blocked: {self.dataset.graph.n_merges_blocked}") - self.log(f"# Accepted: {format(len(accepts), ',')}") - self.log(f"% Accepted: {len(accepts) / len(preds):.4f}") + self.log(f"# Accepted: {format(n_accepts, ',')}") + self.log(f"% Accepted: {n_accepts / len(preds):.4f}") self.log(f"Module Runtime: {t:.4f} {unit}\n") def predict_proposals(self): @@ -220,11 +220,10 @@ def predict_proposals(self): # Save results path = os.path.join(self.output_dir, "proposal_predictions.json") - util.write_json(path, reformat_preds(preds)) + util.write_json(path, self.reformat_preds(preds)) return preds def merge_proposals(self, preds, threshold): - accepts = list() for proposal in self.dataset.graph.get_sorted_proposals(): # Check if proposal has been visited if proposal not in preds: @@ -238,9 +237,7 @@ def merge_proposals(self, preds, threshold): # Check if proposal creates a loop if not nx.has_path(self.dataset.graph, i, j): self.dataset.graph.merge_proposal(proposal) - accepts.append(proposal) del preds[proposal] - return accepts def 
save_results(self): """ @@ -262,19 +259,19 @@ def save_results(self): # Save additional info self.save_connections() - self.write_metadata() + self.config.save(self.output_dir) self.log_handle.close() - # Save result on s3 (if applicable) - if self.s3_dict is not None: - util.upload_dir_to_s3( - self.output_dir, - self.s3_dict["bucket_name"], - self.s3_dict["prefix"] - ) - # --- Helpers --- def log(self, txt): + """ + Logs and prints the given text. + + Parameters + ---------- + txt : str + Text to be logged and printed. + """ print(txt) self.log_handle.write(txt) self.log_handle.write("\n") @@ -304,6 +301,15 @@ def predict(self, data): hat_y = ml_util.tensor_to_list(hat_y) return {idx_to_id[i]: y_i for i, y_i in enumerate(hat_y)} + def reformat_preds(self, preds_dict): + id_to_pred = dict() + for proposal, pred in preds_dict.items(): + node1, node2 = tuple(proposal) + id1 = self.dataset.graph.get_swc_id(node1) + id2 = self.dataset.graph.get_swc_id(node2) + id_to_pred[str((id1, id2))] = pred + return id_to_pred + def save_connections(self, round_id=None): """ Writes the accepted proposals from the graph to a text file. Each line @@ -312,34 +318,10 @@ def save_connections(self, round_id=None): suffix = f"-{round_id}" if round_id else "" path = os.path.join(self.output_dir, f"connections{suffix}.txt") with open(path, "w") as f: - for id_1, id_2 in self.graph.merged_ids: + for id_1, id_2 in self.dataset.graph.merged_ids: f.write(f"{id_1}, {id_2}" + "\n") def save_fragment_ids(self): path = f"{self.output_dir}/segment_ids.txt" segment_ids = list(self.dataset.graph.component_id_to_swc_id.values()) util.write_list(path, segment_ids) - - def write_metadata(self): - """ - Writes metadata about the current pipeline run to a JSON file. 
- """ - metadata = { - "date": datetime.today().strftime("%Y-%m-%d"), - "min_fragment_size": f"{self.config.graph.min_size}um", - "min_fragment_size_with_proposals": f"{self.config.graph.min_size_with_proposals}um", - "node_spacing": self.config.graph.node_spacing, - "remove_doubles": self.config.graph.remove_doubles, - "use_somas": len(self.soma_centroids) > 0, - "proposals_per_leaf": self.config.graph.proposals_per_leaf, - "search_radius": f"{self.config.graph.search_radius}um", - "model_name": os.path.basename(self.model_path), - "accept_threshold": self.config.ml.threshold, - } - path = os.path.join(self.output_dir, "metadata.json") - util.write_json(path, metadata) - - -# --- Helpers --- -def reformat_preds(preds_dict): - return {str(k): v for k, v in preds_dict.items()}