10 changes: 5 additions & 5 deletions src/neuron_proofreader/machine_learning/augmentation.py
@@ -88,7 +88,7 @@ def __call__(self, patches):
for axis in self.axes:
if random.random() > 0.5:
patches[0, ...] = np.flip(patches[0, ...], axis=axis)
patches[1, ...] = np.flip(patches[1, ...], axis=axis)
#patches[1, ...] = np.flip(patches[1, ...], axis=axis)


class RandomRotation3D:
@@ -124,7 +124,7 @@ def __call__(self, patches):
if random.random() < 0.5:
angle = random.uniform(*self.angles)
patches[0, ...] = rotate3d(patches[0, ...], angle, axes)
patches[1, ...] = rotate3d(patches[1, ...], angle, axes, True)
#patches[1, ...] = rotate3d(patches[1, ...], angle, axes, True)


class RandomScale3D:
@@ -174,7 +174,7 @@ def __call__(self, patches):

# Rescale images
patches[0, ...] = zoom(patches[0, ...], zoom_factors, order=3)
patches[1, ...] = zoom(patches[1, ...], zoom_factors, order=0)
#patches[1, ...] = zoom(patches[1, ...], zoom_factors, order=0)
return patches


@@ -207,7 +207,7 @@ def __call__(self, patches):
the input image and "patches[1, ...]" is from the segmentation.
"""
factor = random.uniform(*self.factor_range)
patches[0, ...] = np.clip(patches[0, ...] * factor, 0, 1)
#patches[0, ...] = np.clip(patches[0, ...] * factor, 0, 1)


class RandomNoise3D:
@@ -240,7 +240,7 @@ def __call__(self, img_patch):
std = self.max_std * random.random()
noise = np.random.uniform(-std, std, img_patch[0, ...].shape)
img_patch[0, ...] += noise
img_patch[0, ...] = np.clip(img_patch[0, ...], 0, 1)
#img_patch[0, ...] = np.clip(img_patch[0, ...], 0, 1)


# --- Helpers ---
@@ -13,10 +13,8 @@

import torch

from neuron_proofreader.machine_learning.vision_models import (
CNN3D,
init_feedforward,
)
from neuron_proofreader.machine_learning.vision_models import CNN3D
from neuron_proofreader.utils import ml_util


# --- Multimodal GNN Architectures ---
10 changes: 4 additions & 6 deletions src/neuron_proofreader/machine_learning/point_cloud_models.py
@@ -14,10 +14,8 @@
import torch.nn as nn
import torch.nn.functional as F

from neuron_proofreader.machine_learning.vision_models import (
CNN3D,
init_feedforward,
)
from neuron_proofreader.machine_learning.vision_models import CNN3D
from neuron_proofreader.utils import ml_util


# --- Architectures ---
@@ -70,7 +68,7 @@ def __init__(self, patch_shape, output_dim=128):
output_dim=output_dim,
use_double_conv=True,
)
self.output = init_feedforward(2 * output_dim, 1, 3)
self.output = ml_util.init_feedforward(2 * output_dim, 1, 3)

def forward(self, x):
"""
@@ -231,7 +229,7 @@ def __init__(self, patch_shape, output_dim=128):
output_dim=output_dim,
use_double_conv=True,
)
self.output = init_feedforward(2 * output_dim, 1, 3)
self.output = ml_util.init_feedforward(2 * output_dim, 1, 3)

def forward(self, x):
"""
2 changes: 1 addition & 1 deletion src/neuron_proofreader/machine_learning/train.py
@@ -378,7 +378,7 @@ def _save_mistake_mips(self, x, y, hat_y, idx_offset):
filename = f"{mistake_type}{i + idx_offset}.png"
output_path = os.path.join(self.mistakes_dir, filename)
img_util.plot_image_and_segmentation_mips(
x[i, 0], 2 * x[i, 1], output_path
x[i, 0] + np.min(x[i, 0]), x[i, 0] + np.min(x[i, 0]), output_path
)

def save_model(self, epoch):
99 changes: 43 additions & 56 deletions src/neuron_proofreader/machine_learning/vision_models.py
@@ -10,10 +10,13 @@
"""

from einops import rearrange
from neurobase.finetune import finetune_model

import torch
import torch.nn as nn

from neuron_proofreader.utils import ml_util


# --- CNNs ---
class CNN3D(nn.Module):
@@ -56,12 +59,12 @@ def __init__(

# Convolutional layers
self.conv_layers = init_cnn3d(
2, n_feat_channels, n_conv_layers, use_double_conv=use_double_conv
1, n_feat_channels, n_conv_layers, use_double_conv=use_double_conv
)

# Output layer
flat_size = self._get_flattened_size()
self.output = init_feedforward(flat_size, output_dim, 3)
self.output = ml_util.init_feedforward(flat_size, output_dim, 3)

# Initialize weights
self.apply(self.init_weights)
@@ -79,7 +82,7 @@ def _get_flattened_size(self):
pooling.
"""
with torch.no_grad():
x = torch.zeros(1, 2, *self.patch_shape)
x = torch.zeros(1, 1, *self.patch_shape)
x = self.conv_layers(x)
return x.view(1, -1).size(1)

@@ -128,6 +131,42 @@ def forward(self, x):


# --- Transformers ---
class MAE3D(nn.Module):

def __init__(self):
# Call parent class constructor
super().__init__()

# Load model
full_model = finetune_model(
checkpoint_path="/home/jupyter/models/best_model-v1_mae_S.ckpt",
model_config="mae_S",
task_head_config="binary_classifier",
freeze_encoder=True
)

# Instance attributes
self.encoder = full_model.encoder
self.output = ml_util.init_feedforward(384, 1, 2)

def forward(self, x):
latent = self.encoder(x)
x = latent["latents"][:, 0, :]
x = self.output(x)
return x

def forward_old(self, x):
latent0 = self.encoder(x[:, 0:1, ...])
latent1 = self.encoder(x[:, 1:2, ...])

x0 = latent0["latents"][:, 0, :]
x1 = latent1["latents"][:, 0, :]

x = torch.cat((x0, x1), dim=1)
x = self.output(x)
return x


class ViT3D(nn.Module):
"""
A class that implements a 3D Vision transformer.
@@ -185,7 +224,7 @@ def __init__(
self.norm = nn.LayerNorm(emb_dim)

# Output layer
self.output = init_feedforward(emb_dim, output_dim, 2)
self.output = ml_util.init_feedforward(emb_dim, output_dim, 2)

# Initialize weights
self._init_weights()
@@ -486,55 +525,3 @@ def init_conv_layer(in_channels, out_channels, kernel_size, use_double_conv):
# Pooling
layers.append(nn.MaxPool3d(kernel_size=2))
return nn.Sequential(*layers)


def init_feedforward(input_dim, output_dim, n_layers):
"""
Initializes a feed forward neural network.

Parameters
----------
input_dim : int
Dimension of the input.
output_dim : int
Dimension of the output of this network.
n_layers : int
Number of layers in the network.
"""
layers = list()
input_dim_i = input_dim
output_dim_i = input_dim // 2
for i in range(n_layers):
layers.append(init_mlp(input_dim_i, input_dim_i * 2, output_dim_i))
input_dim_i = input_dim_i // 2
output_dim_i = output_dim_i // 2 if i < n_layers - 2 else output_dim
return nn.Sequential(*layers)


def init_mlp(input_dim, hidden_dim, output_dim, dropout=0.1):
"""
Initializes a multi-layer perceptron (MLP).

Parameters
----------
input_dim : int
Dimension of input feature vector.
hidden_dim : int
Dimension of embedded feature vector.
output_dim : int
Dimension of output feature vector.
dropout : float, optional
Fraction of values to randomly drop during training. Default is 0.1.

Returns
-------
mlp : nn.Sequential
Multi-layer perception network.
"""
mlp = nn.Sequential(
nn.Linear(input_dim, hidden_dim),
nn.GELU(),
nn.Dropout(p=dropout),
nn.Linear(hidden_dim, output_dim),
)
return mlp
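
Note on the new MAE3D class above: it wraps a pretrained masked-autoencoder encoder (loaded via neurobase.finetune.finetune_model from a hard-coded checkpoint path), keeps only the first latent token, and scores it with the relocated ml_util.init_feedforward(384, 1, 2) head. A shape-level sketch of that forward path, not part of the PR, with the encoder output replaced by a random tensor (the batch size and 197-token count are assumptions):

import torch
import torch.nn as nn

# Head with the same structure as ml_util.init_feedforward(384, 1, 2)
head = nn.Sequential(
    nn.Sequential(nn.Linear(384, 768), nn.GELU(), nn.Dropout(p=0.1), nn.Linear(768, 192)),
    nn.Sequential(nn.Linear(192, 384), nn.GELU(), nn.Dropout(p=0.1), nn.Linear(384, 1)),
)

latents = torch.randn(8, 197, 384)  # stand-in for latent["latents"] returned by the encoder
cls_token = latents[:, 0, :]        # forward() keeps only token 0
logits = head(cls_token)            # shape (8, 1)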
10 changes: 5 additions & 5 deletions src/neuron_proofreader/merge_proofreading/merge_datasets.py
Original file line number Diff line number Diff line change
@@ -77,7 +77,7 @@ def __init__(
self,
merge_sites_df,
anisotropy=(1.0, 1.0, 1.0),
brightness_clip=400,
brightness_clip=600,
subgraph_radius=100,
node_spacing=5,
patch_shape=(128, 128, 128),
@@ -324,11 +324,11 @@ def __getitem__(self, idx):

# Stack image channels
try:
patches = np.stack([img_patch, segment_mask], axis=0)
patches = img_patch + 2 * segment_mask
except ValueError:
img_patch = img_util.pad_to_shape(img_patch, self.patch_shape)
patches = np.stack([img_patch, segment_mask], axis=0)
return patches, subgraph, label
patches = img_patch + segment_mask
return patches[np.newaxis], subgraph, label

def sample_brain_id(self):
"""
@@ -940,7 +940,7 @@ def __init__(
# Instance attributes
self.is_multimodal = is_multimodal
self.modality = modality
self.patches_shape = (2,) + self.dataset.patch_shape
self.patches_shape = (1,) + self.dataset.patch_shape
self.use_shuffle = use_shuffle

# --- Core Routines ---
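
With these edits the dataset returns one fused channel, the image patch plus a scaled segmentation mask, instead of a two-channel stack, and the dataloader's patches_shape drops from (2, ...) to (1, ...). A shape-level sketch of the fused patch, not part of the PR (array contents and the 128^3 size are placeholders):

import numpy as np

img_patch = np.random.rand(128, 128, 128).astype(np.float32)             # normalized image patch
segment_mask = (np.random.rand(128, 128, 128) > 0.9).astype(np.float32)  # binary segmentation mask
patches = img_patch + 2 * segment_mask   # fuse the mask into the image channel
patches = patches[np.newaxis]            # shape (1, 128, 128, 128), matching patches_shape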
7 changes: 4 additions & 3 deletions src/neuron_proofreader/utils/img_util.py
@@ -597,12 +597,13 @@ def normalize(img):

Returns
-------
numpy.ndarray
img : numpy.ndarray
Normalized image.
"""
try:
mn, mx = np.percentile(img, [1, 99.9])
return np.clip((img - mn) / (mx - mn + 1e-5), 0, 1)
#mn, mx = np.percentile(img, [1, 99.9])
#return np.clip((img - mn) / (mx - mn + 1e-5), 0, 1)
return (img - img.mean()) / (img.std() + 1e-8)
except Exception as e:
print("Image Normalization Failed:", e)
return np.zeros(img.shape)
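
normalize now z-scores the image (zero mean, unit variance) rather than percentile-clipping it into [0, 1], which matches the [0, 1] clipping being commented out in the augmentations above. A minimal sketch of the new behavior on dummy data, not part of the PR:

import numpy as np

img = np.random.randint(0, 600, size=(64, 64, 64)).astype(np.float32)
out = (img - img.mean()) / (img.std() + 1e-8)
print(out.mean(), out.std())  # approximately 0 and 1; values are no longer confined to [0, 1]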
54 changes: 54 additions & 0 deletions src/neuron_proofreader/utils/ml_util.py
@@ -14,12 +14,66 @@
import networkx as nx
import numpy as np
import torch
import torch.nn as nn

from neuron_proofreader.utils import util

GNN_DEPTH = 2


# --- Architectures ---
def init_feedforward(input_dim, output_dim, n_layers):
"""
Initializes a feed forward neural network.

Parameters
----------
input_dim : int
Dimension of the input.
output_dim : int
Dimension of the output of this network.
n_layers : int
Number of layers in the network.
"""
layers = list()
input_dim_i = input_dim
output_dim_i = input_dim // 2
for i in range(n_layers):
layers.append(init_mlp(input_dim_i, input_dim_i * 2, output_dim_i))
input_dim_i = input_dim_i // 2
output_dim_i = output_dim_i // 2 if i < n_layers - 2 else output_dim
return nn.Sequential(*layers)


def init_mlp(input_dim, hidden_dim, output_dim, dropout=0.1):
"""
Initializes a multi-layer perceptron (MLP).

Parameters
----------
input_dim : int
Dimension of input feature vector.
hidden_dim : int
Dimension of embedded feature vector.
output_dim : int
Dimension of output feature vector.
dropout : float, optional
Fraction of values to randomly drop during training. Default is 0.1.

Returns
-------
mlp : nn.Sequential
Multi-layer perceptron network.
"""
mlp = nn.Sequential(
nn.Linear(input_dim, hidden_dim),
nn.GELU(),
nn.Dropout(p=dropout),
nn.Linear(hidden_dim, output_dim),
)
return mlp


# --- Batch Generation ---
def get_batch(graph, proposals, batch_size, flagged_proposals=set()):
"""
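
For reference, a trace of what the relocated init_feedforward builds; this is not part of the PR, but the dimensions follow directly from the code above. With init_feedforward(256, 1, 3), the call used by the point-cloud models (2 * output_dim = 256), each block is an init_mlp(in, 2 * in, out), widths halve per block, and the final block emits a single logit:

import torch
import torch.nn as nn

# Equivalent structure to init_feedforward(256, 1, 3)
head = nn.Sequential(
    nn.Sequential(nn.Linear(256, 512), nn.GELU(), nn.Dropout(p=0.1), nn.Linear(512, 128)),
    nn.Sequential(nn.Linear(128, 256), nn.GELU(), nn.Dropout(p=0.1), nn.Linear(256, 64)),
    nn.Sequential(nn.Linear(64, 128), nn.GELU(), nn.Dropout(p=0.1), nn.Linear(128, 1)),
)
logits = head(torch.randn(4, 256))  # shape (4, 1)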