From d23367d82d36b3b6f507422153fe4535c3d0dfd7 Mon Sep 17 00:00:00 2001 From: mhavasi Date: Wed, 11 Dec 2024 11:33:35 -0500 Subject: [PATCH] Assert t shape in affine path (#6) * ensure one dimensional t * grammar nit * precommit * added t shape assert to path class. docs nit: batch_size instead of Batch. * docs phrasing nit --------- Co-authored-by: Marton Havasi --- flow_matching/path/affine.py | 31 +++++++++++------------ flow_matching/path/geodesic.py | 10 +++----- flow_matching/path/mixture.py | 9 +++---- flow_matching/path/path.py | 9 ++++--- flow_matching/path/path_sample.py | 22 ++++++++-------- flow_matching/path/scheduler/scheduler.py | 2 +- flow_matching/utils/model_wrapper.py | 4 +-- flow_matching/utils/utils.py | 9 ++++--- 8 files changed, 49 insertions(+), 47 deletions(-) diff --git a/flow_matching/path/affine.py b/flow_matching/path/affine.py index 7e4a18f..81cb7ed 100644 --- a/flow_matching/path/affine.py +++ b/flow_matching/path/affine.py @@ -61,9 +61,9 @@ def sample(self, x_0: Tensor, x_1: Tensor, t: Tensor) -> PathSample: | return :math:`X_0, X_1, X_t = \alpha_t X_1 + \sigma_t X_0`, and the conditional velocity at :math:`X_t, \dot{X}_t = \dot{\alpha}_t X_1 + \dot{\sigma}_t X_0`. Args: - x_0 (Tensor): source data point, shape (Batch, ...). - x_1 (Tensor): target data point, shape (Batch, ...). - t (Tensor, optional): times in [0,1], shape (Batch). + x_0 (Tensor): source data point, shape (batch_size, ...). + x_1 (Tensor): target data point, shape (batch_size, ...). + t (Tensor): times in [0,1], shape (batch_size). Returns: PathSample: a conditional sample at :math:`X_t \sim p_t`. @@ -72,19 +72,18 @@ def sample(self, x_0: Tensor, x_1: Tensor, t: Tensor) -> PathSample: scheduler_output = self.scheduler(t) - if t.ndim == 1: - alpha_t = expand_tensor_like( - input_tensor=scheduler_output.alpha_t, expand_to=x_1 - ) - sigma_t = expand_tensor_like( - input_tensor=scheduler_output.sigma_t, expand_to=x_1 - ) - d_alpha_t = expand_tensor_like( - input_tensor=scheduler_output.d_alpha_t, expand_to=x_1 - ) - d_sigma_t = expand_tensor_like( - input_tensor=scheduler_output.d_sigma_t, expand_to=x_1 - ) + alpha_t = expand_tensor_like( + input_tensor=scheduler_output.alpha_t, expand_to=x_1 + ) + sigma_t = expand_tensor_like( + input_tensor=scheduler_output.sigma_t, expand_to=x_1 + ) + d_alpha_t = expand_tensor_like( + input_tensor=scheduler_output.d_alpha_t, expand_to=x_1 + ) + d_sigma_t = expand_tensor_like( + input_tensor=scheduler_output.d_sigma_t, expand_to=x_1 + ) # construct xt ~ p_t(x|x1). x_t = sigma_t * x_0 + alpha_t * x_1 diff --git a/flow_matching/path/geodesic.py b/flow_matching/path/geodesic.py index 5575bfb..d04bf67 100644 --- a/flow_matching/path/geodesic.py +++ b/flow_matching/path/geodesic.py @@ -74,17 +74,15 @@ def sample(self, x_0: Tensor, x_1: Tensor, t: Tensor) -> PathSample: | return :math:`X_0, X_1, X_t = \exp_{X_1}(\kappa_t \log_{X_1}(X_0))`, and the conditional velocity at :math:`X_t, \dot{X}_t`. Args: - x_0 (Tensor): source data point, shape (Batch, ...). - x_1 (Tensor): target data point, shape (Batch, ...). - t (Tensor, optional): times in [0,1], shape (Batch). + x_0 (Tensor): source data point, shape (batch_size, ...). + x_1 (Tensor): target data point, shape (batch_size, ...). + t (Tensor): times in [0,1], shape (batch_size). Returns: PathSample: A conditional sample at :math:`X_t \sim p_t`. """ self.assert_sample_shape(x_0=x_0, x_1=x_1, t=t) - - if t.ndim <= 1: - t = expand_tensor_like(input_tensor=t, expand_to=x_1[..., 0:1]).clone() + t = expand_tensor_like(input_tensor=t, expand_to=x_1[..., 0:1]).clone() def cond_u(x_0, x_1, t): path = geodesic(self.manifold, x_0, x_1) diff --git a/flow_matching/path/mixture.py b/flow_matching/path/mixture.py index 277ef36..28b4043 100644 --- a/flow_matching/path/mixture.py +++ b/flow_matching/path/mixture.py @@ -70,9 +70,9 @@ def sample(self, x_0: Tensor, x_1: Tensor, t: Tensor) -> DiscretePathSample: | given :math:`(X_0,X_1) \sim \pi(X_0,X_1)` and a scheduler :math:`(\alpha_t,\sigma_t)`. | return :math:`X_0, X_1, t`, and :math:`X_t \sim p_t`. Args: - x_0 (Tensor): source data point, shape (Batch, ...). - x_1 (Tensor): target data point, shape (Batch, ...). - t (Tensor): times in [0,1], shape (Batch). + x_0 (Tensor): source data point, shape (batch_size, ...). + x_1 (Tensor): target data point, shape (batch_size, ...). + t (Tensor): times in [0,1], shape (batch_size). Returns: DiscretePathSample: a conditional sample at :math:`X_t ~ p_t`. @@ -81,8 +81,7 @@ def sample(self, x_0: Tensor, x_1: Tensor, t: Tensor) -> DiscretePathSample: sigma_t = self.scheduler(t).sigma_t - if t.ndim == 1: - sigma_t = expand_tensor_like(input_tensor=sigma_t, expand_to=x_1) + sigma_t = expand_tensor_like(input_tensor=sigma_t, expand_to=x_1) source_indices = torch.rand(size=x_1.shape, device=x_1.device) < sigma_t x_t = torch.where(condition=source_indices, input=x_0, other=x_1) diff --git a/flow_matching/path/path.py b/flow_matching/path/path.py index 45afcbd..c133a14 100644 --- a/flow_matching/path/path.py +++ b/flow_matching/path/path.py @@ -44,15 +44,18 @@ def sample(self, x_0: Tensor, x_1: Tensor, t: Tensor) -> PathSample: | returns :math:`X_0, X_1, X_t \sim p_t(X_t)`, and a conditional target :math:`Y`, all objects are under ``PathSample``. Args: - x_0 (Tensor): source data point, shape (Batch, ...). - x_1 (Tensor): target data point, shape (Batch, ...). - t (Tensor, optional): times in [0,1], shape (Batch). + x_0 (Tensor): source data point, shape (batch_size, ...). + x_1 (Tensor): target data point, shape (batch_size, ...). + t (Tensor): times in [0,1], shape (batch_size). Returns: PathSample: a conditional sample. """ def assert_sample_shape(self, x_0: Tensor, x_1: Tensor, t: Tensor): + assert ( + t.ndim == 1 + ), f"The time vector t must have shape [batch_size]. Got {t.shape}." assert ( t.shape[0] == x_0.shape[0] == x_1.shape[0] ), f"Time t dimension must match the batch size [{x_1.shape[0]}]. Got {t.shape}" diff --git a/flow_matching/path/path_sample.py b/flow_matching/path/path_sample.py index 3db21b6..867032e 100644 --- a/flow_matching/path/path_sample.py +++ b/flow_matching/path/path_sample.py @@ -17,19 +17,19 @@ class PathSample: x_1 (Tensor): the target sample :math:`X_1`. x_0 (Tensor): the source sample :math:`X_0`. t (Tensor): the time sample :math:`t`. - x_t (Tensor): samples :math:`X_t \sim p_t(X_t)`, shape (Batch, ...). - dx_t (Tensor): conditional target :math:`\frac{\partial X}{\partial t}`, shape: (Batch, ...). + x_t (Tensor): samples :math:`X_t \sim p_t(X_t)`, shape (batch_size, ...). + dx_t (Tensor): conditional target :math:`\frac{\partial X}{\partial t}`, shape: (batch_size, ...). """ - x_1: Tensor = field(metadata={"help": "target samples X_1 (Batch, ...)."}) - x_0: Tensor = field(metadata={"help": "source samples X_0 (Batch, ...)."}) - t: Tensor = field(metadata={"help": "time samples t (Batch, ...)."}) + x_1: Tensor = field(metadata={"help": "target samples X_1 (batch_size, ...)."}) + x_0: Tensor = field(metadata={"help": "source samples X_0 (batch_size, ...)."}) + t: Tensor = field(metadata={"help": "time samples t (batch_size, ...)."}) x_t: Tensor = field( - metadata={"help": "samples x_t ~ p_t(X_t), shape (Batch, ...)."} + metadata={"help": "samples x_t ~ p_t(X_t), shape (batch_size, ...)."} ) dx_t: Tensor = field( - metadata={"help": "conditional target dX_t, shape: (Batch, ...)."} + metadata={"help": "conditional target dX_t, shape: (batch_size, ...)."} ) @@ -45,9 +45,9 @@ class DiscretePathSample: x_t (Tensor): the sample along the path :math:`X_t \sim p_t`. """ - x_1: Tensor = field(metadata={"help": "target samples X_1 (Batch, ...)."}) - x_0: Tensor = field(metadata={"help": "source samples X_0 (Batch, ...)."}) - t: Tensor = field(metadata={"help": "time samples t (Batch, ...)."}) + x_1: Tensor = field(metadata={"help": "target samples X_1 (batch_size, ...)."}) + x_0: Tensor = field(metadata={"help": "source samples X_0 (batch_size, ...)."}) + t: Tensor = field(metadata={"help": "time samples t (batch_size, ...)."}) x_t: Tensor = field( - metadata={"help": "samples X_t ~ p_t(X_t), shape (Batch, ...)."} + metadata={"help": "samples X_t ~ p_t(X_t), shape (batch_size, ...)."} ) diff --git a/flow_matching/path/scheduler/scheduler.py b/flow_matching/path/scheduler/scheduler.py index 719b34f..422618a 100644 --- a/flow_matching/path/scheduler/scheduler.py +++ b/flow_matching/path/scheduler/scheduler.py @@ -66,7 +66,7 @@ def __call__(self, t: Tensor) -> SchedulerOutput: """Scheduler for convex paths. Args: - t (Tensor, optional): times in [0,1], shape (...). + t (Tensor): times in [0,1], shape (...). Returns: SchedulerOutput: :math:`\alpha_t,\sigma_t,\frac{\partial}{\partial t}\alpha_t,\frac{\partial}{\partial t}\sigma_t` diff --git a/flow_matching/utils/model_wrapper.py b/flow_matching/utils/model_wrapper.py index 22733ac..ac7d932 100644 --- a/flow_matching/utils/model_wrapper.py +++ b/flow_matching/utils/model_wrapper.py @@ -33,8 +33,8 @@ def forward(self, x: Tensor, t: Tensor, **extras) -> Tensor: | returns the model output for input x at time t, with extra information `extra`. Args: - x (Tensor): input data to the model (Batch, ...). - t (Tensor): time (Batch). + x (Tensor): input data to the model (batch_size, ...). + t (Tensor): time (batch_size). **extras: additional information forwarded to the model, e.g., text condition. Returns: diff --git a/flow_matching/utils/utils.py b/flow_matching/utils/utils.py index 9b75521..beb31ff 100644 --- a/flow_matching/utils/utils.py +++ b/flow_matching/utils/utils.py @@ -43,13 +43,16 @@ def expand_tensor_like(input_tensor: Tensor, expand_to: Tensor) -> Tensor: expand `input_tensor` to have the same shape as `expand_to` along all remaining dimensions. Args: - input_tensor (Tensor): (B,). - expand_to (Tensor): (B, ...). + input_tensor (Tensor): (batch_size,). + expand_to (Tensor): (batch_size, ...). Returns: - Tensor: (B, ...). + Tensor: (batch_size, ...). """ assert input_tensor.ndim == 1, "Input tensor must be a 1d vector." + assert ( + input_tensor.shape[0] == expand_to.shape[0] + ), f"The first (batch_size) dimension must match. Got shape {input_tensor.shape} and {expand_to.shape}." dim_diff = expand_to.ndim - input_tensor.ndim