Skip to content

Commit

Permalink
Merge pull request #27 from quantumlib/format
Browse files Browse the repository at this point in the history
Format files and fix some typos
  • Loading branch information
NoureldinYosri authored Jan 16, 2025
2 parents c451c27 + ca8c775 commit fff3ffc
Show file tree
Hide file tree
Showing 10 changed files with 81 additions and 71 deletions.
5 changes: 4 additions & 1 deletion test/test_protos.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import cast
import itertools
import pytest
import numpy as np
Expand Down Expand Up @@ -75,7 +76,9 @@ def test_valuearray_conversion_trip(unit: Value) -> None:
def test_complex_valuearray_conversion_trip(unit: Value) -> None:
rs = np.random.RandomState(0)
for real, imag in zip(rs.random((4, 2, 4, 3)), rs.random((4, 2, 4, 3))):
v = (real + 1j * imag) * unit
real_ = cast(np.typing.NDArray[np.float64], real)
value = real_ + 1j * imag
v = unit * value
got = ValueArray.from_proto(v.to_proto())
assert got.unit == unit
np.testing.assert_allclose(got.value, real + 1j * imag)
Expand Down
6 changes: 3 additions & 3 deletions test/test_value.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,12 +208,12 @@ def test_hash() -> None:
def test_numpy_sqrt() -> None:
from tunits.units import m, km, cm

u = np.sqrt(8 * km * m) - cm
u: Value = np.sqrt(8 * km * m) - cm
v = 8943.27191 * cm
assert np.isclose(u / v, 1)

u = np.sqrt(8 * km / m)
u = np.sqrt(8 * km / m) # type: ignore[assignment]
assert np.isclose(u, 89.4427191)

u = np.sqrt((8 * km / m).in_base_units())
u = np.sqrt((8 * km / m).in_base_units()) # type: ignore[assignment]
assert np.isclose(u, 89.4427191)
14 changes: 11 additions & 3 deletions test/test_value_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from packaging.version import Version
import numpy as np
import pytest
from tunits.core import raw_WithUnit, raw_UnitArray
Expand Down Expand Up @@ -92,9 +93,16 @@ def test_repr() -> None:
assert repr(km ** (2 / 3.0) * [-1] / kg**3 * s) == "ValueArray(array([-1.]), 'km^(2/3)*s/kg^3')"

# Numpy abbreviation is allowed.
assert repr(list(range(50000)) * km) == (
"LengthArray(array([ 0, 1, " "2, ..., 49997, 49998, 49999]), 'km')"
)
if Version(np.__version__) >= Version('2.2'):
expected_repr = (
"LengthArray(array([ 0, 1, "
"2, ..., 49997, 49998, 49999], shape=(50000,)), 'km')"
)
else:
expected_repr = (
"LengthArray(array([ 0, 1, " "2, ..., 49997, 49998, 49999]), 'km')"
)
assert repr(list(range(50000)) * km) == expected_repr

# Fallback case.
v: ValueArray = raw_WithUnit(
Expand Down
10 changes: 9 additions & 1 deletion tunits/core/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,10 @@ class WithUnit:
class Value(WithUnit):
"""A floating-point value with associated units."""

value: float | complex
real: 'Value'
imag: 'Value'

@classmethod
def from_proto(cls: type[T], msg: tunits_pb2.Value) -> T: ...
def to_proto(self, msg: tunits_pb2.Value | None = None) -> tunits_pb2.Value: ...
Expand Down Expand Up @@ -302,9 +306,13 @@ class Value(WithUnit):
def __getitem__(self, key: Any) -> float: ...

class ValueArray(WithUnit):
value: NDArray[Any]
real: 'ValueArray'
imag: 'ValueArray'

@classmethod
def from_proto(cls: type[T], msg: tunits_pb2.ValueArray) -> T: ...
def to_proto(self, msg: tunits_pb2.ValueArray | None) -> tunits_pb2.ValueArray: ...
def to_proto(self, msg: tunits_pb2.ValueArray | None = None) -> tunits_pb2.ValueArray: ...
def __init__(self, data: Any, unit: Any = None) -> None: ...
def allclose(self, other: ValueArray, *args: Any, **kwargs: dict[str, Any]) -> bool: ...
def __array__(self, dtype: DTypeLike = None) -> NDArray[Any]: ...
Expand Down
33 changes: 12 additions & 21 deletions tunits/core/cython/dimension.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ import abc

from functools import cache


class Dimension(abc.ABC):
"""Dimension abstraction.
Expand Down Expand Up @@ -62,8 +63,7 @@ class _Acceleration(Dimension):
@cache
def valid_base_units() -> tuple[Value, ...]:
return (
default_unit_database.known_units['m']
/ default_unit_database.known_units['s'] ** 2,
default_unit_database.known_units['m'] / default_unit_database.known_units['s'] ** 2,
)

def _value_class(self) -> type[Value]:
Expand All @@ -74,13 +74,14 @@ class _Acceleration(Dimension):


class ValueWithDimension(Dimension, Value):
def __init__(self, val, unit=None, validate:bool=True):
def __init__(self, val, unit=None, validate: bool = True):
super().__init__(val, unit=unit)
if validate and not type(self).is_valid(self):
raise ValueError(f'{self.unit} is not a valid unit for dimension {type(self)}')


class ArrayWithDimension(Dimension, ValueArray):
def __init__(self, val, unit=None, validate:bool=True):
def __init__(self, val, unit=None, validate: bool = True):
super().__init__(val, unit=unit)
if validate and not type(self).is_valid(self):
raise ValueError(f'{self.unit} is not a valid unit for dimension {type(self)}')
Expand Down Expand Up @@ -121,9 +122,7 @@ class _AngularFrequency(Dimension):
@cache
def valid_base_units() -> tuple[Value, ...]:
return (
default_unit_database.known_units['rad']
* default_unit_database.known_units['Hz']
* 2,
default_unit_database.known_units['rad'] * default_unit_database.known_units['Hz'] * 2,
)

def _value_class(self) -> type[Value]:
Expand Down Expand Up @@ -228,8 +227,7 @@ class _Density(Dimension):
@cache
def valid_base_units() -> tuple[Value, ...]:
return (
default_unit_database.known_units['kg']
/ default_unit_database.known_units['m'] ** 3,
default_unit_database.known_units['kg'] / default_unit_database.known_units['m'] ** 3,
)

def _value_class(self) -> type[Value]:
Expand Down Expand Up @@ -554,10 +552,8 @@ class _Noise(Dimension):
@cache
def valid_base_units() -> tuple[Value, ...]:
return (
default_unit_database.known_units['V']
/ default_unit_database.known_units['Hz'] ** 0.5,
default_unit_database.known_units['watt']
/ default_unit_database.known_units['Hz'],
default_unit_database.known_units['V'] / default_unit_database.known_units['Hz'] ** 0.5,
default_unit_database.known_units['watt'] / default_unit_database.known_units['Hz'],
)

def _value_class(self) -> type[Value]:
Expand Down Expand Up @@ -658,10 +654,7 @@ class _Speed(Dimension):
@staticmethod
@cache
def valid_base_units() -> tuple[Value, ...]:
return (
default_unit_database.known_units['m']
/ default_unit_database.known_units['s'],
)
return (default_unit_database.known_units['m'] / default_unit_database.known_units['s'],)

def _value_class(self) -> type[Value]:
return Speed
Expand All @@ -682,8 +675,7 @@ class _SurfaceDensity(Dimension):
@cache
def valid_base_units() -> tuple[Value, ...]:
return (
default_unit_database.known_units['kg']
/ default_unit_database.known_units['m'] ** 2,
default_unit_database.known_units['kg'] / default_unit_database.known_units['m'] ** 2,
)

def _value_class(self) -> type[Value]:
Expand Down Expand Up @@ -745,8 +737,7 @@ class _Torque(Dimension):
@cache
def valid_base_units() -> tuple[Value, ...]:
return (
default_unit_database.known_units['newton']
* default_unit_database.known_units['m'],
default_unit_database.known_units['newton'] * default_unit_database.known_units['m'],
)

def _value_class(self) -> type[Value]:
Expand Down
1 change: 1 addition & 0 deletions tunits/core/cython/proto.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ _SERIALIZATION_ERROR_MESSAGE = (
@functools.cache
def _construct_unit(unit_enum: int, scale_enum: Optional[int] = None) -> 'Value':
from tunits.proto import tunits_pb2

unit_name = _PROTO_TO_UNIT_STRING.get(tunits_pb2.UnitEnum.Name(unit_enum), None)
scale = '' if scale_enum is None else _ENUM_TO_SCALE_SYMBOL[scale_enum]
return _try_interpret_as_with_unit(scale + unit_name)
Expand Down
14 changes: 9 additions & 5 deletions tunits/core/cython/unit_database.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,7 @@ class UnitDatabase:
self.add_root_unit(unit_name)
return self.known_units[unit_name]

def parse_unit_formula(
self, formula: str, auto_create: Optional[bool] = None
) -> Value:
def parse_unit_formula(self, formula: str, auto_create: Optional[bool] = None) -> Value:
"""
:param str formula: Describes a combination of units.
:param None|bool auto_create: If this is set, missing unit strings will
Expand Down Expand Up @@ -109,7 +107,12 @@ class UnitDatabase:
"""
ua = UnitArray(unit_name)
unit: Value = raw_WithUnit(
1, {'factor': 1.0, 'ratio': {'numer': 1, 'denom': 1}, 'exp10': 0}, ua, ua, Value, ValueArray
1,
{'factor': 1.0, 'ratio': {'numer': 1, 'denom': 1}, 'exp10': 0},
ua,
ua,
Value,
ValueArray,
)
self.add_unit(unit_name, unit)

Expand Down Expand Up @@ -152,7 +155,8 @@ class UnitDatabase:
},
parent.base_units,
UnitArray(unit_name),
Value, ValueArray
Value,
ValueArray,
)

self.add_unit(unit_name, unit)
Expand Down
2 changes: 1 addition & 1 deletion tunits/core/cython/unit_mismatch_error.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -22,4 +22,4 @@ class UnitMismatchError(Exception):


class NotTUnitsLikeError(Exception):
"""The value is not a tunits object and can't be converted to one."""
"""The value is not a tunits object and can't be converted to one."""
3 changes: 2 additions & 1 deletion tunits/core/cython/with_unit_value.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,10 @@ class Value(WithUnit):
else:
raise ValueError(f"{msg=} doesn't have a value.")
return cls(v, _proto_to_units(msg.units))

def to_proto(self, msg: Optional['tunits_pb2.Value'] = None) -> 'tunits_pb2.Value':
from tunits.proto import tunits_pb2

if msg is None:
msg = tunits_pb2.Value()
if isinstance(self.value, (complex, np.complexfloating)):
Expand Down
64 changes: 29 additions & 35 deletions tunits/proto/tunits.proto
Original file line number Diff line number Diff line change
Expand Up @@ -20,37 +20,37 @@ option java_multiple_files = true;

// Units supported for serialization.
enum UnitEnum {
DECIBEL = 1; // Power unit (dB).
DECIBEL_MILLIWATTS = 2; // Decibel-milliwatts (dBm).
RADIANS = 3; // Radians (rad).
HERTZ = 4; // Frequency unit (Hz).
VOLT = 5; // Electric potential Unit (V).
SECOND = 6; // Time unit (s).
DECIBEL = 1; // Power unit (dB).
DECIBEL_MILLIWATTS = 2; // Decibel-milliwatts (dBm).
RADIANS = 3; // Radians (rad).
HERTZ = 4; // Frequency unit (Hz).
VOLT = 5; // Electric potential Unit (V).
SECOND = 6; // Time unit (s).
}

enum Scale {
// Enum value should be the associated exponent.
YOTTA = 24; // 10^24
ZETTA = 21; // 10^21
EXA = 18; // 10^18
PETA = 15; // 10^15
TERA = 12; // 10^12
GIGA = 9; // 10^9
MEGA = 6; // 10^6
KILO = 3; // 10^3
HECTO = 2; // 10^2
DECAD = 1; // 10^1
UNITY = 0; // 1
DECI = -1; // 10^-1
CENTI = -2; // 10^-2
MILLI = -3; // 10^-3
MICRO = -6; // 10^-6
NANO = -9; // 10^-9
PICO = -12; // 10^-12
FEMTO = -15; // 10^-15
ATTO = -18; // 10^-18
ZEPTO = -21; // 10^-21
YOCTO = -24; // 10^-24
YOTTA = 24; // 10^24
ZETTA = 21; // 10^21
EXA = 18; // 10^18
PETA = 15; // 10^15
TERA = 12; // 10^12
GIGA = 9; // 10^9
MEGA = 6; // 10^6
KILO = 3; // 10^3
HECTO = 2; // 10^2
DECAD = 1; // 10^1
UNITY = 0; // 1
DECI = -1; // 10^-1
CENTI = -2; // 10^-2
MILLI = -3; // 10^-3
MICRO = -6; // 10^-6
NANO = -9; // 10^-9
PICO = -12; // 10^-12
FEMTO = -15; // 10^-15
ATTO = -18; // 10^-18
ZEPTO = -21; // 10^-21
YOCTO = -24; // 10^-24
}

// The exponent of a unit e.g.
Expand Down Expand Up @@ -79,17 +79,14 @@ message Value {
// Units are repeated to represent combinations of units (e.g. V*s and mV/us).
// Units are combined through multiplication.
repeated Unit units = 1;

oneof value {
double real_value = 2;
Complex complex_value = 3;
}
}

message DoubleArray {
repeated double values = 1 [
packed = true
];
repeated double values = 1 [packed = true];
}

message ComplexArray {
Expand All @@ -102,14 +99,11 @@ message ValueArray {
// Units are repeated to represent combinations of units (e.g. V*s and mV/us).
// Units are combined through multiplication.
repeated Unit units = 1;

oneof values {
// The flattened array.
DoubleArray reals = 2;
ComplexArray complexes = 3;
}

repeated uint32 shape = 4 [
packed = true
]; // The shape of the array.
repeated uint32 shape = 4 [packed = true]; // The shape of the array.
}

0 comments on commit fff3ffc

Please sign in to comment.