@@ -106,20 +106,20 @@ def configure_model(self) -> None:
106106 kernel_size = (
107107 self .nr_of_electrode_grids ,
108108 (
109- int (np .floor (self .nr_of_electrodes_per_grid / 2 ))
110- + (0 if self .nr_of_electrodes_per_grid % 2 == 0 else 1 )
109+ int (np .floor (self .nr_of_electrodes_per_grid / 2 ))
110+ + (0 if self .nr_of_electrodes_per_grid % 2 == 0 else 1 )
111111 ),
112112 18 ,
113113 ),
114114 dilation = (1 , 2 , 1 ),
115115 padding = (
116116 (
117- int (np .floor (self .nr_of_electrode_grids / 2 ))
118- + (0 if self .nr_of_electrode_grids % 2 == 0 else 1 )
117+ int (np .floor (self .nr_of_electrode_grids / 2 ))
118+ + (0 if self .nr_of_electrode_grids % 2 == 0 else 1 )
119119 ),
120120 (
121- int (np .floor (self .nr_of_electrodes_per_grid / 4 ))
122- + (0 if self .nr_of_electrodes_per_grid % 4 == 0 else 1 )
121+ int (np .floor (self .nr_of_electrodes_per_grid / 4 ))
122+ + (0 if self .nr_of_electrodes_per_grid % 4 == 0 else 1 )
123123 ),
124124 0 ,
125125 ),
@@ -133,8 +133,8 @@ def configure_model(self) -> None:
133133 kernel_size = (
134134 self .nr_of_electrode_grids ,
135135 (
136- int (np .floor (self .nr_of_electrodes_per_grid / 7 ))
137- + (0 if self .nr_of_electrodes_per_grid % 7 == 0 else 1 )
136+ int (np .floor (self .nr_of_electrodes_per_grid / 7 ))
137+ + (0 if self .nr_of_electrodes_per_grid % 7 == 0 else 1 )
138138 ),
139139 1 ,
140140 ),
@@ -171,9 +171,7 @@ def configure_model(self) -> None:
171171
172172 self .mlp = mlp
173173
174- model = nn .Sequential (
175- self .cnn_encoder , self .mlp
176- )
174+ model = nn .Sequential (self .cnn_encoder , self .mlp )
177175
178176 self .model = torch .jit .script (model )
179177
@@ -185,14 +183,12 @@ def forward(self, inputs) -> Union[tuple[torch.Tensor, torch.Tensor], torch.Tens
185183 def _reshape_and_normalize (self , inputs ):
186184 x = torch .stack (inputs .split (self .nr_of_electrodes_per_grid , dim = 2 ), dim = 2 )
187185
188- if self .training_means .device != x .device :
189- self .training_means = self .training_means .to (x .device )
190- self .training_stds = self .training_stds .to (x .device )
191-
192186 if self .training_means is not None and self .training_stds is not None :
193- return (x - self .training_means ) / (
194- self .training_stds + 1e-15
195- )
187+ if self .training_means .device != x .device :
188+ self .training_means = self .training_means .to (x .device )
189+ self .training_stds = self .training_stds .to (x .device )
190+
191+ return (x - self .training_means ) / (self .training_stds + 1e-15 )
196192
197193 return (x - x .mean (dim = (3 , 4 ), keepdim = True )) / (
198194 x .std (dim = (3 , 4 ), keepdim = True , unbiased = True ) + 1e-15
0 commit comments