Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 13 additions & 22 deletions src/neuron_proofreader/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
aspects of a system involving graphs, proposals, and machine learning (ML).

"""
from dataclasses import dataclass, field
from dataclasses import dataclass
from typing import Tuple


Expand All @@ -24,13 +24,6 @@ class GraphConfig:
Scaling factors applied to xyz coordinates to account for anisotropy
of microscope. Note this instance of "anisotropy" is only used while
reading SWC files. Default is (1.0, 1.0, 1.0).
complex_bool : bool
Indication of whether to generate complex proposals, meaning proposals
between leaf and non-leaf nodes. Default is False.
long_range_bool : bool, optional
Indication of whether to generate simple proposals within a scaled
distance of "search_radius" from leaves without any proposals. Default
is False.
min_size : float, optional
Minimum path length (in microns) of swc files which are stored as
connected components in the FragmentsGraph. Default is 30.
Expand All @@ -48,27 +41,23 @@ class GraphConfig:
remove_high_risk_merges : bool, optional
Indication of whether to remove high risk merge sites (i.e. close
branching points). Default is False.
smooth_bool : bool, optional
Indication of whether to smooth branches in the graph. Default is
True.
trim_endpoints_bool : bool, optional
Indication of whether to endpoints of branches with exactly one
proposal. Default is True.
"""

anisotropy: Tuple[float] = field(default_factory=tuple)
complex_bool: bool = False
long_range_bool: bool = True
min_size: float = 30.0
min_size_with_proposals: float = 0
node_spacing: int = 5
anisotropy: Tuple[float, float, float] = (1.0, 1.0, 1.0)
max_proposals_per_leaf: int = 3
min_size: float = 40.0
min_size_with_proposals: float = 40.0
node_spacing: int = 1
proposals_per_leaf: int = 3
prune_depth: float = 24.0
remove_doubles: bool = False
remove_doubles: bool = True
remove_high_risk_merges: bool = False
search_radius: float = 20.0
smooth_bool: bool = True
trim_endpoints_bool: bool = True
verbose: bool = True


@dataclass
Expand All @@ -87,7 +76,9 @@ class MLConfig:
batch_size: int = 64
brightness_clip: int = 400
device: str = "cuda"
patch_shape : tuple = (96, 96, 96)
patch_shape: tuple = (96, 96, 96)
shuffle: bool = False
transform: bool = False
threshold: float = 0.8


Expand All @@ -112,5 +103,5 @@ def __init__(self, graph_config: GraphConfig, ml_config: MLConfig):
An instance of the "MLConfig" class that includes configuration
parameters for machine learning models.
"""
self.graph_config = graph_config
self.ml_config = ml_config
self.graph = graph_config
self.ml = ml_config
4 changes: 2 additions & 2 deletions src/neuron_proofreader/machine_learning/gnn_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import torch.nn.init as init

from neuron_proofreader.machine_learning.vision_models import CNN3D
from neuron_proofreader.split_proofreading import feature_extraction
from neuron_proofreader.split_proofreading import split_feature_extraction
from neuron_proofreader.utils.ml_util import FeedForwardNet


Expand Down Expand Up @@ -125,7 +125,7 @@ def init_node_embedding(output_dim):
features.
"""
# Get feature dimensions
node_input_dims = feature_extraction.get_feature_dict()
node_input_dims = split_feature_extraction.get_feature_dict()
dim_b = node_input_dims["branch"]
dim_p = node_input_dims["proposal"]

Expand Down
83 changes: 51 additions & 32 deletions src/neuron_proofreader/proposal_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,9 @@
trim_endpoints_at_proposal
)
from neuron_proofreader.utils import (
geometry_util as geometry, graph_util as gutil, util,
geometry_util as geometry,
graph_util as gutil,
util,
)


Expand All @@ -48,7 +50,7 @@ class ProposalGraph(SkeletonGraph):
def __init__(
self,
anisotropy=(1.0, 1.0, 1.0),
filter_transitive_proposals=False,
gt_path=None,
max_proposals_per_leaf=3,
min_size=0,
min_size_with_proposals=0,
Expand All @@ -57,7 +59,7 @@ def __init__(
remove_high_risk_merges=False,
segmentation_path=None,
soma_centroids=None,
verbose=False,
verbose=True,
):
"""
Instantiates a ProposalGraph object.
Expand Down Expand Up @@ -93,6 +95,7 @@ def __init__(
# Instance attributes - Graph
self.anisotropy = anisotropy
self.component_id_to_swc_id = dict()
self.gt_path = gt_path
self.leaf_kdtree = None
self.soma_ids = set()
self.verbose = verbose
Expand Down Expand Up @@ -192,6 +195,42 @@ def _add_edge(self, edge_id, attrs):
self.add_edge(i, j, radius=attrs["radius"], xyz=attrs["xyz"])
self.xyz_to_edge.update({tuple(xyz): edge_id for xyz in attrs["xyz"]})

def connect_soma_fragments(self, soma_centroids):
merge_cnt = 0
soma_cnt = 0
self.set_kdtree()
for soma_xyz in soma_centroids:
node_ids = self.find_fragments_near_xyz(soma_xyz, 25)
if len(node_ids) > 1:
# Find closest node to soma location
soma_cnt += 1
best_dist = np.inf
best_node = None
for i in node_ids:
dist = geometry.dist(soma_xyz, self.node_xyz[i])
if dist < best_dist:
best_dist = dist
best_node = i
soma_component_id = self.node_component_id[best_node]
self.soma_ids.add(soma_component_id)
node_ids.remove(best_node)

# Merge fragments to soma
soma_xyz = self.node_xyz[best_node]
for i in node_ids:
attrs = {
"radius": np.array([2, 2]),
"xyz": np.array([soma_xyz, self.node_xyz[i]]),
}
self._add_edge((best_node, i), attrs)
self.update_component_ids(soma_component_id, i)
merge_cnt += 1

# Summarize results
results = [f"# Somas Connected: {soma_cnt}"]
results.append(f"# Soma Fragments Merged: {merge_cnt}")
return "\n".join(results)

def relabel_nodes(self):
"""
Reassigns contiguous node IDs and update all dependent structures.
Expand Down Expand Up @@ -296,30 +335,8 @@ def get_kdtree(self, node_type=None):
else:
return KDTree(list(self.xyz_to_edge.keys()))

def query_kdtree(self, xyz, d, node_type=None):
"""
Parameters
----------
xyz : int
Node id.
d : float
Distance from "xyz" that is searched.

Returns
-------
generator[tuple]
Generator that generates the xyz coordinates cooresponding to all
nodes within a distance of "d" from "xyz".
"""
if node_type == "leaf":
return geometry.query_ball(self.leaf_kdtree, xyz, d)
elif node_type == "proposal":
return geometry.query_ball(self.proposal_kdtree, xyz, d)
else:
return geometry.query_ball(self.kdtree, xyz, d)

# --- Proposal Generation ---
def generate_proposals(self, search_radius, gt_graph=None):
def generate_proposals(self, search_radius):
"""
Generates proposals from leaf nodes.

Expand All @@ -330,16 +347,17 @@ def generate_proposals(self, search_radius, gt_graph=None):
gt_graph : networkx.Graph, optional
Ground truth graph. Default is None.
"""
# Generate proposals
# Proposal pipeline
proposals = self.proposal_generator(search_radius)
self.search_radius = search_radius
self.store_proposals(proposals)
self.trim_proposals()

# Set groundtruth
if gt_graph:
if self.gt_path:
gt_graph = ProposalGraph(anisotropy=self.anisotropy)
gt_graph.load(self.gt_path)
self.gt_accepts = groundtruth_generation.run(gt_graph, self)
else:
self.gt_accepts = set()

def add_proposal(self, i, j):
"""
Expand Down Expand Up @@ -600,7 +618,7 @@ def n_nearby_leafs(self, proposal, radius):
a proposal.
"""
xyz = self.proposal_midpoint(proposal)
return len(self.query_kdtree(xyz, radius, "leaf")) - 1
return len(geometry.query_ball(self.leaf_kdtree, xyz, radius)) - 1

# --- Helpers ---
def node_attr(self, i, key):
Expand Down Expand Up @@ -662,7 +680,8 @@ def edge_length(self, edge):

def find_fragments_near_xyz(self, query_xyz, max_dist):
hits = dict()
for xyz in self.query_kdtree(query_xyz, max_dist):
xyz_list = geometry.query_ball(self.kdtree, query_xyz, max_dist)
for xyz in xyz_list:
i, j = self.xyz_to_edge[tuple(xyz)]
dist_i = geometry.dist(self.node_xyz[i], query_xyz)
dist_j = geometry.dist(self.node_xyz[j], query_xyz)
Expand Down
15 changes: 15 additions & 0 deletions src/neuron_proofreader/skeleton_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -592,6 +592,21 @@ def find_closest_node(self, xyz):
_, node = self.kdtree.query(xyz)
return node

def get_summary(self, prefix=""):
# Compute values
n_components = format(nx.number_connected_components(self), ",")
n_nodes = format(self.number_of_nodes(), ",")
n_edges = format(self.number_of_edges(), ",")
memory = util.get_memory_usage()

# Compile results
summary = [f"{prefix} Graph"]
summary.append(f"# Connected Components: {n_components}")
summary.append(f"# Nodes: {n_nodes}")
summary.append(f"# Edges: {n_edges}")
summary.append(f"Memory Consumption: {memory:.2f} GBs")
return "\n".join(summary)

def path_length(self, max_depth=np.inf, root=None):
"""
Computes the path length of the connected component that contains the
Expand Down
Loading
Loading