Skip to content

Commit

Permalink
current updates
Browse files Browse the repository at this point in the history
  • Loading branch information
Dingel321 committed Aug 11, 2023
1 parent 23e0a65 commit 625f927
Show file tree
Hide file tree
Showing 4 changed files with 465 additions and 1,060 deletions.
1,428 changes: 375 additions & 1,053 deletions notebooks/Untitled.ipynb

Large diffs are not rendered by default.

3 changes: 2 additions & 1 deletion src/cryo_sbi/inference/models/build_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,9 @@ def build_npe_flow_model(config: dict, **embedding_kwargs) -> nn.Module:
flow=model,
theta_shift=config["THETA_SHIFT"],
theta_scale=config["THETA_SCALE"],
**{"activation": partial(nn.LeakyReLU, 0.1)},
**{"activation": nn.GELU},#partial(nn.LeakyReLU, 0.1), nn.GELU
)
print("Training with GELU")

return estimator

Expand Down
91 changes: 86 additions & 5 deletions src/cryo_sbi/inference/models/embedding_nets.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,24 @@ def add(class_):
return add


@add_embedding("RESNET18_TEST2")
class ResNet18_Encoder_Test2(nn.Module):
def __init__(self, output_dimension: int):
super(ResNet18_Encoder_Test2, self).__init__()
self.resnet = models.resnet18()
self.resnet.conv1 = nn.Conv2d(
1, 64, kernel_size=(15, 15), stride=(2, 2), padding=(3, 3), bias=False
)
self.resnet.fc = nn.Linear(
in_features=512, out_features=output_dimension, bias=True
)

def forward(self, x):
x = x.unsqueeze(1)
x = self.resnet(x)
return x


@add_embedding("RESNET18")
class ResNet18_Encoder(nn.Module):
def __init__(self, output_dimension: int):
Expand All @@ -43,6 +61,24 @@ def forward(self, x):
x = x.unsqueeze(1)
x = self.resnet(x)
return x


@add_embedding("RESNET18_TEST")
class ResNet18_Encoder_Test(nn.Module):
def __init__(self, output_dimension: int):
super(ResNet18_Encoder_Test, self).__init__()
print("Training with avg pooling")
self.resnet = models.resnet18()
self.resnet.conv1 = nn.Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
self.resnet.fc = nn.Linear(
in_features=512, out_features=output_dimension, bias=True
)
self.resnet.maxpool = nn.MaxPool2d(kernel_size=1, stride=1, padding=0, dilation=1, ceil_mode=False)

def forward(self, x):
x = x.unsqueeze(1)
x = self.resnet(x)
return x


@add_embedding("RESNET50")
Expand Down Expand Up @@ -124,19 +160,16 @@ class EfficientNet_Encoder(nn.Module):
def __init__(self, output_dimension: int):
super(EfficientNet_Encoder, self).__init__()

self.efficient_net = models.efficientnet_b3().features
self.efficient_net = models.efficientnet_b0().features
self.efficient_net[0][0] = nn.Conv2d(
1, 40, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False
1, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False
)
self.avg_pool = nn.AdaptiveAvgPool2d(output_size=1)
self.leakyrelu = nn.LeakyReLU()
self.linear = nn.Linear(1536, output_dimension)

def forward(self, x):
x = x.unsqueeze(1)
x = self.efficient_net(x)
x = self.avg_pool(x).flatten(start_dim=1)
x = self.leakyrelu(self.linear(x))
return x


Expand Down Expand Up @@ -254,6 +287,54 @@ def forward(self, x):
x = x.unsqueeze(1)
x = self.resnet(x)
return x


@add_embedding("RESNET18_FFT_FILTER_224")
class ResNet18_FFT_Encoder_224(nn.Module):
def __init__(self, output_dimension: int):
super(ResNet18_FFT_Encoder_224, self).__init__()
self.resnet = models.resnet18()
self.resnet.conv1 = nn.Conv2d(
1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False
)
self.resnet.fc = nn.Linear(
in_features=512, out_features=output_dimension, bias=True
)

self._fft_filter = LowPassFilter(224, 30)
print("embedding with 224 lp 30")

def forward(self, x):
# Low pass filter images
x = self._fft_filter(x)
# Proceed as normal
x = x.unsqueeze(1)
x = self.resnet(x)
return x


@add_embedding("RESNET18_FFT_FILTER_224_LP25")
class ResNet18_FFT_Encoder_224_LP25(nn.Module):
def __init__(self, output_dimension: int):
super(ResNet18_FFT_Encoder_224_LP25, self).__init__()
self.resnet = models.resnet18()
self.resnet.conv1 = nn.Conv2d(
1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False
)
self.resnet.fc = nn.Linear(
in_features=512, out_features=output_dimension, bias=True
)

self._fft_filter = LowPassFilter(224, 25)
print("embedding with 224 lp 25")

def forward(self, x):
# Low pass filter images
x = self._fft_filter(x)
# Proceed as normal
x = x.unsqueeze(1)
x = self.resnet(x)
return x


@add_embedding("RESNET18_FFT_FILTER_132")
Expand Down
3 changes: 2 additions & 1 deletion src/cryo_sbi/inference/train_npe_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,14 +123,15 @@ def npe_train_no_saving(
estimator = load_model(
train_config, model_state_dict, device, train_from_checkpoint
)


loss = NPELoss(estimator)
optimizer = optim.AdamW(
estimator.parameters(), lr=train_config["LEARNING_RATE"], weight_decay=0.001
)
step = GDStep(optimizer, clip=train_config["CLIP_GRADIENT"])
mean_loss = []

print("Training neural netowrk:")
estimator.train()
with tqdm(range(epochs), unit="epoch") as tq:
Expand Down

0 comments on commit 625f927

Please sign in to comment.