diff --git a/netam/models.py b/netam/models.py index 783fe6ef..0f5b2854 100644 --- a/netam/models.py +++ b/netam/models.py @@ -720,7 +720,9 @@ def forward(self, amino_acid_indices: Tensor, mask: Tensor) -> Tensor: if self.output_dim == 1: return self.single_value.expand(amino_acid_indices.shape) else: - return self.single_value.expand(amino_acid_indices.shape + (self.output_dim,)) + return self.single_value.expand( + amino_acid_indices.shape + (self.output_dim,) + ) class HitClassModel(nn.Module):