Commit

upd
m5l14i11 committed Oct 5, 2024
1 parent d174507 commit 510b65d
Showing 1 changed file with 33 additions and 26 deletions.
59 changes: 33 additions & 26 deletions notebooks/models/autoencoder_v1.py
@@ -1,5 +1,3 @@
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
@@ -9,20 +7,14 @@ class PositionalEncoder(nn.Module):
def __init__(self, d_model, max_len=5000):
super(PositionalEncoder, self).__init__()

pe = torch.zeros(max_len, d_model)
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
div_term = torch.exp(
torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
)
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
pe = pe.unsqueeze(0).transpose(0, 1)

self.register_buffer("pe", pe)
self.pe = nn.Parameter(torch.randn(max_len, d_model) * 0.1)
self.layer_norm = nn.LayerNorm(d_model)

def forward(self, x):
pe_slice = self.pe[: x.size(0), :]
return x + pe_slice
pe_slice = self.pe[: x.size(1), :].unsqueeze(0)
x = x + pe_slice

return self.layer_norm(x)
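
The hunk above replaces the fixed sinusoidal table (previously registered as a buffer) with a learned positional embedding followed by LayerNorm, and the forward pass now slices along the sequence dimension (x.size(1)) rather than the batch dimension. A minimal self-contained sketch of the new module as it appears in the diff; the tensor sizes in the usage lines are illustrative only:

import torch
import torch.nn as nn

class PositionalEncoder(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super().__init__()
        # Learned positional embeddings replace the fixed sinusoidal buffer.
        self.pe = nn.Parameter(torch.randn(max_len, d_model) * 0.1)
        self.layer_norm = nn.LayerNorm(d_model)

    def forward(self, x):
        # x: (batch, seq_len, d_model); take one embedding per position.
        pe_slice = self.pe[: x.size(1), :].unsqueeze(0)
        return self.layer_norm(x + pe_slice)

pos_enc = PositionalEncoder(d_model=32)
out = pos_enc(torch.randn(8, 1, 32))  # seq_len of 1 appears to match the single-step latent used later in forward
print(out.shape)                      # torch.Size([8, 1, 32])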


class AutoEncoder(nn.Module):
@@ -31,6 +23,7 @@ def __init__(
segment_length: int,
n_features: int,
latent_dim: int = 32,
conv_filters=(64, 128),
num_heads: int = 4,
dropout_prob: float = 0.2,
activation_type: str = "leaky_relu",
@@ -47,13 +40,13 @@ def __init__(
self.positional_encoder = PositionalEncoder(latent_dim)

self.encoder = nn.Sequential(
nn.Conv1d(self.n_features, 64, kernel_size=3, padding=1),
nn.BatchNorm1d(64),
nn.Conv1d(self.n_features, conv_filters[0], kernel_size=3, padding=1),
nn.BatchNorm1d(conv_filters[0]),
self._get_activation(self.activation_type),
nn.MaxPool1d(2),
nn.Dropout(dropout_prob),
nn.Conv1d(64, 128, kernel_size=3, padding=1),
nn.BatchNorm1d(128),
nn.Conv1d(conv_filters[0], conv_filters[1], kernel_size=3, padding=1),
nn.BatchNorm1d(conv_filters[1]),
self._get_activation(self.activation_type),
nn.MaxPool1d(2),
nn.Dropout(dropout_prob),
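
With kernel_size=3 and padding=1 the two Conv1d layers preserve the time axis, so only the two MaxPool1d(2) steps shrink it, by a factor of four overall; self.encoded_length (defined in a collapsed part of the file) is therefore presumably segment_length // 4. A shape check on just the convolutional encoder from the hunk above, with illustrative sizes and a plain nn.LeakyReLU standing in for _get_activation:

import torch
import torch.nn as nn

n_features, segment_length, conv_filters, dropout_prob = 6, 64, (64, 128), 0.2

encoder = nn.Sequential(
    nn.Conv1d(n_features, conv_filters[0], kernel_size=3, padding=1),
    nn.BatchNorm1d(conv_filters[0]),
    nn.LeakyReLU(),
    nn.MaxPool1d(2),
    nn.Dropout(dropout_prob),
    nn.Conv1d(conv_filters[0], conv_filters[1], kernel_size=3, padding=1),
    nn.BatchNorm1d(conv_filters[1]),
    nn.LeakyReLU(),
    nn.MaxPool1d(2),
    nn.Dropout(dropout_prob),
)

x = torch.randn(4, n_features, segment_length)  # (batch, channels, time)
print(encoder(x).shape)                          # torch.Size([4, 128, 16]), i.e. time axis = segment_length // 4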
@@ -68,25 +61,29 @@ def __init__(
self.fc_encoder = nn.Sequential(
nn.Linear(128 * self.encoded_length, 256),
nn.LayerNorm(256),
self._get_activation("tanh"),
self._get_activation("gelu"),
nn.Linear(256, latent_dim),
nn.Dropout(dropout_prob),
)

self.fc_decoder = nn.Sequential(
nn.Linear(latent_dim, 256),
nn.LayerNorm(256),
self._get_activation("tanh"),
self._get_activation("gelu"),
nn.Linear(256, 128 * self.encoded_length),
nn.Dropout(dropout_prob),
)

self.decoder = nn.Sequential(
nn.ConvTranspose1d(128, 64, kernel_size=3, padding=1),
nn.BatchNorm1d(64),
nn.ConvTranspose1d(
conv_filters[1], conv_filters[0], kernel_size=3, padding=1
),
nn.BatchNorm1d(conv_filters[0]),
self._get_activation(self.activation_type),
nn.Upsample(scale_factor=2),
nn.ConvTranspose1d(64, self.n_features, kernel_size=3, padding=1),
nn.ConvTranspose1d(
conv_filters[0], self.n_features, kernel_size=3, padding=1
),
nn.BatchNorm1d(self.n_features),
self._get_activation(self.activation_type),
nn.Upsample(scale_factor=2),
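
The activation factory _get_activation sits in a collapsed region of the diff, so its body is not visible here. A plausible reconstruction, given purely as an assumption, covering the names used in this commit (leaky_relu, gelu, tanh, relu) and returning a fresh module on each call:

import torch.nn as nn

def _get_activation(name: str) -> nn.Module:
    # Hypothetical helper; the committed method may differ in names or defaults.
    activations = {
        "relu": nn.ReLU,
        "leaky_relu": nn.LeakyReLU,
        "gelu": nn.GELU,
        "tanh": nn.Tanh,
    }
    return activations[name]()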
@@ -95,7 +92,16 @@ def __init__(
self.residual = nn.Sequential(
nn.Conv1d(self.n_features, self.n_features, kernel_size=3, padding=1),
nn.BatchNorm1d(self.n_features),
self._get_activation("relu"),
self._get_activation("leaky_relu"),
)

self.residual_scaler = nn.Sequential(
nn.Linear(latent_dim, 128),
self._get_activation("gelu"),
nn.Linear(128, 64),
self._get_activation("gelu"),
nn.Linear(64, 1),
nn.Sigmoid(),
)

self.apply(self._init_weights)
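
self.apply(self._init_weights) is carried over from the previous revision, but _init_weights itself is not shown in this diff. A typical Xavier-style initializer is sketched below only as an assumption about what it might do:

import torch.nn as nn

def init_weights(module):
    # Assumed behavior; the actual AutoEncoder._init_weights is collapsed in the diff.
    if isinstance(module, (nn.Linear, nn.Conv1d, nn.ConvTranspose1d)):
        nn.init.xavier_uniform_(module.weight)
        if module.bias is not None:
            nn.init.zeros_(module.bias)

# mirrors the call in __init__:  model.apply(init_weights)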
@@ -114,7 +120,7 @@ def forward(self, x):
latent = self.positional_encoder(latent)

attn_encoded, _ = self.attention_encoder(latent, latent, latent)
attn_encoded = self._get_activation(self.activation_type)(attn_encoded)
attn_encoded = self._get_activation("gelu")(attn_encoded)
attn_encoded = attn_encoded + latent
latent = attn_encoded.squeeze(1)
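
This hunk only changes the activation applied to the attention output from the configured activation_type to GELU; the residual add and the squeeze are unchanged. The construction of attention_encoder is collapsed above, so the sketch below assumes it is an nn.MultiheadAttention(latent_dim, num_heads, batch_first=True); the sizes are illustrative:

import torch
import torch.nn as nn

latent_dim, num_heads = 32, 4
attention_encoder = nn.MultiheadAttention(latent_dim, num_heads, batch_first=True)

latent = torch.randn(8, 1, latent_dim)            # (batch, 1, latent_dim)
attn_encoded, _ = attention_encoder(latent, latent, latent)
attn_encoded = nn.GELU()(attn_encoded) + latent   # GELU, then residual add as in the diff
latent = attn_encoded.squeeze(1)                  # back to (batch, latent_dim)
print(latent.shape)                               # torch.Size([8, 32])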

@@ -125,8 +131,9 @@ def forward(self, x):

residual_out = self.residual(identity)
residual_out = residual_out.permute(0, 2, 1)
residual_scale = self.residual_scaler(latent).unsqueeze(1)

output = decoded + residual_out
output = decoded + residual_scale * residual_out

return output
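
The final hunk is the main behavioral change of the commit: the residual branch is no longer added at full strength; instead residual_scaler maps the latent code to a per-sample gate in (0, 1) that scales it. A standalone sketch of just that gating step, with illustrative sizes and random tensors standing in for decoded, residual_out and latent:

import torch
import torch.nn as nn

batch, seq_len, n_features, latent_dim = 8, 64, 6, 32

residual_scaler = nn.Sequential(
    nn.Linear(latent_dim, 128), nn.GELU(),
    nn.Linear(128, 64), nn.GELU(),
    nn.Linear(64, 1), nn.Sigmoid(),
)

latent = torch.randn(batch, latent_dim)
decoded = torch.randn(batch, seq_len, n_features)
residual_out = torch.randn(batch, seq_len, n_features)

residual_scale = residual_scaler(latent).unsqueeze(1)  # (batch, 1, 1), values in (0, 1)
output = decoded + residual_scale * residual_out       # gate broadcasts over time and features
print(output.shape)                                    # torch.Size([8, 64, 6])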
