diff --git a/src/model.py b/src/model.py index aabf559..355bc4c 100644 --- a/src/model.py +++ b/src/model.py @@ -90,6 +90,7 @@ def _get_conv_output(self, shape): # forward def forward(self, x): + x = nn.Dropout2d(0.1)(x) x = x.transpose(1, 2) x = self.conv1(x) x = self.conv2(x) @@ -102,3 +103,4 @@ def forward(self, x): x = self.fc2(x) x = self.fc3(x) return x +