Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
237 changes: 222 additions & 15 deletions xarray/core/dataset.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from __future__ import annotations

import ast
import asyncio
import builtins
import copy
import datetime
import io
Expand Down Expand Up @@ -72,7 +74,6 @@
Self,
T_ChunkDim,
T_ChunksFreq,
T_DataArray,
T_DataArrayOrSet,
ZarrWriteModes,
)
Expand Down Expand Up @@ -9533,19 +9534,182 @@ def argmax(self, dim: Hashable | None = None, **kwargs) -> Self:
"Dataset.argmin() with a sequence or ... for dim"
)

# Whitelist of builtins exposed to eval() expressions.
# Expressions are evaluated with an empty ``__builtins__`` mapping, so every
# builtin an expression may use has to be listed here explicitly. Modules are
# not part of this table; they are added lazily in _eval_expression to avoid
# a circular import of xarray.
# Each entry is looked up on the ``builtins`` module rather than by bare name:
# inside this class body a bare ``map`` would resolve to the Dataset.map
# method defined earlier, not to the builtin.
_EVAL_NAMESPACE_BUILTINS: dict[str, Any] = {
    builtin_name: getattr(builtins, builtin_name)
    for builtin_name in (
        # Numeric/aggregation functions
        "abs",
        "min",
        "max",
        "round",
        "len",
        "sum",
        "pow",
        "any",
        "all",
        # Type constructors
        "int",
        "float",
        "bool",
        "str",
        "list",
        "tuple",
        "dict",
        "set",
        "slice",
        # Iteration helpers
        "range",
        "zip",
        "enumerate",
        "map",
        "filter",
    )
}

# -------------------------------------------------------------------------
# eval() Implementation Notes (for future maintainers):
#
# This implementation uses native AST-based evaluation instead of pd.eval()
# to support N-dimensional arrays (N > 2). See GitHub issue #11062.
#
# We retain logical operator transformation ('and'/'or'/'not' to '&'/'|'/'~',
# and chained comparisons) for consistency with query(), which still uses
# pd.eval(). We don't migrate query() to this implementation because:
# - query() typically works fine (expressions usually compare 1D coordinates)
# - pd.eval() with numexpr is faster and well-tested for query's use case
# -------------------------------------------------------------------------

class _LogicalOperatorTransformer(ast.NodeTransformer):
"""Transform operators for consistency with query().

query() uses pd.eval() which transforms these operators automatically.
We replicate that behavior here so syntax that works in query() also
works in eval().

Transformations:
1. 'and'/'or'/'not' -> '&'/'|'/'~'
2. 'a < b < c' -> '(a < b) & (b < c)'

These constructs fail on arrays in standard Python because they call
__bool__(), which is ambiguous for multi-element arrays.
"""

def visit_BoolOp(self, node: ast.BoolOp) -> ast.AST:
# Transform: a and b -> a & b, a or b -> a | b
self.generic_visit(node)
op: ast.BitAnd | ast.BitOr
if isinstance(node.op, ast.And):
op = ast.BitAnd()
elif isinstance(node.op, ast.Or):
op = ast.BitOr()
else:
return node

# BoolOp can have multiple values: a and b and c
# Transform to chained BinOp: (a & b) & c
result = node.values[0]
for value in node.values[1:]:
result = ast.BinOp(left=result, op=op, right=value)
return ast.fix_missing_locations(result)

def visit_UnaryOp(self, node: ast.UnaryOp) -> ast.AST:
# Transform: not a -> ~a
self.generic_visit(node)
if isinstance(node.op, ast.Not):
return ast.fix_missing_locations(
ast.UnaryOp(op=ast.Invert(), operand=node.operand)
)
return node

def visit_Compare(self, node: ast.Compare) -> ast.AST:
# Transform chained comparisons: 1 < x < 5 -> (1 < x) & (x < 5)
# Python's chained comparisons use short-circuit evaluation at runtime,
# which calls __bool__ on intermediate results. This fails for arrays.
# We transform to bitwise AND which works element-wise.
self.generic_visit(node)

if len(node.ops) == 1:
# Simple comparison, no transformation needed
return node

# Build individual comparisons and chain with BitAnd
# For: a < b < c < d
# We need: (a < b) & (b < c) & (c < d)
comparisons = []
left = node.left
for op, comparator in zip(node.ops, node.comparators, strict=True):
comp = ast.Compare(left=left, ops=[op], comparators=[comparator])
comparisons.append(comp)
left = comparator

# Chain with BitAnd: (a < b) & (b < c) & ...
result: ast.Compare | ast.BinOp = comparisons[0]
for comp in comparisons[1:]:
result = ast.BinOp(left=result, op=ast.BitAnd(), right=comp)
return ast.fix_missing_locations(result)

def _validate_eval_expression(self, tree: ast.AST) -> None:
    """Reject AST constructs that eval() deliberately does not support.

    The restrictions emulate pd.eval() behavior for consistency. The tree is
    scanned once, and the first offending node found aborts validation.

    Parameters
    ----------
    tree : ast.AST
        Parsed (and already transformed) expression tree to check.

    Raises
    ------
    ValueError
        If the tree contains a lambda expression, or an attribute access
        whose name starts with an underscore.
    """
    for candidate in ast.walk(tree):
        # pd.eval refuses lambdas too ("Only named functions are supported").
        if isinstance(candidate, ast.Lambda):
            raise ValueError(
                "Lambda expressions are not allowed in eval(). "
                "Use direct operations on data variables instead."
            )
        if not isinstance(candidate, ast.Attribute):
            continue
        # Private and dunder attributes both start with an underscore.
        attr_name = candidate.attr
        if attr_name[:1] == "_":
            raise ValueError(
                f"Access to private attributes is not allowed: '{attr_name}'"
            )

def _eval_expression(self, expr: str) -> DataArray:
    """Evaluate an expression string using xarray's native operations.

    The expression is parsed, its logical operators and chained comparisons
    are rewritten by ``_LogicalOperatorTransformer``, the result is checked
    by ``_validate_eval_expression``, and the compiled code is evaluated in
    a namespace built from this dataset.

    Parameters
    ----------
    expr : str
        A single Python expression (no statements or assignments).

    Returns
    -------
    The value the expression evaluates to.

    Raises
    ------
    ValueError
        If ``expr`` is not a syntactically valid expression, or if it
        contains a construct rejected by ``_validate_eval_expression``.
    """
    try:
        tree = ast.parse(expr, mode="eval")
    except SyntaxError as e:
        raise ValueError(f"Invalid expression syntax: {expr}") from e

    # Transform logical operators for consistency with query().
    tree = self._LogicalOperatorTransformer().visit(tree)
    ast.fix_missing_locations(tree)

    self._validate_eval_expression(tree)

    # Build namespace: data variables, coordinates, modules, and safe builtins.
    # Priority order (highest to lowest):
    #   data variables > coordinates > modules > builtins
    # Later updates overwrite earlier ones, so user data always wins when
    # names collide with builtins.
    import xarray as xr  # Lazy import to avoid circular dependency

    namespace: dict[str, Any] = dict(self._EVAL_NAMESPACE_BUILTINS)
    namespace.update({"np": np, "pd": pd, "xr": xr})
    namespace.update({str(name): self.coords[name] for name in self.coords})
    namespace.update({str(name): self[name] for name in self.data_vars})

    # Empty __builtins__ blocks dangerous functions like __import__, exec, open.
    # It is set last so that no variable named "__builtins__" can replace it.
    namespace["__builtins__"] = {}

    code = compile(tree, "<xarray.eval>", "eval")
    # The namespace is passed as *globals*, not as a separate locals mapping:
    # nested scopes such as generator expressions resolve free names through
    # globals only, so with a separate locals dict an expression like
    # ``sum(a * i for i in range(3))`` would fail with NameError on ``a``.
    return builtins.eval(code, namespace)

def eval(
self,
statement: str,
*,
parser: QueryParserOptions = "pandas",
) -> Self | T_DataArray:
parser: QueryParserOptions | Default = _default,
) -> Self | DataArray:
"""
Calculate an expression supplied as a string in the context of the dataset.

This is currently experimental; the API may change particularly around
assignments, which currently return a ``Dataset`` with the additional variable.
Currently only the ``python`` engine is supported, which has the same
performance as executing in python.

Logical operators (``and``, ``or``, ``not``) are automatically transformed
to bitwise operators (``&``, ``|``, ``~``) which work element-wise on arrays.

Parameters
----------
Expand All @@ -9555,7 +9719,11 @@ def eval(
Returns
-------
result : Dataset or DataArray, depending on whether ``statement`` contains an
assignment.
assignment.

Warning
-------
Like ``pd.eval()``, this method should not be used with untrusted input.

Examples
--------
Expand Down Expand Up @@ -9584,16 +9752,55 @@ def eval(
b (x) float64 40B 0.0 0.25 0.5 0.75 1.0
c (x) float64 40B 0.0 1.25 2.5 3.75 5.0
"""
if parser is not _default:
emit_user_level_warning(
"The 'parser' argument to Dataset.eval() is deprecated and will be "
"removed in a future version. Logical operators (and/or/not) are now "
"always transformed to bitwise operators (&/|/~) for array compatibility.",
FutureWarning,
)

return pd.eval( # type: ignore[return-value]
statement,
resolvers=[self],
target=self,
parser=parser,
# Because numexpr returns a numpy array, using that engine results in
# different behavior. We'd be very open to a contribution handling this.
engine="python",
)
statement = statement.strip()

# Check for assignment: "target = expr"
# Must handle compound operators like ==, !=, <=, >=
# Use ast to detect assignment properly
try:
tree = ast.parse(statement, mode="exec")
except SyntaxError as e:
raise ValueError(f"Invalid statement syntax: {statement}") from e

if len(tree.body) != 1:
raise ValueError("Only single statements are supported")

stmt = tree.body[0]

if isinstance(stmt, ast.Assign):
# Assignment: "c = a + b"
if len(stmt.targets) != 1:
raise ValueError("Only single assignment targets are supported")
target = stmt.targets[0]
if not isinstance(target, ast.Name):
raise ValueError(
f"Assignment target must be a simple name, got {type(target).__name__}"
)
target_name = target.id

# Get the expression source
expr_source = ast.unparse(stmt.value)
result: DataArray = self._eval_expression(expr_source)
return self.assign({target_name: result})

elif isinstance(stmt, ast.Expr):
# Expression: "a + b"
expr_source = ast.unparse(stmt.value)
return self._eval_expression(expr_source)

else:
raise ValueError(
f"Unsupported statement type: {type(stmt).__name__}. "
f"Only expressions and assignments are supported."
)

def query(
self,
Expand Down
Loading
Loading