
Commit

Merge pull request #1 from JonathanCrabbe/fourier
Fourier Diffusion
nicolashuynh authored Dec 1, 2023
2 parents d202c75 + c1aaabc commit 252e2a0
Showing 12 changed files with 206 additions and 24 deletions.
5 changes: 3 additions & 2 deletions .pre-commit-config.yaml
@@ -22,7 +22,8 @@ repos:
     rev: v0.991
     hooks:
       - id: mypy
-        args: [
+        args:
+          [
             "--ignore-missing-imports",
             "--scripts-are-modules",
             "--disallow-incomplete-defs",
@@ -35,5 +36,5 @@ repos:
             "--disallow-untyped-calls",
             "--install-types",
             "--non-interactive",
-            "--follow-imports=skip", # This is temporary until the mbi directory is not excluded
+            "--follow-imports=skip",
           ]
1 change: 1 addition & 0 deletions README.md
@@ -19,6 +19,7 @@ conda activate fdiff
 pip install -e .
 ```
 4. If you intend to train models, make sure that wandb is correctly configured on your machine by following [this guide](https://docs.wandb.ai/quickstart).
+5. Some of the datasets are automatically downloaded by our scripts via the Kaggle API. Make sure to create a Kaggle token, as explained [here](https://towardsdatascience.com/downloading-datasets-from-kaggle-for-your-ml-project-b9120d405ea4).
 
 When the packages are installed, you are ready to train diffusion models!
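A quick way to verify that the token is in place before running the download scripts, as a minimal sketch: it assumes the standard Kaggle credential location ~/.kaggle/kaggle.json, which this diff does not spell out, and the helper itself is hypothetical, not part of this commit.

    # check_kaggle_token.py (hypothetical helper, not part of this commit)
    from pathlib import Path

    token_path = Path.home() / ".kaggle" / "kaggle.json"
    assert token_path.exists(), (
        f"No Kaggle token found at {token_path}. "
        "Create one from your Kaggle account settings and place it there."
    )
    print(f"Kaggle token found at {token_path}")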
1 change: 1 addition & 0 deletions cmd/conf/datamodule/ecg.yaml
@@ -1,4 +1,5 @@
 _target_: fdiff.dataloaders.datamodules.ECGDatamodule
 data_dir: ${hydra:runtime.cwd}/data
 random_seed: ${random_seed}
+fourier_transform: ${fourier_transform}
 batch_size: 64
1 change: 1 addition & 0 deletions cmd/conf/train.yaml
@@ -1,4 +1,5 @@
 random_seed: 42
+fourier_transform: false
 defaults:
   - _self_
   - score_model: default
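The ecg.yaml datamodule config above resolves ${fourier_transform} from this new top-level flag through Hydra interpolation, so a single command-line override switches the whole pipeline to the frequency domain. Assuming the training entry point is cmd/train.py (only cmd/sample.py appears in this diff, so the script name is an assumption), the override would look like:

    python cmd/train.py fourier_transform=true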
6 changes: 6 additions & 0 deletions cmd/sample.py
@@ -12,6 +12,7 @@
 from fdiff.sampling.metrics import MetricCollection
 from fdiff.sampling.sampler import DiffusionSampler
 from fdiff.utils.extraction import dict_to_str, get_best_checkpoint
+from fdiff.utils.fourier import idft
 
 
 class SamplingRunner:
@@ -38,6 +39,7 @@ def __init__(self, cfg: DictConfig) -> None:
         # Read training config from model directory and instantiate the right datamodule
         train_cfg = OmegaConf.load(self.save_dir / "train_config.yaml")
         self.datamodule: Datamodule = instantiate(train_cfg.datamodule)
+        self.fourier_transform: bool = self.datamodule.fourier_transform
         self.datamodule.prepare_data()
         self.datamodule.setup()
 
@@ -69,6 +71,10 @@ def sample(self) -> None:
             num_samples=self.num_samples, num_diffusion_steps=self.num_diffusion_steps
         )
 
+        # If sampling in frequency domain, bring back the sample to time domain
+        if self.fourier_transform:
+            X = idft(X)
+
         # Compute metrics
         results = self.metrics(X)
         logging.info(f"Metrics:\n{dict_to_str(results)}")
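The added conversion is shape-preserving, which is why the downstream metrics code needs no change. A minimal sketch of the same step in isolation, with a random tensor standing in for the sampler output:

    import torch

    from fdiff.utils.fourier import idft

    # Stand-in for sampler output when the score model was trained on DFT coefficients
    X_freq = torch.randn(64, 30, 3)  # (num_samples, max_len, n_channels)
    X_time = idft(X_freq)  # back to the time domain, same shape
    assert X_time.shape == X_freq.shape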
32 changes: 26 additions & 6 deletions src/fdiff/dataloaders/datamodules.py
@@ -10,18 +10,26 @@
 from torch.utils.data import DataLoader, Dataset
 
 from fdiff.utils.dataclasses import collate_batch
+from fdiff.utils.fourier import dft
 
 
 class DiffusionDataset(Dataset):
-    def __init__(self, X: torch.Tensor, y: Optional[torch.Tensor] = None):
+    def __init__(
+        self,
+        X: torch.Tensor,
+        y: Optional[torch.Tensor] = None,
+        fourier_transform: bool = False,
+    ) -> None:
         super().__init__()
+        if fourier_transform:
+            X = dft(X).detach()
         self.X = X
         self.y = y
 
     def __len__(self) -> int:
         return len(self.X)
 
-    def __getitem__(self, index) -> dict[str, torch.Tensor]:
+    def __getitem__(self, index: int) -> dict[str, torch.Tensor]:
         data = {}
         data["X"] = self.X[index]
         if self.y is not None:
@@ -35,6 +43,7 @@ def __init__(
         data_dir: Path | str = Path.cwd() / "data",
         random_seed: int = 42,
         batch_size: int = 32,
+        fourier_transform: bool = False,
     ) -> None:
         super().__init__()
         # Cast data_dir to Path type
@@ -43,6 +52,7 @@ def __init__(
         self.data_dir = data_dir / self.dataset_name
         self.random_seed = random_seed
         self.batch_size = batch_size
+        self.fourier_transform = fourier_transform
         self.X_train = torch.Tensor()
         self.y_train: Optional[torch.Tensor] = None
         self.X_test = torch.Tensor()
@@ -61,7 +71,9 @@ def download_data(self) -> None:
         ...
 
     def train_dataloader(self) -> DataLoader:
-        train_set = DiffusionDataset(X=self.X_train, y=self.y_train)
+        train_set = DiffusionDataset(
+            X=self.X_train, y=self.y_train, fourier_transform=self.fourier_transform
+        )
         return DataLoader(
             train_set,
             batch_size=self.batch_size,
@@ -70,7 +82,9 @@ def train_dataloader(self) -> DataLoader:
         )
 
     def test_dataloader(self) -> DataLoader:
-        test_set = DiffusionDataset(X=self.X_test, y=self.y_test)
+        test_set = DiffusionDataset(
+            X=self.X_test, y=self.y_test, fourier_transform=self.fourier_transform
+        )
         return DataLoader(
             test_set,
             batch_size=self.batch_size,
@@ -79,7 +93,9 @@ def test_dataloader(self) -> DataLoader:
         )
 
     def val_dataloader(self) -> DataLoader:
-        test_set = DiffusionDataset(X=self.X_test, y=self.y_test)
+        test_set = DiffusionDataset(
+            X=self.X_test, y=self.y_test, fourier_transform=self.fourier_transform
+        )
         return DataLoader(
             test_set,
             batch_size=self.batch_size,
@@ -106,9 +122,13 @@ def __init__(
         data_dir: Path | str = Path.cwd() / "data",
         random_seed: int = 42,
         batch_size: int = 32,
+        fourier_transform: bool = False,
     ) -> None:
         super().__init__(
-            data_dir=data_dir, random_seed=random_seed, batch_size=batch_size
+            data_dir=data_dir,
+            random_seed=random_seed,
+            batch_size=batch_size,
+            fourier_transform=fourier_transform,
         )
 
     def setup(self, stage: str = "fit") -> None:
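With the new flag, the spectral transform is applied once at dataset construction rather than on every batch. A minimal usage sketch of the modified DiffusionDataset, with random data in place of a real dataset:

    import torch

    from fdiff.dataloaders.datamodules import DiffusionDataset

    X = torch.randn(8, 30, 3)  # (n_samples, max_len, n_channels)
    dataset = DiffusionDataset(X=X, fourier_transform=True)  # stores dft(X) internally
    item = dataset[0]
    print(item["X"].shape)  # torch.Size([30, 3]): DFT coefficients, same shape as a raw series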
35 changes: 28 additions & 7 deletions src/fdiff/sampling/metrics.py
@@ -5,6 +5,7 @@
 import numpy as np
 import torch
 
+from fdiff.utils.fourier import dft
 from fdiff.utils.tensors import check_flat_array
 from fdiff.utils.wasserstein import WassersteinDistances
 
@@ -33,29 +34,49 @@ def __init__(
         original_samples: Optional[np.ndarray | torch.Tensor] = None,
         include_baselines: bool = True,
     ) -> None:
-        for i, metric in enumerate(metrics):
+        metrics_time: list[Metric] = []
+        metrics_freq: list[Metric] = []
+
+        original_samples_freq = (
+            dft(original_samples) if original_samples is not None else None
+        )
+
+        for metric in metrics:
             # If metric is partially instantiated, instantiate it with original samples
             if isinstance(metric, partial):
                 assert (
                     original_samples is not None
                 ), f"Original samples must be provided for metric {metric.name} to be instantiated."
-                metrics[i] = metric(original_samples=original_samples)  # type: ignore
-        self.metrics = metrics
+                metrics_time.append(metric(original_samples=original_samples))  # type: ignore
+                metrics_freq.append(metric(original_samples=original_samples_freq))  # type: ignore
+        self.metrics_time = metrics_time
+        self.metrics_freq = metrics_freq
         self.include_baselines = include_baselines
 
     def __call__(self, other_samples: np.ndarray | torch.Tensor) -> dict[str, float]:
         metric_dict = {}
-        for metric in self.metrics:
-            metric_dict.update(metric(other_samples))
+        other_samples_freq = dft(other_samples)
+        for metric_time, metric_freq in zip(self.metrics_time, self.metrics_freq):
+            metric_dict.update(
+                {f"time_{k}": v for k, v in metric_time(other_samples).items()}
+            )
+            metric_dict.update(
+                {f"freq_{k}": v for k, v in metric_freq(other_samples_freq).items()}
+            )
         if self.include_baselines:
             metric_dict.update(self.baseline_metrics)
         return dict(sorted(metric_dict.items(), key=lambda item: item[0]))
 
     @property
     def baseline_metrics(self) -> dict[str, float]:
         metric_dict = {}
-        for metric in self.metrics:
-            metric_dict.update(metric.baseline_metrics)
+        for metric_time, metric_freq in zip(self.metrics_time, self.metrics_freq):
+            metric_dict.update(
+                {f"time_{k}": v for k, v in metric_time.baseline_metrics.items()}
+            )
+            metric_dict.update(
+                {f"freq_{k}": v for k, v in metric_freq.baseline_metrics.items()}
+            )
         return metric_dict
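Each metric is now instantiated twice, once against the original samples and once against their DFT, and its scores are reported under time_ and freq_ prefixes. A toy sketch of the resulting keying convention, with a plain function standing in for a real Metric:

    import torch

    from fdiff.utils.fourier import dft

    def toy_metric(samples: torch.Tensor) -> dict[str, float]:
        # Stand-in for a real Metric: returns a single named score
        return {"mean_abs": samples.abs().mean().item()}

    X = torch.randn(16, 30, 3)
    results = {f"time_{k}": v for k, v in toy_metric(X).items()}
    results.update({f"freq_{k}": v for k, v in toy_metric(dft(X)).items()})
    print(results)  # {'time_mean_abs': ..., 'freq_mean_abs': ...}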
86 changes: 86 additions & 0 deletions src/fdiff/utils/fourier.py
@@ -0,0 +1,86 @@
import math

import torch
from torch.fft import irfft, rfft


def dft(x: torch.Tensor) -> torch.Tensor:
    """Compute the DFT of the input time series by keeping only the non-redundant components.

    Args:
        x (torch.Tensor): Time series of shape (batch_size, max_len, n_channels).

    Returns:
        torch.Tensor: DFT of x with the same size (batch_size, max_len, n_channels).
    """

    max_len = x.size(1)

    # Compute the FFT until the Nyquist frequency
    dft_full = rfft(x, dim=1, norm="ortho")
    dft_re = torch.real(dft_full)
    dft_im = torch.imag(dft_full)

    # The first harmonic corresponds to the mean, which is always real
    zero_padding = torch.zeros_like(dft_im[:, 0, :], device=x.device)
    assert torch.allclose(
        dft_im[:, 0, :], zero_padding
    ), f"The first harmonic of a real time series should be real, yet got imaginary part {dft_im[:, 0, :]}."
    dft_im = dft_im[:, 1:]

    # If max_len is even, the last component is always zero
    if max_len % 2 == 0:
        assert torch.allclose(
            dft_im[:, -1, :], zero_padding
        ), f"Got an even {max_len=}, which should be real at the Nyquist frequency, yet got imaginary part {dft_im[:, -1, :]}."
        dft_im = dft_im[:, :-1]

    # Concatenate real and imaginary parts
    x_tilde = torch.cat((dft_re, dft_im), dim=1)
    assert (
        x_tilde.size() == x.size()
    ), f"The DFT and the input should have the same size. Got {x_tilde.size()} and {x.size()} instead."

    return x_tilde.detach()


def idft(x: torch.Tensor) -> torch.Tensor:
    """Compute the inverse DFT of the input DFT that only contains non-redundant components.

    Args:
        x (torch.Tensor): DFT of shape (batch_size, max_len, n_channels).

    Returns:
        torch.Tensor: Inverse DFT of x with the same size (batch_size, max_len, n_channels).
    """

    max_len = x.size(1)
    n_real = math.ceil((max_len + 1) / 2)

    # Extract real and imaginary parts
    x_re = x[:, :n_real, :]
    x_im = x[:, n_real:, :]

    # Create imaginary tensor
    zero_padding = torch.zeros(size=(x.size(0), 1, x.size(2)))
    x_im = torch.cat((zero_padding, x_im), dim=1)

    # If number of time steps is even, put the null imaginary part
    if max_len % 2 == 0:
        x_im = torch.cat((x_im, zero_padding), dim=1)

    assert (
        x_im.size() == x_re.size()
    ), f"The real and imaginary parts should have the same shape, got {x_re.size()} and {x_im.size()} instead."

    x_freq = torch.complex(x_re, x_im)

    # Apply IFFT
    x_time = irfft(x_freq, n=max_len, dim=1, norm="ortho")

    assert isinstance(x_time, torch.Tensor)
    assert (
        x_time.size() == x.size()
    ), f"The inverse DFT and the input should have the same size. Got {x_time.size()} and {x.size()} instead."

    return x_time.detach()
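Since dft drops only redundant components (the imaginary parts of the mean and, for even lengths, the Nyquist bin), the transform is exactly invertible: for max_len = 30, rfft yields 16 complex bins, and dft stores their 16 real parts plus the 14 unconstrained imaginary parts, recovering a 30-dimensional representation per channel. A round-trip sketch consistent with the test added below:

    import torch

    from fdiff.utils.fourier import dft, idft

    x = torch.randn(4, 30, 3)  # (batch_size, max_len, n_channels)
    x_tilde = dft(x)  # real-valued DFT coefficients, same shape as x
    x_rec = idft(x_tilde)  # back to the time domain
    assert torch.allclose(x, x_rec, atol=1e-5)  # equal up to float32 round-off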
5 changes: 1 addition & 4 deletions src/fdiff/utils/wandb.py
@@ -8,10 +8,7 @@ def maybe_initialize_wandb(cfg: DictConfig) -> str | None:
     """Initialize wandb if necessary."""
     cfg_flat = flatten_config(cfg)
     if "pytorch_lightning.loggers.WandbLogger" in cfg_flat.values():
-        wandb.init(
-            project="FourierDiffusion",
-            config=cfg_flat,
-        )
+        wandb.init(project="FourierDiffusion", config=cfg_flat, entity="fdiff")
         assert wandb.run is not None
         run_id = wandb.run.id
         assert isinstance(run_id, str)
27 changes: 25 additions & 2 deletions tests/test_datamodules.py
@@ -4,6 +4,7 @@
 
 from fdiff.dataloaders.datamodules import Datamodule
 from fdiff.utils.dataclasses import DiffusableBatch
+from fdiff.utils.fourier import idft
 
 max_len = 30
 n_channels = 3
@@ -20,15 +21,20 @@ def __init__(
         batch_size: int = batch_size,
         max_len: int = max_len,
         n_channels: int = n_channels,
+        fourier_transform: bool = False,
     ) -> None:
         super().__init__(
-            data_dir=data_dir, random_seed=random_seed, batch_size=batch_size
+            data_dir=data_dir,
+            random_seed=random_seed,
+            batch_size=batch_size,
+            fourier_transform=fourier_transform,
         )
         self.max_len = max_len
         self.n_channels = n_channels
         self.batch_size = batch_size
 
     def setup(self, stage: str = "fit") -> None:
+        torch.manual_seed(self.random_seed)
         self.X_train = torch.randn(
             (10 * self.batch_size, self.max_len, self.n_channels), dtype=torch.float32
         )
@@ -46,7 +52,7 @@ def dataset_name(self) -> str:
         return "dummy"
 
 
-def test_dataloader():
+def test_dataloader() -> None:
     datamodule = DummyDatamodule()
     datamodule.prepare_data()
     datamodule.setup()
@@ -55,3 +61,20 @@ def test_dataloader():
     assert isinstance(batch, DiffusableBatch)
     assert batch.X.shape == (batch_size, max_len, n_channels)
     assert batch.y.shape == (batch_size,)
+
+
+def test_fourier_transform() -> None:
+    # Default datamodule
+    datamodule = DummyDatamodule()
+    datamodule.prepare_data()
+    datamodule.setup()
+
+    # Fourier datamodule
+    datamodule_fourier = DummyDatamodule(fourier_transform=True)
+    datamodule_fourier.prepare_data()
+    datamodule_fourier.setup()
+
+    X = datamodule.train_dataloader().dataset.X
+    X_tilde = datamodule_fourier.train_dataloader().dataset.X
+
+    assert torch.allclose(X, idft(X_tilde), atol=1e-5)
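To exercise just the new round-trip check (assuming pytest as the test runner, which the tests/ layout implies):

    pytest tests/test_datamodules.py -k fourier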