Skip to content

Commit

Permalink
added serialize methods to data structures #148
Browse files Browse the repository at this point in the history
  • Loading branch information
jacanchaplais committed Sep 4, 2023
1 parent 7c4eac6 commit 47c6159
Show file tree
Hide file tree
Showing 2 changed files with 104 additions and 24 deletions.
127 changes: 103 additions & 24 deletions graphicle/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,11 @@
from copy import deepcopy
from enum import Enum

import more_itertools as mit
import numpy as np
import numpy.typing as npt
from attr import Factory, cmp_using, define, field, setters
import typing_extensions as tyx
from attr import Factory, asdict, cmp_using, define, field, setters
from mcpid.lookup import PdgRecords
from numpy.lib import recfunctions as rfn
from rich.console import Console
Expand Down Expand Up @@ -412,16 +414,6 @@ class MaskArray(base.MaskBase, base.ArrayBase):
data : sequence[bool]
Boolean values consituting the mask.
Attributes
----------
data : ndarray[bool_]
Numpy representation of the boolean mask.
Methods
-------
copy()
Provides a deepcopy of the data.
Examples
--------
Instantiating, copying, updating by index, and comparison:
Expand Down Expand Up @@ -461,9 +453,10 @@ def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
return _array_ufunc(self, ufunc, method, *inputs, **kwargs)

def __iter__(self) -> ty.Iterator[bool]:
return map(bool, self.data)
yield from map(op.methodcaller("item"), self.data)

def copy(self) -> "MaskArray":
"""Copies the underlying data into a new MaskArray instance."""
return self.__class__(self._data.copy())

def __repr__(self) -> str:
Expand Down Expand Up @@ -520,6 +513,7 @@ def __bool__(self) -> bool:

@property
def data(self) -> base.BoolVector:
"""Numpy representation of the boolean mask."""
return self._data

@data.setter
Expand All @@ -528,6 +522,13 @@ def data(
) -> None:
self._data = values # type: ignore

def serialize(self) -> ty.Tuple[bool, ...]:
"""Provides a serialized version of the underlying data.
.. versionadded:: 0.3.2
"""
return tuple(self.data.tolist())


def _mask_compat(*masks: base.MaskLike) -> bool:
"""Check if a collection of masks are compatible, *ie.* they must be
Expand Down Expand Up @@ -897,6 +898,13 @@ def leaves(
"or",
)

def serialize(self) -> ty.Dict[str, ty.Any]:
"""Returns serialized data as a dictionary.
.. versionadded:: 0.3.2
"""
return {key: val.serialize() for key, val in self._mask_arrays.items()}


@define(eq=False)
class PdgArray(base.ArrayBase):
Expand Down Expand Up @@ -1018,7 +1026,7 @@ def __repr__(self) -> str:
return _array_repr(self)

def __iter__(self) -> ty.Iterator[int]:
yield from map(int, self.data)
yield from map(op.methodcaller("item"), self.data)

def __len__(self) -> int:
return len(self.data)
Expand Down Expand Up @@ -1156,6 +1164,13 @@ def space_parity(self) -> base.DoubleVector:
def charge_parity(self) -> base.DoubleVector:
return self.__get_prop("c")

def serialize(self) -> ty.Tuple[int, ...]:
"""Provides a serialized version of the underlying data.
.. versionadded:: 0.3.2
"""
return tuple(self.data.tolist())


@define(eq=False)
class MomentumArray(base.ArrayBase):
Expand Down Expand Up @@ -1222,8 +1237,8 @@ def __array_wrap__(cls, array: base.AnyVector) -> "MomentumArray":
return cls(array)

def __iter__(self) -> ty.Iterator[MomentumElement]:
flat_vals = map(float, self._data.flatten())
elems = zip(*(flat_vals,) * 4) # type: ignore
flat_vals = map(op.methodcaller("item"), self._data.flatten())
elems = mit.ichunked(flat_vals, 4)
yield from it.starmap(MomentumElement, elems)

def __len__(self) -> int:
Expand Down Expand Up @@ -1569,6 +1584,13 @@ def delta_R(
return calculate._delta_R_symmetric(rap1, self._xy_pol)
return calculate._delta_R(rap1, rap2, self._xy_pol, other._xy_pol)

def serialize(self) -> ty.Tuple[MomentumElement, ...]:
"""Provides a serialized version of the underlying data.
.. versionadded:: 0.3.2
"""
return tuple(self)


@define(eq=False)
class ColorArray(base.ArrayBase):
Expand Down Expand Up @@ -1619,8 +1641,8 @@ def __repr__(self) -> str:
return _array_repr(self)

def __iter__(self) -> ty.Iterator[ColorElement]:
flat_vals = map(int, it.chain.from_iterable(self.data))
elems = zip(*(flat_vals,) * 2)
flat_vals = map(op.methodcaller("item"), self._data.flatten())
elems = mit.ichunked(flat_vals, 2)
yield from it.starmap(ColorElement, elems)

@property
Expand Down Expand Up @@ -1659,6 +1681,13 @@ def __ne__(
) -> MaskArray:
return _array_ne(self, other)

def serialize(self) -> ty.Tuple[ColorElement, ...]:
"""Provides a serialized version of the underlying data.
.. versionadded:: 0.3.2
"""
return tuple(self)


@define(eq=False)
class HelicityArray(base.ArrayBase):
Expand Down Expand Up @@ -1705,7 +1734,7 @@ def __repr__(self) -> str:
return _array_repr(self)

def __iter__(self) -> ty.Iterator[int]:
yield from map(int, self.data)
yield from map(op.methodcaller("item"), self.data)

@property
def data(self) -> base.HalfIntVector:
Expand Down Expand Up @@ -1740,6 +1769,13 @@ def __ne__(
) -> MaskArray:
return _array_ne(self, other)

def serialize(self) -> ty.Tuple[int, ...]:
"""Provides a serialized version of the underlying data.
.. versionadded:: 0.3.2
"""
return tuple(self.data.tolist())


@define(eq=False)
class StatusArray(base.ArrayBase):
Expand Down Expand Up @@ -1792,7 +1828,7 @@ def __repr__(self) -> str:
return _array_repr(self)

def __iter__(self) -> ty.Iterator[int]:
yield from map(int, self.data)
yield from map(op.methodcaller("item"), self.data)

def __getitem__(self, key) -> "StatusArray":
if isinstance(key, base.MaskBase):
Expand Down Expand Up @@ -1880,6 +1916,13 @@ def hard_mask(self) -> MaskGroup:
)
return masks

def serialize(self) -> ty.Tuple[int, ...]:
"""Provides a serialized version of the underlying data.
.. versionadded:: 0.3.2
"""
return tuple(self.data.tolist())


DsetPair = ty.Tuple[ty.Iterator[str], ty.Iterator[base.ArrayBase]]
CompositeType = ty.Union["ParticleSet", "Graphicle"]
Expand Down Expand Up @@ -1925,6 +1968,15 @@ def _composite_copy(instance: CompositeGeneric) -> CompositeGeneric:
return instance.__class__(**dict(zip(names, copies)))


class ParticleSetSerialized(tyx.TypedDict, total=False):
pdg: ty.List[int]
pmu: ty.List[ty.Tuple[float, float, float, float]]
color: ty.List[ty.Tuple[int, int]]
helicity: ty.List[int]
status: ty.List[int]
final: ty.List[bool]


@define
class ParticleSet(base.ParticleBase):
"""Composite of data structures containing particle set description.
Expand Down Expand Up @@ -2041,10 +2093,14 @@ def optional(
final=optional(MaskArray, final),
)

def serialize(self) -> ParticleSetSerialized:
"""Returns serialized data as a dictionary.
class AdjDict(ty.TypedDict):
edges: ty.Tuple[ty.Tuple[int, int, ty.Dict[str, ty.Any]], ...]
nodes: ty.Tuple[ty.Tuple[int, ty.Dict[str, ty.Any]], ...]
.. versionadded:: 0.3.2
"""
return {
key: getattr(self, key).serialize() for key in asdict(self).keys()
} # type: ignore


@define
Expand Down Expand Up @@ -2100,8 +2156,8 @@ def __repr__(self) -> str:
return _array_repr(self)

def __iter__(self) -> ty.Iterator[VertexPair]:
flat_vals = map(int, it.chain.from_iterable(self._data))
elems = zip(*(flat_vals,) * 2)
flat_vals = map(op.methodcaller("item"), self._data.flatten())
elems = mit.ichunked(flat_vals, 2)
yield from it.starmap(VertexPair, elems)

def __len__(self) -> int:
Expand Down Expand Up @@ -2299,6 +2355,23 @@ def to_sparse(self, data: ty.Optional[base.AnyVector] = None) -> coo_array:
shape=(size, size),
)

def serialize(self) -> ty.Tuple[VertexPair, ...]:
"""Provides a serialized version of the underlying data.
.. versionadded:: 0.3.2
"""
return tuple(self)


class GraphicleSerialized(tyx.TypedDict, total=False):
pdg: ty.List[int]
pmu: ty.List[ty.Tuple[float, float, float, float]]
color: ty.List[ty.Tuple[int, int]]
helicity: ty.List[int]
status: ty.List[int]
final: ty.List[bool]
adj: ty.List[ty.Tuple[int, int]]


@define
class Graphicle:
Expand Down Expand Up @@ -2489,3 +2562,9 @@ def edges(self) -> base.VoidVector:
def nodes(self) -> base.IntVector:
"""Vertex ids of each particle with at least one edge."""
return self.adj.nodes

def serialize(self) -> GraphicleSerialized:
"""Returns serialized data as a dictionary."""
out_dict: GraphicleSerialized = self.particles.serialize()
out_dict["adj"] = self.adj.serialize()
return out_dict
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ dependencies = [
"pyjet ==1.9.0",
"rich",
"deprecation",
"more-itertools >=7.2.0",
]

[project.urls]
Expand Down

0 comments on commit 47c6159

Please sign in to comment.