Skip to content

Commit

Permalink
small bug fixes and changes in the embeddings
Browse files Browse the repository at this point in the history
  • Loading branch information
Dingel321 committed Jul 25, 2023
1 parent 3debf9a commit 23e0a65
Show file tree
Hide file tree
Showing 4 changed files with 1,078 additions and 712 deletions.
1,690 changes: 993 additions & 697 deletions notebooks/Untitled.ipynb

Large diffs are not rendered by default.

92 changes: 81 additions & 11 deletions src/cryo_sbi/inference/models/embedding_nets.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,13 +124,13 @@ class EfficientNet_Encoder(nn.Module):
def __init__(self, output_dimension: int):
super(EfficientNet_Encoder, self).__init__()

self.efficient_net = models.efficientnet_b0().features
self.efficient_net = models.efficientnet_b3().features
self.efficient_net[0][0] = nn.Conv2d(
1, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False
1, 40, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False
)
self.avg_pool = nn.AdaptiveAvgPool2d(output_size=1)
self.leakyrelu = nn.LeakyReLU()
self.linear = nn.Linear(1280, output_dimension)
self.linear = nn.Linear(1536, output_dimension)

def forward(self, x):
x = x.unsqueeze(1)
Expand All @@ -140,7 +140,7 @@ def forward(self, x):
return x


@add_embedding("SWINS_FFT_FILTER")
@add_embedding("SWINS")
class SwinTransformerS_Encoder(nn.Module):
def __init__(self, output_dimension: int):
super(SwinTransformerS_Encoder, self).__init__()
Expand All @@ -152,12 +152,8 @@ def __init__(self, output_dimension: int):
self.swin_transformer.head = nn.Linear(
in_features=768, out_features=output_dimension, bias=True
)
self._fft_filter = LowPassFilter(128, 25)

def forward(self, x):
# Low pass filter images
x = self._fft_filter(x)
# Proceed as normal
x = x.unsqueeze(1)
x = self.swin_transformer(x)
return x
Expand Down Expand Up @@ -199,17 +195,17 @@ def forward(self, x):
return x


@add_embedding("REGNET")
@add_embedding("REGNETY")
class RegNetY_Encoder(nn.Module):
def __init__(self, output_dimension: int):
super(RegNetY_Encoder, self).__init__()

self.regnety = models.regnet_y_800mf()
self.regnety = models.regnet_y_1_6gf()
self.regnety.stem[0] = nn.Conv2d(
1, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False
)
self.regnety.fc = nn.Linear(
in_features=784, out_features=output_dimension, bias=True
in_features=888, out_features=output_dimension, bias=True
)

def forward(self, x):
Expand Down Expand Up @@ -282,5 +278,79 @@ def forward(self, x):
x = self.resnet(x)
return x


@add_embedding("RESNET34")
class ResNet34_Encoder(nn.Module):
def __init__(self, output_dimension: int):
super(ResNet34_Encoder, self).__init__()
self.resnet = models.resnet34()
self.resnet.conv1 = nn.Conv2d(
1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False
)
self.resnet.fc = nn.Linear(
in_features=512, out_features=output_dimension, bias=True
)

def forward(self, x):
# Proceed as normal
x = x.unsqueeze(1)
x = self.resnet(x)
return x


@add_embedding("RESNET34_256_LP")
class ResNet34_Encoder(nn.Module):
def __init__(self, output_dimension: int):
super(ResNet34_Encoder, self).__init__()
self.resnet = models.resnet34()
self.resnet.conv1 = nn.Conv2d(
1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False
)
self.resnet.fc = nn.Linear(
in_features=512, out_features=output_dimension, bias=True
)
self._fft_filter = LowPassFilter(256, 50)

def forward(self, x):
# Low pass filter images
x = self._fft_filter(x)
# Proceed as normal
x = x.unsqueeze(1)
x = self.resnet(x)
return x


@add_embedding("VGG19")
class VGG19_Encoder(nn.Module):
def __init__(self, output_dimension: int):
super(VGG19_Encoder, self).__init__()

self.vgg19 = models.vgg19_bn().features
self.vgg19[0] = nn.Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))

self.avgpool = nn.AdaptiveAvgPool2d(output_size=(7, 7))

self.feedforward = nn.Sequential(
*[
nn.Linear(in_features=25088, out_features=4096),
nn.ReLU(inplace=True),
nn.Linear(in_features=4096, out_features=output_dimension, bias=True),
nn.ReLU(inplace=True)
]
)

#self._fft_filter = LowPassFilter(256, 50)

def forward(self, x):
# Low pass filter images
#x = self._fft_filter(x)
# Proceed as normal
x = x.unsqueeze(1)
x = self.vgg19(x)
x = self.avgpool(x).flatten(start_dim=1)
x = self.feedforward(x)
return x


if __name__ == "__main__":
pass
6 changes: 3 additions & 3 deletions src/cryo_sbi/utils/image_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,9 +335,9 @@ def get_image(self, idx: Union[int, list]):
self._index_map is not None
), "Index map not built. First call build_index_map()"
if isinstance(idx, int):
return mrc_to_tensor(self.paths[self._path_index[idx]])[
self._file_index[idx]
]
image = mrc_to_tensor(self.paths[self._path_index[idx]])
if image.ndim > 2:
return image[self._file_index[idx]]
if isinstance(idx, (list, np.ndarray, torch.Tensor)):
return [
mrc_to_tensor(self.paths[self._path_index[i]])[self._file_index[i]]
Expand Down
2 changes: 1 addition & 1 deletion src/cryo_sbi/wpa_simulator/image_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,6 @@ def project_density(
-0.5 * (((grid.unsqueeze(-1) - coords_rot[:, 1, :].unsqueeze(1)) / sigma) ** 2)
).transpose(1, 2)

image = torch.bmm(gauss_x, gauss_y) * norm.reshape(-1, 1, 1)
image = torch.bmm(gauss_x, gauss_y) * norm.reshape(-1, 1, 1)

return image

0 comments on commit 23e0a65

Please sign in to comment.