Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 17 additions & 15 deletions src/power_grid_model_ds/_core/model/arrays/base/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,14 +59,14 @@ class FancyArray(ABC):
_defaults: dict[str, Any] = {}
_str_lengths: dict[str, int] = {}

def __init__(self: Self, *args, data: NDArray | None = None, **kwargs):
def __init__(self, *args, data: NDArray | None = None, **kwargs):
if data is None:
self._data = build_array(*args, dtype=self.get_dtype(), defaults=self.get_defaults(), **kwargs)
else:
self._data = data

@property
def data(self: Self) -> NDArray:
def data(self) -> NDArray:
return self._data

@classmethod
Expand Down Expand Up @@ -110,7 +110,7 @@ def get_dtype(cls):
dtype_list.append((name, dtype))
return np.dtype(dtype_list)

def __repr__(self: Self) -> str:
def __repr__(self) -> str:
try:
data = getattr(self, "data")
if data.size > 3:
Expand All @@ -125,7 +125,7 @@ def __str__(self) -> str:
def __len__(self) -> int:
return len(self._data)

def __iter__(self: Self):
def __iter__(self):
for record in self._data:
yield self.__class__(data=np.array([record]))

Expand Down Expand Up @@ -177,16 +177,18 @@ def __contains__(self: Self, item: Self) -> bool:
return item.data in self._data
return False

def __hash__(self: Self):
def __hash__(self):
return hash(f"{self.__class__} {self}")

def __eq__(self: Self, other):
return self._data.__eq__(other.data)
def __eq__(self, other):
if not isinstance(other, self.__class__):
return False
return self.data.__eq__(other.data)

def __copy__(self: Self):
def __copy__(self):
return self.__class__(data=copy(self._data))

def copy(self: Self):
def copy(self):
"""Return a copy of this array including its data"""
return copy(self)

Expand Down Expand Up @@ -281,15 +283,15 @@ def get(
return self.__class__(data=apply_get(*args, array=self._data, mode_=mode_, **kwargs))

def filter_mask(
self: Self,
self,
*args: int | Iterable[int] | np.ndarray,
mode_: Literal["AND", "OR"] = "AND",
**kwargs: Any | list[Any] | np.ndarray,
) -> np.ndarray:
return get_filter_mask(*args, array=self._data, mode_=mode_, **kwargs)

def exclude_mask(
self: Self,
self,
*args: int | Iterable[int] | np.ndarray,
mode_: Literal["AND", "OR"] = "AND",
**kwargs: Any | list[Any] | np.ndarray,
Expand All @@ -299,7 +301,7 @@ def exclude_mask(
def re_order(self: Self, new_order: ArrayLike, column: str = "id") -> Self:
return self.__class__(data=re_order(self._data, new_order, column=column))

def update_by_id(self: Self, ids: ArrayLike, allow_missing: bool = False, **kwargs) -> None:
def update_by_id(self, ids: ArrayLike, allow_missing: bool = False, **kwargs) -> None:
try:
_ = update_by_id(self._data, ids, allow_missing, **kwargs)
except ValueError as error:
Expand All @@ -312,13 +314,13 @@ def get_updated_by_id(self: Self, ids: ArrayLike, allow_missing: bool = False, *
except ValueError as error:
raise ValueError(f"Cannot update {self.__class__.__name__}. {error}") from error

def check_ids(self: Self, return_duplicates: bool = False) -> NDArray | None:
def check_ids(self, return_duplicates: bool = False) -> NDArray | None:
return check_ids(self._data, return_duplicates=return_duplicates)

def as_table(self: Self, column_width: int | str = "auto", rows: int = 10) -> str:
def as_table(self, column_width: int | str = "auto", rows: int = 10) -> str:
return convert_array_to_string(self, column_width=column_width, rows=rows)

def as_df(self: Self):
def as_df(self):
"""Convert to pandas DataFrame"""
if pandas is None:
raise ImportError("pandas is not installed")
Expand Down