Skip to content

Commit

Permalink
ensure one dimensional t
Browse files Browse the repository at this point in the history
  • Loading branch information
Marton Havasi committed Dec 10, 2024
1 parent dffe02b commit 56fc655
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 16 deletions.
33 changes: 17 additions & 16 deletions flow_matching/path/affine.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,9 +61,9 @@ def sample(self, x_0: Tensor, x_1: Tensor, t: Tensor) -> PathSample:
| return :math:`X_0, X_1, X_t = \alpha_t X_1 + \sigma_t X_0`, and the conditional velocity at :math:`X_t, \dot{X}_t = \dot{\alpha}_t X_1 + \dot{\sigma}_t X_0`.
Args:
x_0 (Tensor): source data point, shape (Batch, ...).
x_1 (Tensor): target data point, shape (Batch, ...).
t (Tensor, optional): times in [0,1], shape (Batch).
x_0 (Tensor): source data point, shape (batch_size, ...).
x_1 (Tensor): target data point, shape (batch_size, ...).
t (Tensor): times in [0,1], shape (batch_size).
Returns:
PathSample: a conditional sample at :math:`X_t \sim p_t`.
Expand All @@ -72,19 +72,20 @@ def sample(self, x_0: Tensor, x_1: Tensor, t: Tensor) -> PathSample:

scheduler_output = self.scheduler(t)

if t.ndim == 1:
alpha_t = expand_tensor_like(
input_tensor=scheduler_output.alpha_t, expand_to=x_1
)
sigma_t = expand_tensor_like(
input_tensor=scheduler_output.sigma_t, expand_to=x_1
)
d_alpha_t = expand_tensor_like(
input_tensor=scheduler_output.d_alpha_t, expand_to=x_1
)
d_sigma_t = expand_tensor_like(
input_tensor=scheduler_output.d_sigma_t, expand_to=x_1
)
assert t.ndim == 1, f"The time vector t must be one of shape [batch_size]. Got {t.shape}"

alpha_t = expand_tensor_like(
input_tensor=scheduler_output.alpha_t, expand_to=x_1
)
sigma_t = expand_tensor_like(
input_tensor=scheduler_output.sigma_t, expand_to=x_1
)
d_alpha_t = expand_tensor_like(
input_tensor=scheduler_output.d_alpha_t, expand_to=x_1
)
d_sigma_t = expand_tensor_like(
input_tensor=scheduler_output.d_sigma_t, expand_to=x_1
)

# construct xt ~ p_t(x|x1).
x_t = sigma_t * x_0 + alpha_t * x_1
Expand Down
1 change: 1 addition & 0 deletions flow_matching/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ def expand_tensor_like(input_tensor: Tensor, expand_to: Tensor) -> Tensor:
Tensor: (B, ...).
"""
assert input_tensor.ndim == 1, "Input tensor must be a 1d vector."
assert input_tensor.shape[0] == expand_to.shape[0], f"The first (batch) dimensions must match. Got shape {input_tensor.shape} and {expand_to.shape}."

dim_diff = expand_to.ndim - input_tensor.ndim

Expand Down

0 comments on commit 56fc655

Please sign in to comment.