Skip to content

Commit e7369b5

Browse files
committed
Merge remote-tracking branch 'origin/main'
2 parents 72d0fd4 + 5477742 commit e7369b5

1 file changed

Lines changed: 14 additions & 18 deletions

File tree

  • myoverse/models/definitions/raul_net/online

myoverse/models/definitions/raul_net/online/v17.py

Lines changed: 14 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -106,20 +106,20 @@ def configure_model(self) -> None:
106106
kernel_size=(
107107
self.nr_of_electrode_grids,
108108
(
109-
int(np.floor(self.nr_of_electrodes_per_grid / 2))
110-
+ (0 if self.nr_of_electrodes_per_grid % 2 == 0 else 1)
109+
int(np.floor(self.nr_of_electrodes_per_grid / 2))
110+
+ (0 if self.nr_of_electrodes_per_grid % 2 == 0 else 1)
111111
),
112112
18,
113113
),
114114
dilation=(1, 2, 1),
115115
padding=(
116116
(
117-
int(np.floor(self.nr_of_electrode_grids / 2))
118-
+ (0 if self.nr_of_electrode_grids % 2 == 0 else 1)
117+
int(np.floor(self.nr_of_electrode_grids / 2))
118+
+ (0 if self.nr_of_electrode_grids % 2 == 0 else 1)
119119
),
120120
(
121-
int(np.floor(self.nr_of_electrodes_per_grid / 4))
122-
+ (0 if self.nr_of_electrodes_per_grid % 4 == 0 else 1)
121+
int(np.floor(self.nr_of_electrodes_per_grid / 4))
122+
+ (0 if self.nr_of_electrodes_per_grid % 4 == 0 else 1)
123123
),
124124
0,
125125
),
@@ -133,8 +133,8 @@ def configure_model(self) -> None:
133133
kernel_size=(
134134
self.nr_of_electrode_grids,
135135
(
136-
int(np.floor(self.nr_of_electrodes_per_grid / 7))
137-
+ (0 if self.nr_of_electrodes_per_grid % 7 == 0 else 1)
136+
int(np.floor(self.nr_of_electrodes_per_grid / 7))
137+
+ (0 if self.nr_of_electrodes_per_grid % 7 == 0 else 1)
138138
),
139139
1,
140140
),
@@ -171,9 +171,7 @@ def configure_model(self) -> None:
171171

172172
self.mlp = mlp
173173

174-
model = nn.Sequential(
175-
self.cnn_encoder, self.mlp
176-
)
174+
model = nn.Sequential(self.cnn_encoder, self.mlp)
177175

178176
self.model = torch.jit.script(model)
179177

@@ -185,14 +183,12 @@ def forward(self, inputs) -> Union[tuple[torch.Tensor, torch.Tensor], torch.Tens
185183
def _reshape_and_normalize(self, inputs):
186184
x = torch.stack(inputs.split(self.nr_of_electrodes_per_grid, dim=2), dim=2)
187185

188-
if self.training_means.device != x.device:
189-
self.training_means = self.training_means.to(x.device)
190-
self.training_stds = self.training_stds.to(x.device)
191-
192186
if self.training_means is not None and self.training_stds is not None:
193-
return (x - self.training_means) / (
194-
self.training_stds + 1e-15
195-
)
187+
if self.training_means.device != x.device:
188+
self.training_means = self.training_means.to(x.device)
189+
self.training_stds = self.training_stds.to(x.device)
190+
191+
return (x - self.training_means) / (self.training_stds + 1e-15)
196192

197193
return (x - x.mean(dim=(3, 4), keepdim=True)) / (
198194
x.std(dim=(3, 4), keepdim=True, unbiased=True) + 1e-15

0 commit comments

Comments
 (0)