diff --git a/narwhals/functions.py b/narwhals/functions.py
index 1b9cc3b264..bd2d4ae1c3 100644
--- a/narwhals/functions.py
+++ b/narwhals/functions.py
@@ -564,8 +564,46 @@ def show_versions() -> None:
         print(f"{k:>13}: {stat}")  # noqa: T201
 
 
+def _validate_separator(separator: str, native_separator: str, **kwargs: Any) -> None:
+    """Raise if `kwargs[native_separator]` is given and differs from `separator`."""
+    if native_separator in kwargs and kwargs[native_separator] != separator:
+        msg = (
+            f"`separator` and `{native_separator}` do not match: "
+            f"`separator`={separator} and `{native_separator}`={kwargs[native_separator]}."
+        )
+        raise TypeError(msg)
+
+
+def _validate_separator_pyarrow(separator: str, **kwargs: Any) -> Any:
+    """Return `kwargs` with a `parse_options` entry whose delimiter is `separator`.
+
+    A user-supplied `parse_options` is kept (it may carry other settings such as
+    `quote_char`) as long as its delimiter agrees with `separator`; otherwise one
+    is created. Every other keyword argument is passed through untouched.
+
+    Raises:
+        TypeError: If `parse_options.delimiter` differs from `separator`.
+    """
+    parse_options = kwargs.get("parse_options")
+    if parse_options is None:
+        from pyarrow import csv  # ignore-banned-import
+
+        return {**kwargs, "parse_options": csv.ParseOptions(delimiter=separator)}
+    if parse_options.delimiter != separator:
+        msg = (
+            "`separator` and `parse_options.delimiter` do not match: "
+            f"`separator`={separator} and `delimiter`={parse_options.delimiter}."
+        )
+        raise TypeError(msg)
+    return kwargs
+
+
 def read_csv(
-    source: str, *, backend: IntoBackend[EagerAllowed], **kwargs: Any
+    source: str,
+    *,
+    backend: IntoBackend[EagerAllowed],
+    separator: str = ",",
+    **kwargs: Any,
 ) -> DataFrame[Any]:
     """Read a CSV file into a DataFrame.
 
@@ -578,6 +616,7 @@ def read_csv(
                 `POLARS`, `MODIN` or `CUDF`.
             - As a string: `"pandas"`, `"pyarrow"`, `"polars"`, `"modin"` or `"cudf"`.
             - Directly as a module `pandas`, `pyarrow`, `polars`, `modin` or `cudf`.
+        separator: Single byte character to use as separator in the file.
         kwargs: Extra keyword arguments which are passed to the native CSV reader.
             For example, you could use
             `nw.read_csv('file.csv', backend='pandas', engine='pyarrow')`.
@@ -599,14 +628,14 @@ def read_csv(
     impl = Implementation.from_backend(backend)
     native_namespace = impl.to_native_namespace()
     native_frame: NativeDataFrame
-    if impl in {
-        Implementation.POLARS,
-        Implementation.PANDAS,
-        Implementation.MODIN,
-        Implementation.CUDF,
-    }:
-        native_frame = native_namespace.read_csv(source, **kwargs)
+    if impl in {Implementation.PANDAS, Implementation.MODIN, Implementation.CUDF}:
+        _validate_separator(separator, "sep", **kwargs)
+        kwargs.pop("sep", None)  # validated above; avoid passing `sep` twice
+        native_frame = native_namespace.read_csv(source, sep=separator, **kwargs)
+    elif impl is Implementation.POLARS:
+        native_frame = native_namespace.read_csv(source, separator=separator, **kwargs)
     elif impl is Implementation.PYARROW:
+        kwargs = _validate_separator_pyarrow(separator, **kwargs)
         from pyarrow import csv  # ignore-banned-import
 
         native_frame = csv.read_csv(source, **kwargs)
@@ -635,7 +664,7 @@ def read_csv(
 
 
 def scan_csv(
-    source: str, *, backend: IntoBackend[Backend], **kwargs: Any
+    source: str, *, backend: IntoBackend[Backend], separator: str = ",", **kwargs: Any
 ) -> LazyFrame[Any]:
     """Lazily read from a CSV file.
 
@@ -651,6 +680,7 @@ def scan_csv(
                 `POLARS`, `MODIN` or `CUDF`.
             - As a string: `"pandas"`, `"pyarrow"`, `"polars"`, `"modin"` or `"cudf"`.
             - Directly as a module `pandas`, `pyarrow`, `polars`, `modin` or `cudf`.
+        separator: Single byte character to use as separator in the file.
         kwargs: Extra keyword arguments which are passed to the native CSV reader.
             For example, you could use
             `nw.scan_csv('file.csv', backend=pd, engine='pyarrow')`.
@@ -676,33 +705,47 @@ def scan_csv(
     native_namespace = implementation.to_native_namespace()
     native_frame: NativeDataFrame | NativeLazyFrame
     if implementation is Implementation.POLARS:
-        native_frame = native_namespace.scan_csv(source, **kwargs)
+        native_frame = native_namespace.scan_csv(source, separator=separator, **kwargs)
     elif implementation in {
         Implementation.PANDAS,
         Implementation.MODIN,
         Implementation.CUDF,
         Implementation.DASK,
-        Implementation.DUCKDB,
         Implementation.IBIS,
     }:
-        native_frame = native_namespace.read_csv(source, **kwargs)
+        _validate_separator(separator, "sep", **kwargs)
+        kwargs.pop("sep", None)  # validated above; avoid passing `sep` twice
+        native_frame = native_namespace.read_csv(source, sep=separator, **kwargs)
+    elif implementation is Implementation.DUCKDB:
+        _validate_separator(separator, "delimiter", **kwargs)
+        _validate_separator(separator, "delim", **kwargs)
+        # Validated above: drop the aliases so the delimiter is only passed once.
+        kwargs.pop("delimiter", None)
+        kwargs.pop("delim", None)
+        native_frame = native_namespace.read_csv(source, delimiter=separator, **kwargs)
     elif implementation is Implementation.PYARROW:
+        kwargs = _validate_separator_pyarrow(separator, **kwargs)
         from pyarrow import csv  # ignore-banned-import
 
         native_frame = csv.read_csv(source, **kwargs)
     elif implementation.is_spark_like():
+        _validate_separator(separator, "sep", **kwargs)
+        _validate_separator(separator, "delimiter", **kwargs)
+        # Validated above: drop the aliases so `sep` is only set once below.
+        kwargs.pop("sep", None)
+        kwargs.pop("delimiter", None)
         if (session := kwargs.pop("session", None)) is None:
             msg = "Spark like backends require a session object to be passed in `kwargs`."
             raise ValueError(msg)
 
         csv_reader = session.read.format("csv")
         native_frame = (
-            csv_reader.load(source)
+            csv_reader.load(source, sep=separator)
             if (
                 implementation is Implementation.SQLFRAME
                 and implementation._backend_version() < (3, 27, 0)
             )
-            else csv_reader.options(**kwargs).load(source)
+            else csv_reader.options(sep=separator, **kwargs).load(source)
         )
     else:  # pragma: no cover
         try:
diff --git a/narwhals/stable/v2/__init__.py b/narwhals/stable/v2/__init__.py
index fe8f7a71b0..1f3c2d457b 100644
--- a/narwhals/stable/v2/__init__.py
+++ b/narwhals/stable/v2/__init__.py
@@ -1138,7 +1138,11 @@ def from_numpy(
 
 
 def read_csv(
-    source: str, *, backend: IntoBackend[EagerAllowed], **kwargs: Any
+    source: str,
+    *,
+    backend: IntoBackend[EagerAllowed],
+    separator: str = ",",
+    **kwargs: Any,
 ) -> DataFrame[Any]:
     """Read a CSV file into a DataFrame.
 
@@ -1151,6 +1155,7 @@ def read_csv(
                 `POLARS`, `MODIN` or `CUDF`.
             - As a string: `"pandas"`, `"pyarrow"`, `"polars"`, `"modin"` or `"cudf"`.
             - Directly as a module `pandas`, `pyarrow`, `polars`, `modin` or `cudf`.
+        separator: Single byte character to use as separator in the file.
         kwargs: Extra keyword arguments which are passed to the native CSV reader.
             For example, you could use
             `nw.read_csv('file.csv', backend='pandas', engine='pyarrow')`.
@@ -1158,11 +1163,13 @@ def read_csv(
     Returns:
         DataFrame.
     """
-    return _stableify(nw_f.read_csv(source, backend=backend, **kwargs))
+    return _stableify(
+        nw_f.read_csv(source, backend=backend, separator=separator, **kwargs)
+    )
 
 
 def scan_csv(
-    source: str, *, backend: IntoBackend[Backend], **kwargs: Any
+    source: str, *, backend: IntoBackend[Backend], separator: str = ",", **kwargs: Any
 ) -> LazyFrame[Any]:
     """Lazily read from a CSV file.
 
@@ -1178,6 +1185,7 @@ def scan_csv(
                 `POLARS`, `MODIN` or `CUDF`.
             - As a string: `"pandas"`, `"pyarrow"`, `"polars"`, `"modin"` or `"cudf"`.
             - Directly as a module `pandas`, `pyarrow`, `polars`, `modin` or `cudf`.
+        separator: Single byte character to use as separator in the file.
         kwargs: Extra keyword arguments which are passed to the native CSV reader.
             For example, you could use
             `nw.scan_csv('file.csv', backend=pd, engine='pyarrow')`.
@@ -1185,7 +1193,9 @@ def scan_csv(
     Returns:
         LazyFrame.
     """
-    return _stableify(nw_f.scan_csv(source, backend=backend, **kwargs))
+    return _stableify(
+        nw_f.scan_csv(source, backend=backend, separator=separator, **kwargs)
+    )
 
 
 def read_parquet(
diff --git a/tests/read_scan_test.py b/tests/read_scan_test.py
index dce09cade5..c341424d51 100644
--- a/tests/read_scan_test.py
+++ b/tests/read_scan_test.py
@@ -27,6 +27,10 @@ def test_read_csv(tmpdir: pytest.TempdirFactory, eager_backend: EagerAllowed) ->
     result = nw.read_csv(filepath, backend=eager_backend)
     assert_equal_data(result, data)
     assert isinstance(result, nw.DataFrame)
+    df_pl.write_csv(filepath, separator=";")
+    result = nw.read_csv(filepath, backend=eager_backend, separator=";")
+    assert_equal_data(result, data)
+    assert isinstance(result, nw.DataFrame)
 
 
 @pytest.mark.skipif(PANDAS_VERSION < (1, 5), reason="too old for pyarrow")
@@ -90,15 +94,100 @@ def test_scan_csv(tmpdir: pytest.TempdirFactory, constructor: Constructor) -> No
     result = nw.scan_csv(filepath, backend=backend, **kwargs)
     assert_equal_data(result, data)
     assert isinstance(result, nw.LazyFrame)
+    df_pl.write_csv(filepath, separator="|")
+    df = nw.from_native(constructor(data))
+    backend = nw.get_native_namespace(df)
+    result = nw.scan_csv(filepath, backend=backend, separator="|", **kwargs)
+    assert_equal_data(result, data)
+    assert isinstance(result, nw.LazyFrame)
 
 
 @pytest.mark.skipif(PANDAS_VERSION < (1, 5), reason="too old for pyarrow")
 def test_scan_csv_kwargs(tmpdir: pytest.TempdirFactory) -> None:
+    pytest.importorskip("pyarrow")
+    from pyarrow import csv
+
     df_pl = pl.DataFrame(data)
     filepath = str(tmpdir / "file.csv")  # type: ignore[operator]
     df_pl.write_csv(filepath)
     result = nw.scan_csv(filepath, backend=pd, engine="pyarrow")
     assert_equal_data(result, data)
+    result = nw.scan_csv(
+        filepath, backend="pyarrow", parse_options=csv.ParseOptions(delimiter=",")
+    )
+    assert_equal_data(result, data)
+
+
+def test_read_csv_raise_sep_multiple(tmpdir: pytest.TempdirFactory) -> None:
+    pytest.importorskip("duckdb")
+    pytest.importorskip("pandas")
+    pytest.importorskip("pyarrow")
+    pytest.importorskip("sqlframe")
+    import duckdb
+    import pandas as pd
+    import pyarrow as pa
+    import sqlframe
+    from pyarrow import csv
+    from sqlframe.duckdb import DuckDBSession
+
+    df_pl = pl.DataFrame(data)
+    filepath = str(tmpdir / "file.csv")  # type: ignore[operator]
+    df_pl.write_csv(filepath)
+
+    msg = "do not match:"
+    with pytest.raises(TypeError, match=msg):
+        nw.read_csv(
+            filepath,
+            backend=pa,
+            separator="|",
+            parse_options=csv.ParseOptions(delimiter=";"),
+        )
+    with pytest.raises(TypeError, match=msg):
+        nw.scan_csv(
+            filepath,
+            backend=pa,
+            separator="|",
+            parse_options=csv.ParseOptions(delimiter=";"),
+        )
+    with pytest.raises(TypeError, match=msg):
+        nw.read_csv(filepath, backend=pd, separator="|", sep=";")
+    with pytest.raises(TypeError, match=msg):
+        nw.scan_csv(filepath, backend=pd, separator="|", sep=";")
+    with pytest.raises(TypeError, match=msg):
+        nw.scan_csv(filepath, backend=duckdb, separator="|", delimiter=";")
+    with pytest.raises(TypeError, match=msg):
+        nw.scan_csv(filepath, backend=duckdb, separator="|", delim=";")
+    with pytest.raises(TypeError, match=msg):
+        nw.scan_csv(
+            filepath,
+            backend=sqlframe,
+            separator="|",
+            sep=";",
+            session=DuckDBSession(),
+            inferSchema=True,
+        )
+    with pytest.raises(TypeError, match=msg):
+        nw.scan_csv(
+            filepath,
+            backend=sqlframe,
+            separator="|",
+            delimiter=";",
+            session=DuckDBSession(),
+            inferSchema=True,
+        )
+
+
+def test_scan_csv_matching_native_separator(tmpdir: pytest.TempdirFactory) -> None:
+    pytest.importorskip("pandas")
+    import pandas as pd
+
+    df_pl = pl.DataFrame(data)
+    filepath = str(tmpdir / "file.csv")  # type: ignore[operator]
+    df_pl.write_csv(filepath, separator=";")
+
+    # A native alias equal to `separator` is redundant, not an error.
+    result = nw.scan_csv(filepath, backend=pd, separator=";", sep=";")
+    assert_equal_data(result, data)
 
 
 @pytest.mark.skipif(PANDAS_VERSION < (1, 5), reason="too old for pyarrow")