diff --git a/docs/api-reference/dataframe.md b/docs/api-reference/dataframe.md
index 4105a376d1..29798df203 100644
--- a/docs/api-reference/dataframe.md
+++ b/docs/api-reference/dataframe.md
@@ -6,6 +6,7 @@
       members:
         - __arrow_c_stream__
         - __getitem__
+        - clear
         - clone
         - collect_schema
         - columns
diff --git a/narwhals/_arrow/dataframe.py b/narwhals/_arrow/dataframe.py
index becec1d051..198864d8da 100644
--- a/narwhals/_arrow/dataframe.py
+++ b/narwhals/_arrow/dataframe.py
@@ -758,4 +758,12 @@ def unpivot(
         # TODO(Unassigned): Even with promote_options="permissive", pyarrow does not
         # upcast numeric to non-numeric (e.g. string) datatypes
 
+    def clear(self, n: int) -> Self:
+        """Return a table with the same schema holding `n` all-null rows."""
+        schema = self.native.schema
+        data = {
+            name: pa.nulls(n, dtype) for name, dtype in zip(schema.names, schema.types)
+        }
+        return self._with_native(pa.table(data))
+
     pivot = not_implemented()
diff --git a/narwhals/_compliant/dataframe.py b/narwhals/_compliant/dataframe.py
index 94bb1207b0..843e7e0e49 100644
--- a/narwhals/_compliant/dataframe.py
+++ b/narwhals/_compliant/dataframe.py
@@ -151,6 +151,7 @@ def columns(self) -> Sequence[str]: ...
     def schema(self) -> Mapping[str, DType]: ...
     @property
     def shape(self) -> tuple[int, int]: ...
+    def clear(self, n: int) -> Self: ...
     def clone(self) -> Self: ...
     def collect(
         self, backend: Implementation | None, **kwargs: Any
diff --git a/narwhals/_pandas_like/dataframe.py b/narwhals/_pandas_like/dataframe.py
index badafde27a..52dcce68b1 100644
--- a/narwhals/_pandas_like/dataframe.py
+++ b/narwhals/_pandas_like/dataframe.py
@@ -1137,3 +1137,18 @@ def explode(self, columns: Sequence[str]) -> Self:
             plx.concat([exploded_frame, *exploded_series], axis=1)[original_columns],
             validate_column_names=False,
         )
+
+    def clear(self, n: int) -> Self:
+        """Return a frame with the same dtypes holding `n` NA-filled rows."""
+        if n == 0:
+            return self.head(0)
+
+        ns = self.__native_namespace__()
+
+        native_dtypes = self.native.dtypes
+        schema = {col: native_dtypes[col] for col in self.native.columns}
+        # NOTE(review): `astype` on an all-NA frame may raise for non-nullable
+        # NumPy dtypes (e.g. int64, bool) - confirm against the classic pandas
+        # constructors before merging.
+        result = ns.DataFrame(ns.NA, index=range(n), columns=self.columns).astype(schema)
+        return self._with_native(result)
diff --git a/narwhals/_polars/dataframe.py b/narwhals/_polars/dataframe.py
index 31136f36c6..7936476738 100644
--- a/narwhals/_polars/dataframe.py
+++ b/narwhals/_polars/dataframe.py
@@ -66,6 +66,7 @@
 # DataFrame methods where PolarsDataFrame just defers to Polars.DataFrame directly.
 INHERITED_METHODS = frozenset(
     [
+        "clear",
         "clone",
         "drop_nulls",
         "estimated_size",
diff --git a/narwhals/dataframe.py b/narwhals/dataframe.py
index be923f61f3..4d02de66a5 100644
--- a/narwhals/dataframe.py
+++ b/narwhals/dataframe.py
@@ -2283,6 +2283,24 @@ def explode(self, columns: str | Sequence[str], *more_columns: str) -> Self:
         """
         return super().explode(columns, *more_columns)
 
+    def clear(self, n: int = 0) -> Self:
+        """Return a DataFrame with the same schema and `n` all-null rows.
+
+        Arguments:
+            n: Number of null-filled rows in the returned frame. Defaults to 0,
+                which yields an empty frame.
+
+        Returns:
+            A new DataFrame with the same columns and dtypes as the original.
+
+        Raises:
+            ValueError: If `n` is negative.
+        """
+        if n < 0:
+            msg = f"`n` should be greater than or equal to 0, got {n}"
+            raise ValueError(msg)
+        return self._with_compliant(self._compliant_frame.clear(n=n))
+
 
 class LazyFrame(BaseFrame[FrameT]):
     """Narwhals LazyFrame, backed by a native lazyframe.
diff --git a/tests/frame/clear_test.py b/tests/frame/clear_test.py
new file mode 100644
index 0000000000..d92e252265
--- /dev/null
+++ b/tests/frame/clear_test.py
@@ -0,0 +1,45 @@
+from __future__ import annotations
+
+import pytest
+
+import narwhals as nw
+from tests.utils import ConstructorEager, assert_equal_data
+
+
+@pytest.mark.parametrize("n", [0, 1, 10])
+def test_clear(constructor_eager: ConstructorEager, n: int) -> None:
+    # `clear(n)` must keep the schema and return exactly `n` all-null rows.
+    data = {
+        "int": [1, 2, 3],
+        "str": ["foo", "bar", "baz"],
+        "float": [0.1, 0.2, 0.3],
+        "bool": [True, False, True],
+    }
+    df = nw.from_native(constructor_eager(data), eager_only=True)
+    df_clear = df.clear(n=n)
+
+    assert len(df_clear) == n
+    assert df.schema == df_clear.schema
+
+    assert_equal_data(df_clear, {k: [None] * n for k in data})
+
+
+def test_clear_default(constructor_eager: ConstructorEager) -> None:
+    # Calling `clear()` without arguments uses the default `n=0`.
+    data = {"a": [1, 2, 3], "b": ["x", "y", "z"]}
+    df = nw.from_native(constructor_eager(data), eager_only=True)
+    df_clear = df.clear()
+
+    assert len(df_clear) == 0
+    assert df.schema == df_clear.schema
+
+
+def test_clear_negative(constructor_eager: ConstructorEager) -> None:
+    # Negative `n` is rejected with a ValueError.
+    n = -1
+    data = {"a": [1, 2, 3]}
+    df = nw.from_native(constructor_eager(data), eager_only=True)
+
+    msg = f"`n` should be greater than or equal to 0, got {n}"
+    with pytest.raises(ValueError, match=msg):
+        df.clear(n=n)