Skip to content

Commit

Permalink
fix(ir): ensure that operation nodes and expressions are slotted
Browse files Browse the repository at this point in the history
  • Loading branch information
kszucs committed Sep 24, 2024
1 parent 660d3aa commit 9dc4ea7
Show file tree
Hide file tree
Showing 14 changed files with 357 additions and 338 deletions.
3 changes: 1 addition & 2 deletions ibis/backends/polars/rewrites.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
from __future__ import annotations

from koerce import attribute, replace
from public import public

import ibis.expr.datatypes as dt
import ibis.expr.operations as ops
from ibis.common.annotations import attribute
from ibis.common.collections import FrozenDict
from ibis.common.patterns import replace
from ibis.common.typing import VarTuple # noqa: TCH001
from ibis.expr.schema import Schema

Expand Down
2 changes: 1 addition & 1 deletion ibis/backends/tests/test_temporal.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,8 @@
SnowflakeProgrammingError,
TrinoUserError,
)
from ibis.conftest import IS_SPARK_REMOTE
from ibis.common.grounds import ValidationError
from ibis.conftest import IS_SPARK_REMOTE

np = pytest.importorskip("numpy")
pd = pytest.importorskip("pandas")
Expand Down
8 changes: 4 additions & 4 deletions ibis/expr/datatypes/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -462,15 +462,15 @@ def is_variadic(self) -> bool:


@public
class Unknown(DataType):
class Unknown(DataType, singleton=True):
"""An unknown type."""

scalar = "UnknownScalar"
column = "UnknownColumn"


@public
class Primitive(DataType):
class Primitive(DataType, singleton=True):
"""Values with known size."""


Expand Down Expand Up @@ -531,7 +531,7 @@ def nbytes(self) -> int:


@public
class String(Variadic):
class String(Variadic, singleton=True):
"""A type representing a string.
Notes
Expand All @@ -546,7 +546,7 @@ class String(Variadic):


@public
class Binary(Variadic):
class Binary(Variadic, singleton=True):
"""A type representing a sequence of bytes.
Notes
Expand Down
35 changes: 22 additions & 13 deletions ibis/expr/datatypes/tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -432,20 +432,29 @@ def test_struct_equality():
assert st3 != st2


def test_booleqn_equality():
def test_singleton_datatypes():
assert dt.null is dt.Null()
assert dt.unknown is dt.Unknown()
assert dt.boolean is dt.Boolean()
assert dt.string is dt.String()
assert dt.binary is dt.Binary()


def test_singleton_boolean():
assert dt.Boolean() == dt.boolean
assert dt.Boolean() == dt.Boolean()
assert dt.Boolean(nullable=True) == dt.boolean
assert dt.Boolean(nullable=False) != dt.boolean
assert dt.Boolean(nullable=False) == dt.Boolean(nullable=False)
assert dt.Boolean(nullable=True) == dt.Boolean(nullable=True)
assert dt.Boolean(nullable=True) != dt.Boolean(nullable=False)


def test_primite_equality():
assert dt.Int64() == dt.int64
assert dt.Int64(nullable=False) != dt.int64
assert dt.Int64(nullable=False) == dt.Int64(nullable=False)
assert dt.Boolean() is dt.boolean
assert dt.Boolean() is dt.Boolean()
assert dt.Boolean(nullable=True) is dt.boolean
assert dt.Boolean(nullable=False) is not dt.boolean
assert dt.Boolean(nullable=False) is dt.Boolean(nullable=False)
assert dt.Boolean(nullable=True) is dt.Boolean(nullable=True)
assert dt.Boolean(nullable=True) is not dt.Boolean(nullable=False)


def test_singleton_primitive():
assert dt.Int64() is dt.int64
assert dt.Int64(nullable=False) is not dt.int64
assert dt.Int64(nullable=False) is dt.Int64(nullable=False)


def test_array_type_not_equals():
Expand Down
2 changes: 1 addition & 1 deletion ibis/expr/operations/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ def name(self):


@public
class Constant(Scalar):
class Constant(Scalar, singleton=True):
"""A function that produces a constant."""

shape = ds.scalar
Expand Down
2 changes: 1 addition & 1 deletion ibis/expr/operations/reductions.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,7 +380,7 @@ class ArrayCollect(Filterable, Reduction):

def __init__(self, arg, order_by, distinct, **kwargs):
if distinct and order_by and [arg] != [key.expr for key in order_by]:
raise ValidationError(
raise ValueError(
"`collect` with `order_by` and `distinct=True` and may only "
"order by the collected column"
)
Expand Down
2 changes: 2 additions & 0 deletions ibis/expr/operations/relations.py
Original file line number Diff line number Diff line change
Expand Up @@ -378,6 +378,8 @@ class Difference(Set):
class PhysicalTable(Relation):
"""Base class for tables with a name."""

__slots__ = ("__weakref__",)

name: str
values = FrozenOrderedDict()

Expand Down
1 change: 1 addition & 0 deletions ibis/expr/tests/test_reductions.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import ibis.expr.operations as ops
from ibis import _
from ibis.common.exceptions import IbisTypeError
from ibis.common.grounds import ValidationError


@pytest.mark.parametrize(
Expand Down
8 changes: 6 additions & 2 deletions ibis/expr/types/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import webbrowser
from typing import TYPE_CHECKING, Any, NoReturn

from koerce import MatchError
from koerce import Immutable, MatchError
from public import public

import ibis
Expand Down Expand Up @@ -38,11 +38,15 @@

class _FixedTextJupyterMixin:
"""No-op when rich is not installed."""

__slots__ = ()

Check warning on line 42 in ibis/expr/types/core.py

View check run for this annotation

Codecov / codecov/patch

ibis/expr/types/core.py#L42

Added line #L42 was not covered by tests
else:

class _FixedTextJupyterMixin(JupyterMixin):
"""JupyterMixin adds a spurious newline to text, this fixes the issue."""

__slots__ = ()

def _repr_mimebundle_(self, *args, **kwargs):
try:
bundle = super()._repr_mimebundle_(*args, **kwargs)
Expand All @@ -63,7 +67,7 @@ def _capture_rich_renderable(renderable: RenderableType) -> str:


@public
class Expr:
class Expr(Immutable):
"""Base expression class."""

__slots__ = ("_arg",)
Expand Down
8 changes: 7 additions & 1 deletion ibis/expr/types/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -2493,6 +2493,9 @@ class NullColumn(Column, NullValue):
pass


_THE_NULL = None


@public
@deferrable
def null(type: dt.DataType | str | None = None) -> Value:
Expand All @@ -2516,8 +2519,11 @@ def null(type: dt.DataType | str | None = None) -> Value:
│ True │
└──────┘
"""
global _THE_NULL # noqa: PLW0603
if type is None:
type = dt.null
if _THE_NULL is None:
_THE_NULL = ops.Literal(None, dt.null).to_expr()
return _THE_NULL
return ops.Literal(None, type).to_expr()


Expand Down
18 changes: 18 additions & 0 deletions ibis/tests/expr/test_value_exprs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1735,3 +1735,21 @@ def test_value_fillna_depr_warn():
t = ibis.table({"a": "int", "b": "str"})
with pytest.warns(FutureWarning, match="v9.1"):
t.b.fillna("missing")


def assert_slotted(obj):
assert hasattr(obj, "__slots__")
assert not hasattr(obj, "__dict__")


def test_that_value_expressions_are_slotted():
t = ibis.table({"a": "int", "b": "str"})
exprs = [
t.a,
t.b,
t.a + 1,
t,
]
for expr in exprs:
assert_slotted(expr)
assert_slotted(expr.op())
Loading

0 comments on commit 9dc4ea7

Please sign in to comment.