
Commit 30fad17

require einops for all tests

1 parent ce15886

File tree: 1 file changed (+10, -1 lines)


tests/test_CABlock.py

Lines changed: 10 additions & 1 deletion
@@ -12,6 +12,7 @@
 from __future__ import annotations

 import unittest
+from unittest import skipUnless

 import torch
 from parameterized import parameterized
@@ -21,7 +22,7 @@
 from monai.utils import optional_import
 from tests.utils import SkipIfBeforePyTorchVersion, assert_allclose

-rearrange, _ = optional_import("einops", name="rearrange")
+einops, has_einops = optional_import("einops")


 TEST_CASES_CAB = []
@@ -70,42 +71,49 @@ def test_gating_mechanism(self):
 class TestCABlock(unittest.TestCase):

     @parameterized.expand(TEST_CASES_CAB)
+    @skipUnless(has_einops, "Requires einops")
     def test_shape(self, input_param, input_shape, expected_shape):
         net = CABlock(**input_param)
         with eval_mode(net):
             result = net(torch.randn(input_shape))
             self.assertEqual(result.shape, expected_shape)

+    @skipUnless(has_einops, "Requires einops")
     def test_invalid_spatial_dims(self):
         with self.assertRaises(ValueError):
             CABlock(spatial_dims=4, dim=64, num_heads=4, bias=True)

     @SkipIfBeforePyTorchVersion((2, 0))
+    @skipUnless(has_einops, "Requires einops")
     def test_flash_attention(self):
         device = "cuda" if torch.cuda.is_available() else "cpu"
         block = CABlock(spatial_dims=2, dim=64, num_heads=4, bias=True, flash_attention=True).to(device)
         x = torch.randn(2, 64, 32, 32).to(device)
         output = block(x)
         self.assertEqual(output.shape, x.shape)

+    @skipUnless(has_einops, "Requires einops")
     def test_temperature_parameter(self):
         block = CABlock(spatial_dims=2, dim=64, num_heads=4, bias=True)
         self.assertTrue(isinstance(block.temperature, torch.nn.Parameter))
         self.assertEqual(block.temperature.shape, (4, 1, 1))

+    @skipUnless(has_einops, "Requires einops")
     def test_qkv_transformation_2d(self):
         block = CABlock(spatial_dims=2, dim=64, num_heads=4, bias=True)
         x = torch.randn(2, 64, 32, 32)
         qkv = block.qkv(x)
         self.assertEqual(qkv.shape, (2, 192, 32, 32))

+    @skipUnless(has_einops, "Requires einops")
     def test_qkv_transformation_3d(self):
         block = CABlock(spatial_dims=3, dim=64, num_heads=4, bias=True)
         x = torch.randn(2, 64, 16, 16, 16)
         qkv = block.qkv(x)
         self.assertEqual(qkv.shape, (2, 192, 16, 16, 16))

     @SkipIfBeforePyTorchVersion((2, 0))
+    @skipUnless(has_einops, "Requires einops")
     def test_flash_vs_normal_attention(self):
         device = "cuda" if torch.cuda.is_available() else "cpu"
         block_flash = CABlock(spatial_dims=2, dim=64, num_heads=4, bias=True, flash_attention=True).to(device)
@@ -120,6 +128,7 @@ def test_flash_vs_normal_attention(self):

         assert_allclose(out_flash, out_normal, atol=1e-4)

+    @skipUnless(has_einops, "Requires einops")
     def test_deterministic_small_input(self):
         block = CABlock(spatial_dims=2, dim=2, num_heads=1, bias=False)
         with torch.no_grad():
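
For readers unfamiliar with the pattern, below is a minimal, self-contained sketch of the gating idiom this commit applies: monai.utils.optional_import returns the imported module together with a boolean success flag (as the diff shows), and unittest.skipUnless uses that flag to skip tests when the optional dependency is missing. The test class and test name here are illustrative only, not part of the commit.

import unittest
from unittest import skipUnless

import torch

from monai.utils import optional_import

# optional_import returns (module_or_stub, success_flag); the flag drives the skip.
einops, has_einops = optional_import("einops")


class TestOptionalEinops(unittest.TestCase):
    # Illustrative example of the skip-gating pattern, not part of the commit.

    @skipUnless(has_einops, "Requires einops")
    def test_rearrange_transpose(self):
        # einops.rearrange is only reached when the import actually succeeded.
        x = torch.randn(2, 4, 8)
        y = einops.rearrange(x, "b c n -> b n c")
        self.assertEqual(y.shape, (2, 8, 4))


if __name__ == "__main__":
    unittest.main()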
