Skip to content

Commit

Permalink
feat(api): support quarterly truncation (#9715)
Browse files Browse the repository at this point in the history
Adds support for quarterly date and timestamp truncate. Closes #9714.
  • Loading branch information
cpcloud authored Jul 29, 2024
1 parent 20ceee5 commit 75b31c2
Show file tree
Hide file tree
Showing 8 changed files with 49 additions and 23 deletions.
18 changes: 0 additions & 18 deletions ibis/backends/dask/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
import dask.dataframe as dd
import numpy as np
import pandas as pd
from packaging.version import parse as vparse

import ibis.backends.dask.kernels as dask_kernels
import ibis.expr.operations as ops
Expand Down Expand Up @@ -97,23 +96,6 @@ def mapper(df, cases, results, default):

return cls.partitionwise(mapper, kwargs, name=op.name, dtype=dtype)

@classmethod
def visit(cls, op: ops.TimestampTruncate | ops.DateTruncate, arg, unit):
# TODO(kszucs): should use serieswise()
if vparse(pd.__version__) >= vparse("2.2"):
units = {"m": "min"}
else:
units = {"m": "Min", "ms": "L"}

unit = units.get(unit.short, unit.short)

if unit in "YMWD":
return arg.dt.to_period(unit).dt.to_timestamp()
try:
return arg.dt.floor(unit)
except ValueError:
return arg.dt.to_period(unit).dt.to_timestamp()

@classmethod
def visit(cls, op: ops.IntervalFromInteger, unit, **kwargs):
if unit.short in {"Y", "Q", "M", "W"}:
Expand Down
2 changes: 1 addition & 1 deletion ibis/backends/pandas/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ def visit(cls, op: ops.TimestampTruncate | ops.DateTruncate, arg, unit):

unit = units.get(unit.short, unit.short)

if unit in "YMWD":
if unit in "YQMWD":
return arg.dt.to_period(unit).dt.to_timestamp()
try:
return arg.dt.floor(unit)
Expand Down
9 changes: 6 additions & 3 deletions ibis/backends/sql/compilers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -837,6 +837,7 @@ def visit_ExtractSecond(self, op, *, arg):
def visit_TimestampTruncate(self, op, *, arg, unit):
unit_mapping = {
"Y": "year",
"Q": "quarter",
"M": "month",
"W": "week",
"D": "day",
Expand All @@ -847,10 +848,12 @@ def visit_TimestampTruncate(self, op, *, arg, unit):
"us": "us",
}

if (unit := unit_mapping.get(unit.short)) is None:
raise com.UnsupportedOperationError(f"Unsupported truncate unit {unit}")
if (raw_unit := unit_mapping.get(unit.short)) is None:
raise com.UnsupportedOperationError(
f"Unsupported truncate unit {unit.short!r}"
)

return self.f.date_trunc(unit, arg)
return self.f.date_trunc(raw_unit, arg)

def visit_DateTruncate(self, op, *, arg, unit):
return self.visit_TimestampTruncate(op, arg=arg, unit=unit)
Expand Down
1 change: 1 addition & 0 deletions ibis/backends/sql/compilers/clickhouse.py
Original file line number Diff line number Diff line change
Expand Up @@ -373,6 +373,7 @@ def visit_TimestampFromUNIX(self, op, *, arg, unit):
def visit_TimestampTruncate(self, op, *, arg, unit):
converters = {
"Y": "toStartOfYear",
"Q": "toStartOfQuarter",
"M": "toStartOfMonth",
"W": "toMonday",
"D": "toDate",
Expand Down
13 changes: 13 additions & 0 deletions ibis/backends/sql/compilers/mysql.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,19 @@ def visit_LRStrip(self, op, *, arg, position):
)

def visit_DateTimestampTruncate(self, op, *, arg, unit):
if unit.short == "Q":
# adapted from https://stackoverflow.com/a/11884743
return (
# January 1 of the year of the `arg`
self.f.makedate(self.f.year(arg), 1)
# add the current quarter's number of quarters minus one to Jan 1
# first quarter: add zero
# second quarter: add one
# third quarter: add two
# fourth quarter: add three
+ sge.Interval(this=self.f.quarter(arg) - 1, unit=self.v.QUARTER)
)

truncate_formats = {
"s": "%Y-%m-%d %H:%i:%s",
"m": "%Y-%m-%d %H:%i:00",
Expand Down
1 change: 1 addition & 0 deletions ibis/backends/sql/compilers/oracle.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,6 +329,7 @@ def visit_Xor(self, op, *, left, right):
def visit_DateTruncate(self, op, *, arg, unit):
trunc_unit_mapping = {
"Y": "year",
"Q": "Q",
"M": "MONTH",
"W": "IW",
"D": "DDD",
Expand Down
14 changes: 14 additions & 0 deletions ibis/backends/sql/compilers/sqlite.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,6 +290,20 @@ def visit_Modulus(self, op, *, left, right):
return self.f.anon.mod(left, right)

def _temporal_truncate(self, func, arg, unit):
if unit.short == "Q":
return sge.Case(
ifs=[
self.if_(
sge.Between(
this=self.cast(self.f.strftime("%m", arg), dt.int32),
low=sge.convert(lower),
high=sge.convert(lower + 2),
),
self.f.strftime(f"%Y-{lower:0>2}-01", arg),
)
for lower in range(1, 13, 3)
],
)
modifiers = {
DateUnit.DAY: ("start of day",),
DateUnit.WEEK: ("weekday 0", "-6 days"),
Expand Down
14 changes: 13 additions & 1 deletion ibis/backends/tests/test_temporal.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,17 @@ def test_timestamp_extract_week_of_year(backend, alltypes, df):
),
],
),
param(
"Q",
"Q",
marks=[
pytest.mark.notimpl(
["polars"],
raises=AssertionError,
reason="numpy array are different",
),
],
),
param(
"M",
"M",
Expand Down Expand Up @@ -399,7 +410,7 @@ def test_timestamp_truncate(backend, alltypes, df, ibis_unit, pandas_unit):

dtns = df.timestamp_col.dt

if ibis_unit in ("Y", "M", "D", "W"):
if ibis_unit in ("Y", "Q", "M", "D", "W"):
expected = dtns.to_period(pandas_unit).dt.to_timestamp()
else:
expected = dtns.floor(pandas_unit)
Expand All @@ -414,6 +425,7 @@ def test_timestamp_truncate(backend, alltypes, df, ibis_unit, pandas_unit):
"unit",
[
"Y",
"Q",
"M",
"D",
param(
Expand Down

0 comments on commit 75b31c2

Please sign in to comment.