From 787439b28e86563370d50f40f50be74701db6bdb Mon Sep 17 00:00:00 2001 From: anna-grim Date: Thu, 29 Jan 2026 21:41:17 +0000 Subject: [PATCH] feat: near soma sampling, aggressive aug --- .../machine_learning/augmentation.py | 42 ++++++---- .../machine_learning/exaspim_dataloader.py | 6 +- .../merge_proofreading/merge_datasets.py | 48 ++++++----- .../split_proofreading/split_datasets.py | 6 +- src/neuron_proofreader/utils/ml_util.py | 17 ++++ src/neuron_proofreader/visualization_skel.py | 83 +++++++++++++++++++ 6 files changed, 159 insertions(+), 43 deletions(-) create mode 100644 src/neuron_proofreader/visualization_skel.py diff --git a/src/neuron_proofreader/machine_learning/augmentation.py b/src/neuron_proofreader/machine_learning/augmentation.py index 5ab8c75..26be77f 100644 --- a/src/neuron_proofreader/machine_learning/augmentation.py +++ b/src/neuron_proofreader/machine_learning/augmentation.py @@ -30,7 +30,7 @@ def __init__(self): RandomFlip3D(), RandomRotation3D(), RandomNoise3D(), - #RandomContrast3D() + RandomContrast3D() ] def __call__(self, patches): @@ -45,7 +45,7 @@ def __call__(self, patches): input image and second is the segmentation. """ for transform in self.transforms: - transform(patches) + patches = transform(patches) return patches @@ -81,6 +81,7 @@ def __call__(self, patches): if random.random() > 0.5: patches[0, ...] = np.flip(patches[0, ...], axis=axis) patches[1, ...] = np.flip(patches[1, ...], axis=axis) + return patches class RandomRotation3D: @@ -117,6 +118,7 @@ def __call__(self, patches): angle = random.uniform(*self.angles) self.rotate3d(patches[0, ...], angle, axes, False) self.rotate3d(patches[1, ...], angle, axes, True) + return patches @staticmethod def rotate3d(img_patch, angle, axes, is_segmentation=False): @@ -206,17 +208,16 @@ class RandomContrast3D: Adjusts the contrast of a 3D image by scaling voxel intensities. 
    """

-    def __init__(self, factor_range=(0.8, 1.2)):
+    def __init__(self, p_low=(0, 90), p_high=(97.5, 100)):
         """
         Initializes a RandomContrast3D transformer.

         Parameters
         ----------
-        factor_range : Tuple[float], optional
-            Tuple of integers representing the range of contrast factors.
-            Default is (0.8, 1.2).
+        p_low, p_high : Tuple[float], optional — percentile ranges sampled for contrast stretch.

         """
-        self.factor_range = factor_range
+        self.p_low = p_low
+        self.p_high = p_high

     def __call__(self, patches):
         """
@@ -225,11 +226,18 @@ def __call__(self, patches):
         Parameters
         ----------
         patches : numpy.ndarray
-            Image with the shape (2, H, W, D), where "patches[0, ...]" is from
-            the input image and "patches[1, ...]" is from the segmentation.
+            Image with the shape (2, H, W, D), where the zeroth channel is
+            from the raw image and first channel is from the segmentation.

         """
-        factor = random.uniform(*self.factor_range)
-        patches[0, ...] = np.clip(patches[0, ...] * factor, 0, 1)
+        p_low = np.random.uniform(*self.p_low)
+        p_high = np.random.uniform(*self.p_high)
+        print("percentiles:", p_low, p_high)
+        lo = np.percentile(patches[0], p_low)
+        hi = np.percentile(patches[0], p_high)
+        print("intensities:", lo, hi)
+        patches[0] = (patches[0] - lo) / (hi - lo + 1e-5)
+        patches[0] = np.clip(patches[0], 0, 1)
+        return patches


 class RandomNoise3D:
@@ -237,7 +245,7 @@ class RandomNoise3D:
     Adds random Gaussian noise to a 3D image.
     """

-    def __init__(self, max_std=0.3):
+    def __init__(self, max_std=0.2):
         """
         Initializes a RandomNoise3D transformer.

@@ -247,9 +255,9 @@ def __init__(self, max_std=0.3):
             Maximum standard deviation of the Gaussian noise distribution.
             Default is 0.3.
         """
-        self.max_std = 0.2 # max_std
+        self.max_std = max_std

-    def __call__(self, img_patch):
+    def __call__(self, img_patches):
         """
         Adds Gaussian noise to the input 3D image.

@@ -260,6 +268,6 @@ def __call__(self, img_patch):
            the input image and "patches[1, ...]" is from the segmentation.

""" std = self.max_std * random.random() - noise = np.random.uniform(-std, std, img_patch[0, ...].shape) - img_patch[0, ...] += noise - img_patch[0, ...] = np.clip(img_patch[0, ...], 0, 1) + img_patches[0] += np.random.uniform(-std, std, img_patches[0].shape) + img_patches[0] = np.clip(img_patches[0], 0, 1) + return img_patches diff --git a/src/neuron_proofreader/machine_learning/exaspim_dataloader.py b/src/neuron_proofreader/machine_learning/exaspim_dataloader.py index 57bfe6e..be2a1ec 100644 --- a/src/neuron_proofreader/machine_learning/exaspim_dataloader.py +++ b/src/neuron_proofreader/machine_learning/exaspim_dataloader.py @@ -21,7 +21,7 @@ import torch from neuron_proofreader.utils.img_util import TensorStoreReader -from neuron_proofreader.utils import swc_util, img_util, util +from neuron_proofreader.utils import swc_util, util # --- Custom Datasets --- @@ -66,7 +66,7 @@ def __init__( Default is (1, 99.9). prefetch_foreground_sampling : int, optional Number of image patches that are preloaded during foreground - search in "self.sample_segmentation_voxel" and + search in "self.sample_segmentation_voxel" and "self.sample_bright_voxel". Default is 32. """ # Call parent class @@ -421,7 +421,7 @@ def read_segmentation(self, brain_id, voxel): numpy.ndarray Segmentation patch. 
""" - return self.segmentations[brain_id].read(voxel, self.patch_shape) + return self.segmentations[brain_id].read(voxel, self.patch_shape) # --- Custom Dataloader --- diff --git a/src/neuron_proofreader/merge_proofreading/merge_datasets.py b/src/neuron_proofreader/merge_proofreading/merge_datasets.py index cc539f3..605f0cb 100644 --- a/src/neuron_proofreader/merge_proofreading/merge_datasets.py +++ b/src/neuron_proofreader/merge_proofreading/merge_datasets.py @@ -77,7 +77,7 @@ def __init__( self, merge_sites_df, anisotropy=(1.0, 1.0, 1.0), - brightness_clip=400, + brightness_clip=500, subgraph_radius=100, node_spacing=5, patch_shape=(128, 128, 128), @@ -329,9 +329,6 @@ def __getitem__(self, idx): except ValueError: img_patch = img_util.pad_to_shape(img_patch, self.patch_shape) patches = np.stack([img_patch, segment_mask], axis=0) - - if subgraph.number_of_nodes() == 0: - print("Empty subgraph! -->", brain_id, voxel, label) return patches, subgraph, label def sample_brain_id(self): @@ -427,26 +424,27 @@ def get_random_negative_site(self): while True: # Sample node if outcome < 0.4: + # Any node node = util.sample_once(list(self.graphs[brain_id].nodes)) - elif outcome < 0.5: - node = util.sample_once(self.graphs[brain_id].get_leafs()) - elif outcome < 0.6: - branching_nodes = self.gt_graphs[brain_id].get_branchings() - node = util.sample_once(branching_nodes) - if self.check_nearby_branching(brain_id, node, use_gt=True): - continue - else: - subgraph = self.gt_graphs[brain_id].get_rooted_subgraph( - node, self.subgraph_radius - ) - return brain_id, subgraph, 0 - else: + #elif outcome < 0.5: + # # Node close to soma + # node = self.sample_node_nearby_soma(brain_id) + elif outcome < 0.8: + # Branching node branching_nodes = self.graphs[brain_id].get_branchings() if len(branching_nodes) > 0: node = util.sample_once(branching_nodes) else: outcome = 0 continue + else: + # Branching node from GT + branching_nodes = self.gt_graphs[brain_id].get_branchings() + node = 
util.sample_once(branching_nodes) + subgraph = self.gt_graphs[brain_id].get_rooted_subgraph( + node, self.subgraph_radius + ) + return brain_id, subgraph, 0 # Extract rooted subgraph subgraph = self.graphs[brain_id].get_rooted_subgraph( @@ -454,10 +452,10 @@ def get_random_negative_site(self): ) # Check branching - if self.graphs[brain_id].degree(node) == 3: + if self.graphs[brain_id].degree(node) > 2: is_high_degree = self.graphs[brain_id].degree(node) > 3 is_too_branchy = self.check_nearby_branching(brain_id, node) - if is_too_branchy or is_high_degree: + if is_high_degree or is_too_branchy: continue # Check if node is close to merge site @@ -512,10 +510,11 @@ def get_segment_mask(self, brain_id, center, subgraph): # Annotate fragment center = subgraph.get_voxel(0) + offset = img_util.get_offset(center, self.patch_shape) for node1, node2 in subgraph.edges: # Get local voxel coordinates - voxel1 = subgraph.get_local_voxel(node1, center, self.patch_shape) - voxel2 = subgraph.get_local_voxel(node2, center, self.patch_shape) + voxel1 = subgraph.get_local_voxel(node1, offset) + voxel2 = subgraph.get_local_voxel(node2, offset) # Populate mask voxels = geometry_util.make_digital_line(voxel1, voxel2) @@ -622,6 +621,13 @@ def is_nearby_merge_site(self, brain_id, node): dist, _ = self.merge_site_kdtrees[brain_id].query(xyz) return dist < 100 + def sample_node_nearby_soma(self, brain_id): + subgraph = self.gt_graphs[brain_id].get_rooted_subgraph(0, 600) + gt_node = util.sample_once(subgraph.nodes) + gt_xyz = self.gt_graphs[brain_id].node_xyz[gt_node] + d, node = self.graphs[brain_id].kdtree.query(gt_xyz) + return node + class MergeSiteTrainDataset(MergeSiteDataset): """ diff --git a/src/neuron_proofreader/split_proofreading/split_datasets.py b/src/neuron_proofreader/split_proofreading/split_datasets.py index 3d61ebc..96d0eab 100644 --- a/src/neuron_proofreader/split_proofreading/split_datasets.py +++ b/src/neuron_proofreader/split_proofreading/split_datasets.py @@ -16,12 
+16,14 @@ from neuron_proofreader.proposal_graph import ProposalGraph from neuron_proofreader.machine_learning.augmentation import ImageTransforms -from neuron_proofreader.machine_learning.subgraph_sampler import SubgraphSampler +from neuron_proofreader.machine_learning.subgraph_sampler import ( + SubgraphSampler +) from neuron_proofreader.split_proofreading.feature_extraction import ( FeaturePipeline, HeteroGraphData ) -from neuron_proofreader.utils import geometry_util, util +from neuron_proofreader.utils import geometry_util, img_util, util class FragmentsDataset(IterableDataset): diff --git a/src/neuron_proofreader/utils/ml_util.py b/src/neuron_proofreader/utils/ml_util.py index 0ab2fd2..b21a6aa 100644 --- a/src/neuron_proofreader/utils/ml_util.py +++ b/src/neuron_proofreader/utils/ml_util.py @@ -203,3 +203,20 @@ def to_cpu(tensor, to_numpy=False): return np.array(tensor.detach().cpu()) else: return tensor.detach().cpu() + + +def to_tensor(arr): + """ + Converts a numpy array to a tensor. + + Parameters + ---------- + arr : numpy.ndarray + Array to be converted. + + Returns + ------- + torch.Tensor + Array converted to tensor. + """ + return torch.tensor(arr, dtype=torch.float32) diff --git a/src/neuron_proofreader/visualization_skel.py b/src/neuron_proofreader/visualization_skel.py new file mode 100644 index 0000000..708e88e --- /dev/null +++ b/src/neuron_proofreader/visualization_skel.py @@ -0,0 +1,83 @@ +""" +Created on Sat Sep 30 10:00:00 2025 + +@author: Anna Grim +@email: anna.grim@alleninstitute.org + +Code for visualizing SkeletonGraphs. + +""" + +import plotly.graph_objects as go + + +def visualize(graph): + """ + Visualizes the given graph using Plotly. + + Parameters + ---------- + graph : SkeletonGraph + Graph to be visualized. 
+    """
+    # Initializations
+    data = get_edge_trace(graph)
+    layout = get_layout()
+
+    # Generate plot
+    fig = go.Figure(data=data, layout=layout)
+    fig.show()
+
+
+def get_edge_trace(graph, color="blue", name=""):
+    """
+    Generates a 3D edge trace for visualizing the edges of a graph.
+
+    Parameters
+    ----------
+    graph : SkeletonGraph
+        Graph to be visualized.
+    color : str, optional
+        Color to use for the edge lines in the plot. Default is "blue".
+    name : str, optional
+        Name of the edge trace. Default is an empty string.
+
+    Returns
+    -------
+    plotly.graph_objects.Scatter3d
+        Scatter3d object that represents the 3D trace of the graph edges.
+    """
+    # Build coordinate lists
+    x, y, z = list(), list(), list()
+    for u, v in graph.edges():
+        x0, y0, z0 = graph.node_xyz[u]
+        x1, y1, z1 = graph.node_xyz[v]
+        x.extend([x0, x1, None])
+        y.extend([y0, y1, None])
+        z.extend([z0, z1, None])
+
+    # Set edge trace
+    edge_trace = go.Scatter3d(
+        x=x, y=y, z=z, mode="lines", line=dict(color=color, width=3), name=name
+    )
+    return edge_trace
+
+
+# --- Helpers ---
+def get_layout():
+    """
+    Generates the layout for a 3D plot using Plotly.
+
+    Returns
+    -------
+    plotly.graph_objects.Layout
+        Layout object that defines the appearance and properties of the plot.
+    """
+    layout = go.Layout(
+        scene=dict(aspectmode="manual", aspectratio=dict(x=1, y=1, z=1)),
+        showlegend=True,
+        template="plotly_white",
+        height=700,
+        width=1200,
+    )
+    return layout