diff --git a/flow_matching/path/path.py b/flow_matching/path/path.py index 5503483..c133a14 100644 --- a/flow_matching/path/path.py +++ b/flow_matching/path/path.py @@ -55,7 +55,7 @@ def sample(self, x_0: Tensor, x_1: Tensor, t: Tensor) -> PathSample: def assert_sample_shape(self, x_0: Tensor, x_1: Tensor, t: Tensor): assert ( t.ndim == 1 - ), f"The time vector t must be one of shape [batch_size]. Got {t.shape}." + ), f"The time vector t must have shape [batch_size]. Got {t.shape}." assert ( t.shape[0] == x_0.shape[0] == x_1.shape[0] ), f"Time t dimension must match the batch size [{x_1.shape[0]}]. Got {t.shape}"