42 changes: 25 additions & 17 deletions src/neuron_proofreader/machine_learning/augmentation.py
@@ -30,7 +30,7 @@ def __init__(self):
RandomFlip3D(),
RandomRotation3D(),
RandomNoise3D(),
#RandomContrast3D()
RandomContrast3D()
]

def __call__(self, patches):
@@ -45,7 +45,7 @@ def __call__(self, patches):
input image and second is the segmentation.
"""
for transform in self.transforms:
transform(patches)
patches = transform(patches)
return patches


@@ -81,6 +81,7 @@ def __call__(self, patches):
if random.random() > 0.5:
patches[0, ...] = np.flip(patches[0, ...], axis=axis)
patches[1, ...] = np.flip(patches[1, ...], axis=axis)
return patches


class RandomRotation3D:
@@ -117,6 +118,7 @@ def __call__(self, patches):
angle = random.uniform(*self.angles)
self.rotate3d(patches[0, ...], angle, axes, False)
self.rotate3d(patches[1, ...], angle, axes, True)
return patches

@staticmethod
def rotate3d(img_patch, angle, axes, is_segmentation=False):
@@ -206,17 +208,16 @@ class RandomContrast3D:
Adjusts the contrast of a 3D image by rescaling voxel intensities between randomly sampled low and high percentiles.
"""

def __init__(self, factor_range=(0.8, 1.2)):
def __init__(self, p_low=(0, 90), p_high=(97.5, 100)):
"""
Initializes a RandomContrast3D transformer.

Parameters
----------
factor_range : Tuple[float], optional
Tuple of integers representing the range of contrast factors.
Default is (0.8, 1.2).
p_low : Tuple[float], optional
Range of percentiles used to sample the lower intensity bound.
Default is (0, 90).
p_high : Tuple[float], optional
Range of percentiles used to sample the upper intensity bound.
Default is (97.5, 100).
"""
self.factor_range = factor_range
self.p_low = p_low
self.p_high = p_high

def __call__(self, patches):
"""
@@ -225,19 +226,26 @@ def __call__(self, patches):
Parameters
----------
patches : numpy.ndarray
Image with the shape (2, H, W, D), where "patches[0, ...]" is from
the input image and "patches[1, ...]" is from the segmentation.
Image with the shape (2, H, W, D), where the zeroth channel is
from the raw image and the first channel is from the segmentation.
"""
factor = random.uniform(*self.factor_range)
patches[0, ...] = np.clip(patches[0, ...] * factor, 0, 1)
p_low = np.random.uniform(*self.p_low)
p_high = np.random.uniform(*self.p_high)
print("precentiles:", p_low, p_high)
lo = np.percentile(patches[0], p_low)
hi = np.percentile(patches[0], p_high)
print("intensities:", lo, hi)
patches[0] = (patches[0] - lo) / (hi - lo + 1e-5)
patches[0] = np.clip(patches[0], 0, 1)
return patches


class RandomNoise3D:
"""
Adds random noise, drawn uniformly from [-std, std], to a 3D image.
"""

def __init__(self, max_std=0.3):
def __init__(self, max_std=0.2):
"""
Initializes a RandomNoise3D transformer.

@@ -247,9 +255,9 @@ def __init__(self, max_std=0.3):
Maximum spread of the noise, which is drawn uniformly from
[-std, std] with "std" sampled below this value. Default is 0.2.
"""
self.max_std = 0.2 # max_std
self.max_std = max_std

def __call__(self, img_patch):
def __call__(self, img_patches):
"""
Adds random uniform noise to the raw image channel of the input patches.

@@ -260,6 +268,6 @@ def __call__(self, img_patch):
the input image and "patches[1, ...]" is from the segmentation.
"""
std = self.max_std * random.random()
noise = np.random.uniform(-std, std, img_patch[0, ...].shape)
img_patch[0, ...] += noise
img_patch[0, ...] = np.clip(img_patch[0, ...], 0, 1)
img_patches[0] += np.random.uniform(-std, std, img_patches[0].shape)
img_patches[0] = np.clip(img_patches[0], 0, 1)
return img_patches
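
Because every transform now returns the patches array, ImageTransforms can chain them functionally. A minimal usage sketch, not part of this diff; the synthetic patch values below only illustrate the expected (2, H, W, D) layout:

import numpy as np

from neuron_proofreader.machine_learning.augmentation import ImageTransforms

# Channel 0 holds the normalized raw image, channel 1 the segmentation mask
img = np.random.rand(64, 64, 64).astype(np.float32)
seg = (np.random.rand(64, 64, 64) > 0.9).astype(np.float32)
patches = np.stack([img, seg], axis=0)

# Applies flip, rotation, noise, and percentile-based contrast in sequence
transforms = ImageTransforms()
patches = transforms(patches)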
6 changes: 3 additions & 3 deletions src/neuron_proofreader/machine_learning/exaspim_dataloader.py
@@ -21,7 +21,7 @@
import torch

from neuron_proofreader.utils.img_util import TensorStoreReader
from neuron_proofreader.utils import swc_util, img_util, util
from neuron_proofreader.utils import swc_util, util


# --- Custom Datasets ---
@@ -66,7 +66,7 @@ def __init__(
Default is (1, 99.9).
prefetch_foreground_sampling : int, optional
Number of image patches that are preloaded during foreground
search in "self.sample_segmentation_voxel" and
search in "self.sample_segmentation_voxel" and
"self.sample_bright_voxel". Default is 32.
"""
# Call parent class
@@ -421,7 +421,7 @@ def read_segmentation(self, brain_id, voxel):
numpy.ndarray
Segmentation patch.
"""
return self.segmentations[brain_id].read(voxel, self.patch_shape)
return self.segmentations[brain_id].read(voxel, self.patch_shape)


# --- Custom Dataloader ---
48 changes: 27 additions & 21 deletions src/neuron_proofreader/merge_proofreading/merge_datasets.py
@@ -77,7 +77,7 @@ def __init__(
self,
merge_sites_df,
anisotropy=(1.0, 1.0, 1.0),
brightness_clip=400,
brightness_clip=500,
subgraph_radius=100,
node_spacing=5,
patch_shape=(128, 128, 128),
@@ -329,9 +329,6 @@ def __getitem__(self, idx):
except ValueError:
img_patch = img_util.pad_to_shape(img_patch, self.patch_shape)
patches = np.stack([img_patch, segment_mask], axis=0)

if subgraph.number_of_nodes() == 0:
print("Empty subgraph! -->", brain_id, voxel, label)
return patches, subgraph, label

def sample_brain_id(self):
@@ -427,37 +424,38 @@ def get_random_negative_site(self):
while True:
# Sample node
if outcome < 0.4:
# Any node
node = util.sample_once(list(self.graphs[brain_id].nodes))
elif outcome < 0.5:
node = util.sample_once(self.graphs[brain_id].get_leafs())
elif outcome < 0.6:
branching_nodes = self.gt_graphs[brain_id].get_branchings()
node = util.sample_once(branching_nodes)
if self.check_nearby_branching(brain_id, node, use_gt=True):
continue
else:
subgraph = self.gt_graphs[brain_id].get_rooted_subgraph(
node, self.subgraph_radius
)
return brain_id, subgraph, 0
else:
#elif outcome < 0.5:
# # Node close to soma
# node = self.sample_node_nearby_soma(brain_id)
elif outcome < 0.8:
# Branching node
branching_nodes = self.graphs[brain_id].get_branchings()
if len(branching_nodes) > 0:
node = util.sample_once(branching_nodes)
else:
outcome = 0
continue
else:
# Branching node from GT
branching_nodes = self.gt_graphs[brain_id].get_branchings()
node = util.sample_once(branching_nodes)
subgraph = self.gt_graphs[brain_id].get_rooted_subgraph(
node, self.subgraph_radius
)
return brain_id, subgraph, 0

# Extract rooted subgraph
subgraph = self.graphs[brain_id].get_rooted_subgraph(
node, self.subgraph_radius
)

# Check branching
if self.graphs[brain_id].degree(node) == 3:
if self.graphs[brain_id].degree(node) > 2:
is_high_degree = self.graphs[brain_id].degree(node) > 3
is_too_branchy = self.check_nearby_branching(brain_id, node)
if is_too_branchy or is_high_degree:
if is_high_degree or is_too_branchy:
continue

# Check if node is close to merge site
@@ -512,10 +510,11 @@ def get_segment_mask(self, brain_id, center, subgraph):

# Annotate fragment
center = subgraph.get_voxel(0)
offset = img_util.get_offset(center, self.patch_shape)
for node1, node2 in subgraph.edges:
# Get local voxel coordinates
voxel1 = subgraph.get_local_voxel(node1, center, self.patch_shape)
voxel2 = subgraph.get_local_voxel(node2, center, self.patch_shape)
voxel1 = subgraph.get_local_voxel(node1, offset)
voxel2 = subgraph.get_local_voxel(node2, offset)

# Populate mask
voxels = geometry_util.make_digital_line(voxel1, voxel2)
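
For reference, the offset-based conversion above amounts to subtracting the patch's corner voxel from each node's global voxel before rasterizing the edge into the mask. A standalone sketch of that pattern, under the assumption that get_offset returns center minus half the patch shape; the helper below is an illustrative stand-in for geometry_util.make_digital_line, not the repo's implementation:

import numpy as np

def line_voxels(v1, v2):
    # Densely sample the segment and round to integer voxel coordinates
    n = int(np.linalg.norm(np.subtract(v2, v1))) + 2
    return np.unique(np.round(np.linspace(v1, v2, num=n)).astype(int), axis=0)

patch_shape = np.array([128, 128, 128])
center = np.array([500, 400, 300])       # global voxel at the patch center
offset = center - patch_shape // 2       # assumed meaning of get_offset

mask = np.zeros(tuple(patch_shape), dtype=np.uint8)
node_voxels = {1: np.array([510, 410, 310]), 2: np.array([520, 390, 305])}
for n1, n2 in [(1, 2)]:
    v1 = node_voxels[n1] - offset        # cf. get_local_voxel(node, offset)
    v2 = node_voxels[n2] - offset
    mask[tuple(line_voxels(v1, v2).T)] = 1

Precomputing the offset once, as the change above does, avoids re-deriving it from the center and patch shape for every node in the edge loop.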
@@ -622,6 +621,13 @@ def is_nearby_merge_site(self, brain_id, node):
dist, _ = self.merge_site_kdtrees[brain_id].query(xyz)
return dist < 100

def sample_node_nearby_soma(self, brain_id):
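"""
Samples a node in "self.graphs[brain_id]" that lies near the soma: a
ground-truth node is drawn from the subgraph rooted at node 0 within a
fixed radius (600), assumed to contain the soma, then the fragment
graph's KD-tree returns the closest fragment node to that location.

Parameters
----------
brain_id
Unique identifier of the whole-brain dataset to sample from.

Returns
-------
int
Node ID of the fragment graph node nearest the sampled location.
"""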
subgraph = self.gt_graphs[brain_id].get_rooted_subgraph(0, 600)
gt_node = util.sample_once(subgraph.nodes)
gt_xyz = self.gt_graphs[brain_id].node_xyz[gt_node]
d, node = self.graphs[brain_id].kdtree.query(gt_xyz)
return node


class MergeSiteTrainDataset(MergeSiteDataset):
"""
6 changes: 4 additions & 2 deletions src/neuron_proofreader/split_proofreading/split_datasets.py
@@ -16,12 +16,14 @@

from neuron_proofreader.proposal_graph import ProposalGraph
from neuron_proofreader.machine_learning.augmentation import ImageTransforms
from neuron_proofreader.machine_learning.subgraph_sampler import SubgraphSampler
from neuron_proofreader.machine_learning.subgraph_sampler import (
SubgraphSampler
)
from neuron_proofreader.split_proofreading.feature_extraction import (
FeaturePipeline,
HeteroGraphData
)
from neuron_proofreader.utils import geometry_util, util
from neuron_proofreader.utils import geometry_util, img_util, util


class FragmentsDataset(IterableDataset):
17 changes: 17 additions & 0 deletions src/neuron_proofreader/utils/ml_util.py
@@ -203,3 +203,20 @@ def to_cpu(tensor, to_numpy=False):
return np.array(tensor.detach().cpu())
else:
return tensor.detach().cpu()


def to_tensor(arr):
"""
Converts a numpy array to a tensor.

Parameters
----------
arr : numpy.ndarray
Array to be converted.

Returns
-------
torch.Tensor
Array converted to tensor.
"""
return torch.tensor(arr, dtype=torch.float32)
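
A quick usage note (values below are illustrative): to_tensor always copies its input, since torch.tensor allocates a new tensor; torch.from_numpy would instead share memory with the array.

import numpy as np
patches = np.zeros((2, 64, 64, 64), dtype=np.float32)
x = to_tensor(patches)   # torch.float32 tensor of shape (2, 64, 64, 64)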
83 changes: 83 additions & 0 deletions src/neuron_proofreader/visualization_skel.py
@@ -0,0 +1,83 @@
"""
Created on Sat Sep 30 10:00:00 2025

@author: Anna Grim
@email: [email protected]

Code for visualizing SkeletonGraphs.

"""

import plotly.graph_objects as go


def visualize(graph):
"""
Visualizes the given graph using Plotly.

Parameters
----------
graph : SkeletonGraph
Graph to be visualized.
"""
# Initializations
data = get_edge_trace(graph)
layout = get_layout()

# Generate plot
fig = go.Figure(data=data, layout=layout)
fig.show()


def get_edge_trace(graph, color="blue", name=""):
"""
Generates a 3D edge trace for visualizing the edges of a graph.

Parameters
----------
graph : SkeletonGraph
Graph to be visualized.
color : str, optional
Color to use for the edge lines in the plot. Default is "blue".
name : str, optional
Name of the edge trace. Default is an empty string.

Returns
-------
plotly.graph_objects.Scatter3d
Scatter3d object that represents the 3D trace of the graph edges.
"""
# Build coordinate lists
x, y, z = list(), list(), list()
for u, v in graph.edges():
x0, y0, z0 = graph.node_xyz[u]
x1, y1, z1 = graph.node_xyz[v]
x.extend([x0, x1, None])
y.extend([y0, y1, None])
z.extend([z0, z1, None])

# Set edge trace
edge_trace = go.Scatter3d(
x=x, y=y, z=z, mode="lines", line=dict(color=color, width=3), name=name
)
return edge_trace


# --- Helpers ---
def get_layout():
"""
Generates the layout for a 3D plot using Plotly.

Returns
-------
plotly.graph_objects.Layout
Layout object that defines the appearance and properties of the plot.
"""
layout = go.Layout(
scene=dict(aspectmode="manual", aspectratio=dict(x=1, y=1, z=1)),
showlegend=True,
template="plotly_white",
height=700,
width=1200,
)
return layout
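
Since get_edge_trace takes a color and legend name and the layout enables the legend, several graphs can be overlaid in one figure. A minimal sketch; the two SkeletonGraph instances are assumed to be loaded elsewhere:

import plotly.graph_objects as go

# gt_graph and fragment_graph are assumed SkeletonGraph objects with
# .edges() and a .node_xyz lookup, loaded elsewhere in the pipeline.
traces = [
    get_edge_trace(gt_graph, color="black", name="ground truth"),
    get_edge_trace(fragment_graph, color="red", name="fragments"),
]
fig = go.Figure(data=traces, layout=get_layout())
fig.show()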