From d7076f1fbae09ce714e9bc749c88c7cba1e452e4 Mon Sep 17 00:00:00 2001 From: Michal Januszewski Date: Tue, 21 Jun 2022 15:33:26 -0700 Subject: [PATCH] Make BoundingBox objects hashable. PiperOrigin-RevId: 456365921 --- connectomics/common/array.py | 3 +++ connectomics/common/bounding_box.py | 19 +++++++++++-------- connectomics/common/bounding_box_test.py | 5 +++++ 3 files changed, 19 insertions(+), 8 deletions(-) diff --git a/connectomics/common/array.py b/connectomics/common/array.py index da9561f..39df51e 100644 --- a/connectomics/common/array.py +++ b/connectomics/common/array.py @@ -159,6 +159,9 @@ def copy(self, *args, **kwargs) -> 'MutableArray': def __str__(self): return np.ndarray.__repr__(self) + def __hash__(self): + return hash(self.tobytes()) + class MutableArray(array_mixins.MutableArrayMixin, ImmutableArray): """Strongly typed mutable version of np.ndarray.""" diff --git a/connectomics/common/bounding_box.py b/connectomics/common/bounding_box.py index bc8dc97..3108386 100644 --- a/connectomics/common/bounding_box.py +++ b/connectomics/common/bounding_box.py @@ -28,12 +28,13 @@ BoolSequence = Union[bool, Sequence[bool]] -@dataclasses.dataclass +@dataclasses.dataclass(frozen=True, init=False) class BoundingBoxBase(Generic[T]): """BoundingBox encapsulates start/end coordinate pairs of the same length.""" _start: Tuple[T, ...] - _end: Tuple[T, ...] _size: Tuple[T, ...] + is_border_start: array.ImmutableArray + is_border_end: array.ImmutableArray def __init__( self, @@ -87,8 +88,8 @@ def __init__( else: size = end - start - self._start = self._tupleize(start) - self._size = self._tupleize(size) + object.__setattr__(self, '_start', self._tupleize(start)) + object.__setattr__(self, '_size', self._tupleize(size)) if len(self.start) != len(self.end) or len(self.end) != len(self.start): raise ValueError( @@ -99,17 +100,19 @@ def __init__( if len(is_border_start) != self.rank: raise ValueError( f'is_border_start needs to have exactly {self.rank} items') - self.is_border_start = np.asarray(is_border_start) + is_border_start = np.asarray(is_border_start) else: - self.is_border_start = np.zeros(self.rank, dtype=bool) + is_border_start = np.zeros(self.rank, dtype=bool) + object.__setattr__(self, 'is_border_start', array.ImmutableArray(is_border_start)) if is_border_end is not None: if len(is_border_end) != self.rank: raise ValueError( f'is_border_end needs to have exactly {self.rank} items') - self.is_border_end = np.asarray(is_border_end) + is_border_end = np.asarray(is_border_end) else: - self.is_border_end = np.zeros(self.rank, dtype=bool) + is_border_end = np.zeros(self.rank, dtype=bool) + object.__setattr__(self, 'is_border_end', array.ImmutableArray(is_border_end)) def __eq__(self: S, other: S) -> bool: for k, v in self.__dict__.items(): diff --git a/connectomics/common/bounding_box_test.py b/connectomics/common/bounding_box_test.py index 5855862..129a92e 100644 --- a/connectomics/common/bounding_box_test.py +++ b/connectomics/common/bounding_box_test.py @@ -282,6 +282,11 @@ def test_to_slice_float(self): expected_4d = none_slice * (4 - dim_size) + expected_slice self.assertEqual(expected_4d, box.to_slice4d()) + def test_hashability(self): + b = Box(start=[1, 2, 3], size=[4, 5, 6]) + d = {b: 1} + self.assertEqual(d[b], 1) + class GlobalTest(absltest.TestCase):