Skip to content

Commit

Permalink
refactor(polars): remove casting where possible; handle conversion on…
Browse files Browse the repository at this point in the history
… output (#9673)
  • Loading branch information
cpcloud authored Jul 23, 2024
1 parent e4ff1bd commit 8717629
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 33 deletions.
8 changes: 6 additions & 2 deletions ibis/backends/polars/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -496,11 +496,15 @@ def execute(
return expr.__pandas_result__(df.to_pandas())
else:
assert isinstance(expr, ir.Column), type(expr)
if expr.type().is_temporal():

dtype = expr.type()
if dtype.is_temporal():
return expr.__pandas_result__(df.to_pandas())
else:
from ibis.formats.pandas import PandasData

# note: skip frame-construction overhead
return df.to_series().to_pandas()
return PandasData.convert_column(df.to_series().to_pandas(), dtype)

def to_polars(
self,
Expand Down
43 changes: 13 additions & 30 deletions ibis/backends/polars/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,9 +100,7 @@ def literal(op, **_):
return pl.struct(values)
elif dtype.is_interval():
return _make_duration(value, dtype)
elif dtype.is_null():
return pl.lit(value)
elif dtype.is_binary():
elif dtype.is_null() or dtype.is_binary():
return pl.lit(value)
else:
typ = PolarsType.from_ibis(dtype)
Expand Down Expand Up @@ -502,17 +500,15 @@ def in_values(op, **kw):
@translate.register(ops.StringLength)
def string_length(op, **kw):
arg = translate(op.arg, **kw)
typ = PolarsType.from_ibis(op.dtype)
return arg.str.len_bytes().cast(typ)
return arg.str.len_bytes()


@translate.register(ops.Capitalize)
def capitalize(op, **kw):
arg = translate(op.arg, **kw)
typ = PolarsType.from_ibis(op.dtype)
first = arg.str.slice(0, 1).str.to_uppercase()
rest = arg.str.slice(1, None).str.to_lowercase()
return (first + rest).cast(typ)
return first + rest


@translate.register(ops.StringUnary)
Expand Down Expand Up @@ -757,9 +753,7 @@ def reduction(op, **kw):

first, *rest = args
method = operator.methodcaller(agg, *rest)
return method(first.filter(reduce(operator.and_, predicates))).cast(
PolarsType.from_ibis(op.dtype)
)
return method(first.filter(reduce(operator.and_, predicates)))


@translate.register(ops.Mode)
Expand All @@ -770,13 +764,12 @@ def execute_mode(op, **kw):
if (where := op.where) is not None:
predicate &= translate(where, **kw)

dtype = PolarsType.from_ibis(op.dtype)
# `mode` can return more than one value so the additional `get(0)` call is
# necessary to enforce aggregation behavior of a scalar value per group
#
# eventually we may want to support an Ibis API like `modes` that returns a
# list of all the modes per group.
return arg.filter(predicate).mode().get(0).cast(dtype)
return arg.filter(predicate).mode().get(0)


@translate.register(ops.Quantile)
Expand Down Expand Up @@ -1018,16 +1011,13 @@ def array_flatten(op, **kw):
def extract_date_field(op, **kw):
arg = translate(op.arg, **kw)
method = operator.methodcaller(_date_methods[type(op)])
return method(arg.dt).cast(pl.Int32)
return method(arg.dt)


@translate.register(ops.ExtractEpochSeconds)
def extract_epoch_seconds(op, **kw):
arg = translate(op.arg, **kw)
return arg.dt.epoch("s").cast(pl.Int32)


_day_of_week_offset = vparse(pl.__version__) >= vparse("0.15.1")
return arg.dt.epoch("s")


_unary = {
Expand All @@ -1039,9 +1029,7 @@ def extract_epoch_seconds(op, **kw):
ops.Ceil: lambda arg: arg.ceil().cast(pl.Int64),
ops.Cos: operator.methodcaller("cos"),
ops.Cot: lambda arg: 1.0 / arg.tan(),
ops.DayOfWeekIndex: (
lambda arg: arg.dt.weekday().cast(pl.Int16) - _day_of_week_offset
),
ops.DayOfWeekIndex: lambda arg: arg.dt.weekday() - 1,
ops.Exp: operator.methodcaller("exp"),
ops.Floor: lambda arg: arg.floor().cast(pl.Int64),
ops.IsInf: operator.methodcaller("is_infinite"),
Expand All @@ -1062,7 +1050,7 @@ def extract_epoch_seconds(op, **kw):

@translate.register(ops.DayOfWeekName)
def day_of_week_name(op, **kw):
index = translate(op.arg, **kw).dt.weekday() - _day_of_week_offset
index = translate(op.arg, **kw).dt.weekday() - 1
arg = None
for i, name in enumerate(calendar.day_name):
arg = pl.when(index == i).then(pl.lit(name)).otherwise(arg)
Expand Down Expand Up @@ -1102,28 +1090,23 @@ def comparison(op, **kw):
def between(op, **kw):
op_arg = op.arg
arg = translate(op_arg, **kw)
dtype = op_arg.dtype
lower = translate(ops.Cast(op.lower_bound, dtype), **kw)
upper = translate(ops.Cast(op.upper_bound, dtype), **kw)
lower = translate(op.lower_bound, **kw)
upper = translate(op.upper_bound, **kw)
return arg.is_between(lower, upper, closed="both")


@translate.register(ops.BitwiseLeftShift)
def bitwise_left_shift(op, **kw):
left = translate(op.left, **kw)
right = translate(op.right, **kw)
return (left.cast(pl.Int64) * 2 ** right.cast(pl.Int64)).cast(
PolarsType.from_ibis(op.dtype)
)
return left.cast(pl.Int64) * 2 ** right.cast(pl.Int64)


@translate.register(ops.BitwiseRightShift)
def bitwise_right_shift(op, **kw):
left = translate(op.left, **kw)
right = translate(op.right, **kw)
return (left.cast(pl.Int64) // 2 ** right.cast(pl.Int64)).cast(
PolarsType.from_ibis(op.dtype)
)
return left.cast(pl.Int64) // 2 ** right.cast(pl.Int64)


_binops = {
Expand Down
2 changes: 1 addition & 1 deletion ibis/backends/polars/tests/test_udf.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def test_multiple_argument_udf(alltypes):
df = alltypes[["smallint_col", "int_col"]].execute()
expected = df.smallint_col + df.int_col

tm.assert_series_equal(result, expected.rename("tmp"))
tm.assert_series_equal(result, expected.astype("int64").rename("tmp"))


@pytest.mark.parametrize(
Expand Down

0 comments on commit 8717629

Please sign in to comment.