From cabc6d3ff707a479c6f6c28ef2f3718d43f74d57 Mon Sep 17 00:00:00 2001 From: anna-grim Date: Fri, 30 Jan 2026 21:01:25 +0000 Subject: [PATCH 1/3] refactor: updated split inference pipeline --- src/neuron_proofreader/config.py | 16 +- .../merge_feature_generation.py | 0 src/neuron_proofreader/proposal_graph.py | 6 +- .../split_proofreading/proposal_generation.py | 3 +- .../split_proofreading/split_datasets.py | 5 +- ...raction.py => split_feature_extraction.py} | 0 .../split_proofreading/split_inference.py | 269 ++++++++---------- src/neuron_proofreader/utils/graph_util.py | 14 +- src/neuron_proofreader/utils/swc_util.py | 20 +- 9 files changed, 139 insertions(+), 194 deletions(-) create mode 100644 src/neuron_proofreader/merge_proofreading/merge_feature_generation.py rename src/neuron_proofreader/split_proofreading/{feature_extraction.py => split_feature_extraction.py} (100%) diff --git a/src/neuron_proofreader/config.py b/src/neuron_proofreader/config.py index 4c18c02a..1f8baede 100644 --- a/src/neuron_proofreader/config.py +++ b/src/neuron_proofreader/config.py @@ -81,22 +81,14 @@ class MLConfig: batch_size : int The number of samples processed in one batch during training or inference. Default is 64. - multiscale : int - Level in the image pyramid that voxel coordinates must index into. threshold : float A general threshold value used for classification. Default is 0.6. - model_type : str - Type of machine learning model to use. Default is "GraphNeuralNet". """ batch_size: int = 64 - device: str = "cpu" - lr: float = 1e-4 - multiscale: int = 1 - n_epochs: int = 1000 - threshold: float = 0.6 - transform: bool = False - validation_split: float = 0.15 - weight_decay: float = 1e-3 + brightness_clip: int = 400 + device: str = "cuda" + patch_shape : tuple = (96, 96, 96) + threshold: float = 0.8 class Config: diff --git a/src/neuron_proofreader/merge_proofreading/merge_feature_generation.py b/src/neuron_proofreader/merge_proofreading/merge_feature_generation.py new file mode 100644 index 00000000..e69de29b diff --git a/src/neuron_proofreader/proposal_graph.py b/src/neuron_proofreader/proposal_graph.py index b8cfd248..8a683672 100644 --- a/src/neuron_proofreader/proposal_graph.py +++ b/src/neuron_proofreader/proposal_graph.py @@ -56,7 +56,6 @@ def __init__( prune_depth=20.0, remove_high_risk_merges=False, segmentation_path=None, - smooth_bool=True, soma_centroids=None, verbose=False, ): @@ -82,9 +81,6 @@ def __init__( branching points). Default is False. segmentation_path : str, optional Path to segmentation stored in GCS bucket. Default is None. - smooth_bool : bool, optional - Indication of whether to smooth xyz coordinates from SWC files. - Default is True. soma_centroids : List[Tuple[float]] or None, optional Phyiscal coordinates of soma centroids. Default is None. 
verbose : bool, optional @@ -106,6 +102,7 @@ def __init__( self.gt_accepts = set() self.merged_ids = set() self.n_merges_blocked = 0 + self.n_proposals_blocked = 0 self.node_proposals = defaultdict(set) self.proposals = set() @@ -123,7 +120,6 @@ def __init__( prune_depth=prune_depth, remove_high_risk_merges=remove_high_risk_merges, segmentation_path=segmentation_path, - smooth_bool=smooth_bool, soma_centroids=soma_centroids, verbose=verbose, ) diff --git a/src/neuron_proofreader/split_proofreading/proposal_generation.py b/src/neuron_proofreader/split_proofreading/proposal_generation.py index 9f22a330..0b592c41 100644 --- a/src/neuron_proofreader/split_proofreading/proposal_generation.py +++ b/src/neuron_proofreader/split_proofreading/proposal_generation.py @@ -59,7 +59,6 @@ def __init__( self.max_attempts = max_attempts self.max_proposals_per_leaf = max_proposals_per_leaf self.min_size_with_proposals = min_size_with_proposals - self.n_proposals_blocked = 0 self.search_scaling_factor = search_scaling_factor def __call__(self, initial_radius): @@ -246,7 +245,7 @@ def is_valid_proposal(self, leaf, i): """ if i is not None: is_soma = (self.graph.is_soma(i) and self.graph.is_soma(leaf)) - self.n_proposals_blocked += 1 if is_soma else 0 + self.graph.n_proposals_blocked += 1 if is_soma else 0 return not is_soma else: return False diff --git a/src/neuron_proofreader/split_proofreading/split_datasets.py b/src/neuron_proofreader/split_proofreading/split_datasets.py index 96d0eaba..adee3210 100644 --- a/src/neuron_proofreader/split_proofreading/split_datasets.py +++ b/src/neuron_proofreader/split_proofreading/split_datasets.py @@ -19,7 +19,7 @@ from neuron_proofreader.machine_learning.subgraph_sampler import ( SubgraphSampler ) -from neuron_proofreader.split_proofreading.feature_extraction import ( +from neuron_proofreader.split_proofreading.split_feature_extraction import ( FeaturePipeline, HeteroGraphData ) @@ -132,7 +132,8 @@ def load_graph(self, swc_pointer, is_gt=False, metadata_path=None): swc_pointer : str Path to SWC files to be loaded. metadata_path : str - ... + Patch to JSON file containing metadata on block that fragments + were extracted from. 
Returns ------- diff --git a/src/neuron_proofreader/split_proofreading/feature_extraction.py b/src/neuron_proofreader/split_proofreading/split_feature_extraction.py similarity index 100% rename from src/neuron_proofreader/split_proofreading/feature_extraction.py rename to src/neuron_proofreader/split_proofreading/split_feature_extraction.py diff --git a/src/neuron_proofreader/split_proofreading/split_inference.py b/src/neuron_proofreader/split_proofreading/split_inference.py index 16f80100..959baae8 100644 --- a/src/neuron_proofreader/split_proofreading/split_inference.py +++ b/src/neuron_proofreader/split_proofreading/split_inference.py @@ -41,13 +41,14 @@ from neuron_proofreader.proposal_graph import ProposalGraph from neuron_proofreader.machine_learning.subgraph_sampler import ( - SubgraphSampler, SeededSubgraphSampler + SubgraphSampler, + SeededSubgraphSampler ) -from neuron_proofreader.split_proofreading import datasets -from neuron_proofreader.split_proofreading.feature_generation import ( - FeatureGenerator, +from neuron_proofreader.split_proofreading.split_feature_extraction import ( + FeaturePipeline, + HeteroGraphData ) -from neuron_proofreader.split_proofreading.models import init_model +from neuron_proofreader.machine_learning.gnn_models import VisionHGAT from neuron_proofreader.utils import geometry_util, ml_util, util @@ -121,14 +122,14 @@ def __init__( self.log_handle = open(log_path, 'a') # --- Core --- - def run(self, swc_pointer): + def run(self, swcs_path, search_radius): """ Executes the full inference pipeline. Parameters ---------- - swc_pointer : Any - Pointer to SWC files used to build an instance of FragmentGraph, + swcs_path : str + Path to SWC files used to build an instance of FragmentGraph, see "swc_util.Reader" for further documentation. 
""" # Initializations @@ -137,50 +138,48 @@ def run(self, swc_pointer): t0 = time() # Main - self.build_graph(swc_pointer) + self.build_graph(swcs_path) self.connect_soma_fragments() if self.soma_centroids else None - self.generate_proposals(self.graph_config.search_radius) - self.classify_proposals(self.ml_config.threshold) + self.generate_proposals(search_radius) + self.classify_proposals(self.ml_config.threshold, search_radius) # Finish t, unit = util.time_writer(time() - t0) - self.report_graph(prefix="\nFinal") - self.report(f"Total Runtime: {t:.2f} {unit}\n") + self.log_graph_specs(prefix="\nFinal") + self.log(f"Total Runtime: {t:.2f} {unit}\n") self.save_results() - def run_schedule( - self, swc_pointer, radius_schedule, threshold_schedule - ): + def run_schedule(self, swcs_path, radius_schedule, threshold_schedule): # Initializations self.log_experiment() self.write_metadata() t0 = time() # Main - self.build_graph(swc_pointer) - for i, radius in enumerate(radius_schedule): - self.report(f"\n--- Round {i + 1}: Radius = {radius} ---") - self.generate_proposals(radius_schedule[i]) - self.classify_proposals(threshold_schedule[i]) - self.report_graph(prefix="Current") + self.build_graph(swcs_path) + schedules = zip(radius_schedule, threshold_schedule) + for i, (radius, threshold) in enumerate(schedules): + self.log(f"\n--- Round {i + 1}: Radius = {radius} ---") + self.generate_proposals(radius) + self.classify_proposals(threshold) + self.log_graph_specs(prefix="Current") # Finish t, unit = util.time_writer(time() - t0) - self.report_graph(prefix="\nFinal") - self.report(f"Total Runtime: {t:.2f} {unit}\n") + self.log_graph_specs(prefix="\nFinal") + self.log(f"Total Runtime: {t:.2f} {unit}\n") self.save_results() - def build_graph(self, swc_pointer): + def build_graph(self, swcs_path): """ Builds a graph from the given fragments. Parameters ---------- - fragment_pointer : dict, list, str - Pointer to SWC files used to build an instance of FragmentGraph, - see "swc_util.Reader" for further documentation. + fragment_pointer : str + Path to SWC files to be loaded into graph. 
""" - self.report("Step 1: Build Fragments Graph") + self.log("Step 1: Build Graph") t0 = time() # Initialize graph @@ -192,11 +191,10 @@ def build_graph(self, swc_pointer): prune_depth=self.graph_config.prune_depth, remove_high_risk_merges=self.graph_config.remove_high_risk_merges, segmentation_path=self.segmentation_path, - smooth_bool=self.graph_config.smooth_bool, soma_centroids=self.soma_centroids, verbose=True, ) - self.graph.load(swc_pointer) + self.graph.load(swcs_path) # Filter fragments if self.graph_config.remove_doubles: @@ -209,17 +207,17 @@ def build_graph(self, swc_pointer): print("# Soma Fragments:", len(self.graph.soma_ids)) t, unit = util.time_writer(time() - t0) - self.report_graph(prefix="\nInitial") - self.report(f"Module Runtime: {t:.2f} {unit}\n") + self.log_graph_specs(prefix="\nInitial") + self.log(f"Module Runtime: {t:.2f} {unit}\n") def connect_soma_fragments(self): # Initializations - self.graph.init_kdtree() + self.graph.set_kdtree() # Parse locations merge_cnt, soma_cnt = 0, 0 for soma_xyz in self.soma_centroids: - node_ids = self.graph.find_fragments_near_xyz(soma_xyz, 20) + node_ids = self.graph.find_fragments_near_xyz(soma_xyz, 25) if len(node_ids) > 1: # Find closest node to soma location soma_cnt += 1 @@ -249,30 +247,24 @@ def connect_soma_fragments(self): print("# Soma Fragment Merges:", merge_cnt) del self.graph.kdtree - def generate_proposals(self, radius): + def generate_proposals(self, search_radius): """ Generates proposals for the fragments graph based on the specified configuration. """ # Main t0 = time() - self.report("Step 2: Generate Proposals") - self.graph.generate_proposals( - radius, - complex_bool=self.graph_config.complex_bool, - long_range_bool=self.graph_config.long_range_bool, - proposals_per_leaf=self.graph_config.proposals_per_leaf, - trim_endpoints_bool=self.graph_config.trim_endpoints_bool, - ) + self.log("\nStep 2: Generate Proposals") + self.graph.generate_proposals(search_radius) n_proposals = format(self.graph.n_proposals(), ",") # Report results t, unit = util.time_writer(time() - t0) - self.report(f"# Proposals: {n_proposals}") - self.report(f"# Proposals Blocked: {self.graph.n_proposals_blocked}") - self.report(f"Module Runtime: {t:.2f} {unit}\n") + self.log(f"# Proposals: {n_proposals}") + self.log(f"# Proposals Blocked: {self.graph.n_proposals_blocked}") + self.log(f"Module Runtime: {t:.2f} {unit}\n") - def classify_proposals(self, accept_threshold): + def classify_proposals(self, accept_threshold, search_radius): """ Classifies proposals by calling "self.inference_engine". This routine generates features and runs a GNN to make predictions. Proposals with @@ -280,29 +272,32 @@ def classify_proposals(self, accept_threshold): graph as an edge. 
""" # Initializations - self.report("Step 3: Run Inference") - proposals = self.graph.list_proposals() - - # Main + self.log("Step 3: Run Inference") t0 = time() - self.inference_engine = InferenceEngine( + + # Generate model predictions + n_proposals = self.graph.n_proposals() + inference_engine = InferenceEngine( self.graph, self.img_path, self.model_path, self.ml_config, - self.graph_config.search_radius, - accept_threshold=accept_threshold, + search_radius, segmentation_path=self.segmentation_path, ) - accepts = self.inference_engine.run() - self.accepted_proposals.extend(accepts) + preds_dict = inference_engine.run() + path = os.path.join(self.output_dir, "proposal_predictions.json") + util.write_json(path, reformat_preds(preds_dict)) + + # Update graph + stop # Report results t, unit = util.time_writer(time() - t0) - self.report(f"# Merges Blocked: {self.graph.n_merges_blocked}") - self.report(f"# Accepted: {format(len(accepts), ',')}") - self.report(f"% Accepted: {len(accepts) / len(proposals):.4f}") - self.report(f"Module Runtime: {t:.4f} {unit}\n") + self.log(f"# Merges Blocked: {self.graph.n_merges_blocked}") + self.log(f"# Accepted: {format(len(accepts), ',')}") + self.log(f"% Accepted: {len(accepts) / n_proposals:.4f}") + self.log(f"Module Runtime: {t:.4f} {unit}\n") def save_results(self): """ @@ -371,19 +366,19 @@ def write_metadata(self): util.write_json(path, metadata) # --- Summaries --- - def report(self, txt): + def log(self, txt): print(txt) self.log_handle.write(txt) self.log_handle.write("\n") def log_experiment(self): - self.report("\nExperiment Overview") - self.report("-------------------------------------------------------") - self.report(f"Brain_ID: {self.brain_id}") - self.report(f"Segmentation_ID: {self.segmentation_id}") - self.report("\n") + self.log("\nExperiment Overview") + self.log("-" * len(self.segmentation_id)) + self.log(f"Brain_ID: {self.brain_id}") + self.log(f"Segmentation_ID: {self.segmentation_id}") + self.log("\n") - def report_graph(self, prefix="\n"): + def log_graph_specs(self, prefix="\n"): """ Prints an overview of the graph's structure and memory usage. """ @@ -393,12 +388,12 @@ def report_graph(self, prefix="\n"): n_nodes = format(self.graph.number_of_nodes(), ",") n_edges = format(self.graph.number_of_edges(), ",") - # Report - self.report(f"{prefix} Graph") - self.report(f"# Connected Components: {n_components}") - self.report(f"# Nodes: {n_nodes}") - self.report(f"# Edges: {n_edges}") - self.report(f"Memory Consumption: {util.get_memory_usage():.2f} GBs") + # Report results + self.log(f"{prefix} Graph") + self.log(f"# Connected Components: {n_components}") + self.log(f"# Nodes: {n_nodes}") + self.log(f"# Edges: {n_edges}") + self.log(f"Memory Consumption: {util.get_memory_usage():.2f} GBs") class InferenceEngine: @@ -413,8 +408,7 @@ def __init__( img_path, model_path, ml_config, - radius, - accept_threshold=0.9, + search_radius, segmentation_path=None, ): """ @@ -430,7 +424,7 @@ def __init__( ml_config : MLConfig Configuration object containing parameters and settings required for the inference. - radius : float + search_radius : float Search radius used to generate proposals. segmentation_path : str, optional Path to segmentation stored in GCS bucket. Default is None. 
@@ -438,108 +432,64 @@ def __init__( # Instance attributes self.batch_size = ml_config.batch_size self.device = ml_config.device - self.graph = graph - self.ml_config = ml_config - self.radius = radius - self.threshold = accept_threshold + self.model = VisionHGAT(ml_config.patch_shape) + self.pbar = tqdm(total=graph.n_proposals(), desc="Inference") + self.subgraph_sampler = self.get_subgraph_sampler(graph) # Feature generator - self.feature_generator = FeatureGenerator( - self.graph, + self.feature_extractor = FeaturePipeline( + graph, img_path, - anisotropy=self.ml_config.anisotropy, - is_multimodal=self.ml_config.is_multimodal, - multiscale=self.ml_config.multiscale, - segmentation_path=segmentation_path, + search_radius, + brightness_clip=ml_config.brightness_clip, + patch_shape=ml_config.patch_shape, + segmentation_path=segmentation_path ) - # Model - device = "cuda" if ml_config.is_multimodal else "cpu" - self.model = init_model(ml_config.is_multimodal) - ml_util.load_model(self.model, model_path, device=device) + # Load weights + ml_util.load_model(self.model, model_path, device=ml_config.device) - def init_dataloader(self): - if len(self.graph.soma_ids) > 0: - return SeededSubgraphSampler(self.graph, self.batch_size) + def get_subgraph_sampler(self, graph): + if len(graph.soma_ids) > 0: + return SeededSubgraphSampler(graph, self.batch_size) else: - return SubgraphSampler(self.graph, self.batch_size) - - def run(self, return_preds=False): - """ - Runs inference by forming batches of proposals, then performing the - following steps for each batch: (1) generate features, (2) classify - proposals by running model, and (3) adding each accepted proposal as - an edge to "graph" if it does not create a cycle. + return SubgraphSampler(graph, self.batch_size) - Parameters - ---------- - graph : ProposalGraph - Graph that inference will be performed on. - proposals : list - Proposals to be classified as accept or reject. - - Returns - ------- - ProposalGraph - Updated graph with accepted proposals added as edges. - list - Accepted proposals. - """ - # Initializations - dataloader = self.init_dataloader() - pbar = tqdm(total=self.graph.n_proposals(), desc="Inference") - - # Main - accepts = list() - hat_y = dict() - for batch in dataloader: - # Feature generation - features = self.feature_generator.run(batch, self.radius) - heterograph_data = datasets.init(features, batch["graph"]) + def run(self): + preds = dict() + for subgraph in self.subgraph_sampler: + # Get model inputs + features = self.feature_extractor(subgraph) + data = HeteroGraphData(features) # Run model - hat_y_i = self.predict(heterograph_data) - if return_preds: - hat_y.update(hat_y_i) + preds.update(self.predict(data)) + self.pbar.update(subgraph.n_proposals()) + return preds - # Determine which proposals to accept - accepts.extend(self.update_graph(hat_y_i)) - pbar.update(len(batch["proposals"])) - - # Return results - if return_preds: - return accepts, hat_y - else: - return accepts - - def predict(self, heterograph_data): + def predict(self, data): """ - Runs the model on the given dataset to generate and filter - predictions. + ... Parameters ---------- - data : HeteroGeneousDataset - Dataset containing graph information, including feature matrices - and other relevant attributes needed for GNN input. + data : HeteroGraphData + ... + Returns ------- - dict - A dictionary that maps a proposal to the model's prediction (i.e. - probability). 
+ Dict[Frozenset[int], float] + Dictionary that maps proposal IDs to model predictions. """ # Generate predictions with torch.no_grad(): - x, edge_index, edge_attr = ml_util.get_inputs( - heterograph_data.data, self.device - ) - hat_y = sigmoid(self.model(x, edge_index, edge_attr)) + x = data.get_inputs().to(self.device) + hat_y = sigmoid(self.model(x)) # Reformat predictions - n_proposals = len(heterograph_data.data["proposal"]["y"]) - hat_y = ml_util.to_cpu(hat_y[0:n_proposals, 0], to_numpy=True) - idxs = heterograph_data.idxs_proposals["idx_to_id"] - return {idxs[i]: p for i, p in enumerate(hat_y)} + idx_to_id = data.idxs_proposals.idx_to_id + hat_y = ml_util.tensor_to_list(hat_y) + return {idx_to_id[i]: y_i for i, y_i in enumerate(hat_y)} def update_graph(self, preds, high_threshold=0.9): """ @@ -628,3 +578,8 @@ def add_accepts(self, proposals): self.graph.merge_proposal(proposal) accepts.append(proposal) return accepts + + +# --- Helpers --- +def reformat_preds(preds_dict): + return {tuple(k): v for k, v in preds_dict.items()} diff --git a/src/neuron_proofreader/utils/graph_util.py b/src/neuron_proofreader/utils/graph_util.py index 6ed54ae2..f174a062 100644 --- a/src/neuron_proofreader/utils/graph_util.py +++ b/src/neuron_proofreader/utils/graph_util.py @@ -56,7 +56,6 @@ def __init__( prune_depth=24.0, remove_high_risk_merges=False, segmentation_path=None, - smooth_bool=True, soma_centroids=None, verbose=False, ): @@ -83,9 +82,6 @@ def __init__( branching points). Default is False. segmentation_path : str, optional Path to segmentation stored in GCS bucket. Default is None. - smooth_bool : bool, optional - Indication of whether to smooth xyz coordinates from SWC files. - Default is True. soma_centroids : List[Tuple[float]] or None, optional Physcial coordinates of soma centroids. Default is None. verbose : bool, optional @@ -98,7 +94,6 @@ def __init__( self.node_spacing = node_spacing self.prune_depth = prune_depth self.remove_high_risk_merges_bool = remove_high_risk_merges - self.smooth_bool = smooth_bool self.soma_centroids = soma_centroids self.verbose = verbose @@ -307,14 +302,13 @@ def dist(i, j): if graph.degree[j] != 2: path_length += edge_length irreducible_nodes.add(j) - attrs = to_numpy(attrs) - if self.smooth_bool: - n_pts = int(edge_length / self.node_spacing) - self.resample_curve_3d(graph, attrs, (root, j), n_pts) if graph.degree[j] == 1: leafs.add(j) - # Finish + attrs = to_numpy(attrs) + n_pts = int(edge_length / self.node_spacing) + self.resample_curve_3d(graph, attrs, (root, j), n_pts) + irreducible_edges[(root, j)] = attrs root = None diff --git a/src/neuron_proofreader/utils/swc_util.py b/src/neuron_proofreader/utils/swc_util.py index 316b5f3f..a53aeb8e 100644 --- a/src/neuron_proofreader/utils/swc_util.py +++ b/src/neuron_proofreader/utils/swc_util.py @@ -24,6 +24,7 @@ ThreadPoolExecutor, as_completed, ) +from google.auth.exceptions import RefreshError, TransportError from google.cloud import storage from io import BytesIO, StringIO from tqdm import tqdm @@ -235,8 +236,8 @@ def read_from_zip(self, zip_path): List of dictionaries whose keys and values are the attribute names and values from an SWC file. 
""" + swc_dicts = deque() with ZipFile(zip_path, "r") as zip_file: - swc_dicts = deque() swc_files = [f for f in zip_file.namelist() if f.endswith(".swc")] for f in swc_files: swc_dict = self.read_from_zipped_file(zip_file, f) @@ -402,7 +403,10 @@ def read_from_gcs_zips(self, bucket_name, zip_paths): # Store results for process in as_completed(processes): - swc_dicts.extend(process.result()) + try: + swc_dicts.extend(process.result()) + except RefreshError: + pass pbar.update(1) return swc_dicts @@ -426,10 +430,14 @@ def read_from_gcs_zip(self, bucket_name, zip_path, filenames=None): List of dictionaries whose keys and values are the attribute names and values from an SWC file. """ - # Download zip - client = storage.Client() - bucket = client.bucket(bucket_name) - zip_content = bucket.blob(zip_path).download_as_bytes() + try: + # Download zip + client = storage.Client() + bucket = client.bucket(bucket_name) + zip_content = bucket.blob(zip_path).download_as_bytes() + except TransportError: + print(f"Failed to read {zip_path}!") + return deque() # Process files swc_dicts = deque() From 78e00b35a51ec6b0fd22c588a7a46fc1b7652e98 Mon Sep 17 00:00:00 2001 From: anna-grim Date: Sat, 31 Jan 2026 02:04:34 +0000 Subject: [PATCH 2/3] messy and in progress --- src/neuron_proofreader/config.py | 35 +-- .../machine_learning/gnn_models.py | 4 +- src/neuron_proofreader/proposal_graph.py | 54 +++- src/neuron_proofreader/skeleton_graph.py | 15 + .../split_proofreading/split_datasets.py | 261 ++++++++++-------- .../split_feature_extraction.py | 14 +- .../split_proofreading/split_inference.py | 247 +++++------------ src/neuron_proofreader/utils/ml_util.py | 3 +- 8 files changed, 293 insertions(+), 340 deletions(-) diff --git a/src/neuron_proofreader/config.py b/src/neuron_proofreader/config.py index 1f8baede..ffab1de2 100644 --- a/src/neuron_proofreader/config.py +++ b/src/neuron_proofreader/config.py @@ -8,7 +8,7 @@ aspects of a system involving graphs, proposals, and machine learning (ML). """ -from dataclasses import dataclass, field +from dataclasses import dataclass from typing import Tuple @@ -24,13 +24,6 @@ class GraphConfig: Scaling factors applied to xyz coordinates to account for anisotropy of microscope. Note this instance of "anisotropy" is only used while reading SWC files. Default is (1.0, 1.0, 1.0). - complex_bool : bool - Indication of whether to generate complex proposals, meaning proposals - between leaf and non-leaf nodes. Default is False. - long_range_bool : bool, optional - Indication of whether to generate simple proposals within a scaled - distance of "search_radius" from leaves without any proposals. Default - is False. min_size : float, optional Minimum path length (in microns) of swc files which are stored as connected components in the FragmentsGraph. Default is 30. @@ -48,27 +41,23 @@ class GraphConfig: remove_high_risk_merges : bool, optional Indication of whether to remove high risk merge sites (i.e. close branching points). Default is False. - smooth_bool : bool, optional - Indication of whether to smooth branches in the graph. Default is - True. trim_endpoints_bool : bool, optional Indication of whether to endpoints of branches with exactly one proposal. Default is True. 
""" - anisotropy: Tuple[float] = field(default_factory=tuple) - complex_bool: bool = False - long_range_bool: bool = True - min_size: float = 30.0 - min_size_with_proposals: float = 0 - node_spacing: int = 5 + anisotropy: Tuple[float, float, float] = (1.0, 1.0, 1.0) + max_proposals_per_leaf: int = 3 + min_size: float = 40.0 + min_size_with_proposals: float = 40.0 + node_spacing: int = 1 proposals_per_leaf: int = 3 prune_depth: float = 24.0 - remove_doubles: bool = False + remove_doubles: bool = True remove_high_risk_merges: bool = False search_radius: float = 20.0 - smooth_bool: bool = True trim_endpoints_bool: bool = True + verbose: bool = True @dataclass @@ -87,7 +76,9 @@ class MLConfig: batch_size: int = 64 brightness_clip: int = 400 device: str = "cuda" - patch_shape : tuple = (96, 96, 96) + patch_shape: tuple = (96, 96, 96) + shuffle: bool = False + transform: bool = False threshold: float = 0.8 @@ -112,5 +103,5 @@ def __init__(self, graph_config: GraphConfig, ml_config: MLConfig): An instance of the "MLConfig" class that includes configuration parameters for machine learning models. """ - self.graph_config = graph_config - self.ml_config = ml_config + self.graph = graph_config + self.ml = ml_config diff --git a/src/neuron_proofreader/machine_learning/gnn_models.py b/src/neuron_proofreader/machine_learning/gnn_models.py index 1db956e6..2a3e5774 100644 --- a/src/neuron_proofreader/machine_learning/gnn_models.py +++ b/src/neuron_proofreader/machine_learning/gnn_models.py @@ -16,7 +16,7 @@ import torch.nn.init as init from neuron_proofreader.machine_learning.vision_models import CNN3D -from neuron_proofreader.split_proofreading import feature_extraction +from neuron_proofreader.split_proofreading import split_feature_extraction from neuron_proofreader.utils.ml_util import FeedForwardNet @@ -125,7 +125,7 @@ def init_node_embedding(output_dim): features. """ # Get feature dimensions - node_input_dims = feature_extraction.get_feature_dict() + node_input_dims = split_feature_extraction.get_feature_dict() dim_b = node_input_dims["branch"] dim_p = node_input_dims["proposal"] diff --git a/src/neuron_proofreader/proposal_graph.py b/src/neuron_proofreader/proposal_graph.py index 8a683672..537862e5 100644 --- a/src/neuron_proofreader/proposal_graph.py +++ b/src/neuron_proofreader/proposal_graph.py @@ -31,7 +31,9 @@ trim_endpoints_at_proposal ) from neuron_proofreader.utils import ( - geometry_util as geometry, graph_util as gutil, util, + geometry_util as geometry, + graph_util as gutil, + util, ) @@ -48,7 +50,7 @@ class ProposalGraph(SkeletonGraph): def __init__( self, anisotropy=(1.0, 1.0, 1.0), - filter_transitive_proposals=False, + gt_path=None, max_proposals_per_leaf=3, min_size=0, min_size_with_proposals=0, @@ -57,7 +59,7 @@ def __init__( remove_high_risk_merges=False, segmentation_path=None, soma_centroids=None, - verbose=False, + verbose=True, ): """ Instantiates a ProposalGraph object. 
@@ -93,6 +95,7 @@ def __init__( # Instance attributes - Graph self.anisotropy = anisotropy self.component_id_to_swc_id = dict() + self.gt_path = gt_path self.leaf_kdtree = None self.soma_ids = set() self.verbose = verbose @@ -192,6 +195,40 @@ def _add_edge(self, edge_id, attrs): self.add_edge(i, j, radius=attrs["radius"], xyz=attrs["xyz"]) self.xyz_to_edge.update({tuple(xyz): edge_id for xyz in attrs["xyz"]}) + def connect_soma_fragments(self, soma_centroids): + merge_cnt, soma_cnt = 0, 0 + for soma_xyz in soma_centroids: + node_ids = self.find_fragments_near_xyz(soma_xyz, 25) + if len(node_ids) > 1: + # Find closest node to soma location + soma_cnt += 1 + best_dist = np.inf + best_node = None + for i in node_ids: + dist = geometry.dist(soma_xyz, self.node_xyz[i]) + if dist < best_dist: + best_dist = dist + best_node = i + soma_component_id = self.node_component_id[best_node] + self.soma_ids.add(soma_component_id) + node_ids.remove(best_node) + + # Merge fragments to soma + soma_xyz = self.node_xyz[best_node] + for i in node_ids: + attrs = { + "radius": np.array([2, 2]), + "xyz": np.array([soma_xyz, self.node_xyz[i]]), + } + self._add_edge((best_node, i), attrs) + self.update_component_ids(soma_component_id, i) + merge_cnt += 1 + + # Summarize results + results = [f"# Somas Connected: {soma_cnt}"] + results.append(f"# Soma Fragments Merged: {merge_cnt}") + return results + def relabel_nodes(self): """ Reassigns contiguous node IDs and update all dependent structures. @@ -319,7 +356,7 @@ def query_kdtree(self, xyz, d, node_type=None): return geometry.query_ball(self.kdtree, xyz, d) # --- Proposal Generation --- - def generate_proposals(self, search_radius, gt_graph=None): + def generate_proposals(self, search_radius): """ Generates proposals from leaf nodes. @@ -330,16 +367,17 @@ def generate_proposals(self, search_radius, gt_graph=None): gt_graph : networkx.Graph, optional Ground truth graph. Default is None. 
""" - # Generate proposals + # Proposal pipeline proposals = self.proposal_generator(search_radius) + self.search_radius = search_radius self.store_proposals(proposals) self.trim_proposals() # Set groundtruth - if gt_graph: + if self.gt_path: + gt_graph = ProposalGraph(anisotropy=self.anisotropy) + gt_graph.load(self.gt_path) self.gt_accepts = groundtruth_generation.run(gt_graph, self) - else: - self.gt_accepts = set() def add_proposal(self, i, j): """ diff --git a/src/neuron_proofreader/skeleton_graph.py b/src/neuron_proofreader/skeleton_graph.py index fcf77179..949cab7a 100644 --- a/src/neuron_proofreader/skeleton_graph.py +++ b/src/neuron_proofreader/skeleton_graph.py @@ -592,6 +592,21 @@ def find_closest_node(self, xyz): _, node = self.kdtree.query(xyz) return node + def get_summary(self, prefix=""): + # Compute values + n_components = format(nx.number_connected_components(self), ",") + n_nodes = format(self.graph.number_of_nodes(), ",") + n_edges = format(self.graph.number_of_edges(), ",") + memory = util.get_memory_usage() + + # Compile results + summary = [f"{prefix} Graph"] + summary.append(f"# Connected Components: {n_components}") + summary.append(f"# Nodes: {n_nodes}") + summary.append(f"# Edges: {n_edges}") + summary.append(f"Memory Consumption: {memory:.2f} GBs") + return "\n".join(summary) + def path_length(self, max_depth=np.inf, root=None): """ Computes the path length of the connected component that contains the diff --git a/src/neuron_proofreader/split_proofreading/split_datasets.py b/src/neuron_proofreader/split_proofreading/split_datasets.py index adee3210..4c3c523d 100644 --- a/src/neuron_proofreader/split_proofreading/split_datasets.py +++ b/src/neuron_proofreader/split_proofreading/split_datasets.py @@ -10,13 +10,13 @@ from torch.utils.data import IterableDataset -import numpy as np import os import pandas as pd from neuron_proofreader.proposal_graph import ProposalGraph from neuron_proofreader.machine_learning.augmentation import ImageTransforms from neuron_proofreader.machine_learning.subgraph_sampler import ( + SeededSubgraphSampler, SubgraphSampler ) from neuron_proofreader.split_proofreading.split_feature_extraction import ( @@ -26,114 +26,70 @@ from neuron_proofreader.utils import geometry_util, img_util, util +# --- Single Brain Dataset --- class FragmentsDataset(IterableDataset): - """ - A dataset class for storing graphs used to train models to perform split - correction. Graphs are stored in the "self.graphs" attribute, which is a - dictionary containing the followin items: - - Key: (brain_id, segmentation_id, example_id) - - Value: graph that is an instance of ProposalGraph - - This dataset is populated using the "self.add_graph" method, which - requires the following inputs: - (1) key: Unique identifier of graph. - (2) gt_pointer: Path to ground truth SWC files. - (2) pred_pointer: Path to predicted SWC files. - (3) img_path: Path to whole-brain image stored in cloud bucket. - - Note: This dataset supports graphs from multiple whole-brain datasets. - """ def __init__( self, - config, - brightness_clip=400, - patch_shape=(128, 128, 128), - shuffle=True, - transform=False - ): - """ - Instantiates a FragmentsDataset object. - - Parameters - ---------- - config : GraphConfig - Config object that stores parameters used to build graphs. - patch_shape : Tuple[int], optional - Shape of image patch input to a vision model. Default is (128, 128, - 128). - transform : bool, optional - Indication of whether to apply augmentation to input images. - Default is False. 
- """ - # Instance attributes - self.brightness_clip = brightness_clip - self.config = config - self.feature_extractors = dict() - self.graphs = dict() - self.patch_shape = patch_shape - self.shuffle = shuffle - self.transform = ImageTransforms() if transform else False - - # --- Load Data --- - def add_graph( - self, - key, - gt_pointer, - pred_pointer, + fragments_path, img_path, + config, + gt_path=None, metadata_path=None, - segmentation_path=None + segmentation_path=None, + soma_centroids=None ): """ - Loads a fragments graph, generates proposals, and initializes feature - extraction. + Instantiates a FragmentsDataset object. Parameters ---------- - key : Tuple[str] - Unique identifier used to register the graph and its feature - pipeline. - gt_pointer : str - Path to ground-truth SWC files to be loaded. - pred_pointer : str + fragments_path : str Path to predicted SWC files to be loaded. img_path : str - Path to the raw image associated with the graph. + Path to the raw image associated with the fragments. + config : Config + ... + gt_pointer : str, optional + Path to ground-truth SWC files to be loaded. Default is None. metadata_path : str ... segmentation_path : str - Path to the segmentation mask associated with the graph. + Path to the segmentation that fragments were generated from. + Default is None. """ - # Add graph - gt_graph = self.load_graph(gt_pointer, is_gt=True) - self.graphs[key] = self.load_graph(pred_pointer, metadata_path) - self.graphs[key].generate_proposals( - self.config.search_radius, gt_graph=gt_graph + # Instance attributes + self.config = config + self.transform = ImageTransforms() if config.ml.transform else False + self.graph = self._load_graph( + fragments_path, metadata_path, segmentation_path, soma_centroids ) - # Generate features - self.feature_extractors[key] = FeaturePipeline( - self.graphs[key], + # Feature extractor + self.feature_extractor = FeaturePipeline( + self.graph, img_path, - self.config.search_radius, - brightness_clip=self.brightness_clip, - patch_shape=self.patch_shape, + brightness_clip=self.config.ml.brightness_clip, + patch_shape=self.config.ml.patch_shape, segmentation_path=segmentation_path ) - def load_graph(self, swc_pointer, is_gt=False, metadata_path=None): + def _load_graph( + self, fragments_path, metadata_path, segmentation_path, soma_centroids + ): """ - Loads a graph by reading and processing SWC files specified by - "swc_pointer". + Loads a graph by reading and processing SWC files specified by the + given path. Parameters ---------- - swc_pointer : str + fragments_path : str Path to SWC files to be loaded. metadata_path : str Patch to JSON file containing metadata on block that fragments were extracted from. + soma_centroids : List[Tuple[float]] + List of physical coordinates that represent soma centers. 
Returns ------- @@ -142,14 +98,22 @@ def load_graph(self, swc_pointer, is_gt=False, metadata_path=None): """ # Build graph graph = ProposalGraph( - anisotropy=self.config.anisotropy, - min_size=self.config.min_size + anisotropy=self.config.graph.anisotropy, + min_size=self.config.graph.min_size, + min_size_with_proposals=self.config.graph.min_size_with_proposals, + node_spacing=self.config.graph.node_spacing, + prune_depth=self.config.graph.prune_depth, + remove_high_risk_merges=self.config.graph.remove_high_risk_merges, + segmentation_path=segmentation_path, + soma_centroids=soma_centroids ) - graph.load(swc_pointer) + graph.load(fragments_path) # Post process fragments - if not is_gt: + if metadata_path: self.clip_fragments(graph, metadata_path) + + if self.config.graph.remove_doubles: geometry_util.remove_doubles(graph, 200) return graph @@ -165,27 +129,93 @@ def __iter__(self): targets : torch.Tensor Ground truth labels. """ - # Initialize subgraph samplers - samplers = dict() - for key, graph in self.graphs.items(): - samplers[key] = iter(SubgraphSampler(graph, max_proposals=32)) + for subgraph in self.get_sampler(): + yield self.get_inputs(subgraph) + + def get_inputs(self, subgraph): + features = self.feature_extractor(subgraph) + data = HeteroGraphData(features) + if self.gt_path: + return data.get_inputs(), data.get_targets() + else: + return data.get_inputs() + + # --- Helpers --- + @staticmethod + def clip_fragments(graph, metadata_path): + # Extract bounding box + bucket_name, path = util.parse_cloud_path(metadata_path) + metadata = util.read_json_from_gcs(bucket_name, path) + origin = metadata["chunk_origin"][::-1] + shape = metadata["chunk_shape"][::-1] + + # Clip graph + nodes = list() + for i in graph.nodes: + voxel = graph.get_voxel(i) + if not img_util.is_contained(voxel - origin, shape): + nodes.append(i) + graph.remove_nodes_from(nodes) + graph.relabel_nodes() - # Iterate over dataset + def get_sampler(self): + batch_size = self.config.ml.batch_size + if len(self.graph.soma_ids) > 0: + sampler = SeededSubgraphSampler( + self.graph, max_proposals=batch_size + ) + else: + sampler = SubgraphSampler(self.graph, max_proposals=batch_size) + return iter(sampler) + + +# --- Multi-Brain Dataset --- +class MultiBrainFragmentsDataset: + """ + A dataset class for storing graphs used to train models to perform split + correction. Graphs are stored in the "self.graphs" attribute, which is a + dictionary containing the followin items: + - Key: (brain_id, segmentation_id, example_id) + - Value: graph that is an instance of ProposalGraph + + This dataset is populated using the "self.add_graph" method, which + requires the following inputs: + (1) key: Unique identifier of graph. + (2) gt_pointer: Path to ground truth SWC files. + (2) pred_pointer: Path to predicted SWC files. + (3) img_path: Path to whole-brain image stored in cloud bucket. + + Note: This dataset supports graphs from multiple whole-brain datasets. + """ + def __init__(self, shuffle=True): + # Instance attributes + self.datasets = dict() + self.shuffle = shuffle + + def add_dataset(self, key, dataset): + self.datasets[key] = dataset + + def __iter__(self): + """ + Iterates over the dataset and yields model-ready inputs and targets. + + Yields + ------ + inputs : HeteroGraphData + Heterogeneous graph data. + targets : torch.Tensor + Ground truth labels. 
+ """ + samplers = self.init_samplers() while len(samplers) > 0: key = self.get_next_key(samplers) try: - # Feature extraction subgraph = next(samplers[key]) - features = self.feature_extractors[key](subgraph) - - # Get model inputs - data = HeteroGraphData(features) - inputs = data.get_inputs() - targets = data.get_targets() - yield inputs, targets + yield self.datasets.get_inputs(subgraph) except StopIteration: del samplers[key] + # --- Helpers --- def get_next_key(self, samplers): """ Gets the next key to sample from a dictionary of samplers. @@ -206,23 +236,14 @@ def get_next_key(self, samplers): keys = sorted(samplers.keys()) return keys[0] - # --- Helpers --- - @staticmethod - def clip_fragments(graph, metadata_path): - # Extract bounding box - bucket_name, path = util.parse_cloud_path(metadata_path) - metadata = util.read_json_from_gcs(bucket_name, path) - origin = metadata["chunk_origin"][::-1] - shape = metadata["chunk_shape"][::-1] - - # Clip graph - nodes = list() - for i in graph.nodes: - voxel = graph.get_voxel(i) - if not img_util.is_contained(voxel - origin, shape): - nodes.append(i) - graph.remove_nodes_from(nodes) - graph.relabel_nodes() + def init_samplers(self): + samplers = dict() + for key, dataset in self.datasets.items(): + batch_size = dataset.config.ml.batch_size + samplers[key] = iter( + SubgraphSampler(dataset.graph, max_proposals=batch_size) + ) + return samplers def n_proposals(self): """ @@ -233,7 +254,10 @@ def n_proposals(self): int Number of proposals. """ - return np.sum([graph.n_proposals() for graph in self.graphs.values()]) + cnt = 0 + for dataset in self.datasets.values(): + cnt += dataset.graph.n_proposals() + return cnt def p_accepts(self): """ @@ -244,8 +268,10 @@ def p_accepts(self): float Percentage of accepted proposals in ground truth. """ - cnts = [len(graph.gt_accepts) for graph in self.graphs.values()] - return np.sum(cnts) / self.n_proposals() + accepts_cnt = 0 + for dataset in self.datasets.values(): + accepts_cnt += len(dataset.graph.gt_accepts) + return accepts_cnt / self.n_proposals() def save_examples_summary(self, path): """ @@ -257,8 +283,9 @@ def save_examples_summary(self, path): Output path for the CSV file. """ examples_summary = list() - for key in sorted(self.graphs.keys()): - examples_summary.extend([key] * self.graphs[key].n_proposals()) + for key in sorted(self.datasets.keys()): + n_proposals = self.datasets[key].graph.n_proposals() + examples_summary.extend([key] * n_proposals) pd.DataFrame(examples_summary).to_csv(path) diff --git a/src/neuron_proofreader/split_proofreading/split_feature_extraction.py b/src/neuron_proofreader/split_proofreading/split_feature_extraction.py index a3940d4b..1b7f6d87 100644 --- a/src/neuron_proofreader/split_proofreading/split_feature_extraction.py +++ b/src/neuron_proofreader/split_proofreading/split_feature_extraction.py @@ -29,7 +29,6 @@ def __init__( self, graph, img_path, - search_radius, brightness_clip=400, padding=50, patch_shape=(96, 96, 96), @@ -44,8 +43,6 @@ def __init__( Graph to extract features from. img_path : str Path to image of whole-brain dataset. - search_radius : float - Search radius used to generate proposals. brightness_clip : int, optional ... padding : int, optional @@ -58,7 +55,7 @@ def __init__( Path to segmentation of whole-brain dataset. 
""" self.extractors = [ - SkeletonFeatureExtractor(graph, search_radius), + SkeletonFeatureExtractor(graph), ImageFeatureExtractor( graph, img_path, @@ -89,7 +86,7 @@ class SkeletonFeatureExtractor: A class for extracting skeleton-based features. """ - def __init__(self, graph, search_radius): + def __init__(self, graph): """ Instantiates a SkeletonFeatureExtractor object. @@ -97,12 +94,9 @@ def __init__(self, graph, search_radius): ---------- graph : ProposalGraph Graph to extract features from. - search_radius : float - Search radius used to generate edge proposals. """ # Instance attributes self.graph = graph - self.search_radius = search_radius # Build KD-tree from leaf nodes self.graph.set_kdtree(node_type="leaf") @@ -184,8 +178,8 @@ def extract_proposal_features(self, subgraph, features): for p in subgraph.proposals: proposal_features[p] = np.concatenate( ( - self.graph.proposal_length(p) / self.search_radius, - self.graph.n_nearby_leafs(p, self.search_radius), + self.graph.proposal_length(p) / self.graph.search_radius, + self.graph.n_nearby_leafs(p, self.graph.search_radius), self.graph.proposal_attr(p, "radius"), self.graph.proposal_directionals(p, 16), self.graph.proposal_directionals(p, 32), diff --git a/src/neuron_proofreader/split_proofreading/split_inference.py b/src/neuron_proofreader/split_proofreading/split_inference.py index 959baae8..2eb20cc4 100644 --- a/src/neuron_proofreader/split_proofreading/split_inference.py +++ b/src/neuron_proofreader/split_proofreading/split_inference.py @@ -39,17 +39,15 @@ import os import torch -from neuron_proofreader.proposal_graph import ProposalGraph -from neuron_proofreader.machine_learning.subgraph_sampler import ( - SubgraphSampler, - SeededSubgraphSampler +from neuron_proofreader.split_proofreading.split_datasets import ( + FragmentsDataset ) from neuron_proofreader.split_proofreading.split_feature_extraction import ( FeaturePipeline, HeteroGraphData ) from neuron_proofreader.machine_learning.gnn_models import VisionHGAT -from neuron_proofreader.utils import geometry_util, ml_util, util +from neuron_proofreader.utils import ml_util, util class InferencePipeline: @@ -63,15 +61,14 @@ class InferencePipeline: def __init__( self, - brain_id, - segmentation_id, + fragments_path, img_path, model_path, output_dir, config, + log_preamble="", segmentation_path=None, - soma_centroids=None, - s3_dict=None, + soma_centroids=list(), ): """ Initializes an object that executes the full split correction @@ -95,157 +92,80 @@ def __init__( segmentation_path : str, optional Path to segmentation stored in GCS bucket. The default is None. soma_centroids : List[Tuple[float]] or None, optional - Physcial coordinates of soma centroids. The default is None. - s3_dict : dict, optional - ... + Physcial coordinates of soma centroids. The default is an empty + list. 
""" # Instance attributes self.accepted_proposals = list() + self.config = config self.img_path = img_path self.model_path = model_path - self.brain_id = brain_id - self.segmentation_id = segmentation_id - self.segmentation_path = segmentation_path - self.soma_centroids = soma_centroids or list() - self.s3_dict = s3_dict - - # Extract config settings - self.graph_config = config.graph_config - self.ml_config = config.ml_config - - # Set output directory self.output_dir = output_dir - util.mkdir(self.output_dir) + self.soma_centroids = soma_centroids - # Initialize logger + # Logger + util.mkdir(self.output_dir) log_path = os.path.join(self.output_dir, "runtimes.txt") self.log_handle = open(log_path, 'a') + self.log(log_preamble) - # --- Core --- - def run(self, swcs_path, search_radius): + # Load data + self._load_data(fragments_path, img_path, segmentation_path) + + def _load_data(self, fragments_path, img_path, segmentation_path): """ - Executes the full inference pipeline. + Builds a graph from the given fragments. Parameters ---------- - swcs_path : str - Path to SWC files used to build an instance of FragmentGraph, - see "swc_util.Reader" for further documentation. + fragments_path : str + Path to SWC files to be loaded into graph. """ - # Initializations - self.log_experiment() - self.write_metadata() + # Load data t0 = time() - - # Main - self.build_graph(swcs_path) - self.connect_soma_fragments() if self.soma_centroids else None - self.generate_proposals(search_radius) - self.classify_proposals(self.ml_config.threshold, search_radius) - - # Finish - t, unit = util.time_writer(time() - t0) - self.log_graph_specs(prefix="\nFinal") - self.log(f"Total Runtime: {t:.2f} {unit}\n") - self.save_results() - - def run_schedule(self, swcs_path, radius_schedule, threshold_schedule): - # Initializations - self.log_experiment() - self.write_metadata() - t0 = time() - - # Main - self.build_graph(swcs_path) - schedules = zip(radius_schedule, threshold_schedule) - for i, (radius, threshold) in enumerate(schedules): - self.log(f"\n--- Round {i + 1}: Radius = {radius} ---") - self.generate_proposals(radius) - self.classify_proposals(threshold) - self.log_graph_specs(prefix="Current") - - # Finish - t, unit = util.time_writer(time() - t0) - self.log_graph_specs(prefix="\nFinal") - self.log(f"Total Runtime: {t:.2f} {unit}\n") - self.save_results() - - def build_graph(self, swcs_path): + self.log("Step 1: Build Graph") + self.dataset = FragmentsDataset( + fragments_path, + img_path, + self.config, + segmentation_path=segmentation_path, + soma_centroids=self.soma_centroids + ) + self.save_fragment_ids() + + # Connect fragments very close to soma + self.log(f"# Soma Fragments: {len(self.dataset.graph.soma_ids)}") + if len(self.soma_centroids) > 0: + somas = self.soma_centroids + results = self.dataset.graph.connect_soma_fragments(somas) + self.log(results) + + # Log results + elapsed, unit = util.time_writer(time() - t0) + self.log(self.dataset.graph.get_summary(prefix="\nInitial")) + self.log(f"Module Runtime: {elapsed:.2f} {unit}\n") + + # --- Core Routines --- + def run(self, search_radius): """ - Builds a graph from the given fragments. + Executes the full inference pipeline. Parameters ---------- - fragment_pointer : str - Path to SWC files to be loaded into graph. + swcs_path : str + Path to SWC files used to build an instance of FragmentGraph, + see "swc_util.Reader" for further documentation. 
""" - self.log("Step 1: Build Graph") + # Main t0 = time() - - # Initialize graph - self.graph = ProposalGraph( - anisotropy=self.graph_config.anisotropy, - min_size=self.graph_config.min_size, - min_size_with_proposals=self.graph_config.min_size_with_proposals, - node_spacing=self.graph_config.node_spacing, - prune_depth=self.graph_config.prune_depth, - remove_high_risk_merges=self.graph_config.remove_high_risk_merges, - segmentation_path=self.segmentation_path, - soma_centroids=self.soma_centroids, - verbose=True, - ) - self.graph.load(swcs_path) - - # Filter fragments - if self.graph_config.remove_doubles: - geometry_util.remove_doubles(self.graph, 160) + self.generate_proposals(search_radius) + self.classify_proposals(self.config.ml.threshold) # Report results - path = f"{self.output_dir}/segment_ids.txt" - swc_ids = list(self.graph.component_id_to_swc_id.values()) - util.write_list(path, swc_ids) - print("# Soma Fragments:", len(self.graph.soma_ids)) - t, unit = util.time_writer(time() - t0) - self.log_graph_specs(prefix="\nInitial") - self.log(f"Module Runtime: {t:.2f} {unit}\n") - - def connect_soma_fragments(self): - # Initializations - self.graph.set_kdtree() - - # Parse locations - merge_cnt, soma_cnt = 0, 0 - for soma_xyz in self.soma_centroids: - node_ids = self.graph.find_fragments_near_xyz(soma_xyz, 25) - if len(node_ids) > 1: - # Find closest node to soma location - soma_cnt += 1 - best_dist = np.inf - best_node = None - for i in node_ids: - dist = geometry_util.dist(soma_xyz, self.graph.node_xyz[i]) - if dist < best_dist: - best_dist = dist - best_node = i - soma_component_id = self.graph.node_component_id[best_node] - self.graph.soma_ids.add(soma_component_id) - node_ids.remove(best_node) - - # Merge fragments to soma - soma_xyz = self.graph.node_xyz[best_node] - for i in node_ids: - attrs = { - "radius": np.array([2, 2]), - "xyz": np.array([soma_xyz, self.graph.node_xyz[i]]), - } - self.graph._add_edge((best_node, i), attrs) - self.graph.update_component_ids(soma_component_id, i) - merge_cnt += 1 - - print("# Somas Connected:", soma_cnt) - print("# Soma Fragment Merges:", merge_cnt) - del self.graph.kdtree + self.log(self.dataset.graph.get_summary(prefix="\nFinal")) + self.log(f"Total Runtime: {t:.2f} {unit}\n") + self.save_results() def generate_proposals(self, search_radius): """ @@ -255,16 +175,17 @@ def generate_proposals(self, search_radius): # Main t0 = time() self.log("\nStep 2: Generate Proposals") - self.graph.generate_proposals(search_radius) + self.dataset.graph.generate_proposals(search_radius) n_proposals = format(self.graph.n_proposals(), ",") + n_proposals_blocked = self.dataset.graph.n_proposals_blocked # Report results t, unit = util.time_writer(time() - t0) self.log(f"# Proposals: {n_proposals}") - self.log(f"# Proposals Blocked: {self.graph.n_proposals_blocked}") + self.log(f"# Proposals Blocked: {n_proposals_blocked}") self.log(f"Module Runtime: {t:.2f} {unit}\n") - def classify_proposals(self, accept_threshold, search_radius): + def classify_proposals(self, accept_threshold): """ Classifies proposals by calling "self.inference_engine". This routine generates features and runs a GNN to make predictions. 
Proposals with @@ -282,7 +203,6 @@ def classify_proposals(self, accept_threshold, search_radius): self.img_path, self.model_path, self.ml_config, - search_radius, segmentation_path=self.segmentation_path, ) preds_dict = inference_engine.run() @@ -330,7 +250,7 @@ def save_results(self): self.s3_dict["prefix"] ) - # --- io --- + # --- Helpers --- def save_connections(self, round_id=None): """ Writes the accepted proposals from the graph to a text file. Each line @@ -342,6 +262,11 @@ def save_connections(self, round_id=None): for id_1, id_2 in self.graph.merged_ids: f.write(f"{id_1}, {id_2}" + "\n") + def save_fragment_ids(self): + path = f"{self.output_dir}/segment_ids.txt" + segment_ids = list(self.dataset.graph.component_id_to_swc_id.values()) + util.write_list(path, segment_ids) + def write_metadata(self): """ Writes metadata about the current pipeline run to a JSON file. @@ -371,30 +296,6 @@ def log(self, txt): self.log_handle.write(txt) self.log_handle.write("\n") - def log_experiment(self): - self.log("\nExperiment Overview") - self.log("-" * len(self.segmentation_id)) - self.log(f"Brain_ID: {self.brain_id}") - self.log(f"Segmentation_ID: {self.segmentation_id}") - self.log("\n") - - def log_graph_specs(self, prefix="\n"): - """ - Prints an overview of the graph's structure and memory usage. - """ - # Compute values - n_components = nx.number_connected_components(self.graph) - n_components = format(n_components, ",") - n_nodes = format(self.graph.number_of_nodes(), ",") - n_edges = format(self.graph.number_of_edges(), ",") - - # Report results - self.log(f"{prefix} Graph") - self.log(f"# Connected Components: {n_components}") - self.log(f"# Nodes: {n_nodes}") - self.log(f"# Edges: {n_edges}") - self.log(f"Memory Consumption: {util.get_memory_usage():.2f} GBs") - class InferenceEngine: """ @@ -404,12 +305,9 @@ class InferenceEngine: def __init__( self, - graph, - img_path, + dataset, model_path, ml_config, - search_radius, - segmentation_path=None, ): """ Initializes an inference engine by loading images and setting class @@ -424,8 +322,6 @@ def __init__( ml_config : MLConfig Configuration object containing parameters and settings required for the inference. - search_radius : float - Search radius used to generate proposals. segmentation_path : str, optional Path to segmentation stored in GCS bucket. Default is None. 
""" @@ -440,7 +336,6 @@ def __init__( self.feature_extractor = FeaturePipeline( graph, img_path, - search_radius, brightness_clip=ml_config.brightness_clip, patch_shape=ml_config.patch_shape, segmentation_path=segmentation_path @@ -449,15 +344,9 @@ def __init__( # Load weights ml_util.load_model(self.model, model_path, device=ml_config.device) - def get_subgraph_sampler(self, graph): - if len(graph.soma_ids) > 0: - return SeededSubgraphSampler(graph, self.batch_size) - else: - return SubgraphSampler(graph, self.batch_size) - def run(self): preds = dict() - for subgraph in self.subgraph_sampler: + for subgraph in self.dataset: # Get model inputs features = self.feature_extractor(subgraph) data = HeteroGraphData(features) @@ -582,4 +471,4 @@ def add_accepts(self, proposals): # --- Helpers --- def reformat_preds(preds_dict): - return {tuple(k): v for k, v in preds_dict.items()} + return {str(k): v for k, v in preds_dict.items()} diff --git a/src/neuron_proofreader/utils/ml_util.py b/src/neuron_proofreader/utils/ml_util.py index b21a6aad..14cbe778 100644 --- a/src/neuron_proofreader/utils/ml_util.py +++ b/src/neuron_proofreader/utils/ml_util.py @@ -106,8 +106,7 @@ def init_mlp(input_dim, hidden_dim, output_dim, dropout=0.1): """ mlp = nn.Sequential( nn.Linear(input_dim, hidden_dim), - #nn.LeakyReLU(), - nn.GELU(), + nn.LeakyReLU(), nn.Dropout(p=dropout), nn.Linear(hidden_dim, output_dim), ) From 3e6ce2886be5a310dd409a04db552a28ced5abd4 Mon Sep 17 00:00:00 2001 From: anna-grim Date: Sat, 31 Jan 2026 04:41:41 +0000 Subject: [PATCH 3/3] refactor: split inference almost working --- src/neuron_proofreader/proposal_graph.py | 33 +- src/neuron_proofreader/skeleton_graph.py | 4 +- .../split_proofreading/split_datasets.py | 2 +- .../split_proofreading/split_inference.py | 283 +++++------------- 4 files changed, 90 insertions(+), 232 deletions(-) diff --git a/src/neuron_proofreader/proposal_graph.py b/src/neuron_proofreader/proposal_graph.py index 537862e5..e7272de6 100644 --- a/src/neuron_proofreader/proposal_graph.py +++ b/src/neuron_proofreader/proposal_graph.py @@ -196,7 +196,9 @@ def _add_edge(self, edge_id, attrs): self.xyz_to_edge.update({tuple(xyz): edge_id for xyz in attrs["xyz"]}) def connect_soma_fragments(self, soma_centroids): - merge_cnt, soma_cnt = 0, 0 + merge_cnt = 0 + soma_cnt = 0 + self.set_kdtree() for soma_xyz in soma_centroids: node_ids = self.find_fragments_near_xyz(soma_xyz, 25) if len(node_ids) > 1: @@ -227,7 +229,7 @@ def connect_soma_fragments(self, soma_centroids): # Summarize results results = [f"# Somas Connected: {soma_cnt}"] results.append(f"# Soma Fragments Merged: {merge_cnt}") - return results + return "\n".join(results) def relabel_nodes(self): """ @@ -333,28 +335,6 @@ def get_kdtree(self, node_type=None): else: return KDTree(list(self.xyz_to_edge.keys())) - def query_kdtree(self, xyz, d, node_type=None): - """ - Parameters - ---------- - xyz : int - Node id. - d : float - Distance from "xyz" that is searched. - - Returns - ------- - generator[tuple] - Generator that generates the xyz coordinates cooresponding to all - nodes within a distance of "d" from "xyz". - """ - if node_type == "leaf": - return geometry.query_ball(self.leaf_kdtree, xyz, d) - elif node_type == "proposal": - return geometry.query_ball(self.proposal_kdtree, xyz, d) - else: - return geometry.query_ball(self.kdtree, xyz, d) - # --- Proposal Generation --- def generate_proposals(self, search_radius): """ @@ -638,7 +618,7 @@ def n_nearby_leafs(self, proposal, radius): a proposal. 
""" xyz = self.proposal_midpoint(proposal) - return len(self.query_kdtree(xyz, radius, "leaf")) - 1 + return len(geometry.query_ball(self.leaf_kdtree, xyz, radius)) - 1 # --- Helpers --- def node_attr(self, i, key): @@ -700,7 +680,8 @@ def edge_length(self, edge): def find_fragments_near_xyz(self, query_xyz, max_dist): hits = dict() - for xyz in self.query_kdtree(query_xyz, max_dist): + xyz_list = geometry.query_ball(self.kdtree, query_xyz, max_dist) + for xyz in xyz_list: i, j = self.xyz_to_edge[tuple(xyz)] dist_i = geometry.dist(self.node_xyz[i], query_xyz) dist_j = geometry.dist(self.node_xyz[j], query_xyz) diff --git a/src/neuron_proofreader/skeleton_graph.py b/src/neuron_proofreader/skeleton_graph.py index 949cab7a..67b37db0 100644 --- a/src/neuron_proofreader/skeleton_graph.py +++ b/src/neuron_proofreader/skeleton_graph.py @@ -595,8 +595,8 @@ def find_closest_node(self, xyz): def get_summary(self, prefix=""): # Compute values n_components = format(nx.number_connected_components(self), ",") - n_nodes = format(self.graph.number_of_nodes(), ",") - n_edges = format(self.graph.number_of_edges(), ",") + n_nodes = format(self.number_of_nodes(), ",") + n_edges = format(self.number_of_edges(), ",") memory = util.get_memory_usage() # Compile results diff --git a/src/neuron_proofreader/split_proofreading/split_datasets.py b/src/neuron_proofreader/split_proofreading/split_datasets.py index 4c3c523d..84391dc5 100644 --- a/src/neuron_proofreader/split_proofreading/split_datasets.py +++ b/src/neuron_proofreader/split_proofreading/split_datasets.py @@ -135,7 +135,7 @@ def __iter__(self): def get_inputs(self, subgraph): features = self.feature_extractor(subgraph) data = HeteroGraphData(features) - if self.gt_path: + if self.graph.gt_path: return data.get_inputs(), data.get_targets() else: return data.get_inputs() diff --git a/src/neuron_proofreader/split_proofreading/split_inference.py b/src/neuron_proofreader/split_proofreading/split_inference.py index 2eb20cc4..ff9404e1 100644 --- a/src/neuron_proofreader/split_proofreading/split_inference.py +++ b/src/neuron_proofreader/split_proofreading/split_inference.py @@ -35,17 +35,12 @@ from tqdm import tqdm import networkx as nx -import numpy as np import os import torch from neuron_proofreader.split_proofreading.split_datasets import ( FragmentsDataset ) -from neuron_proofreader.split_proofreading.split_feature_extraction import ( - FeaturePipeline, - HeteroGraphData -) from neuron_proofreader.machine_learning.gnn_models import VisionHGAT from neuron_proofreader.utils import ml_util, util @@ -99,7 +94,7 @@ def __init__( self.accepted_proposals = list() self.config = config self.img_path = img_path - self.model_path = model_path + self.model = VisionHGAT(config.ml.patch_shape) self.output_dir = output_dir self.soma_centroids = soma_centroids @@ -111,6 +106,7 @@ def __init__( # Load data self._load_data(fragments_path, img_path, segmentation_path) + ml_util.load_model(self.model, model_path, device=config.ml.device) def _load_data(self, fragments_path, img_path, segmentation_path): """ @@ -171,12 +167,17 @@ def generate_proposals(self, search_radius): """ Generates proposals for the fragments graph based on the specified configuration. + + Parameters + ---------- + search_radius : float + Search radius (in microns) used to generate proposals. 
""" # Main t0 = time() self.log("\nStep 2: Generate Proposals") self.dataset.graph.generate_proposals(search_radius) - n_proposals = format(self.graph.n_proposals(), ",") + n_proposals = format(self.dataset.graph.n_proposals(), ",") n_proposals_blocked = self.dataset.graph.n_proposals_blocked # Report results @@ -192,33 +193,58 @@ def classify_proposals(self, accept_threshold): a prediction above "self.threshold" are accepted and added to the graph as an edge. """ - # Initializations - self.log("Step 3: Run Inference") t0 = time() + self.log("Step 3: Run Inference") - # Generate model predictions - n_proposals = self.graph.n_proposals() - inference_engine = InferenceEngine( - self.graph, - self.img_path, - self.model_path, - self.ml_config, - segmentation_path=self.segmentation_path, - ) - preds_dict = inference_engine.run() - path = os.path.join(self.output_dir, "proposal_predictions.json") - util.write_json(path, reformat_preds(preds_dict)) - - # Update graph - stop + # Main + accepts = set() + new_threshold = 0.99 + preds = self.predict_proposals() + while True: + # Update graph + cur_threshold = new_threshold + accepts.update(self.merge_proposals(cur_threshold)) + + # Update threshold + new_threshold = cur_threshold - 0.025 + if cur_threshold == new_threshold: + break # Report results t, unit = util.time_writer(time() - t0) self.log(f"# Merges Blocked: {self.graph.n_merges_blocked}") self.log(f"# Accepted: {format(len(accepts), ',')}") - self.log(f"% Accepted: {len(accepts) / n_proposals:.4f}") + self.log(f"% Accepted: {len(accepts) / len(preds):.4f}") self.log(f"Module Runtime: {t:.4f} {unit}\n") + def predict_proposals(self): + # Main + preds = dict() + pbar = tqdm(total=self.dataset.graph.n_proposals(), desc="Inference") + for subgraph in self.dataset: + data = self.dataset.get_inputs(subgraph) + preds.update(self.predict(data)) + pbar.update(subgraph.n_proposals()) + + # Save results + path = os.path.join(self.output_dir, "proposal_predictions.json") + util.write_json(path, reformat_preds(preds)) + return preds + + def merge_proposals(self, preds, threshold): + accepts = list() + for proposal in self.dataset.graph.get_sorted_proposals(): + # Check if proposal satifies threshold + i, j = tuple(proposal) + if preds[proposal] < threshold: + continue + + # Check if proposal creates a loop + if not nx.has_path(self.graph, i, j): + self.graph.merge_proposal(proposal) + accepts.append(proposal) + return accepts + def save_results(self): """ Saves the processed results from running the inference pipeline, @@ -251,6 +277,35 @@ def save_results(self): ) # --- Helpers --- + def log(self, txt): + print(txt) + self.log_handle.write(txt) + self.log_handle.write("\n") + + def predict(self, data): + """ + ... + + Parameters + ---------- + data : HeteroGraphData + ... + + Returns + ------- + Dict[Frozenset[int], float] + Dictionary that maps proposal IDs to model predictions. + """ + # Generate predictions + with torch.no_grad(): + x = data.get_inputs().to(self.device) + hat_y = sigmoid(self.model(x)) + + # Reformat predictions + idx_to_id = data.idxs_proposals.idx_to_id + hat_y = ml_util.tensor_to_list(hat_y) + return {idx_to_id[i]: y_i for i, y_i in enumerate(hat_y)} + def save_connections(self, round_id=None): """ Writes the accepted proposals from the graph to a text file. 
Each line @@ -290,184 +345,6 @@ def write_metadata(self): path = os.path.join(self.output_dir, "metadata.json") util.write_json(path, metadata) - # --- Summaries --- - def log(self, txt): - print(txt) - self.log_handle.write(txt) - self.log_handle.write("\n") - - -class InferenceEngine: - """ - Class that runs inference with a machine learning model that has been - trained to classify edge proposals. - """ - - def __init__( - self, - dataset, - model_path, - ml_config, - ): - """ - Initializes an inference engine by loading images and setting class - attributes. - - Parameters - ---------- - img_path : str - Path to image. - model_path : str - Path to machine learning model weights. - ml_config : MLConfig - Configuration object containing parameters and settings required - for the inference. - segmentation_path : str, optional - Path to segmentation stored in GCS bucket. Default is None. - """ - # Instance attributes - self.batch_size = ml_config.batch_size - self.device = ml_config.device - self.model = VisionHGAT(ml_config.patch_shape) - self.pbar = tqdm(total=graph.n_proposals(), desc="Inference") - self.subgraph_sampler = self.get_subgraph_sampler(graph) - - # Feature generator - self.feature_extractor = FeaturePipeline( - graph, - img_path, - brightness_clip=ml_config.brightness_clip, - patch_shape=ml_config.patch_shape, - segmentation_path=segmentation_path - ) - - # Load weights - ml_util.load_model(self.model, model_path, device=ml_config.device) - - def run(self): - preds = dict() - for subgraph in self.dataset: - # Get model inputs - features = self.feature_extractor(subgraph) - data = HeteroGraphData(features) - - # Run model - preds.update(self.predict(data)) - self.pbar.update(subgraph.n_proposals()) - return preds - - def predict(self, data): - """ - ... - - Parameters - ---------- - data : HeteroGraphData - ... - - Returns - ------- - Dict[Frozenset[int], float] - Dictionary that maps proposal IDs to model predictions. - """ - # Generate predictions - with torch.no_grad(): - x = data.get_inputs().to(self.device) - hat_y = sigmoid(self.model(x)) - - # Reformat predictions - idx_to_id = data.idxs_proposals.idx_to_id - hat_y = ml_util.tensor_to_list(hat_y) - return {idx_to_id[i]: y_i for i, y_i in enumerate(hat_y)} - - def update_graph(self, preds, high_threshold=0.9): - """ - Determines which proposals to accept based on prediction scores and - the specified threshold. - - Parameters - ---------- - preds : dict - Dictionary that maps proposal ids to probability generated from - machine learning model. - high_threshold : float, optional - Threshold value for separating the best proposals from the rest. - Default is 0.9. - - Returns - ------- - list - Proposals to be added as edges to "graph". - """ - # Partition proposals into best and the rest - preds = {k: v for k, v in preds.items() if v > self.threshold} - best_proposals, proposals = self.separate_best(preds, high_threshold) - - # Determine which proposals to accept - accepts = list() - accepts.extend(self.add_accepts(best_proposals)) - accepts.extend(self.add_accepts(proposals)) - return accepts - - def separate_best(self, preds, high_threshold): - """ - Splits "preds" into two separate dictionaries such that one contains - the best proposals (i.e. simple proposals with high confidence) and - the other contains all other proposals. - - Parameters - ---------- - preds : dict - Dictionary that maps proposal ids to probability generated from - machine learning model. 
- high_threshold : float - Threshold on acceptance probability for proposals. - - Returns - ------- - list - Proposal IDs determined to be the best. - list - All other proposal IDs. - """ - best_probs, probs = list(), list() - best_proposals, proposals = list(), list() - simple_proposals = self.graph.simple_proposals() - for proposal, prob in preds.items(): - if proposal in simple_proposals and prob > high_threshold: - best_proposals.append(proposal) - best_probs.append(prob) - else: - proposals.append(proposal) - probs.append(prob) - best_idxs = np.argsort(best_probs) - idxs = np.argsort(probs) - return np.array(best_proposals)[best_idxs], np.array(proposals)[idxs] - - def add_accepts(self, proposals): - """ - ... - - Parameters - ---------- - proposals : list[frozenset] - Proposals with predicted probability above threshold to be added - to the graph. - - Returns - ------- - List[frozenset] - List of proposals that do not create a cycle when iteratively - added to "graph". - """ - accepts = list() - for proposal in proposals: - i, j = tuple(proposal) - if not nx.has_path(self.graph, i, j): - self.graph.merge_proposal(proposal) - accepts.append(proposal) - return accepts - # --- Helpers --- def reformat_preds(preds_dict):
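
For reference, the threshold-decay acceptance scheme implemented by the new classify_proposals / merge_proposals methods can be sketched in isolation as follows. This is a minimal sketch, not the pipeline code: proposal_scores, accept_threshold, and the plain networkx.Graph stand in for the pipeline's preds dictionary, configured threshold, and ProposalGraph.

import networkx as nx


def accept_proposals(graph, proposal_scores, accept_threshold, step=0.025):
    """
    Greedily accepts proposals by sweeping a decaying confidence threshold,
    skipping any proposal whose endpoints are already connected (adding such
    a proposal would create a cycle).
    """
    accepts = set()
    threshold = 0.99
    while threshold >= accept_threshold:
        for proposal, score in proposal_scores.items():
            if proposal in accepts or score < threshold:
                continue
            i, j = tuple(proposal)
            if not nx.has_path(graph, i, j):
                graph.add_edge(i, j)
                accepts.add(proposal)
        threshold -= step
    return accepts


# Toy example: two fragments (1-2 and 3-4) with two competing proposals;
# the higher-scoring proposal is merged first, the second is rejected
# because it would close a cycle.
toy = nx.Graph([(1, 2), (3, 4)])
scores = {frozenset((2, 3)): 0.95, frozenset((1, 4)): 0.90}
print(accept_proposals(toy, scores, accept_threshold=0.8))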
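
The reformat_preds change from tuple(k) to str(k) matters because the prediction dictionary is keyed by frozenset proposal IDs, and neither frozenset nor tuple keys can be written as JSON. Assuming util.write_json ultimately delegates to json.dump, a quick check:

import json

preds = {frozenset((10, 42)): 0.93}

# Frozenset (or tuple) keys cannot be serialized directly
try:
    json.dumps(preds)
except TypeError as e:
    print(f"raw keys fail: {e}")

# Stringifying the keys, as reformat_preds now does, makes the dict serializable
print(json.dumps({str(k): v for k, v in preds.items()}))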
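
Similarly, connect_soma_fragments and find_fragments_near_xyz now call geometry.query_ball on the graph's KD-tree directly. A rough standalone equivalent using SciPy is shown below; the coordinates and tree construction are illustrative, while the 25 micron radius mirrors the hard-coded value in the diff.

import numpy as np
from scipy.spatial import KDTree

# Coordinates of skeleton points (e.g. the keys of xyz_to_edge), in microns
coords = np.array([[0.0, 0.0, 0.0], [5.0, 0.0, 0.0], [100.0, 0.0, 0.0]])
kdtree = KDTree(coords)

# Find all skeleton points within 25 microns of a soma centroid
soma_xyz = np.array([1.0, 0.0, 0.0])
idxs = kdtree.query_ball_point(soma_xyz, r=25.0)
print(coords[idxs])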