Skip to content

Commit

Permalink
precommit
Browse files Browse the repository at this point in the history
  • Loading branch information
Marton Havasi committed Dec 10, 2024
1 parent 9afa176 commit 1e6525a
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 5 deletions.
4 changes: 3 additions & 1 deletion flow_matching/path/affine.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,9 @@ def sample(self, x_0: Tensor, x_1: Tensor, t: Tensor) -> PathSample:

scheduler_output = self.scheduler(t)

assert t.ndim == 1, f"The time vector t must be one of shape [batch_size]. Got {t.shape}"
assert (
t.ndim == 1
), f"The time vector t must be one of shape [batch_size]. Got {t.shape}."

alpha_t = expand_tensor_like(
input_tensor=scheduler_output.alpha_t, expand_to=x_1
Expand Down
10 changes: 6 additions & 4 deletions flow_matching/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,14 +43,16 @@ def expand_tensor_like(input_tensor: Tensor, expand_to: Tensor) -> Tensor:
expand `input_tensor` to have the same shape as `expand_to` along all remaining dimensions.
Args:
input_tensor (Tensor): (B,).
expand_to (Tensor): (B, ...).
input_tensor (Tensor): (batch_size,).
expand_to (Tensor): (batch_size, ...).
Returns:
Tensor: (B, ...).
Tensor: (batch_size, ...).
"""
assert input_tensor.ndim == 1, "Input tensor must be a 1d vector."
assert input_tensor.shape[0] == expand_to.shape[0], f"The first (batch) dimension must match. Got shape {input_tensor.shape} and {expand_to.shape}."
assert (
input_tensor.shape[0] == expand_to.shape[0]
), f"The first (batch_size) dimension must match. Got shape {input_tensor.shape} and {expand_to.shape}."

dim_diff = expand_to.ndim - input_tensor.ndim

Expand Down

0 comments on commit 1e6525a

Please sign in to comment.