Skip to content

Commit

Permalink
Add type hinting for score models
Browse files Browse the repository at this point in the history
  • Loading branch information
JonathanCrabbe committed Dec 22, 2023
1 parent 7167575 commit 6a53ea5
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 9 deletions.
17 changes: 10 additions & 7 deletions src/fdiff/models/score_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
PositionalEncoding,
TimeEncoding,
)
from fdiff.schedulers.ddpm import CustomDDPMScheduler
from fdiff.schedulers.sde import SDE
from fdiff.utils.dataclasses import DiffusableBatch
from fdiff.utils.losses import get_ddpm_loss, get_sde_loss_fn
Expand All @@ -24,7 +23,7 @@ def __init__(
self,
n_channels: int,
max_len: int,
noise_scheduler: DDPMScheduler | CustomDDPMScheduler | SDE,
noise_scheduler: DDPMScheduler | SDE,
fourier_noise_scaling: bool = True,
d_model: int = 60,
num_layers: int = 3,
Expand Down Expand Up @@ -105,7 +104,6 @@ def training_step(
on_epoch=True,
on_step=True,
)
assert isinstance(loss, torch.Tensor)
return loss

def validation_step(
Expand All @@ -130,7 +128,12 @@ def configure_optimizers(self) -> OptimizerLRScheduler:
lr_scheduler_config = {"scheduler": lr_scheduler, "interval": "step"}
return {"optimizer": optimizer, "lr_scheduler": lr_scheduler_config}

def set_loss_fn(self) -> tuple[Callable, Callable]:
def set_loss_fn(
self,
) -> tuple[
Callable[[nn.Module, DiffusableBatch], torch.Tensor],
Callable[[nn.Module, DiffusableBatch], torch.Tensor],
]:
# Depending on the scheduler, get the right loss function

if isinstance(self.noise_scheduler, DDPMScheduler):
Expand Down Expand Up @@ -162,17 +165,17 @@ def set_loss_fn(self) -> tuple[Callable, Callable]:

else:
raise NotImplementedError(
"Scheduler not implemented yet, cannot set loss function"
f"Scheduler {self.noise_scheduler} not implemented yet, cannot set loss function."
)

def set_time_encoder(self) -> TimeEncoding | GaussianFourierProjection:
if isinstance(self.noise_scheduler, (DDPMScheduler, CustomDDPMScheduler)):
if isinstance(self.noise_scheduler, DDPMScheduler):
return TimeEncoding(d_model=self.d_model, max_time=self.max_time)

elif isinstance(self.noise_scheduler, SDE):
return GaussianFourierProjection(d_model=self.d_model)

else:
raise NotImplementedError(
"Scheduler not implemented yet, cannot set loss function"
f"Scheduler {self.noise_scheduler} not implemented yet, cannot set time encoder."
)
4 changes: 2 additions & 2 deletions src/fdiff/utils/losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def loss_fn(model: nn.Module, batch: DiffusableBatch) -> torch.Tensor:
# grad log(p(x(t)|x(0))) = (-1) * Cov^{-1} (x(t) - mean)

# Reduction
losses = reduce_op(losses.reshape(losses.shape[0], -1), dim=-1)
losses = reduce_op(losses.reshape(losses.shape[0], -1), dim=-1) # type: ignore

else:
# Compute the Mahalanobis distance, cf. https://arxiv.org/pdf/2111.13606.pdf + https://www.iro.umontreal.ca/~vincentp/Publications/smdae_techreport.pdf
Expand All @@ -120,7 +120,7 @@ def loss_fn(model: nn.Module, batch: DiffusableBatch) -> torch.Tensor:

# 3) Compute the loss
losses = torch.square(scaled_difference)
losses = reduce_op(losses.reshape(losses.shape[0], -1), dim=-1)
losses = reduce_op(losses.reshape(losses.shape[0], -1), dim=-1) # type: ignore

loss = torch.mean(losses)
return loss
Expand Down

0 comments on commit 6a53ea5

Please sign in to comment.