Add sampling callback
JonathanCrabbe committed Dec 27, 2023
1 parent b805018 commit 639a287
Showing 3 changed files with 105 additions and 0 deletions.
13 changes: 13 additions & 0 deletions cmd/conf/trainer/callbacks/default.yaml
@@ -6,3 +6,16 @@
- _target_: pytorch_lightning.callbacks.EarlyStopping
monitor: val/loss
patience: 20
- _target_: fdiff.utils.callbacks.SamplingCallback
every_n_epochs: 1
sample_batch_size: ${datamodule.batch_size}
num_samples: 200
num_diffusion_steps: 1000
metrics:
- _target_: fdiff.sampling.metrics.SlicedWasserstein
_partial_: true
random_seed: ${random_seed}
num_directions: 200
- _target_: fdiff.sampling.metrics.MarginalWasserstein
_partial_: true
random_seed: ${random_seed}
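
For orientation, a minimal sketch of how this callback list is materialized, assuming Hydra's standard instantiate utility and a composed training config (here called cfg) that resolves interpolations such as ${datamodule.batch_size} and ${random_seed}:

from hydra.utils import instantiate

# Each list entry carries a _target_, so instantiate builds the object directly.
# Entries with _partial_: true (the metrics) come back as functools.partial
# factories, which SamplingCallback later binds to the reference data.
callbacks = [instantiate(cb) for cb in cfg.trainer.callbacks]

Here cfg.trainer.callbacks is an assumed path into the composed config; the repository's actual entry point may wire this differently.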
6 changes: 6 additions & 0 deletions cmd/train.py
@@ -11,6 +11,7 @@

from fdiff.dataloaders.datamodules import Datamodule
from fdiff.models.score_models import ScoreModule
from fdiff.utils.callbacks import SamplingCallback
from fdiff.utils.extraction import dict_to_str, get_training_params
from fdiff.utils.wandb import maybe_initialize_wandb

@@ -50,6 +51,11 @@ def __init__(self, cfg: DictConfig) -> None:
training_params = get_training_params(self.datamodule, self.trainer)
self.score_model = self.score_model(**training_params)

# If a sampling callback is present, set up its datamodule
for callback in self.trainer.callbacks: # type: ignore
if isinstance(callback, SamplingCallback):
callback.setup_datamodule(datamodule=self.datamodule)

def train(self) -> None:
assert not (
self.score_model.scale_noise and not self.datamodule.fourier_transform
86 changes: 86 additions & 0 deletions src/fdiff/utils/callbacks.py
@@ -0,0 +1,86 @@
import pytorch_lightning as pl
import torch

from fdiff.dataloaders.datamodules import Datamodule
from fdiff.models.score_models import ScoreModule
from fdiff.sampling.metrics import Metric, MetricCollection
from fdiff.sampling.sampler import DiffusionSampler

from .fourier import idft


class SamplingCallback(pl.Callback):
def __init__(
self,
every_n_epochs: int,
sample_batch_size: int,
num_samples: int,
num_diffusion_steps: int,
metrics: list[Metric],
) -> None:
super().__init__()
self.every_n_epochs = every_n_epochs
self.sample_batch_size = sample_batch_size
self.num_samples = num_samples
self.num_diffusion_steps = num_diffusion_steps
self.metrics = metrics
self.datamodule_initialized = False

def setup_datamodule(self, datamodule: Datamodule) -> None:
# Extract the necessary information from the datamodule
self.standardize = datamodule.standardize
self.fourier_transform = datamodule.fourier_transform
self.feature_mean, self.feature_std = datamodule.feature_mean_and_std
self.metric_collection = MetricCollection(
metrics=self.metrics,
original_samples=datamodule.X_train,
include_baselines=False,
)
self.datamodule_initialized = True

def on_train_start(self, trainer: pl.Trainer, pl_module: ScoreModule) -> None:
# Initialize the sampler with the score model
self.sampler = DiffusionSampler(
score_model=pl_module,
sample_batch_size=self.sample_batch_size,
)

def on_train_epoch_end(
self, trainer: pl.Trainer, pl_module: pl.LightningModule
) -> None:
if trainer.current_epoch % self.every_n_epochs == 0:
# Sample from score model
X = self.sample()

# Compute metrics
results = self.metric_collection(X)

# Add a metrics/ prefix to the keys in results
results = {f"metrics/{key}": value for key, value in results.items()}

# Log metrics
pl_module.log_dict(results, on_step=False, on_epoch=True)

def sample(self) -> torch.Tensor:
# Check that the datamodule is initialized
assert self.datamodule_initialized, (
"The datamodule has not been initialized. "
"Please call `setup_datamodule` before sampling."
)

# Sample from score model
X = self.sampler.sample(
num_samples=self.num_samples,
num_diffusion_steps=self.num_diffusion_steps,
)

# Map to the original scale if the input was standardized
if self.standardize:
X = X * self.feature_std + self.feature_mean

# If sampling in frequency domain, bring back the sample to time domain
if self.fourier_transform:
X = idft(X)
assert isinstance(X, torch.Tensor)
return X
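
As a usage note, a hedged sketch of driving the callback outside the Hydra config (hypothetical values; datamodule stands for an already-prepared fdiff Datamodule, and the metric arguments mirror the YAML above):

from functools import partial

import pytorch_lightning as pl

from fdiff.sampling.metrics import SlicedWasserstein
from fdiff.utils.callbacks import SamplingCallback

callback = SamplingCallback(
    every_n_epochs=1,
    sample_batch_size=64,  # hypothetical; the config uses ${datamodule.batch_size}
    num_samples=200,
    num_diffusion_steps=1000,
    # Partially applied metrics, matching the _partial_: true entries above.
    metrics=[partial(SlicedWasserstein, random_seed=42, num_directions=200)],
)
callback.setup_datamodule(datamodule=datamodule)  # required before sampling
trainer = pl.Trainer(callbacks=[callback], max_epochs=10)

Note that setup_datamodule must run before training starts, mirroring the wiring added in cmd/train.py; otherwise the assertion in sample() fires.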
