Skip to content

Commit 3ab9213

Browse files
committed
Fix test input shape and class name for UMambaUNet
1 parent b554a10 commit 3ab9213

File tree

3 files changed

+5
-109
lines changed

3 files changed

+5
-109
lines changed

networks/nets/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -144,4 +144,3 @@
144144
from .vnet import VNet
145145
from .voxelmorph import VoxelMorph, VoxelMorphUNet
146146
from .vqvae import VQVAE
147-
from .u_mamba import UMamba

networks/nets/u_mamba.py

Lines changed: 0 additions & 99 deletions
This file was deleted.

tests/test_networks_u_mamba.py

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,18 @@
11
import unittest
22
import torch
3-
from monai.networks.nets import UMamba
3+
from monai.networks.nets import UMambaUNet
44

55
class TestUMamba(unittest.TestCase):
66
def test_forward_shape(self):
77
# Set up input dimensions and model
8-
input_tensor = torch.randn(2, 1, 64, 64) # (batch_size, channels, H, W)
9-
model = UMamba(in_channels=1, out_channels=2) # example args
10-
11-
# Forward pass
8+
input_tensor = torch.randn(2, 1, 16, 64, 64)
9+
model = UMambaUNet(in_channels=1, out_channels=2)
1210
output = model(input_tensor)
13-
14-
# Assert output shape matches expectation
15-
self.assertEqual(output.shape, (2, 2, 64, 64)) # adjust if necessary
11+
self.assertEqual(output.shape, (2, 2, 16, 64, 64))
1612

1713
def test_script(self):
1814
# Test JIT scripting if supported
19-
model = UMamba(in_channels=1, out_channels=2)
15+
model = UMambaUNet(in_channels=1, out_channels=2)
2016
scripted = torch.jit.script(model)
2117
x = torch.randn(1, 1, 64, 64)
2218
out = scripted(x)

0 commit comments

Comments
 (0)