Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 2 additions & 4 deletions narwhals/_duckdb/dataframe.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from __future__ import annotations

import contextlib
from functools import reduce
from operator import and_
from typing import TYPE_CHECKING, Any, Iterator, Mapping, Sequence
Expand Down Expand Up @@ -47,9 +46,6 @@
from narwhals.typing import AsofJoinStrategy, JoinStrategy, LazyUniqueKeepStrategy
from narwhals.utils import _FullContext

with contextlib.suppress(ImportError): # requires duckdb>=1.3.0
from duckdb import SQLExpression


class DuckDBLazyFrame(
CompliantLazyFrame[
Expand Down Expand Up @@ -382,6 +378,8 @@ def unique(
"with `subset` specified."
)
raise NotImplementedError(msg)
from duckdb import SQLExpression

# Sanitise input
if any(x not in self.columns for x in subset_):
msg = f"Columns {set(subset_).difference(self.columns)} not found in {self.columns}."
Expand Down
24 changes: 20 additions & 4 deletions narwhals/_duckdb/expr.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from __future__ import annotations

import contextlib
import operator
from typing import TYPE_CHECKING, Any, Callable, Literal, Sequence, cast

Expand Down Expand Up @@ -46,9 +45,6 @@
)
from narwhals.utils import Version, _FullContext

with contextlib.suppress(ImportError): # requires duckdb>=1.3.0
from duckdb import SQLExpression


class DuckDBExpr(LazyExpr["DuckDBLazyFrame", "Expression"]):
_implementation = Implementation.DUCKDB
Expand Down Expand Up @@ -95,6 +91,8 @@ def _cum_window_func(
reverse: bool,
func_name: Literal["sum", "max", "min", "count", "product"],
) -> WindowFunction:
from duckdb import SQLExpression

def func(window_inputs: WindowInputs) -> Expression:
order_by_sql = generate_order_by_sql(
*window_inputs.order_by, ascending=not reverse
Expand All @@ -117,6 +115,8 @@ def _rolling_window_func(
min_samples: int,
ddof: int | None = None,
) -> WindowFunction:
from duckdb import SQLExpression

ensure_type(window_size, int, type(None))
ensure_type(min_samples, int)
supported_funcs = ["sum", "mean", "std", "var"]
Expand Down Expand Up @@ -162,6 +162,7 @@ def broadcast(self, kind: Literal[ExprKind.AGGREGATION, ExprKind.LITERAL]) -> Se
if self._backend_version < (1, 3):
msg = "At least version 1.3 of DuckDB is required for binary operations between aggregates and columns."
raise NotImplementedError(msg)
from duckdb import SQLExpression

template = "{expr} over ()"

Expand Down Expand Up @@ -492,6 +493,8 @@ def null_count(self) -> Self:

@requires.backend_version((1, 3))
def over(self, partition_by: Sequence[str], order_by: Sequence[str] | None) -> Self:
from duckdb import SQLExpression

if (window_function := self._window_function) is not None:
assert order_by is not None # noqa: S101

Expand Down Expand Up @@ -551,6 +554,8 @@ def round(self, decimals: int) -> Self:

@requires.backend_version((1, 3))
def shift(self, n: int) -> Self:
from duckdb import SQLExpression

ensure_type(n, int)

def func(window_inputs: WindowInputs) -> Expression:
Expand All @@ -565,6 +570,8 @@ def func(window_inputs: WindowInputs) -> Expression:

@requires.backend_version((1, 3))
def is_first_distinct(self) -> Self:
from duckdb import SQLExpression

def func(window_inputs: WindowInputs) -> Expression:
order_by_sql = generate_order_by_sql(*window_inputs.order_by, ascending=True)
if window_inputs.partition_by:
Expand All @@ -581,6 +588,8 @@ def func(window_inputs: WindowInputs) -> Expression:

@requires.backend_version((1, 3))
def is_last_distinct(self) -> Self:
from duckdb import SQLExpression

def func(window_inputs: WindowInputs) -> Expression:
order_by_sql = generate_order_by_sql(*window_inputs.order_by, ascending=False)
if window_inputs.partition_by:
Expand All @@ -597,6 +606,8 @@ def func(window_inputs: WindowInputs) -> Expression:

@requires.backend_version((1, 3))
def diff(self) -> Self:
from duckdb import SQLExpression

def func(window_inputs: WindowInputs) -> Expression:
order_by_sql = generate_order_by_sql(*window_inputs.order_by, ascending=True)
partition_by_sql = generate_partition_by_sql(*window_inputs.partition_by)
Expand Down Expand Up @@ -695,6 +706,7 @@ def fill_null(
if self._backend_version < (1, 3): # pragma: no cover
msg = f"`fill_null` with `strategy={strategy}` is only available in 'duckdb>=1.3.0'."
raise NotImplementedError(msg)
from duckdb import SQLExpression

def _fill_with_strategy(window_inputs: WindowInputs) -> Expression:
order_by_sql = generate_order_by_sql(
Expand Down Expand Up @@ -731,6 +743,8 @@ def func(expr: Expression) -> Expression:

@requires.backend_version((1, 3))
def is_unique(self) -> Self:
from duckdb import SQLExpression

def func(expr: Expression) -> Expression:
sql = f"count(*) over (partition by {expr})"
return SQLExpression(sql) == lit(1) # type: ignore[no-any-return, unused-ignore]
Expand All @@ -739,6 +753,8 @@ def func(expr: Expression) -> Expression:

@requires.backend_version((1, 3))
def rank(self, method: RankMethod, *, descending: bool) -> Self:
from duckdb import SQLExpression

if method in {"min", "max", "average"}:
func = FunctionExpression("rank")
elif method == "dense":
Expand Down
3 changes: 2 additions & 1 deletion narwhals/_pandas_like/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,7 +309,8 @@ def func(df: PandasLikeDataFrame) -> Sequence[PandasLikeSeries]: # noqa: C901,
results = [result_frame.get_column(name) for name in aliases]
if order_by:
for s in results:
s._scatter_in_place(sorting_indices, s)
# `sorting_indices` is bound here: it was initialised in the earlier `if order_by` block above.
s._scatter_in_place(sorting_indices, s) # pyright: ignore[reportPossiblyUnboundVariable]
return results
if reverse:
return [s._gather_slice(slice(None, None, -1)) for s in results]
Expand Down
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -326,8 +326,10 @@ warn_return_any = false
pythonPlatform = "All"
# NOTE (stubs do unsafe `TypeAlias` and `TypeVar` imports)
# pythonVersion = "3.8"
reportFunctionMemberAccess = "error"
reportMissingImports = "none"
reportMissingModuleSource = "none"
reportPossiblyUnboundVariable = "error"
reportPrivateImportUsage = "none"
reportUnusedExpression = "none" # handled by (https://docs.astral.sh/ruff/rules/unused-variable/)
typeCheckingMode = "basic"
Expand Down
Loading