diff --git a/src/neuron_proofreader/merge_proofreading/merge_datasets.py b/src/neuron_proofreader/merge_proofreading/merge_datasets.py index 605f0cb..3fe3848 100644 --- a/src/neuron_proofreader/merge_proofreading/merge_datasets.py +++ b/src/neuron_proofreader/merge_proofreading/merge_datasets.py @@ -77,7 +77,7 @@ def __init__( self, merge_sites_df, anisotropy=(1.0, 1.0, 1.0), - brightness_clip=500, + brightness_clip=400, subgraph_radius=100, node_spacing=5, patch_shape=(128, 128, 128), diff --git a/src/neuron_proofreader/merge_proofreading/merge_inference.py b/src/neuron_proofreader/merge_proofreading/merge_inference.py index 6c3aa9c..7543dc2 100644 --- a/src/neuron_proofreader/merge_proofreading/merge_inference.py +++ b/src/neuron_proofreader/merge_proofreading/merge_inference.py @@ -350,9 +350,31 @@ def read_superchunk(self, nodes): superchunk = img_util.normalize(superchunk) return superchunk, start.astype(int) + def is_near_leaf(self, node, threshold=50): + # Check if node is branching + if self.graph.degree[node] > 2: + return False + + # Search neighborhood + queue = [(node, 0)] + visited = {node} + while len(queue) > 0: + # Visit node + i, dist_i = queue.pop() + if self.graph.degree[i] == 1: + return True + + # Update queue + for j in self.graph.neighbors(i): + dist_j = dist_i + self.graph.dist(i, j) + if j not in visited and dist_j < threshold: + queue.append((j, dist_j)) + visited.add(j) + return False + def is_node_valid(self, node): is_contained = self.is_contained(node) - is_nonleaf = self.graph.degree[node] > 1 + is_nonleaf = self.is_near_leaf(node) return is_contained and is_nonleaf diff --git a/src/neuron_proofreader/utils/ml_util.py b/src/neuron_proofreader/utils/ml_util.py index b21a6aa..14cbe77 100644 --- a/src/neuron_proofreader/utils/ml_util.py +++ b/src/neuron_proofreader/utils/ml_util.py @@ -106,8 +106,7 @@ def init_mlp(input_dim, hidden_dim, output_dim, dropout=0.1): """ mlp = nn.Sequential( nn.Linear(input_dim, hidden_dim), - #nn.LeakyReLU(), - nn.GELU(), + nn.LeakyReLU(), nn.Dropout(p=dropout), nn.Linear(hidden_dim, output_dim), )