diff --git a/narwhals/_duckdb/dataframe.py b/narwhals/_duckdb/dataframe.py index 42d5117946..2acc228dad 100644 --- a/narwhals/_duckdb/dataframe.py +++ b/narwhals/_duckdb/dataframe.py @@ -180,15 +180,15 @@ def aggregate(self, *exprs: DuckDBExpr) -> Self: selection = [val.alias(name) for name, val in evaluate_exprs(self, *exprs)] try: return self._with_native(self.native.aggregate(selection)) # type: ignore[arg-type] - except Exception as e: # noqa: BLE001 - raise catch_duckdb_exception(e, self) from None + except Exception as e: + raise catch_duckdb_exception(e, self) from e def select(self, *exprs: DuckDBExpr) -> Self: selection = (val.alias(name) for name, val in evaluate_exprs(self, *exprs)) try: return self._with_native(self.native.select(*selection)) - except Exception as e: # noqa: BLE001 - raise catch_duckdb_exception(e, self) from None + except Exception as e: + raise catch_duckdb_exception(e, self) from e def drop(self, columns: Sequence[str], *, strict: bool) -> Self: columns_to_drop = parse_columns_to_drop(self, columns, strict=strict) diff --git a/narwhals/_duckdb/expr_dt.py b/narwhals/_duckdb/expr_dt.py index 43bc249416..1386da0528 100644 --- a/narwhals/_duckdb/expr_dt.py +++ b/narwhals/_duckdb/expr_dt.py @@ -74,23 +74,6 @@ def total_microseconds(self) -> DuckDBExpr: + F("datepart", lit("microsecond"), expr) ) - def truncate(self, every: str) -> DuckDBExpr: - interval = Interval.parse(every) - multiple, unit = interval.multiple, interval.unit - if multiple != 1: - # https://github.com/duckdb/duckdb/issues/17554 - msg = f"Only multiple 1 is currently supported for DuckDB.\nGot {multiple!s}." - raise ValueError(msg) - if unit == "ns": - msg = "Truncating to nanoseconds is not yet supported for DuckDB." - raise NotImplementedError(msg) - format = lit(UNITS_DICT[unit]) - - def _truncate(expr: Expression) -> Expression: - return F("date_trunc", format, expr) - - return self.compliant._with_elementwise(_truncate) - def offset_by(self, by: str) -> DuckDBExpr: interval = Interval.parse_no_constraints(by) format = lit(f"{interval.multiple!s} {UNITS_DICT[interval.unit]}") diff --git a/narwhals/_duckdb/namespace.py b/narwhals/_duckdb/namespace.py index 09b5ecd8eb..b149637632 100644 --- a/narwhals/_duckdb/namespace.py +++ b/narwhals/_duckdb/namespace.py @@ -3,7 +3,7 @@ import operator from functools import reduce from itertools import chain -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, ClassVar import duckdb from duckdb import CoalesceOperator, Expression @@ -44,6 +44,18 @@ class DuckDBNamespace( SQLNamespace[DuckDBLazyFrame, DuckDBExpr, "DuckDBPyRelation", Expression] ): _implementation: Implementation = Implementation.DUCKDB + UNITS_DICT: ClassVar = { + "y": lit("year"), + "q": lit("quarter"), + "mo": lit("month"), + "d": lit("day"), + "h": lit("hour"), + "m": lit("minute"), + "s": lit("second"), + "ms": lit("millisecond"), + "us": lit("microsecond"), + "ns": lit("nanosecond"), + } def __init__(self, *, version: Version) -> None: self._version = version diff --git a/narwhals/_ibis/expr_dt.py b/narwhals/_ibis/expr_dt.py index 7d98dba0df..2b0dd75f56 100644 --- a/narwhals/_ibis/expr_dt.py +++ b/narwhals/_ibis/expr_dt.py @@ -1,21 +1,14 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Callable +from typing import TYPE_CHECKING from narwhals._duration import Interval -from narwhals._ibis.utils import ( - UNITS_DICT_BUCKET, - UNITS_DICT_TRUNCATE, - timedelta_to_ibis_interval, -) +from narwhals._ibis.utils import timedelta_to_ibis_interval from narwhals._sql.expr_dt import SQLExprDateTimeNamesSpace from narwhals._utils import not_implemented if TYPE_CHECKING: - import ibis.expr.types as ir - from narwhals._ibis.expr import IbisExpr - from narwhals._ibis.utils import BucketUnit, TruncateUnit class IbisExprDateTimeNamespace(SQLExprDateTimeNamesSpace["IbisExpr"]): @@ -32,32 +25,6 @@ def weekday(self) -> IbisExpr: # Ibis uses 0-6 for Monday-Sunday. Add 1 to match polars. return self.compliant._with_callable(lambda expr: expr.day_of_week.index() + 1) - def _bucket(self, kwds: dict[BucketUnit, Any], /) -> Callable[..., ir.TimestampValue]: - def fn(expr: ir.TimestampValue) -> ir.TimestampValue: - return expr.bucket(**kwds) - - return fn - - def _truncate(self, unit: TruncateUnit, /) -> Callable[..., ir.TimestampValue]: - def fn(expr: ir.TimestampValue) -> ir.TimestampValue: - return expr.truncate(unit) - - return fn - - def truncate(self, every: str) -> IbisExpr: - interval = Interval.parse(every) - multiple, unit = interval.multiple, interval.unit - if unit == "q": - multiple, unit = 3 * multiple, "mo" - if multiple != 1: - if self.compliant._backend_version < (7, 1): # pragma: no cover - msg = "Truncating datetimes with multiples of the unit is only supported in Ibis >= 7.1." - raise NotImplementedError(msg) - fn = self._bucket({UNITS_DICT_BUCKET[unit]: multiple}) - else: - fn = self._truncate(UNITS_DICT_TRUNCATE[unit]) - return self.compliant._with_callable(fn) - def offset_by(self, by: str) -> IbisExpr: interval = Interval.parse_no_constraints(by) unit = interval.unit diff --git a/narwhals/_ibis/namespace.py b/narwhals/_ibis/namespace.py index 3509d805fc..4982b3d310 100644 --- a/narwhals/_ibis/namespace.py +++ b/narwhals/_ibis/namespace.py @@ -3,7 +3,7 @@ import operator from functools import reduce from itertools import chain -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, ClassVar import ibis import ibis.expr.types as ir @@ -29,6 +29,18 @@ class IbisNamespace(SQLNamespace[IbisLazyFrame, IbisExpr, "ir.Table", "ir.Value"]): _implementation: Implementation = Implementation.IBIS + UNITS_DICT: ClassVar = { + "y": "Y", + "q": "Q", + "mo": "M", + "d": "D", + "h": "h", + "m": "m", + "s": "s", + "ms": "ms", + "us": "us", + "ns": "ns", + } def __init__(self, *, version: Version) -> None: self._version = version diff --git a/narwhals/_ibis/utils.py b/narwhals/_ibis/utils.py index 09a8ce2ffb..9c81a78755 100644 --- a/narwhals/_ibis/utils.py +++ b/narwhals/_ibis/utils.py @@ -53,35 +53,10 @@ def lit(value: Any, dtype: Any | None = None) -> Incomplete: desc_nulls_last = cast("SortFn", partial(ibis.desc, nulls_first=False)) -BucketUnit: TypeAlias = Literal[ - "years", - "quarters", - "months", - "days", - "hours", - "minutes", - "seconds", - "milliseconds", - "microseconds", - "nanoseconds", -] TruncateUnit: TypeAlias = Literal[ "Y", "Q", "M", "W", "D", "h", "m", "s", "ms", "us", "ns" ] -UNITS_DICT_BUCKET: Mapping[IntervalUnit, BucketUnit] = { - "y": "years", - "q": "quarters", - "mo": "months", - "d": "days", - "h": "hours", - "m": "minutes", - "s": "seconds", - "ms": "milliseconds", - "us": "microseconds", - "ns": "nanoseconds", -} - UNITS_DICT_TRUNCATE: Mapping[IntervalUnit, TruncateUnit] = { "y": "Y", "q": "Q", @@ -277,4 +252,6 @@ def function(name: str, *args: ir.Value | PythonLiteral) -> ir.Value: if name == "substr": # Ibis is 0-indexed here, SQL is 1-indexed return cast("ir.StringColumn", expr).substr(args[1] - 1, *args[2:]) # type: ignore[operator] # pyright: ignore[reportArgumentType] + if name == "date_trunc": + return cast("ir.TimestampColumn", args[1]).truncate(args[0]) # pyright: ignore[reportArgumentType] return getattr(expr, FUNCTION_REMAPPING.get(name, name))(*args[1:]) diff --git a/narwhals/_spark_like/expr_dt.py b/narwhals/_spark_like/expr_dt.py index b5f4752425..feb2800769 100644 --- a/narwhals/_spark_like/expr_dt.py +++ b/narwhals/_spark_like/expr_dt.py @@ -83,22 +83,6 @@ def _nanosecond(expr: Column) -> Column: def weekday(self) -> SparkLikeExpr: return self.compliant._with_elementwise(self._weekday) - def truncate(self, every: str) -> SparkLikeExpr: - interval = Interval.parse(every) - multiple, unit = interval.multiple, interval.unit - if multiple != 1: - msg = f"Only multiple 1 is currently supported for Spark-like.\nGot {multiple!s}." - raise ValueError(msg) - if unit == "ns": - msg = "Truncating to nanoseconds is not yet supported for Spark-like." - raise NotImplementedError(msg) - format = UNITS_DICT[unit] - - def _truncate(expr: Column) -> Column: - return self.compliant._F.date_trunc(format, expr) - - return self.compliant._with_elementwise(_truncate) - def offset_by(self, by: str) -> SparkLikeExpr: interval = Interval.parse_no_constraints(by) multiple, unit = interval.multiple, interval.unit diff --git a/narwhals/_spark_like/namespace.py b/narwhals/_spark_like/namespace.py index 2e0be736ca..d0c233bcb8 100644 --- a/narwhals/_spark_like/namespace.py +++ b/narwhals/_spark_like/namespace.py @@ -2,7 +2,7 @@ import operator from functools import reduce -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, ClassVar from narwhals._expression_parsing import ( combine_alias_output_names, @@ -43,6 +43,19 @@ class SparkLikeNamespace( SQLNamespace[SparkLikeLazyFrame, SparkLikeExpr, "SQLFrameDataFrame", "Column"] ): + UNITS_DICT: ClassVar = { + "y": "year", + "q": "quarter", + "mo": "month", + "d": "day", + "h": "hour", + "m": "minute", + "s": "second", + "ms": "millisecond", + "us": "microsecond", + "ns": "nanosecond", + } + def __init__(self, *, version: Version, implementation: Implementation) -> None: self._version = version self._implementation = implementation diff --git a/narwhals/_sql/expr_dt.py b/narwhals/_sql/expr_dt.py index 85b65aaf05..07d79800ee 100644 --- a/narwhals/_sql/expr_dt.py +++ b/narwhals/_sql/expr_dt.py @@ -4,6 +4,7 @@ from narwhals._compliant import LazyExprNamespace from narwhals._compliant.any_namespace import DateTimeNamespace +from narwhals._duration import Interval from narwhals._sql.typing import SQLExprT @@ -13,6 +14,9 @@ class SQLExprDateTimeNamesSpace( def _function(self, name: str, *args: Any) -> SQLExprT: return self.compliant._function(name, *args) # type: ignore[no-any-return] + def _lit(self, value: Any) -> SQLExprT: + return self.compliant._lit(value) # type: ignore[no-any-return] + def year(self) -> SQLExprT: return self.compliant._with_elementwise(lambda expr: self._function("year", expr)) @@ -46,3 +50,20 @@ def date(self) -> SQLExprT: return self.compliant._with_elementwise( lambda expr: self._function("to_date", expr) ) + + def truncate(self, every: str) -> SQLExprT: + interval = Interval.parse(every) + multiple, unit = interval.multiple, interval.unit + if multiple != 1: + msg = f"Only multiple 1 is currently supported for SQL-like backends.\nGot {multiple!s}." + raise ValueError(msg) + if unit == "ns": + msg = "Truncating to nanoseconds is not yet supported for Spark-like." + raise NotImplementedError(msg) + ns = self.compliant.__narwhals_namespace__() + format = ns.UNITS_DICT[unit] + + def _truncate(expr: Any) -> Any: + return self._function("date_trunc", format, expr) + + return self.compliant._with_elementwise(_truncate) diff --git a/narwhals/_sql/namespace.py b/narwhals/_sql/namespace.py index dee8a7e470..304333ca3f 100644 --- a/narwhals/_sql/namespace.py +++ b/narwhals/_sql/namespace.py @@ -9,8 +9,9 @@ from narwhals._sql.typing import SQLExprT, SQLLazyFrameT if TYPE_CHECKING: - from collections.abc import Iterable + from collections.abc import Iterable, Mapping + from narwhals._duration import IntervalUnit from narwhals.typing import PythonLiteral @@ -18,6 +19,8 @@ class SQLNamespace( LazyNamespace[SQLLazyFrameT, SQLExprT, NativeFrameT], Protocol[SQLLazyFrameT, SQLExprT, NativeFrameT, NativeExprT], ): + UNITS_DICT: Mapping[IntervalUnit, Any] + def _function(self, name: str, *args: NativeExprT | PythonLiteral) -> NativeExprT: ... def _lit(self, value: Any) -> NativeExprT: ... def _when( diff --git a/tests/expr_and_series/dt/truncate_test.py b/tests/expr_and_series/dt/truncate_test.py index 120d030878..d7ec5f1206 100644 --- a/tests/expr_and_series/dt/truncate_test.py +++ b/tests/expr_and_series/dt/truncate_test.py @@ -104,10 +104,10 @@ def test_truncate_multiples( every: str, expected: list[datetime], ) -> None: - if any(x in str(constructor) for x in ("cudf", "pyspark", "duckdb")): + if any(x in str(constructor) for x in ("cudf", "pyspark", "duckdb", "ibis")): # Reasons: # - cudf: https://github.com/rapidsai/cudf/issues/18654 - # - pyspark/sqlframe: Only multiple 1 is currently supported + # - sql-like: Only multiple 1 is currently supported request.applymarker(pytest.mark.xfail()) if every.endswith("ns") and any( x in str(constructor) for x in ("polars", "duckdb", "ibis")