Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions array_api_strict/_creation_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,8 +240,9 @@ def full(
_check_valid_dtype(dtype)
_check_device(device)

if isinstance(fill_value, Array) and fill_value.ndim == 0:
fill_value = fill_value._array
if not isinstance(fill_value, bool | int | float | complex):
msg = f"Expected Python scalar fill_value, got type {type(fill_value)}"
raise TypeError(msg)
res = np.full(shape, fill_value, dtype=_np_dtype(dtype))
if DType(res.dtype) not in _all_dtypes:
# This will happen if the fill value is not something that NumPy
Expand Down Expand Up @@ -270,6 +271,10 @@ def full_like(
if device is None:
device = x.device

if not isinstance(fill_value, bool | int | float | complex):
msg = f"Expected Python scalar fill_value, got type {type(fill_value)}"
raise TypeError(msg)

res = np.full_like(x._array, fill_value, dtype=_np_dtype(dtype))
if DType(res.dtype) not in _all_dtypes:
# This will happen if the fill value is not something that NumPy
Expand Down
2 changes: 2 additions & 0 deletions array_api_strict/tests/test_creation_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,7 @@ def test_full_errors():
assert_raises(ValueError, lambda: full((1,), 0, device="gpu"))
assert_raises(ValueError, lambda: full((1,), 0, dtype=int))
assert_raises(ValueError, lambda: full((1,), 0, dtype="i"))
assert_raises(TypeError, lambda: full((1,), asarray(0)))


def test_full_like_errors():
Expand All @@ -169,6 +170,7 @@ def test_full_like_errors():
assert_raises(ValueError, lambda: full_like(asarray(1), 0, device="gpu"))
assert_raises(ValueError, lambda: full_like(asarray(1), 0, dtype=int))
assert_raises(ValueError, lambda: full_like(asarray(1), 0, dtype="i"))
assert_raises(TypeError, lambda: full(asarray(1), asarray(0)))


def test_linspace_errors():
Expand Down
Loading