Commit b554a10

Add U-Mamba architecture to MONAI
1 parent b58e883 commit b554a10

File tree

networks/nets/__init__.py
networks/nets/u_mamba.py
tests/test_networks_u_mamba.py

3 files changed: +272 -0 lines changed

networks/nets/__init__.py

Lines changed: 147 additions & 0 deletions
@@ -0,0 +1,147 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from .ahnet import AHnet, Ahnet, AHNet
from .attentionunet import AttentionUnet
from .autoencoder import AutoEncoder
from .autoencoderkl import AutoencoderKL
from .basic_unet import BasicUNet, BasicUnet, Basicunet, basicunet
from .basic_unetplusplus import BasicUNetPlusPlus, BasicUnetPlusPlus, BasicunetPlusPlus, basicunetplusplus
from .classifier import Classifier, Critic, Discriminator
from .controlnet import ControlNet
from .daf3d import DAF3D
from .densenet import (
    DenseNet,
    Densenet,
    DenseNet121,
    Densenet121,
    DenseNet169,
    Densenet169,
    DenseNet201,
    Densenet201,
    DenseNet264,
    Densenet264,
    densenet121,
    densenet169,
    densenet201,
    densenet264,
)
from .diffusion_model_unet import DiffusionModelUNet
from .dints import DiNTS, TopologyConstruction, TopologyInstance, TopologySearch
from .dynunet import DynUNet, DynUnet, Dynunet
from .efficientnet import (
    BlockArgs,
    EfficientNet,
    EfficientNetBN,
    EfficientNetBNFeatures,
    EfficientNetEncoder,
    drop_connect,
    get_efficientnet_image_size,
)
from .flexible_unet import FLEXUNET_BACKBONE, FlexibleUNet, FlexUNet, FlexUNetEncoderRegister
from .fullyconnectednet import FullyConnectedNet, VarFullyConnectedNet
from .generator import Generator
from .highresnet import HighResBlock, HighResNet
from .hovernet import Hovernet, HoVernet, HoVerNet, HoverNet
from .masked_autoencoder_vit import MaskedAutoEncoderViT
from .mednext import (
    MedNeXt,
    MedNext,
    MedNextB,
    MedNeXtB,
    MedNextBase,
    MedNextL,
    MedNeXtL,
    MedNeXtLarge,
    MedNextLarge,
    MedNextM,
    MedNeXtM,
    MedNeXtMedium,
    MedNextMedium,
    MedNextS,
    MedNeXtS,
    MedNeXtSmall,
    MedNextSmall,
)
from .milmodel import MILModel
from .netadapter import NetAdapter
from .patchgan_discriminator import MultiScalePatchDiscriminator, PatchDiscriminator
from .quicknat import Quicknat
from .regressor import Regressor
from .regunet import GlobalNet, LocalNet, RegUNet
from .resnet import (
    ResNet,
    ResNetBlock,
    ResNetBottleneck,
    ResNetEncoder,
    ResNetFeatures,
    get_medicalnet_pretrained_resnet_args,
    get_pretrained_resnet_medicalnet,
    resnet10,
    resnet18,
    resnet34,
    resnet50,
    resnet101,
    resnet152,
    resnet200,
)
from .segresnet import SegResNet, SegResNetVAE
from .segresnet_ds import SegResNetDS, SegResNetDS2
from .senet import (
    SENet,
    SEnet,
    Senet,
    SENet154,
    SEnet154,
    Senet154,
    SEResNet50,
    SEresnet50,
    Seresnet50,
    SEResNet101,
    SEresnet101,
    Seresnet101,
    SEResNet152,
    SEresnet152,
    Seresnet152,
    SEResNext50,
    SEResNeXt50,
    SEresnext50,
    Seresnext50,
    SEResNext101,
    SEResNeXt101,
    SEresnext101,
    Seresnext101,
    senet154,
    seresnet50,
    seresnet101,
    seresnet152,
    seresnext50,
    seresnext101,
)
from .spade_autoencoderkl import SPADEAutoencoderKL
from .spade_diffusion_model_unet import SPADEDiffusionModelUNet
from .spade_network import SPADENet
from .swin_unetr import PatchMerging, PatchMergingV2, SwinUNETR
from .torchvision_fc import TorchVisionFCModel
from .transchex import BertAttention, BertMixedLayer, BertOutput, BertPreTrainedModel, MultiModal, Pooler, Transchex
from .transformer import DecoderOnlyTransformer
from .unet import UNet, Unet
from .unetr import UNETR
from .varautoencoder import VarAutoEncoder
from .vista3d import VISTA3D, vista3d132
from .vit import ViT
from .vitautoenc import ViTAutoEnc
from .vnet import VNet
from .voxelmorph import VoxelMorph, VoxelMorphUNet
from .vqvae import VQVAE
from .u_mamba import UMamba

networks/nets/u_mamba.py

Lines changed: 99 additions & 0 deletions
@@ -0,0 +1,99 @@
import torch
import torch.nn as nn
import torch.nn.functional as F


# Simple placeholder for the SSM (Mamba-like block)
class SSMBlock(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.linear1 = nn.Linear(dim, dim)
        self.linear2 = nn.Linear(dim, dim)

    def forward(self, x):
        # x: (B, L, C)
        return self.linear2(F.silu(self.linear1(x)))


class UMambaBlock(nn.Module):
    def __init__(self, in_channels, hidden_channels):
        super().__init__()
        self.conv_res1 = nn.Sequential(
            nn.Conv3d(in_channels, in_channels, kernel_size=3, padding=1),
            nn.InstanceNorm3d(in_channels),
            nn.LeakyReLU(),
        )
        self.conv_res2 = nn.Sequential(
            nn.Conv3d(in_channels, in_channels, kernel_size=3, padding=1),
            nn.InstanceNorm3d(in_channels),
            nn.LeakyReLU(),
        )

        # layer norm acts on the flattened tokens, whose last dimension is in_channels
        self.layernorm = nn.LayerNorm(in_channels)
        self.linear1 = nn.Linear(in_channels, hidden_channels)
        self.linear2 = nn.Linear(hidden_channels, in_channels)
        self.conv1d = nn.Conv1d(hidden_channels, hidden_channels, kernel_size=3, padding=1)
        self.ssm = SSMBlock(hidden_channels)

    def forward(self, x):
        # x: (B, C, H, W, D)
        residual = x
        x = self.conv_res1(x)
        x = self.conv_res2(x) + residual

        # flatten the spatial dimensions into a token sequence: (B, L, C) with L = H * W * D
        x_flat = x.flatten(2).transpose(1, 2)
        x_norm = self.layernorm(x_flat)
        x_proj = self.linear1(x_norm)

        x_silu = F.silu(x_proj)
        x_ssm = self.ssm(x_silu)
        x_conv1d = self.conv1d(x_proj.transpose(1, 2)).transpose(1, 2)

        x_combined = F.silu(x_conv1d) * F.silu(x_ssm)
        x_out = self.linear2(x_combined)
        x_out = x_out.transpose(1, 2).reshape(x.shape)

        return x + x_out  # residual connection


class ResidualBlock(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv3d(channels, channels, kernel_size=3, padding=1),
            nn.BatchNorm3d(channels),
            nn.ReLU(),
            nn.Conv3d(channels, channels, kernel_size=3, padding=1),
            nn.BatchNorm3d(channels),
        )

    def forward(self, x):
        return F.relu(x + self.block(x))


# U-shaped encoder-decoder built from UMambaBlocks, exported as monai.networks.nets.UMamba
class UMamba(nn.Module):
    def __init__(self, in_channels=1, out_channels=1, base_channels=32):
        super().__init__()
        # stem projects the input to base_channels; each UMambaBlock preserves its channel count
        self.stem = nn.Conv3d(in_channels, base_channels, kernel_size=3, padding=1)
        self.enc1 = UMambaBlock(base_channels, base_channels)
        self.down1 = nn.Conv3d(base_channels, base_channels * 2, kernel_size=3, stride=2, padding=1)

        self.enc2 = UMambaBlock(base_channels * 2, base_channels * 2)
        self.down2 = nn.Conv3d(base_channels * 2, base_channels * 4, kernel_size=3, stride=2, padding=1)

        self.bottleneck = UMambaBlock(base_channels * 4, base_channels * 4)

        self.up2 = nn.ConvTranspose3d(base_channels * 4, base_channels * 2, kernel_size=2, stride=2)
        self.dec2 = ResidualBlock(base_channels * 4)  # operates on the concatenated skip features

        self.up1 = nn.ConvTranspose3d(base_channels * 4, base_channels, kernel_size=2, stride=2)
        self.dec1 = ResidualBlock(base_channels * 2)

        self.final = nn.Conv3d(base_channels * 2, out_channels, kernel_size=1)

    def forward(self, x):
        x1 = self.enc1(self.stem(x))
        x2 = self.enc2(self.down1(x1))
        x3 = self.bottleneck(self.down2(x2))

        x = self.up2(x3)
        x = self.dec2(torch.cat([x, x2], dim=1))
        x = self.up1(x)
        x = self.dec1(torch.cat([x, x1], dim=1))
        return self.final(x)
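
For orientation, a minimal usage sketch of the network added here; the patch size and argument values are illustrative, not part of the commit:

import torch
from monai.networks.nets import UMamba

# illustrative settings: single-channel 3D patch, two output classes
model = UMamba(in_channels=1, out_channels=2, base_channels=32)
x = torch.randn(1, 1, 32, 32, 32)  # (B, C, H, W, D); spatial dims divisible by 4
logits = model(x)                  # -> (1, 2, 32, 32, 32)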

tests/test_networks_u_mamba.py

Lines changed: 26 additions & 0 deletions
@@ -0,0 +1,26 @@
import unittest

import torch

from monai.networks.nets import UMamba


class TestUMamba(unittest.TestCase):
    def test_forward_shape(self):
        # small 3D volume keeps the test fast; spatial size must be divisible by 4 (two stride-2 downsamplings)
        input_tensor = torch.randn(2, 1, 32, 32, 32)  # (batch_size, channels, H, W, D)
        model = UMamba(in_channels=1, out_channels=2)

        # forward pass
        output = model(input_tensor)

        # output keeps the spatial size and has out_channels channels
        self.assertEqual(output.shape, (2, 2, 32, 32, 32))

    def test_script(self):
        # the network should be TorchScript-compatible
        model = UMamba(in_channels=1, out_channels=2)
        scripted = torch.jit.script(model)
        x = torch.randn(1, 1, 32, 32, 32)
        out = scripted(x)
        self.assertEqual(out.shape, (1, 2, 32, 32, 32))


if __name__ == "__main__":
    unittest.main()
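
Because of the unittest.main() guard, the test module can also be run directly, assuming MONAI is importable from the repository root:

python tests/test_networks_u_mamba.py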
