From 871762929489d53cba288ab4b7028b432184734b Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Tue, 23 Jul 2024 07:47:56 -0700 Subject: [PATCH] refactor(polars): remove casting where possible; handle conversion on output (#9673) --- ibis/backends/polars/__init__.py | 8 +++-- ibis/backends/polars/compiler.py | 43 ++++++++------------------ ibis/backends/polars/tests/test_udf.py | 2 +- 3 files changed, 20 insertions(+), 33 deletions(-) diff --git a/ibis/backends/polars/__init__.py b/ibis/backends/polars/__init__.py index aca178ccb254..e3bcb62a0b8e 100644 --- a/ibis/backends/polars/__init__.py +++ b/ibis/backends/polars/__init__.py @@ -496,11 +496,15 @@ def execute( return expr.__pandas_result__(df.to_pandas()) else: assert isinstance(expr, ir.Column), type(expr) - if expr.type().is_temporal(): + + dtype = expr.type() + if dtype.is_temporal(): return expr.__pandas_result__(df.to_pandas()) else: + from ibis.formats.pandas import PandasData + # note: skip frame-construction overhead - return df.to_series().to_pandas() + return PandasData.convert_column(df.to_series().to_pandas(), dtype) def to_polars( self, diff --git a/ibis/backends/polars/compiler.py b/ibis/backends/polars/compiler.py index 7d899aaa426f..b96dec94fa60 100644 --- a/ibis/backends/polars/compiler.py +++ b/ibis/backends/polars/compiler.py @@ -100,9 +100,7 @@ def literal(op, **_): return pl.struct(values) elif dtype.is_interval(): return _make_duration(value, dtype) - elif dtype.is_null(): - return pl.lit(value) - elif dtype.is_binary(): + elif dtype.is_null() or dtype.is_binary(): return pl.lit(value) else: typ = PolarsType.from_ibis(dtype) @@ -502,17 +500,15 @@ def in_values(op, **kw): @translate.register(ops.StringLength) def string_length(op, **kw): arg = translate(op.arg, **kw) - typ = PolarsType.from_ibis(op.dtype) - return arg.str.len_bytes().cast(typ) + return arg.str.len_bytes() @translate.register(ops.Capitalize) def capitalize(op, **kw): arg = translate(op.arg, **kw) - typ = PolarsType.from_ibis(op.dtype) first = arg.str.slice(0, 1).str.to_uppercase() rest = arg.str.slice(1, None).str.to_lowercase() - return (first + rest).cast(typ) + return first + rest @translate.register(ops.StringUnary) @@ -757,9 +753,7 @@ def reduction(op, **kw): first, *rest = args method = operator.methodcaller(agg, *rest) - return method(first.filter(reduce(operator.and_, predicates))).cast( - PolarsType.from_ibis(op.dtype) - ) + return method(first.filter(reduce(operator.and_, predicates))) @translate.register(ops.Mode) @@ -770,13 +764,12 @@ def execute_mode(op, **kw): if (where := op.where) is not None: predicate &= translate(where, **kw) - dtype = PolarsType.from_ibis(op.dtype) # `mode` can return more than one value so the additional `get(0)` call is # necessary to enforce aggregation behavior of a scalar value per group # # eventually we may want to support an Ibis API like `modes` that returns a # list of all the modes per group. - return arg.filter(predicate).mode().get(0).cast(dtype) + return arg.filter(predicate).mode().get(0) @translate.register(ops.Quantile) @@ -1018,16 +1011,13 @@ def array_flatten(op, **kw): def extract_date_field(op, **kw): arg = translate(op.arg, **kw) method = operator.methodcaller(_date_methods[type(op)]) - return method(arg.dt).cast(pl.Int32) + return method(arg.dt) @translate.register(ops.ExtractEpochSeconds) def extract_epoch_seconds(op, **kw): arg = translate(op.arg, **kw) - return arg.dt.epoch("s").cast(pl.Int32) - - -_day_of_week_offset = vparse(pl.__version__) >= vparse("0.15.1") + return arg.dt.epoch("s") _unary = { @@ -1039,9 +1029,7 @@ def extract_epoch_seconds(op, **kw): ops.Ceil: lambda arg: arg.ceil().cast(pl.Int64), ops.Cos: operator.methodcaller("cos"), ops.Cot: lambda arg: 1.0 / arg.tan(), - ops.DayOfWeekIndex: ( - lambda arg: arg.dt.weekday().cast(pl.Int16) - _day_of_week_offset - ), + ops.DayOfWeekIndex: lambda arg: arg.dt.weekday() - 1, ops.Exp: operator.methodcaller("exp"), ops.Floor: lambda arg: arg.floor().cast(pl.Int64), ops.IsInf: operator.methodcaller("is_infinite"), @@ -1062,7 +1050,7 @@ def extract_epoch_seconds(op, **kw): @translate.register(ops.DayOfWeekName) def day_of_week_name(op, **kw): - index = translate(op.arg, **kw).dt.weekday() - _day_of_week_offset + index = translate(op.arg, **kw).dt.weekday() - 1 arg = None for i, name in enumerate(calendar.day_name): arg = pl.when(index == i).then(pl.lit(name)).otherwise(arg) @@ -1102,9 +1090,8 @@ def comparison(op, **kw): def between(op, **kw): op_arg = op.arg arg = translate(op_arg, **kw) - dtype = op_arg.dtype - lower = translate(ops.Cast(op.lower_bound, dtype), **kw) - upper = translate(ops.Cast(op.upper_bound, dtype), **kw) + lower = translate(op.lower_bound, **kw) + upper = translate(op.upper_bound, **kw) return arg.is_between(lower, upper, closed="both") @@ -1112,18 +1099,14 @@ def between(op, **kw): def bitwise_left_shift(op, **kw): left = translate(op.left, **kw) right = translate(op.right, **kw) - return (left.cast(pl.Int64) * 2 ** right.cast(pl.Int64)).cast( - PolarsType.from_ibis(op.dtype) - ) + return left.cast(pl.Int64) * 2 ** right.cast(pl.Int64) @translate.register(ops.BitwiseRightShift) def bitwise_right_shift(op, **kw): left = translate(op.left, **kw) right = translate(op.right, **kw) - return (left.cast(pl.Int64) // 2 ** right.cast(pl.Int64)).cast( - PolarsType.from_ibis(op.dtype) - ) + return left.cast(pl.Int64) // 2 ** right.cast(pl.Int64) _binops = { diff --git a/ibis/backends/polars/tests/test_udf.py b/ibis/backends/polars/tests/test_udf.py index c9c905699361..f79295fd82c0 100644 --- a/ibis/backends/polars/tests/test_udf.py +++ b/ibis/backends/polars/tests/test_udf.py @@ -47,7 +47,7 @@ def test_multiple_argument_udf(alltypes): df = alltypes[["smallint_col", "int_col"]].execute() expected = df.smallint_col + df.int_col - tm.assert_series_equal(result, expected.rename("tmp")) + tm.assert_series_equal(result, expected.astype("int64").rename("tmp")) @pytest.mark.parametrize(