From 06111bb015ebb164e7919504276cd6847a887cc9 Mon Sep 17 00:00:00 2001 From: Marton Havasi Date: Tue, 10 Dec 2024 19:31:36 +0000 Subject: [PATCH] added t shape assert to path class. docs nit: batch_size instead of Batch. --- flow_matching/path/affine.py | 4 ---- flow_matching/path/geodesic.py | 10 ++++------ flow_matching/path/mixture.py | 9 ++++----- flow_matching/path/path.py | 9 ++++++--- flow_matching/path/path_sample.py | 22 +++++++++++----------- flow_matching/path/scheduler/scheduler.py | 2 +- flow_matching/utils/model_wrapper.py | 4 ++-- 7 files changed, 28 insertions(+), 32 deletions(-) diff --git a/flow_matching/path/affine.py b/flow_matching/path/affine.py index 4232f38..81cb7ed 100644 --- a/flow_matching/path/affine.py +++ b/flow_matching/path/affine.py @@ -72,10 +72,6 @@ def sample(self, x_0: Tensor, x_1: Tensor, t: Tensor) -> PathSample: scheduler_output = self.scheduler(t) - assert ( - t.ndim == 1 - ), f"The time vector t must be one of shape [batch_size]. Got {t.shape}." - alpha_t = expand_tensor_like( input_tensor=scheduler_output.alpha_t, expand_to=x_1 ) diff --git a/flow_matching/path/geodesic.py b/flow_matching/path/geodesic.py index 5575bfb..d04bf67 100644 --- a/flow_matching/path/geodesic.py +++ b/flow_matching/path/geodesic.py @@ -74,17 +74,15 @@ def sample(self, x_0: Tensor, x_1: Tensor, t: Tensor) -> PathSample: | return :math:`X_0, X_1, X_t = \exp_{X_1}(\kappa_t \log_{X_1}(X_0))`, and the conditional velocity at :math:`X_t, \dot{X}_t`. Args: - x_0 (Tensor): source data point, shape (Batch, ...). - x_1 (Tensor): target data point, shape (Batch, ...). - t (Tensor, optional): times in [0,1], shape (Batch). + x_0 (Tensor): source data point, shape (batch_size, ...). + x_1 (Tensor): target data point, shape (batch_size, ...). + t (Tensor): times in [0,1], shape (batch_size). Returns: PathSample: A conditional sample at :math:`X_t \sim p_t`. """ self.assert_sample_shape(x_0=x_0, x_1=x_1, t=t) - - if t.ndim <= 1: - t = expand_tensor_like(input_tensor=t, expand_to=x_1[..., 0:1]).clone() + t = expand_tensor_like(input_tensor=t, expand_to=x_1[..., 0:1]).clone() def cond_u(x_0, x_1, t): path = geodesic(self.manifold, x_0, x_1) diff --git a/flow_matching/path/mixture.py b/flow_matching/path/mixture.py index 277ef36..28b4043 100644 --- a/flow_matching/path/mixture.py +++ b/flow_matching/path/mixture.py @@ -70,9 +70,9 @@ def sample(self, x_0: Tensor, x_1: Tensor, t: Tensor) -> DiscretePathSample: | given :math:`(X_0,X_1) \sim \pi(X_0,X_1)` and a scheduler :math:`(\alpha_t,\sigma_t)`. | return :math:`X_0, X_1, t`, and :math:`X_t \sim p_t`. Args: - x_0 (Tensor): source data point, shape (Batch, ...). - x_1 (Tensor): target data point, shape (Batch, ...). - t (Tensor): times in [0,1], shape (Batch). + x_0 (Tensor): source data point, shape (batch_size, ...). + x_1 (Tensor): target data point, shape (batch_size, ...). + t (Tensor): times in [0,1], shape (batch_size). Returns: DiscretePathSample: a conditional sample at :math:`X_t ~ p_t`. @@ -81,8 +81,7 @@ def sample(self, x_0: Tensor, x_1: Tensor, t: Tensor) -> DiscretePathSample: sigma_t = self.scheduler(t).sigma_t - if t.ndim == 1: - sigma_t = expand_tensor_like(input_tensor=sigma_t, expand_to=x_1) + sigma_t = expand_tensor_like(input_tensor=sigma_t, expand_to=x_1) source_indices = torch.rand(size=x_1.shape, device=x_1.device) < sigma_t x_t = torch.where(condition=source_indices, input=x_0, other=x_1) diff --git a/flow_matching/path/path.py b/flow_matching/path/path.py index 45afcbd..5503483 100644 --- a/flow_matching/path/path.py +++ b/flow_matching/path/path.py @@ -44,15 +44,18 @@ def sample(self, x_0: Tensor, x_1: Tensor, t: Tensor) -> PathSample: | returns :math:`X_0, X_1, X_t \sim p_t(X_t)`, and a conditional target :math:`Y`, all objects are under ``PathSample``. Args: - x_0 (Tensor): source data point, shape (Batch, ...). - x_1 (Tensor): target data point, shape (Batch, ...). - t (Tensor, optional): times in [0,1], shape (Batch). + x_0 (Tensor): source data point, shape (batch_size, ...). + x_1 (Tensor): target data point, shape (batch_size, ...). + t (Tensor): times in [0,1], shape (batch_size). Returns: PathSample: a conditional sample. """ def assert_sample_shape(self, x_0: Tensor, x_1: Tensor, t: Tensor): + assert ( + t.ndim == 1 + ), f"The time vector t must be one of shape [batch_size]. Got {t.shape}." assert ( t.shape[0] == x_0.shape[0] == x_1.shape[0] ), f"Time t dimension must match the batch size [{x_1.shape[0]}]. Got {t.shape}" diff --git a/flow_matching/path/path_sample.py b/flow_matching/path/path_sample.py index 3db21b6..867032e 100644 --- a/flow_matching/path/path_sample.py +++ b/flow_matching/path/path_sample.py @@ -17,19 +17,19 @@ class PathSample: x_1 (Tensor): the target sample :math:`X_1`. x_0 (Tensor): the source sample :math:`X_0`. t (Tensor): the time sample :math:`t`. - x_t (Tensor): samples :math:`X_t \sim p_t(X_t)`, shape (Batch, ...). - dx_t (Tensor): conditional target :math:`\frac{\partial X}{\partial t}`, shape: (Batch, ...). + x_t (Tensor): samples :math:`X_t \sim p_t(X_t)`, shape (batch_size, ...). + dx_t (Tensor): conditional target :math:`\frac{\partial X}{\partial t}`, shape: (batch_size, ...). """ - x_1: Tensor = field(metadata={"help": "target samples X_1 (Batch, ...)."}) - x_0: Tensor = field(metadata={"help": "source samples X_0 (Batch, ...)."}) - t: Tensor = field(metadata={"help": "time samples t (Batch, ...)."}) + x_1: Tensor = field(metadata={"help": "target samples X_1 (batch_size, ...)."}) + x_0: Tensor = field(metadata={"help": "source samples X_0 (batch_size, ...)."}) + t: Tensor = field(metadata={"help": "time samples t (batch_size, ...)."}) x_t: Tensor = field( - metadata={"help": "samples x_t ~ p_t(X_t), shape (Batch, ...)."} + metadata={"help": "samples x_t ~ p_t(X_t), shape (batch_size, ...)."} ) dx_t: Tensor = field( - metadata={"help": "conditional target dX_t, shape: (Batch, ...)."} + metadata={"help": "conditional target dX_t, shape: (batch_size, ...)."} ) @@ -45,9 +45,9 @@ class DiscretePathSample: x_t (Tensor): the sample along the path :math:`X_t \sim p_t`. """ - x_1: Tensor = field(metadata={"help": "target samples X_1 (Batch, ...)."}) - x_0: Tensor = field(metadata={"help": "source samples X_0 (Batch, ...)."}) - t: Tensor = field(metadata={"help": "time samples t (Batch, ...)."}) + x_1: Tensor = field(metadata={"help": "target samples X_1 (batch_size, ...)."}) + x_0: Tensor = field(metadata={"help": "source samples X_0 (batch_size, ...)."}) + t: Tensor = field(metadata={"help": "time samples t (batch_size, ...)."}) x_t: Tensor = field( - metadata={"help": "samples X_t ~ p_t(X_t), shape (Batch, ...)."} + metadata={"help": "samples X_t ~ p_t(X_t), shape (batch_size, ...)."} ) diff --git a/flow_matching/path/scheduler/scheduler.py b/flow_matching/path/scheduler/scheduler.py index 719b34f..422618a 100644 --- a/flow_matching/path/scheduler/scheduler.py +++ b/flow_matching/path/scheduler/scheduler.py @@ -66,7 +66,7 @@ def __call__(self, t: Tensor) -> SchedulerOutput: """Scheduler for convex paths. Args: - t (Tensor, optional): times in [0,1], shape (...). + t (Tensor): times in [0,1], shape (...). Returns: SchedulerOutput: :math:`\alpha_t,\sigma_t,\frac{\partial}{\partial t}\alpha_t,\frac{\partial}{\partial t}\sigma_t` diff --git a/flow_matching/utils/model_wrapper.py b/flow_matching/utils/model_wrapper.py index 22733ac..ac7d932 100644 --- a/flow_matching/utils/model_wrapper.py +++ b/flow_matching/utils/model_wrapper.py @@ -33,8 +33,8 @@ def forward(self, x: Tensor, t: Tensor, **extras) -> Tensor: | returns the model output for input x at time t, with extra information `extra`. Args: - x (Tensor): input data to the model (Batch, ...). - t (Tensor): time (Batch). + x (Tensor): input data to the model (batch_size, ...). + t (Tensor): time (batch_size). **extras: additional information forwarded to the model, e.g., text condition. Returns: