32 changes: 30 additions & 2 deletions cellfinder/core/classify/tools.py
@@ -7,13 +7,41 @@
from keras import Model

from cellfinder.core import logger
from cellfinder.core.classify.resnet import build_model, layer_type
from cellfinder.core.classify import resnet, vit


def build_model(
network_depth: str,
learning_rate: float,
**kwargs,
) -> Model:
"""
Detect which model family to build (ViT or ResNet) from the network depth
string and construct the corresponding model.

:param network_depth: The name of the model configuration to build
:param learning_rate: The learning rate to use

:return: A Keras model
"""
if network_depth in vit.vit_configs:
return vit.build_model(
network_depth=network_depth,
learning_rate=learning_rate,
**kwargs,
)
elif network_depth in resnet.resnet_unit_blocks:
return resnet.build_model(
network_depth=network_depth,
learning_rate=learning_rate,
**kwargs,
)
else:
raise ValueError(f"Unknown network depth: {network_depth}")


def get_model(
existing_model: Optional[os.PathLike] = None,
model_weights: Optional[os.PathLike] = None,
network_depth: Optional[layer_type] = None,
network_depth: Optional[str] = None,
learning_rate: float = 0.0001,
inference: bool = False,
continue_training: bool = False,
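For reference, a minimal usage sketch of the dispatcher added above (illustrative only; the lookup tables `vit.vit_configs` and `resnet.resnet_unit_blocks` are taken from this diff, and "50-layer" is assumed to be a valid ResNet depth as suggested by the train_yaml.py mapping below):

from cellfinder.core.classify.tools import build_model

# Dispatches to vit.build_model because the name is a key of vit.vit_configs
vit_model = build_model(network_depth="vit-12-layer", learning_rate=1e-4)

# Dispatches to resnet.build_model because the name is a key of
# resnet.resnet_unit_blocks
resnet_model = build_model(network_depth="50-layer", learning_rate=1e-4)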
284 changes: 284 additions & 0 deletions cellfinder/core/classify/vit.py
@@ -0,0 +1,284 @@
from dataclasses import dataclass
from typing import Dict, List, Literal, Optional, Tuple

from keras import Model, layers, optimizers
from keras import ops as K


@dataclass
class VITConfig:
num_layers: int
hidden_dim: int
num_heads: int
expanding_factor: int
patch_size: Tuple[int, int, int]
layer_norm_eps: float = 1e-6


network_type = Literal[
"vit-4-layer",
"vit-8-layer",
"vit-12-layer",
"vit-24-layer",
"vit-32-layer",
]


vit_configs: Dict[network_type, VITConfig] = {
"vit-4-layer": VITConfig(
num_layers=4,
hidden_dim=64,
num_heads=8,
expanding_factor=4,
patch_size=(8, 8, 4),
),
"vit-8-layer": VITConfig(
num_layers=8,
hidden_dim=256,
num_heads=8,
expanding_factor=4,
patch_size=(8, 8, 4),
),
"vit-12-layer": VITConfig(
num_layers=12,
hidden_dim=768,
num_heads=8,
expanding_factor=4,
patch_size=(8, 8, 4),
),
"vit-24-layer": VITConfig(
num_layers=24,
hidden_dim=1024,
num_heads=8,
expanding_factor=4,
patch_size=(8, 8, 4),
),
"vit-32-layer": VITConfig(
num_layers=32,
hidden_dim=4096,
num_heads=8,
expanding_factor=4,
patch_size=(8, 8, 4),
),
}


class PositionalEmbeddings(layers.Layer):
"""
Add positional embeddings to the input tensor.
Keras does not appear to provide a built-in layer for this yet, so it is
implemented here.

:param int embedding_dim: The dimension of the embeddings
"""

def __init__(
self,
embedding_dim: int,
**kwargs,
):
super().__init__(**kwargs)
self.embedding_dim = embedding_dim

def build(
self,
input_shape: Tuple[int],
):
_, num_tokens, _ = input_shape
self.position_embedding = layers.Embedding(
input_dim=num_tokens, output_dim=self.embedding_dim
)
self.positions = K.arange(0, num_tokens, 1)

def call(
self,
inputs,
):
return K.broadcast_to(
self.position_embedding(self.positions), K.shape(inputs)
)


def attention_block(
inputs,
layer_norm_eps: float,
num_heads: int,
hidden_dim: int,
name="attention_block",
):
"""
Apply a multi-head attention block.

:param inputs: The input tensor
:param layer_norm_eps: The epsilon value for the layer normalization
:param num_heads: The number of heads in the multi-head attention
:param hidden_dim: The hidden dimension of the multi-head attention
:param name: The name of the block
:return: The residual-output of the multi-head attention block
"""
normalized_inputs = layers.LayerNormalization(
epsilon=layer_norm_eps,
name=f"{name}--layer_norm",
)(inputs)
return layers.MultiHeadAttention(
num_heads=num_heads,
key_dim=hidden_dim // num_heads,
name=f"{name}--mha",
)(normalized_inputs, normalized_inputs)


def mlp_block(
inputs,
hidden_dim: int,
layer_norm_eps: float,
expanding_factor: int,
name: str = "mlp_block",
):
"""
Apply a multi-layer perceptron block.

:param inputs: The input tensor
:param hidden_dim: The hidden dimension of the MLP
:param layer_norm_eps: The epsilon value for the layer normalization
:param expanding_factor: The factor by which the hidden dimension is
expanded in the MLP
:param name: The name of the block
:return: The residual-output of the MLP block
"""
normalized_inputs = layers.LayerNormalization(
epsilon=layer_norm_eps,
name=f"{name}--layer_norm",
)(inputs)
hidden_states = layers.Dense(
units=hidden_dim * expanding_factor,
activation=K.gelu,
name=f"{name}--up",
)(normalized_inputs)
return layers.Dense(
units=hidden_dim,
name=f"{name}--down",
)(hidden_states)


def transformer_block(
residual_stream,
layer_norm_eps: float = 1e-6,
num_heads: int = 8,
hidden_dim: int = 128,
expanding_factor: int = 4,
name: str = "transformer_block",
):
"""
Apply a transformer block a.k.a. transformer layer.

:param residual_stream: The input tensor
:param layer_norm_eps: The epsilon value for the layer normalization
:param num_heads: The number of heads in the multi-head attention
:param hidden_dim: The hidden dimension of the multi-head attention
:param expanding_factor: The factor by which the hidden dimension is
expanded in the MLP
:param name: The name of the block
:return: The residual-output of the transformer block
"""

attention_outputs = attention_block(
residual_stream,
layer_norm_eps=layer_norm_eps,
num_heads=num_heads,
hidden_dim=hidden_dim,
name=f"{name}--attention_block",
)

residual_stream = layers.Add()([residual_stream, attention_outputs])

mlp_outputs = mlp_block(
residual_stream,
hidden_dim=hidden_dim,
layer_norm_eps=layer_norm_eps,
expanding_factor=expanding_factor,
name=f"{name}--mlp_block",
)

# Skip connection
residual_stream = layers.Add()([residual_stream, mlp_outputs])

return residual_stream


def build_model(
input_shape: Tuple[int, int, int, int] = (50, 50, 20, 2),
network_depth: network_type = "vit-24-layer",
optimizer: Optional[optimizers.Optimizer] = None,
learning_rate: float = 0.0005,
loss: str = "categorical_crossentropy",
metrics: List[str] = ["accuracy"],
num_classes: int = 2,
classification_activation: str = "softmax",
) -> Model:
"""
Build a Vision Transformer model.

Mostly follows the signature of the ResNet model, but with additional
parameters for the Vision Transformer.
"""
config = vit_configs[network_depth]
embedding_dim = config.hidden_dim

# Get the input layer
inputs = layers.Input(shape=input_shape)

# Create patches.
patches = layers.Conv3D(
name="patch_embedding",
filters=embedding_dim,
kernel_size=config.patch_size,
strides=config.patch_size,
padding="VALID",
)(inputs)
patches = layers.Reshape(target_shape=(-1, embedding_dim))(patches)
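# Shape check (illustrative, assuming the default input_shape (50, 50, 20, 2)
# and patch_size (8, 8, 4)): the strided VALID Conv3D yields
# floor(50/8) * floor(50/8) * floor(20/4) = 6 * 6 * 5 = 180 patches,
# so the reshape above produces a (batch, 180, embedding_dim) token sequence.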

# Add positional embeddings
positional_embeddings = PositionalEmbeddings(embedding_dim=embedding_dim)(
patches
)

residual_stream = layers.Add()(
[
patches,
positional_embeddings,
]
)

# Create multiple layers of the Transformer block.
for layer_idx in range(config.num_layers):
residual_stream = transformer_block(
residual_stream,
num_heads=config.num_heads,
hidden_dim=config.hidden_dim,
expanding_factor=config.expanding_factor,
layer_norm_eps=config.layer_norm_eps,
name=f"layer_{layer_idx}",
)

normalized_stream = layers.LayerNormalization(
epsilon=config.layer_norm_eps, name="pre_logits_norm"
)(residual_stream)
flat_feature_vector = layers.GlobalAvgPool1D()(normalized_stream)

outputs = layers.Dense(
units=num_classes,
activation=classification_activation,
)(flat_feature_vector)

model = Model(
inputs=inputs,
outputs=outputs,
)

if optimizer is None:
optimizer = optimizers.Adam(learning_rate=learning_rate)

model.compile(
optimizer,
loss=loss,
metrics=metrics,
)
return model
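A short, illustrative sketch of building the smallest ViT configuration and checking its expected input/output shapes (assumes Keras 3 with a working backend and NumPy; not part of this PR):

import numpy as np

from cellfinder.core.classify import vit

model = vit.build_model(network_depth="vit-4-layer", learning_rate=1e-4)

# Default input_shape is (50, 50, 20, 2); the head is a two-class softmax
dummy = np.random.rand(1, 50, 50, 20, 2).astype("float32")
print(model.predict(dummy).shape)  # expected: (1, 2)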
5 changes: 5 additions & 0 deletions cellfinder/core/train/train_yaml.py
@@ -44,6 +44,11 @@
"50": "50-layer",
"101": "101-layer",
"152": "152-layer",
"vit-4": "vit-4-layer",
"vit-8": "vit-8-layer",
"vit-12": "vit-12-layer",
"vit-24": "vit-24-layer",
"vit-32": "vit-32-layer",
}

