From 75b31c2ebc0a136a11513b483a35088c38bdb0a9 Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Mon, 29 Jul 2024 16:53:28 -0400 Subject: [PATCH] feat(api): support quarterly truncation (#9715) Adds support for quarterly date and timestamp truncate. Closes #9714. --- ibis/backends/dask/executor.py | 18 ------------------ ibis/backends/pandas/executor.py | 2 +- ibis/backends/sql/compilers/base.py | 9 ++++++--- ibis/backends/sql/compilers/clickhouse.py | 1 + ibis/backends/sql/compilers/mysql.py | 13 +++++++++++++ ibis/backends/sql/compilers/oracle.py | 1 + ibis/backends/sql/compilers/sqlite.py | 14 ++++++++++++++ ibis/backends/tests/test_temporal.py | 14 +++++++++++++- 8 files changed, 49 insertions(+), 23 deletions(-) diff --git a/ibis/backends/dask/executor.py b/ibis/backends/dask/executor.py index 921985a833c7..12d975d79966 100644 --- a/ibis/backends/dask/executor.py +++ b/ibis/backends/dask/executor.py @@ -7,7 +7,6 @@ import dask.dataframe as dd import numpy as np import pandas as pd -from packaging.version import parse as vparse import ibis.backends.dask.kernels as dask_kernels import ibis.expr.operations as ops @@ -97,23 +96,6 @@ def mapper(df, cases, results, default): return cls.partitionwise(mapper, kwargs, name=op.name, dtype=dtype) - @classmethod - def visit(cls, op: ops.TimestampTruncate | ops.DateTruncate, arg, unit): - # TODO(kszucs): should use serieswise() - if vparse(pd.__version__) >= vparse("2.2"): - units = {"m": "min"} - else: - units = {"m": "Min", "ms": "L"} - - unit = units.get(unit.short, unit.short) - - if unit in "YMWD": - return arg.dt.to_period(unit).dt.to_timestamp() - try: - return arg.dt.floor(unit) - except ValueError: - return arg.dt.to_period(unit).dt.to_timestamp() - @classmethod def visit(cls, op: ops.IntervalFromInteger, unit, **kwargs): if unit.short in {"Y", "Q", "M", "W"}: diff --git a/ibis/backends/pandas/executor.py b/ibis/backends/pandas/executor.py index cca34cd9f0f2..2e2aa6504e5e 100644 --- a/ibis/backends/pandas/executor.py +++ b/ibis/backends/pandas/executor.py @@ -185,7 +185,7 @@ def visit(cls, op: ops.TimestampTruncate | ops.DateTruncate, arg, unit): unit = units.get(unit.short, unit.short) - if unit in "YMWD": + if unit in "YQMWD": return arg.dt.to_period(unit).dt.to_timestamp() try: return arg.dt.floor(unit) diff --git a/ibis/backends/sql/compilers/base.py b/ibis/backends/sql/compilers/base.py index a040609e0118..e0897e81416c 100644 --- a/ibis/backends/sql/compilers/base.py +++ b/ibis/backends/sql/compilers/base.py @@ -837,6 +837,7 @@ def visit_ExtractSecond(self, op, *, arg): def visit_TimestampTruncate(self, op, *, arg, unit): unit_mapping = { "Y": "year", + "Q": "quarter", "M": "month", "W": "week", "D": "day", @@ -847,10 +848,12 @@ def visit_TimestampTruncate(self, op, *, arg, unit): "us": "us", } - if (unit := unit_mapping.get(unit.short)) is None: - raise com.UnsupportedOperationError(f"Unsupported truncate unit {unit}") + if (raw_unit := unit_mapping.get(unit.short)) is None: + raise com.UnsupportedOperationError( + f"Unsupported truncate unit {unit.short!r}" + ) - return self.f.date_trunc(unit, arg) + return self.f.date_trunc(raw_unit, arg) def visit_DateTruncate(self, op, *, arg, unit): return self.visit_TimestampTruncate(op, arg=arg, unit=unit) diff --git a/ibis/backends/sql/compilers/clickhouse.py b/ibis/backends/sql/compilers/clickhouse.py index 743bd70ad398..31bdb11bf44b 100644 --- a/ibis/backends/sql/compilers/clickhouse.py +++ b/ibis/backends/sql/compilers/clickhouse.py @@ -373,6 +373,7 @@ def visit_TimestampFromUNIX(self, op, *, arg, unit): def visit_TimestampTruncate(self, op, *, arg, unit): converters = { "Y": "toStartOfYear", + "Q": "toStartOfQuarter", "M": "toStartOfMonth", "W": "toMonday", "D": "toDate", diff --git a/ibis/backends/sql/compilers/mysql.py b/ibis/backends/sql/compilers/mysql.py index 95f262eebfc5..bbf616d4a2ec 100644 --- a/ibis/backends/sql/compilers/mysql.py +++ b/ibis/backends/sql/compilers/mysql.py @@ -285,6 +285,19 @@ def visit_LRStrip(self, op, *, arg, position): ) def visit_DateTimestampTruncate(self, op, *, arg, unit): + if unit.short == "Q": + # adapted from https://stackoverflow.com/a/11884743 + return ( + # January 1 of the year of the `arg` + self.f.makedate(self.f.year(arg), 1) + # add the current quarter's number of quarters minus one to Jan 1 + # first quarter: add zero + # second quarter: add one + # third quarter: add two + # fourth quarter: add three + + sge.Interval(this=self.f.quarter(arg) - 1, unit=self.v.QUARTER) + ) + truncate_formats = { "s": "%Y-%m-%d %H:%i:%s", "m": "%Y-%m-%d %H:%i:00", diff --git a/ibis/backends/sql/compilers/oracle.py b/ibis/backends/sql/compilers/oracle.py index e952b71c047f..8953f74ec4b0 100644 --- a/ibis/backends/sql/compilers/oracle.py +++ b/ibis/backends/sql/compilers/oracle.py @@ -329,6 +329,7 @@ def visit_Xor(self, op, *, left, right): def visit_DateTruncate(self, op, *, arg, unit): trunc_unit_mapping = { "Y": "year", + "Q": "Q", "M": "MONTH", "W": "IW", "D": "DDD", diff --git a/ibis/backends/sql/compilers/sqlite.py b/ibis/backends/sql/compilers/sqlite.py index 2c12cd9d97eb..6067d0a2557a 100644 --- a/ibis/backends/sql/compilers/sqlite.py +++ b/ibis/backends/sql/compilers/sqlite.py @@ -290,6 +290,20 @@ def visit_Modulus(self, op, *, left, right): return self.f.anon.mod(left, right) def _temporal_truncate(self, func, arg, unit): + if unit.short == "Q": + return sge.Case( + ifs=[ + self.if_( + sge.Between( + this=self.cast(self.f.strftime("%m", arg), dt.int32), + low=sge.convert(lower), + high=sge.convert(lower + 2), + ), + self.f.strftime(f"%Y-{lower:0>2}-01", arg), + ) + for lower in range(1, 13, 3) + ], + ) modifiers = { DateUnit.DAY: ("start of day",), DateUnit.WEEK: ("weekday 0", "-6 days"), diff --git a/ibis/backends/tests/test_temporal.py b/ibis/backends/tests/test_temporal.py index 7ea03c7069d2..a0835ce5ffbd 100644 --- a/ibis/backends/tests/test_temporal.py +++ b/ibis/backends/tests/test_temporal.py @@ -274,6 +274,17 @@ def test_timestamp_extract_week_of_year(backend, alltypes, df): ), ], ), + param( + "Q", + "Q", + marks=[ + pytest.mark.notimpl( + ["polars"], + raises=AssertionError, + reason="numpy array are different", + ), + ], + ), param( "M", "M", @@ -399,7 +410,7 @@ def test_timestamp_truncate(backend, alltypes, df, ibis_unit, pandas_unit): dtns = df.timestamp_col.dt - if ibis_unit in ("Y", "M", "D", "W"): + if ibis_unit in ("Y", "Q", "M", "D", "W"): expected = dtns.to_period(pandas_unit).dt.to_timestamp() else: expected = dtns.floor(pandas_unit) @@ -414,6 +425,7 @@ def test_timestamp_truncate(backend, alltypes, df, ibis_unit, pandas_unit): "unit", [ "Y", + "Q", "M", "D", param(