1212from __future__ import annotations
1313
1414import unittest
15+ from unittest import skipUnless
1516
1617import torch
1718from parameterized import parameterized
2122from monai .utils import optional_import
2223from tests .utils import SkipIfBeforePyTorchVersion , assert_allclose
2324
24- rearrange , _ = optional_import ("einops" , name = "rearrange " )
25+ einops , has_einops = optional_import ("einops" )
2526
2627
2728TEST_CASES_CAB = []
@@ -70,42 +71,49 @@ def test_gating_mechanism(self):
7071class TestCABlock (unittest .TestCase ):
7172
7273 @parameterized .expand (TEST_CASES_CAB )
74+ @skipUnless (has_einops , "Requires einops" )
7375 def test_shape (self , input_param , input_shape , expected_shape ):
7476 net = CABlock (** input_param )
7577 with eval_mode (net ):
7678 result = net (torch .randn (input_shape ))
7779 self .assertEqual (result .shape , expected_shape )
7880
81+ @skipUnless (has_einops , "Requires einops" )
7982 def test_invalid_spatial_dims (self ):
8083 with self .assertRaises (ValueError ):
8184 CABlock (spatial_dims = 4 , dim = 64 , num_heads = 4 , bias = True )
8285
8386 @SkipIfBeforePyTorchVersion ((2 , 0 ))
87+ @skipUnless (has_einops , "Requires einops" )
8488 def test_flash_attention (self ):
8589 device = "cuda" if torch .cuda .is_available () else "cpu"
8690 block = CABlock (spatial_dims = 2 , dim = 64 , num_heads = 4 , bias = True , flash_attention = True ).to (device )
8791 x = torch .randn (2 , 64 , 32 , 32 ).to (device )
8892 output = block (x )
8993 self .assertEqual (output .shape , x .shape )
9094
95+ @skipUnless (has_einops , "Requires einops" )
9196 def test_temperature_parameter (self ):
9297 block = CABlock (spatial_dims = 2 , dim = 64 , num_heads = 4 , bias = True )
9398 self .assertTrue (isinstance (block .temperature , torch .nn .Parameter ))
9499 self .assertEqual (block .temperature .shape , (4 , 1 , 1 ))
95100
101+ @skipUnless (has_einops , "Requires einops" )
96102 def test_qkv_transformation_2d (self ):
97103 block = CABlock (spatial_dims = 2 , dim = 64 , num_heads = 4 , bias = True )
98104 x = torch .randn (2 , 64 , 32 , 32 )
99105 qkv = block .qkv (x )
100106 self .assertEqual (qkv .shape , (2 , 192 , 32 , 32 ))
101107
108+ @skipUnless (has_einops , "Requires einops" )
102109 def test_qkv_transformation_3d (self ):
103110 block = CABlock (spatial_dims = 3 , dim = 64 , num_heads = 4 , bias = True )
104111 x = torch .randn (2 , 64 , 16 , 16 , 16 )
105112 qkv = block .qkv (x )
106113 self .assertEqual (qkv .shape , (2 , 192 , 16 , 16 , 16 ))
107114
108115 @SkipIfBeforePyTorchVersion ((2 , 0 ))
116+ @skipUnless (has_einops , "Requires einops" )
109117 def test_flash_vs_normal_attention (self ):
110118 device = "cuda" if torch .cuda .is_available () else "cpu"
111119 block_flash = CABlock (spatial_dims = 2 , dim = 64 , num_heads = 4 , bias = True , flash_attention = True ).to (device )
@@ -120,6 +128,7 @@ def test_flash_vs_normal_attention(self):
120128
121129 assert_allclose (out_flash , out_normal , atol = 1e-4 )
122130
131+ @skipUnless (has_einops , "Requires einops" )
123132 def test_deterministic_small_input (self ):
124133 block = CABlock (spatial_dims = 2 , dim = 2 , num_heads = 1 , bias = False )
125134 with torch .no_grad ():
0 commit comments