From 1e6525a4d76e9a9cbac3d16662502d0cc7839e78 Mon Sep 17 00:00:00 2001 From: Marton Havasi Date: Tue, 10 Dec 2024 19:07:12 +0000 Subject: [PATCH] precommit --- flow_matching/path/affine.py | 4 +++- flow_matching/utils/utils.py | 10 ++++++---- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/flow_matching/path/affine.py b/flow_matching/path/affine.py index 004da12..4232f38 100644 --- a/flow_matching/path/affine.py +++ b/flow_matching/path/affine.py @@ -72,7 +72,9 @@ def sample(self, x_0: Tensor, x_1: Tensor, t: Tensor) -> PathSample: scheduler_output = self.scheduler(t) - assert t.ndim == 1, f"The time vector t must be one of shape [batch_size]. Got {t.shape}" + assert ( + t.ndim == 1 + ), f"The time vector t must be one of shape [batch_size]. Got {t.shape}." alpha_t = expand_tensor_like( input_tensor=scheduler_output.alpha_t, expand_to=x_1 diff --git a/flow_matching/utils/utils.py b/flow_matching/utils/utils.py index 23dd1db..beb31ff 100644 --- a/flow_matching/utils/utils.py +++ b/flow_matching/utils/utils.py @@ -43,14 +43,16 @@ def expand_tensor_like(input_tensor: Tensor, expand_to: Tensor) -> Tensor: expand `input_tensor` to have the same shape as `expand_to` along all remaining dimensions. Args: - input_tensor (Tensor): (B,). - expand_to (Tensor): (B, ...). + input_tensor (Tensor): (batch_size,). + expand_to (Tensor): (batch_size, ...). Returns: - Tensor: (B, ...). + Tensor: (batch_size, ...). """ assert input_tensor.ndim == 1, "Input tensor must be a 1d vector." - assert input_tensor.shape[0] == expand_to.shape[0], f"The first (batch) dimension must match. Got shape {input_tensor.shape} and {expand_to.shape}." + assert ( + input_tensor.shape[0] == expand_to.shape[0] + ), f"The first (batch_size) dimension must match. Got shape {input_tensor.shape} and {expand_to.shape}." dim_diff = expand_to.ndim - input_tensor.ndim