Skip to content

Commit ec57e21

Browse files
committed
1 parent 22f8445 commit ec57e21

File tree

1 file changed

+21
-13
lines changed

1 file changed

+21
-13
lines changed

tests/extension/decimal/array.py

Lines changed: 21 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
from collections.abc import (
55
Callable,
66
Iterable,
7+
MutableSequence,
8+
Sequence,
79
)
810
import decimal
911
import numbers
@@ -37,6 +39,7 @@
3739
from pandas._typing import (
3840
ArrayLike,
3941
AstypeArg,
42+
ListLike,
4043
ScalarIndexer,
4144
SequenceIndexer,
4245
SequenceNotStr,
@@ -91,7 +94,7 @@ class DecimalArray(OpsMixin, ExtensionScalarOpsMixin, ExtensionArray):
9194

9295
def __init__(
9396
self,
94-
values: list[decimal.Decimal | float] | np.ndarray,
97+
values: MutableSequence[decimal._DecimalNew] | np.ndarray | ExtensionArray,
9598
dtype: DecimalDtype | None = None,
9699
copy: bool = False,
97100
context: decimal.Context | None = None,
@@ -123,7 +126,7 @@ def dtype(self) -> DecimalDtype:
123126
@classmethod
124127
def _from_sequence(
125128
cls,
126-
scalars: list[decimal.Decimal | float] | np.ndarray,
129+
scalars: list[decimal._DecimalNew] | np.ndarray | ExtensionArray,
127130
dtype: DecimalDtype | None = None,
128131
copy: bool = False,
129132
) -> Self:
@@ -140,7 +143,9 @@ def _from_sequence_of_strings(
140143

141144
@classmethod
142145
def _from_factorized(
143-
cls, values: list[decimal.Decimal | float] | np.ndarray, original: Any
146+
cls,
147+
values: list[decimal._DecimalNew] | np.ndarray | ExtensionArray,
148+
original: Any,
144149
) -> Self:
145150
return cls(values)
146151

@@ -186,7 +191,7 @@ def reconstruct(
186191
x: (
187192
decimal.Decimal
188193
| numbers.Number
189-
| list[decimal.Decimal | float]
194+
| list[decimal._DecimalNew]
190195
| np.ndarray
191196
),
192197
) -> decimal.Decimal | numbers.Number | DecimalArray:
@@ -240,21 +245,24 @@ def astype(self, dtype, copy=True):
240245

241246
return super().astype(dtype, copy=copy)
242247

243-
def __setitem__(self, key: object, value: decimal._DecimalNew) -> None:
248+
def __setitem__(
249+
self,
250+
key: int | slice[Any, Any, Any] | ListLike,
251+
value: decimal._DecimalNew | Sequence[decimal._DecimalNew],
252+
) -> None:
244253
if is_list_like(value):
254+
assert isinstance(value, Iterable)
245255
if is_scalar(key):
246256
raise ValueError("setting an array element with a sequence.")
247-
value = [ # type: ignore[assignment]
248-
decimal.Decimal(v) # type: ignore[arg-type]
249-
for v in value # type: ignore[union-attr] # pyright: ignore[reportAssignmentType,reportGeneralTypeIssues]
257+
value = [
258+
decimal.Decimal(v) # type: ignore[arg-type] # pyright: ignore[reportArgumentType]
259+
for v in value
250260
]
251261
else:
252-
value = decimal.Decimal(value)
262+
value = decimal.Decimal(value) # type: ignore[arg-type] # pyright: ignore[reportArgumentType]
253263

254-
key = check_array_indexer( # type: ignore[call-overload]
255-
self, key # pyright: ignore[reportArgumentType,reportCallIssue]
256-
)
257-
self._data[key] = value # type: ignore[call-overload] # pyright: ignore[reportArgumentType,reportCallIssue]
264+
key = check_array_indexer(self, key)
265+
self._data[key] = value
258266

259267
def __len__(self) -> int:
260268
return len(self._data)

0 commit comments

Comments
 (0)