From f50b085f7dece84b3bf175ad4d9db251aaaee447 Mon Sep 17 00:00:00 2001 From: Tom Cobb Date: Fri, 1 Aug 2025 14:54:22 +0000 Subject: [PATCH] Add support for int Duration --- src/scanspec/specs.py | 8 +++++--- tests/test_specs.py | 5 +++++ 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/src/scanspec/specs.py b/src/scanspec/specs.py index 460f12b8..0ca23549 100644 --- a/src/scanspec/specs.py +++ b/src/scanspec/specs.py @@ -9,7 +9,7 @@ import warnings from collections.abc import Callable, Mapping -from typing import Any, Generic, Literal, overload +from typing import Any, Generic, Literal, SupportsFloat, overload import numpy as np import numpy.typing as npt @@ -110,9 +110,11 @@ def shape(self) -> tuple[int, ...]: def __rmul__(self, other: int) -> Product[Axis]: return if_instance_do(other, int, lambda o: Product(Repeat(o), self)) - def __rmatmul__(self, other: float) -> ConstantDuration[Axis]: + def __rmatmul__(self, other: SupportsFloat) -> ConstantDuration[Axis]: return if_instance_do( - other, float, lambda o: ConstantDuration(constant_duration=o, spec=self) + other, + SupportsFloat, + lambda o: ConstantDuration(constant_duration=float(o), spec=self), ) @overload diff --git a/tests/test_specs.py b/tests/test_specs.py index a276ca9e..1600fe4b 100644 --- a/tests/test_specs.py +++ b/tests/test_specs.py @@ -630,6 +630,11 @@ def test_constant_duration(): assert "Only one of left and right defines a duration" in str(msg.value) +def test_int_duration(): + spec1 = 1 @ Line("x", 0, 1, 2) + assert spec1.duration() == 1.0 + + @pytest.mark.filterwarnings("ignore:fly") def test_fly(): spec = fly(Line("x", 0, 1, 5), 0.1)