diff --git a/sgm/models/autoencoder.py b/sgm/models/autoencoder.py index 2949b910..8eccc080 100644 --- a/sgm/models/autoencoder.py +++ b/sgm/models/autoencoder.py @@ -549,15 +549,17 @@ def __init__( class IdentityFirstStage(AbstractAutoencoder): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) + self.encoder = lambda x: x + self.decoder = lambda x: x def get_input(self, x: Any) -> Any: return x def encode(self, x: Any, *args, **kwargs) -> Any: - return x + return self.encoder(x) def decode(self, x: Any, *args, **kwargs) -> Any: - return x + return self.decoder(x) class AEIntegerWrapper(nn.Module):