Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions narwhals/_duckdb/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,15 +180,15 @@ def aggregate(self, *exprs: DuckDBExpr) -> Self:
selection = [val.alias(name) for name, val in evaluate_exprs(self, *exprs)]
try:
return self._with_native(self.native.aggregate(selection)) # type: ignore[arg-type]
except Exception as e: # noqa: BLE001
raise catch_duckdb_exception(e, self) from None
except Exception as e:
raise catch_duckdb_exception(e, self) from e

def select(self, *exprs: DuckDBExpr) -> Self:
selection = (val.alias(name) for name, val in evaluate_exprs(self, *exprs))
try:
return self._with_native(self.native.select(*selection))
except Exception as e: # noqa: BLE001
raise catch_duckdb_exception(e, self) from None
except Exception as e:
raise catch_duckdb_exception(e, self) from e

def drop(self, columns: Sequence[str], *, strict: bool) -> Self:
columns_to_drop = parse_columns_to_drop(self, columns, strict=strict)
Expand Down
17 changes: 0 additions & 17 deletions narwhals/_duckdb/expr_dt.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,23 +74,6 @@ def total_microseconds(self) -> DuckDBExpr:
+ F("datepart", lit("microsecond"), expr)
)

def truncate(self, every: str) -> DuckDBExpr:
interval = Interval.parse(every)
multiple, unit = interval.multiple, interval.unit
if multiple != 1:
# https://github.com/duckdb/duckdb/issues/17554
msg = f"Only multiple 1 is currently supported for DuckDB.\nGot {multiple!s}."
raise ValueError(msg)
if unit == "ns":
msg = "Truncating to nanoseconds is not yet supported for DuckDB."
raise NotImplementedError(msg)
format = lit(UNITS_DICT[unit])

def _truncate(expr: Expression) -> Expression:
return F("date_trunc", format, expr)

return self.compliant._with_elementwise(_truncate)

def offset_by(self, by: str) -> DuckDBExpr:
interval = Interval.parse_no_constraints(by)
format = lit(f"{interval.multiple!s} {UNITS_DICT[interval.unit]}")
Expand Down
14 changes: 13 additions & 1 deletion narwhals/_duckdb/namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import operator
from functools import reduce
from itertools import chain
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, ClassVar

import duckdb
from duckdb import CoalesceOperator, Expression
Expand Down Expand Up @@ -44,6 +44,18 @@ class DuckDBNamespace(
SQLNamespace[DuckDBLazyFrame, DuckDBExpr, "DuckDBPyRelation", Expression]
):
_implementation: Implementation = Implementation.DUCKDB
UNITS_DICT: ClassVar = {
"y": lit("year"),
"q": lit("quarter"),
"mo": lit("month"),
"d": lit("day"),
"h": lit("hour"),
"m": lit("minute"),
"s": lit("second"),
"ms": lit("millisecond"),
"us": lit("microsecond"),
"ns": lit("nanosecond"),
}

def __init__(self, *, version: Version) -> None:
self._version = version
Expand Down
37 changes: 2 additions & 35 deletions narwhals/_ibis/expr_dt.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,14 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Any, Callable
from typing import TYPE_CHECKING

from narwhals._duration import Interval
from narwhals._ibis.utils import (
UNITS_DICT_BUCKET,
UNITS_DICT_TRUNCATE,
timedelta_to_ibis_interval,
)
from narwhals._ibis.utils import timedelta_to_ibis_interval
from narwhals._sql.expr_dt import SQLExprDateTimeNamesSpace
from narwhals._utils import not_implemented

if TYPE_CHECKING:
import ibis.expr.types as ir

from narwhals._ibis.expr import IbisExpr
from narwhals._ibis.utils import BucketUnit, TruncateUnit


class IbisExprDateTimeNamespace(SQLExprDateTimeNamesSpace["IbisExpr"]):
Expand All @@ -32,32 +25,6 @@ def weekday(self) -> IbisExpr:
# Ibis uses 0-6 for Monday-Sunday. Add 1 to match polars.
return self.compliant._with_callable(lambda expr: expr.day_of_week.index() + 1)

def _bucket(self, kwds: dict[BucketUnit, Any], /) -> Callable[..., ir.TimestampValue]:
def fn(expr: ir.TimestampValue) -> ir.TimestampValue:
return expr.bucket(**kwds)

return fn

def _truncate(self, unit: TruncateUnit, /) -> Callable[..., ir.TimestampValue]:
def fn(expr: ir.TimestampValue) -> ir.TimestampValue:
return expr.truncate(unit)

return fn

def truncate(self, every: str) -> IbisExpr:
interval = Interval.parse(every)
multiple, unit = interval.multiple, interval.unit
if unit == "q":
multiple, unit = 3 * multiple, "mo"
if multiple != 1:
if self.compliant._backend_version < (7, 1): # pragma: no cover
msg = "Truncating datetimes with multiples of the unit is only supported in Ibis >= 7.1."
raise NotImplementedError(msg)
fn = self._bucket({UNITS_DICT_BUCKET[unit]: multiple})
else:
fn = self._truncate(UNITS_DICT_TRUNCATE[unit])
return self.compliant._with_callable(fn)

def offset_by(self, by: str) -> IbisExpr:
interval = Interval.parse_no_constraints(by)
unit = interval.unit
Expand Down
14 changes: 13 additions & 1 deletion narwhals/_ibis/namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import operator
from functools import reduce
from itertools import chain
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, ClassVar

import ibis
import ibis.expr.types as ir
Expand All @@ -29,6 +29,18 @@

class IbisNamespace(SQLNamespace[IbisLazyFrame, IbisExpr, "ir.Table", "ir.Value"]):
_implementation: Implementation = Implementation.IBIS
UNITS_DICT: ClassVar = {
"y": "Y",
"q": "Q",
"mo": "M",
"d": "D",
"h": "h",
"m": "m",
"s": "s",
"ms": "ms",
"us": "us",
"ns": "ns",
}

def __init__(self, *, version: Version) -> None:
self._version = version
Expand Down
27 changes: 2 additions & 25 deletions narwhals/_ibis/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,35 +53,10 @@ def lit(value: Any, dtype: Any | None = None) -> Incomplete:
desc_nulls_last = cast("SortFn", partial(ibis.desc, nulls_first=False))


BucketUnit: TypeAlias = Literal[
"years",
"quarters",
"months",
"days",
"hours",
"minutes",
"seconds",
"milliseconds",
"microseconds",
"nanoseconds",
]
TruncateUnit: TypeAlias = Literal[
"Y", "Q", "M", "W", "D", "h", "m", "s", "ms", "us", "ns"
]

UNITS_DICT_BUCKET: Mapping[IntervalUnit, BucketUnit] = {
"y": "years",
"q": "quarters",
"mo": "months",
"d": "days",
"h": "hours",
"m": "minutes",
"s": "seconds",
"ms": "milliseconds",
"us": "microseconds",
"ns": "nanoseconds",
}

UNITS_DICT_TRUNCATE: Mapping[IntervalUnit, TruncateUnit] = {
"y": "Y",
"q": "Q",
Expand Down Expand Up @@ -277,4 +252,6 @@ def function(name: str, *args: ir.Value | PythonLiteral) -> ir.Value:
if name == "substr":
# Ibis is 0-indexed here, SQL is 1-indexed
return cast("ir.StringColumn", expr).substr(args[1] - 1, *args[2:]) # type: ignore[operator] # pyright: ignore[reportArgumentType]
if name == "date_trunc":
return cast("ir.TimestampColumn", args[1]).truncate(args[0]) # pyright: ignore[reportArgumentType]
return getattr(expr, FUNCTION_REMAPPING.get(name, name))(*args[1:])
16 changes: 0 additions & 16 deletions narwhals/_spark_like/expr_dt.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,22 +83,6 @@ def _nanosecond(expr: Column) -> Column:
def weekday(self) -> SparkLikeExpr:
return self.compliant._with_elementwise(self._weekday)

def truncate(self, every: str) -> SparkLikeExpr:
interval = Interval.parse(every)
multiple, unit = interval.multiple, interval.unit
if multiple != 1:
msg = f"Only multiple 1 is currently supported for Spark-like.\nGot {multiple!s}."
raise ValueError(msg)
if unit == "ns":
msg = "Truncating to nanoseconds is not yet supported for Spark-like."
raise NotImplementedError(msg)
format = UNITS_DICT[unit]

def _truncate(expr: Column) -> Column:
return self.compliant._F.date_trunc(format, expr)

return self.compliant._with_elementwise(_truncate)

def offset_by(self, by: str) -> SparkLikeExpr:
interval = Interval.parse_no_constraints(by)
multiple, unit = interval.multiple, interval.unit
Expand Down
15 changes: 14 additions & 1 deletion narwhals/_spark_like/namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import operator
from functools import reduce
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, ClassVar

from narwhals._expression_parsing import (
combine_alias_output_names,
Expand Down Expand Up @@ -43,6 +43,19 @@
class SparkLikeNamespace(
SQLNamespace[SparkLikeLazyFrame, SparkLikeExpr, "SQLFrameDataFrame", "Column"]
):
UNITS_DICT: ClassVar = {
"y": "year",
"q": "quarter",
"mo": "month",
"d": "day",
"h": "hour",
"m": "minute",
"s": "second",
"ms": "millisecond",
"us": "microsecond",
"ns": "nanosecond",
}

def __init__(self, *, version: Version, implementation: Implementation) -> None:
self._version = version
self._implementation = implementation
Expand Down
21 changes: 21 additions & 0 deletions narwhals/_sql/expr_dt.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from narwhals._compliant import LazyExprNamespace
from narwhals._compliant.any_namespace import DateTimeNamespace
from narwhals._duration import Interval
from narwhals._sql.typing import SQLExprT


Expand All @@ -13,6 +14,9 @@ class SQLExprDateTimeNamesSpace(
def _function(self, name: str, *args: Any) -> SQLExprT:
return self.compliant._function(name, *args) # type: ignore[no-any-return]

def _lit(self, value: Any) -> SQLExprT:
return self.compliant._lit(value) # type: ignore[no-any-return]

def year(self) -> SQLExprT:
return self.compliant._with_elementwise(lambda expr: self._function("year", expr))

Expand Down Expand Up @@ -46,3 +50,20 @@ def date(self) -> SQLExprT:
return self.compliant._with_elementwise(
lambda expr: self._function("to_date", expr)
)

def truncate(self, every: str) -> SQLExprT:
interval = Interval.parse(every)
multiple, unit = interval.multiple, interval.unit
if multiple != 1:
msg = f"Only multiple 1 is currently supported for SQL-like backends.\nGot {multiple!s}."
raise ValueError(msg)
if unit == "ns":
msg = "Truncating to nanoseconds is not yet supported for Spark-like."
raise NotImplementedError(msg)
ns = self.compliant.__narwhals_namespace__()
format = ns.UNITS_DICT[unit]

def _truncate(expr: Any) -> Any:
return self._function("date_trunc", format, expr)

return self.compliant._with_elementwise(_truncate)
5 changes: 4 additions & 1 deletion narwhals/_sql/namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,18 @@
from narwhals._sql.typing import SQLExprT, SQLLazyFrameT

if TYPE_CHECKING:
from collections.abc import Iterable
from collections.abc import Iterable, Mapping

from narwhals._duration import IntervalUnit
from narwhals.typing import PythonLiteral


class SQLNamespace(
LazyNamespace[SQLLazyFrameT, SQLExprT, NativeFrameT],
Protocol[SQLLazyFrameT, SQLExprT, NativeFrameT, NativeExprT],
):
UNITS_DICT: Mapping[IntervalUnit, Any]

def _function(self, name: str, *args: NativeExprT | PythonLiteral) -> NativeExprT: ...
def _lit(self, value: Any) -> NativeExprT: ...
def _when(
Expand Down
4 changes: 2 additions & 2 deletions tests/expr_and_series/dt/truncate_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,10 +104,10 @@ def test_truncate_multiples(
every: str,
expected: list[datetime],
) -> None:
if any(x in str(constructor) for x in ("cudf", "pyspark", "duckdb")):
if any(x in str(constructor) for x in ("cudf", "pyspark", "duckdb", "ibis")):
# Reasons:
# - cudf: https://github.com/rapidsai/cudf/issues/18654
# - pyspark/sqlframe: Only multiple 1 is currently supported
# - sql-like: Only multiple 1 is currently supported
request.applymarker(pytest.mark.xfail())
if every.endswith("ns") and any(
x in str(constructor) for x in ("polars", "duckdb", "ibis")
Expand Down
Loading