Skip to content

Commit

Permalink
added cryoDRGN encoder into embeddings
Browse files Browse the repository at this point in the history
  • Loading branch information
Dingel321 committed Jun 25, 2024
1 parent d9624f9 commit 67ba7e6
Showing 1 changed file with 17 additions and 12 deletions.
29 changes: 17 additions & 12 deletions src/cryo_sbi/inference/models/embedding_nets.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,22 +77,27 @@ def forward(self, x):


@add_embedding("RESNET18_FFT")
class ResNet18_Encoder(nn.Module):
class ResNet18_FFT(nn.Module):
def __init__(self, output_dimension: int):
super(ResNet18_Encoder, self).__init__()
print("Using FFT ResNet18")
self.resnet = models.resnet18()
self.resnet.conv1 = nn.Conv2d(
2, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False
)
self.resnet.fc = nn.Linear(
in_features=512, out_features=output_dimension, bias=True
super(ResNet18_FFT, self).__init__()
print("Using FFT1")

self.drgn_encoder = nn.Sequential(
nn.Linear(12892, 512),
nn.GELU(),
nn.Linear(512, 256),
nn.GELU(),
nn.Linear(256, output_dimension),
nn.GELU(),
)
self.mask = Mask(128, 64, inside=True).mask.flatten()

def forward(self, x):
x = torch.fft.fftshift(torch.fft.fft2(x, dim=(-2, -1)))
x = torch.stack([x.real, x.imag], dim=1)
x = self.resnet(x)
if x.dim == 2:
x = x.unsqueeze(0)
x = torch.fft.fftshift(torch.fft.fft2(x, dim=(-2, -1))).real
x = x.flatten(start_dim=1)[:, self.mask]
x = self.drgn_encoder(x)
return x


Expand Down

0 comments on commit 67ba7e6

Please sign in to comment.