Skip to content

Commit

Permalink
added t shape assert to path class. docs nit: batch_size instead of B…
Browse files Browse the repository at this point in the history
…atch.
  • Loading branch information
Marton Havasi committed Dec 10, 2024
1 parent 1e6525a commit 06111bb
Show file tree
Hide file tree
Showing 7 changed files with 28 additions and 32 deletions.
4 changes: 0 additions & 4 deletions flow_matching/path/affine.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,10 +72,6 @@ def sample(self, x_0: Tensor, x_1: Tensor, t: Tensor) -> PathSample:

scheduler_output = self.scheduler(t)

assert (
t.ndim == 1
), f"The time vector t must be one of shape [batch_size]. Got {t.shape}."

alpha_t = expand_tensor_like(
input_tensor=scheduler_output.alpha_t, expand_to=x_1
)
Expand Down
10 changes: 4 additions & 6 deletions flow_matching/path/geodesic.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,17 +74,15 @@ def sample(self, x_0: Tensor, x_1: Tensor, t: Tensor) -> PathSample:
| return :math:`X_0, X_1, X_t = \exp_{X_1}(\kappa_t \log_{X_1}(X_0))`, and the conditional velocity at :math:`X_t, \dot{X}_t`.
Args:
x_0 (Tensor): source data point, shape (Batch, ...).
x_1 (Tensor): target data point, shape (Batch, ...).
t (Tensor, optional): times in [0,1], shape (Batch).
x_0 (Tensor): source data point, shape (batch_size, ...).
x_1 (Tensor): target data point, shape (batch_size, ...).
t (Tensor): times in [0,1], shape (batch_size).
Returns:
PathSample: A conditional sample at :math:`X_t \sim p_t`.
"""
self.assert_sample_shape(x_0=x_0, x_1=x_1, t=t)

if t.ndim <= 1:
t = expand_tensor_like(input_tensor=t, expand_to=x_1[..., 0:1]).clone()
t = expand_tensor_like(input_tensor=t, expand_to=x_1[..., 0:1]).clone()

def cond_u(x_0, x_1, t):
path = geodesic(self.manifold, x_0, x_1)
Expand Down
9 changes: 4 additions & 5 deletions flow_matching/path/mixture.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,9 +70,9 @@ def sample(self, x_0: Tensor, x_1: Tensor, t: Tensor) -> DiscretePathSample:
| given :math:`(X_0,X_1) \sim \pi(X_0,X_1)` and a scheduler :math:`(\alpha_t,\sigma_t)`.
| return :math:`X_0, X_1, t`, and :math:`X_t \sim p_t`.
Args:
x_0 (Tensor): source data point, shape (Batch, ...).
x_1 (Tensor): target data point, shape (Batch, ...).
t (Tensor): times in [0,1], shape (Batch).
x_0 (Tensor): source data point, shape (batch_size, ...).
x_1 (Tensor): target data point, shape (batch_size, ...).
t (Tensor): times in [0,1], shape (batch_size).
Returns:
DiscretePathSample: a conditional sample at :math:`X_t ~ p_t`.
Expand All @@ -81,8 +81,7 @@ def sample(self, x_0: Tensor, x_1: Tensor, t: Tensor) -> DiscretePathSample:

sigma_t = self.scheduler(t).sigma_t

if t.ndim == 1:
sigma_t = expand_tensor_like(input_tensor=sigma_t, expand_to=x_1)
sigma_t = expand_tensor_like(input_tensor=sigma_t, expand_to=x_1)

source_indices = torch.rand(size=x_1.shape, device=x_1.device) < sigma_t
x_t = torch.where(condition=source_indices, input=x_0, other=x_1)
Expand Down
9 changes: 6 additions & 3 deletions flow_matching/path/path.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,15 +44,18 @@ def sample(self, x_0: Tensor, x_1: Tensor, t: Tensor) -> PathSample:
| returns :math:`X_0, X_1, X_t \sim p_t(X_t)`, and a conditional target :math:`Y`, all objects are under ``PathSample``.
Args:
x_0 (Tensor): source data point, shape (Batch, ...).
x_1 (Tensor): target data point, shape (Batch, ...).
t (Tensor, optional): times in [0,1], shape (Batch).
x_0 (Tensor): source data point, shape (batch_size, ...).
x_1 (Tensor): target data point, shape (batch_size, ...).
t (Tensor): times in [0,1], shape (batch_size).
Returns:
PathSample: a conditional sample.
"""

def assert_sample_shape(self, x_0: Tensor, x_1: Tensor, t: Tensor):
assert (
t.ndim == 1
), f"The time vector t must be one of shape [batch_size]. Got {t.shape}."
assert (
t.shape[0] == x_0.shape[0] == x_1.shape[0]
), f"Time t dimension must match the batch size [{x_1.shape[0]}]. Got {t.shape}"
22 changes: 11 additions & 11 deletions flow_matching/path/path_sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,19 +17,19 @@ class PathSample:
x_1 (Tensor): the target sample :math:`X_1`.
x_0 (Tensor): the source sample :math:`X_0`.
t (Tensor): the time sample :math:`t`.
x_t (Tensor): samples :math:`X_t \sim p_t(X_t)`, shape (Batch, ...).
dx_t (Tensor): conditional target :math:`\frac{\partial X}{\partial t}`, shape: (Batch, ...).
x_t (Tensor): samples :math:`X_t \sim p_t(X_t)`, shape (batch_size, ...).
dx_t (Tensor): conditional target :math:`\frac{\partial X}{\partial t}`, shape: (batch_size, ...).
"""

x_1: Tensor = field(metadata={"help": "target samples X_1 (Batch, ...)."})
x_0: Tensor = field(metadata={"help": "source samples X_0 (Batch, ...)."})
t: Tensor = field(metadata={"help": "time samples t (Batch, ...)."})
x_1: Tensor = field(metadata={"help": "target samples X_1 (batch_size, ...)."})
x_0: Tensor = field(metadata={"help": "source samples X_0 (batch_size, ...)."})
t: Tensor = field(metadata={"help": "time samples t (batch_size, ...)."})
x_t: Tensor = field(
metadata={"help": "samples x_t ~ p_t(X_t), shape (Batch, ...)."}
metadata={"help": "samples x_t ~ p_t(X_t), shape (batch_size, ...)."}
)
dx_t: Tensor = field(
metadata={"help": "conditional target dX_t, shape: (Batch, ...)."}
metadata={"help": "conditional target dX_t, shape: (batch_size, ...)."}
)


Expand All @@ -45,9 +45,9 @@ class DiscretePathSample:
x_t (Tensor): the sample along the path :math:`X_t \sim p_t`.
"""

x_1: Tensor = field(metadata={"help": "target samples X_1 (Batch, ...)."})
x_0: Tensor = field(metadata={"help": "source samples X_0 (Batch, ...)."})
t: Tensor = field(metadata={"help": "time samples t (Batch, ...)."})
x_1: Tensor = field(metadata={"help": "target samples X_1 (batch_size, ...)."})
x_0: Tensor = field(metadata={"help": "source samples X_0 (batch_size, ...)."})
t: Tensor = field(metadata={"help": "time samples t (batch_size, ...)."})
x_t: Tensor = field(
metadata={"help": "samples X_t ~ p_t(X_t), shape (Batch, ...)."}
metadata={"help": "samples X_t ~ p_t(X_t), shape (batch_size, ...)."}
)
2 changes: 1 addition & 1 deletion flow_matching/path/scheduler/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def __call__(self, t: Tensor) -> SchedulerOutput:
"""Scheduler for convex paths.
Args:
t (Tensor, optional): times in [0,1], shape (...).
t (Tensor): times in [0,1], shape (...).
Returns:
SchedulerOutput: :math:`\alpha_t,\sigma_t,\frac{\partial}{\partial t}\alpha_t,\frac{\partial}{\partial t}\sigma_t`
Expand Down
4 changes: 2 additions & 2 deletions flow_matching/utils/model_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@ def forward(self, x: Tensor, t: Tensor, **extras) -> Tensor:
| returns the model output for input x at time t, with extra information `extra`.
Args:
x (Tensor): input data to the model (Batch, ...).
t (Tensor): time (Batch).
x (Tensor): input data to the model (batch_size, ...).
t (Tensor): time (batch_size).
**extras: additional information forwarded to the model, e.g., text condition.
Returns:
Expand Down

0 comments on commit 06111bb

Please sign in to comment.