Skip to content
Draft
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
260 changes: 255 additions & 5 deletions array_api_tests/test_special_cases.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def or_(i: float) -> bool:

def make_and(cond1: UnaryCheck, cond2: UnaryCheck) -> UnaryCheck:
def and_(i: float) -> bool:
return cond1(i) or cond2(i)
return cond1(i) and cond2(i)

return and_

Expand Down Expand Up @@ -492,6 +492,179 @@ def check_result(result: float) -> bool:
return check_result, expr


def parse_complex_value(value_str: str) -> complex:
"""
Parses a complex value string to return a complex number, e.g.

>>> parse_complex_value('+0 + 0j')
0j
>>> parse_complex_value('NaN + NaN j')
(nan+nanj)
>>> parse_complex_value('0 + NaN j')
nanj
>>> parse_complex_value('+0 + πj/2')
1.5707963267948966j
>>> parse_complex_value('+infinity + 3πj/4')
(inf+2.356194490192345j)

Handles formats: "A + Bj", "A + B j", "A + πj/N", "A + Nπj/M"
"""
m = r_complex_value.match(value_str)
if m is None:
raise ParseError(value_str)

# Parse real part with its sign
real_sign = m.group(1) if m.group(1) else "+"
real_val_str = m.group(2)
real_val = parse_value(real_sign + real_val_str)

# Parse imaginary part with its sign
imag_sign = m.group(3)
# Group 4 is πj form (e.g., "πj/2"), group 5 is plain form (e.g., "NaN")
if m.group(4): # πj form
imag_val_str_raw = m.group(4)
# Remove 'j' to get coefficient: "πj/2" -> "π/2"
imag_val_str = imag_val_str_raw.replace('j', '')
else: # plain form
imag_val_str_raw = m.group(5)
# Strip trailing 'j' if present: "0j" -> "0"
imag_val_str = imag_val_str_raw[:-1] if imag_val_str_raw.endswith('j') else imag_val_str_raw

imag_val = parse_value(imag_sign + imag_val_str)

return complex(real_val, imag_val)


def make_strict_eq_complex(v: complex) -> Callable[[complex], bool]:
"""
Creates a checker for complex values that respects sign of zero and NaN.
"""
real_check = make_strict_eq(v.real)
imag_check = make_strict_eq(v.imag)

def strict_eq_complex(z: complex) -> bool:
return real_check(z.real) and imag_check(z.imag)

return strict_eq_complex


def parse_complex_cond(
a_cond_str: str, b_cond_str: str
) -> Tuple[Callable[[complex], bool], str, FromDtypeFunc]:
"""
Parses complex condition strings for real (a) and imaginary (b) parts.

Returns:
- cond: Function that checks if a complex number meets the condition
- expr: String expression for the condition
- from_dtype: Strategy generator for complex numbers meeting the condition
"""
# Parse conditions for real and imaginary parts separately
a_cond, a_expr_template, a_from_dtype = parse_cond(a_cond_str)
b_cond, b_expr_template, b_from_dtype = parse_cond(b_cond_str)

# Create compound condition
def complex_cond(z: complex) -> bool:
return a_cond(z.real) and b_cond(z.imag)

# Create expression
a_expr = a_expr_template.replace("{}", "real(x_i)")
b_expr = b_expr_template.replace("{}", "imag(x_i)")
expr = f"{a_expr} and {b_expr}"

# Create strategy that generates complex numbers
def complex_from_dtype(dtype: DataType, **kw) -> st.SearchStrategy[complex]:
assert len(kw) == 0 # sanity check
# For complex dtype, we need to get the corresponding float dtype
# complex64 -> float32, complex128 -> float64
if hasattr(dtype, 'name'):
if 'complex64' in str(dtype):
float_dtype = xp.float32
elif 'complex128' in str(dtype):
float_dtype = xp.float64
else:
# Fallback to float64
float_dtype = xp.float64
else:
float_dtype = xp.float64

real_strat = a_from_dtype(float_dtype)
imag_strat = b_from_dtype(float_dtype)
return st.builds(complex, real_strat, imag_strat)

return complex_cond, expr, complex_from_dtype


def _check_component_with_tolerance(actual: float, expected: float, allow_any_sign: bool) -> bool:
"""
Helper to check if actual matches expected, with optional sign flexibility and tolerance.
"""
if allow_any_sign and not math.isnan(expected):
return abs(actual) == abs(expected) or math.isclose(abs(actual), abs(expected), abs_tol=0.01)
elif not math.isnan(expected):
check_fn = make_strict_eq(expected) if expected == 0 or math.isinf(expected) else make_rough_eq(expected)
return check_fn(actual)
else:
return math.isnan(actual)


def parse_complex_result(result_str: str) -> Tuple[Callable[[complex], bool], str]:
"""
Parses a complex result string to return a checker and expression.

Handles cases like:
- "``+0 + 0j``" - exact complex value
- "``0 + NaN j`` (sign of the real component is unspecified)"
- "``+0 + πj/2``" - with π expressions (uses approximate equality)
"""
# Check for unspecified sign notes
unspecified_real_sign = "sign of the real component is unspecified" in result_str
unspecified_imag_sign = "sign of the imaginary component is unspecified" in result_str

# Extract the complex value from backticks - need to handle spaces in complex values
# Pattern: ``...`` where ... can contain spaces (for complex values like "0 + NaN j")
m = re.search(r"``([^`]+)``", result_str)
if m:
value_str = m.group(1)
# Check if the value contains π expressions (for approximate comparison)
has_pi = 'π' in value_str

try:
expected = parse_complex_value(value_str)
except ParseError:
raise ParseError(result_str)

# Create checker based on whether signs are unspecified and whether π is involved
if has_pi:
# Use approximate equality for both real and imaginary parts if they involve π
def check_result(z: complex) -> bool:
real_match = _check_component_with_tolerance(z.real, expected.real, unspecified_real_sign)
imag_match = _check_component_with_tolerance(z.imag, expected.imag, unspecified_imag_sign)
return real_match and imag_match
elif unspecified_real_sign and not math.isnan(expected.real):
# Allow any sign for real part
def check_result(z: complex) -> bool:
imag_check = make_strict_eq(expected.imag)
return abs(z.real) == abs(expected.real) and imag_check(z.imag)
elif unspecified_imag_sign and not math.isnan(expected.imag):
# Allow any sign for imaginary part
def check_result(z: complex) -> bool:
real_check = make_strict_eq(expected.real)
return real_check(z.real) and abs(z.imag) == abs(expected.imag)
elif unspecified_real_sign and unspecified_imag_sign:
# Allow any sign for both parts
def check_result(z: complex) -> bool:
return abs(z.real) == abs(expected.real) and abs(z.imag) == abs(expected.imag)
else:
# Exact match including signs
check_result = make_strict_eq_complex(expected)

expr = value_str
return check_result, expr
else:
raise ParseError(result_str)


class Case(Protocol):
cond_expr: str
result_expr: str
Expand Down Expand Up @@ -549,6 +722,16 @@ class UnaryCase(Case):
"If ``x_i`` is ``NaN`` and the sign bit of ``x_i`` is ``(.+)``, "
"the result is ``(.+)``"
)
# Regex patterns for complex special cases
r_complex_marker = re.compile(
r"For complex floating-point operands, let ``a = real\(x_i\)``, ``b = imag\(x_i\)``"
)
r_complex_case = re.compile(r"If ``a`` is (.+) and ``b`` is (.+), the result is (.+)")
# Matches complex values like "+0 + 0j", "NaN + NaN j", "infinity + NaN j", "πj/2", "3πj/4"
# Two formats: 1) πj/N expressions where j is part of the coefficient, 2) plain values followed by j
r_complex_value = re.compile(
r"([+-]?)([^\s]+)\s*([+-])\s*(?:(\d*πj(?:/\d+)?)|([^\s]+))\s*j?"
)


def integers_from_dtype(dtype: DataType, **kw) -> st.SearchStrategy[float]:
Expand Down Expand Up @@ -630,6 +813,14 @@ def check_result(i: float, result: float) -> bool:
return check_result


def make_complex_unary_check_result(check_fn: Callable[[complex], bool]) -> UnaryResultCheck:
"""Wraps a complex check function for use in UnaryCase."""
def check_result(in_value, out_value):
# in_value is complex, out_value is complex
return check_fn(out_value)
return check_result


def parse_unary_case_block(case_block: str, func_name: str) -> List[UnaryCase]:
"""
Parses a Sphinx-formatted docstring of a unary function to return a list of
Expand Down Expand Up @@ -677,8 +868,46 @@ def parse_unary_case_block(case_block: str, func_name: str) -> List[UnaryCase]:

"""
cases = []
# Check if the case block contains complex cases by looking for the marker
in_complex_section = r_complex_marker.search(case_block) is not None

for case_m in r_case.finditer(case_block):
case_str = case_m.group(1)

# Try to parse complex cases if we're in the complex section
if in_complex_section and (m := r_complex_case.search(case_str)):
try:
a_cond_str = m.group(1)
b_cond_str = m.group(2)
result_str = m.group(3)

# Skip cases with complex expressions like "cis(b)"
if "cis" in result_str or "*" in result_str:
warn(f"case for {func_name} not machine-readable: '{case_str}'")
continue

# Parse the complex condition and result
complex_cond, cond_expr, complex_from_dtype = parse_complex_cond(
a_cond_str, b_cond_str
)
_check_result, result_expr = parse_complex_result(result_str)

check_result = make_complex_unary_check_result(_check_result)

case = UnaryCase(
cond_expr=cond_expr,
cond=complex_cond,
cond_from_dtype=complex_from_dtype,
result_expr=result_expr,
check_result=check_result,
raw_case=case_str,
)
cases.append(case)
except ParseError as e:
warn(f"case for {func_name} not machine-readable: '{e.value}'")
continue

# Parse regular (real-valued) cases
if r_already_int_case.search(case_str):
cases.append(already_int_case)
elif r_even_round_halves_case.search(case_str):
Expand Down Expand Up @@ -1257,10 +1486,31 @@ def test_unary(func_name, func, case):
# drawing multiple examples like a normal test, or just hard-coding a
# single example test case without using hypothesis.
filterwarnings('ignore', category=NonInteractiveExampleWarning)
in_value = case.cond_from_dtype(xp.float64).example()
x = xp.asarray(in_value, dtype=xp.float64)
out = func(x)
out_value = float(out)

# Determine if this is a complex case by checking the strategy
# Try to generate an example to see if it's complex
try:
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@copilot Please avoid this nested try-except clause. When parsing and constructing the case, could you add a flag to whether the case comes from a float or complex section, and use it here to select the appropriate strategy.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Refactored to use an is_complex flag on UnaryCase instead of nested try-except. Added the flag to the dataclass (default=False), set it to True when parsing complex cases, and use it in test_unary to select the appropriate dtype strategy. Commit e48216b.

in_value = case.cond_from_dtype(xp.float64).example()
except Exception:
# If float64 fails, try complex128
try:
in_value = case.cond_from_dtype(xp.complex128).example()
except Exception:
# Fallback to float64
in_value = case.cond_from_dtype(xp.float64).example()

# Determine appropriate dtype based on input value type
if isinstance(in_value, complex):
dtype = xp.complex128
x = xp.asarray(in_value, dtype=dtype)
out = func(x)
out_value = complex(out)
else:
dtype = xp.float64
x = xp.asarray(in_value, dtype=dtype)
out = func(x)
out_value = float(out)

assert case.check_result(in_value, out_value), (
f"out={out_value}, but should be {case.result_expr} [{func_name}()]\n"
)
Expand Down
Loading