Skip to content

Commit 110b817

Browse files
authored
Add validation of loader kwargs to dataset_load (#241)
As discussed on #238: #238 (comment) http://b/388077145
1 parent 4fe7fd0 commit 110b817

File tree

3 files changed

+101
-5
lines changed

3 files changed

+101
-5
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
* Fix `model_signing` breaking changes from `1.0.0` release
66
* Add `KaggleDatasetAdapter.POLARS` support to `dataset_load`
7+
* Add validation of kwargs to `dataset_load`
78

89
## v0.3.11 (April 1, 2025)
910

src/kagglehub/datasets.py

Lines changed: 38 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,19 @@
1515
# Patterns that are always ignored for dataset uploading.
1616
DEFAULT_IGNORE_PATTERNS = [".git/", "*/.git/", ".cache/", ".huggingface/"]
1717
# Mapping of adapters to the optional dependencies required to run them
18-
LOAD_DATASET_ADAPTER_OPTIONAL_DEPENDENCIES_MAP = {
18+
DATASET_LOAD_ADAPTER_OPTIONAL_DEPENDENCIES_MAP = {
1919
KaggleDatasetAdapter.HUGGING_FACE: "hf-datasets",
2020
KaggleDatasetAdapter.PANDAS: "pandas-datasets",
2121
KaggleDatasetAdapter.POLARS: "polars-datasets",
2222
}
2323

24+
# Per-adapter allow-list of the loader-specific keyword arguments of `dataset_load`.
# A keyword that is passed with a non-None value but is absent from the selected
# adapter's set is reported as invalid by `validate_dataset_load_args`.
_DATASET_LOAD_VALID_KWARGS_MAP = {
    KaggleDatasetAdapter.HUGGING_FACE: {
        "hf_kwargs",
        "pandas_kwargs",
        "sql_query",
    },
    KaggleDatasetAdapter.PANDAS: {
        "pandas_kwargs",
        "sql_query",
    },
    KaggleDatasetAdapter.POLARS: {
        "polars_frame_type",
        "polars_kwargs",
        "sql_query",
    },
}
30+
2431

2532
def dataset_download(handle: str, path: Optional[str] = None, *, force_download: Optional[bool] = False) -> str:
2633
"""Download dataset files
@@ -80,7 +87,7 @@ def dataset_load(
8087
pandas_kwargs: Any = None, # noqa: ANN401
8188
sql_query: Optional[str] = None,
8289
hf_kwargs: Any = None, # noqa: ANN401
83-
polars_frame_type: PolarsFrameType = PolarsFrameType.LAZY_FRAME,
90+
polars_frame_type: Optional[PolarsFrameType] = None,
8491
polars_kwargs: Any = None, # noqa: ANN401
8592
) -> Any: # noqa: ANN401
8693
"""Load a Kaggle Dataset into a python object based on the selected adapter
@@ -96,7 +103,8 @@ def dataset_load(
96103
for details: https://pandas.pydata.org/docs/reference/api/pandas.read_sql_query.html
97104
hf_kwargs:
98105
(dict) Optional set of kwargs to pass to Dataset.from_pandas() while constructing the Dataset
99-
polars_frame_type: (PolarsFrameType) Optional value to control what type of frame to return from polars
106+
polars_frame_type: (PolarsFrameType) Optional control for which Frame to return: LazyFrame or DataFrame. The
107+
default is PolarsFrameType.LAZY_FRAME.
100108
polars_kwargs:
101109
(dict) Optional set of kwargs to pass to the polars `read_*` method while constructing the DataFrame(s)
102110
Returns:
@@ -107,6 +115,15 @@ def dataset_load(
107115
A LazyFrame or DataFrame (or dict[int | str, LazyFrame] / dict[int | str, DataFrame] for Excel-like
108116
files with multiple sheets)
109117
"""
118+
validate_dataset_load_args(
119+
adapter,
120+
pandas_kwargs=pandas_kwargs,
121+
sql_query=sql_query,
122+
hf_kwargs=hf_kwargs,
123+
polars_frame_type=polars_frame_type,
124+
polars_kwargs=polars_kwargs,
125+
)
126+
polars_frame_type = polars_frame_type if polars_frame_type is not None else PolarsFrameType.LAZY_FRAME
110127
try:
111128
if adapter is KaggleDatasetAdapter.HUGGING_FACE:
112129
import kagglehub.hf_datasets
@@ -134,7 +151,7 @@ def dataset_load(
134151
not_implemented_error_message = f"{adapter} is not yet implemented"
135152
raise NotImplementedError(not_implemented_error_message)
136153
except ImportError:
137-
adapter_optional_dependency = LOAD_DATASET_ADAPTER_OPTIONAL_DEPENDENCIES_MAP[adapter]
154+
adapter_optional_dependency = DATASET_LOAD_ADAPTER_OPTIONAL_DEPENDENCIES_MAP[adapter]
138155
import_warning_message = (
139156
f"The 'dataset_load' function requires the '{adapter_optional_dependency}' extras. "
140157
f"Install them with 'pip install kagglehub[{adapter_optional_dependency}]'"
@@ -157,3 +174,20 @@ def load_dataset(
157174
"load_dataset is deprecated and will be removed in a future version.", DeprecationWarning, stacklevel=2
158175
)
159176
return dataset_load(adapter, handle, path, pandas_kwargs=pandas_kwargs, sql_query=sql_query, hf_kwargs=hf_kwargs)
177+
178+
179+
def validate_dataset_load_args(
    adapter: KaggleDatasetAdapter,
    **kwargs: Any,  # noqa: ANN401
) -> None:
    """Warn about loader kwargs that do not apply to the selected adapter.

    A kwarg is considered invalid when it was given a non-None value but its name
    is not in the adapter's entry of `_DATASET_LOAD_VALID_KWARGS_MAP`. Invalid
    kwargs are only reported through a single logger warning; nothing is raised.

    Args:
        adapter: (KaggleDatasetAdapter) The adapter the kwargs are checked against
        **kwargs: The loader kwargs to check, keyed by their parameter name. Entries
            whose value is None are treated as "not provided" and never reported.
    """
    valid_kwargs = _DATASET_LOAD_VALID_KWARGS_MAP.get(adapter)
    if valid_kwargs is None:
        # No allow-list is registered for this adapter, so there is nothing to
        # validate against. Returning (rather than raising KeyError from a direct
        # subscript) leaves reporting of an unsupported adapter to the caller.
        return

    # Keep the caller's kwarg order so the warning text is deterministic.
    invalid_kwargs_list: list[str] = [
        key for key, value in kwargs.items() if value is not None and key not in valid_kwargs
    ]
    if not invalid_kwargs_list:
        return

    invalid_kwargs = ", ".join(invalid_kwargs_list)
    logger.warning(f"{invalid_kwargs} are invalid for {adapter} and will be ignored")

tests/test_dataset_load.py

Lines changed: 62 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
1+
import io
2+
import logging
13
import os
24
from typing import Any
35
from unittest.mock import MagicMock, patch
46

57
import polars as pl
68
from requests import Response
79

8-
from kagglehub.datasets import KaggleDatasetAdapter, PolarsFrameType, dataset_load
10+
from kagglehub.datasets import KaggleDatasetAdapter, PolarsFrameType, dataset_load, logger
911
from kagglehub.exceptions import KaggleApiHTTPError
1012
from tests.fixtures import BaseTestCase
1113

@@ -42,6 +44,22 @@ def _load_hf_dataset_with_invalid_file_type_and_assert_raises(self) -> None:
4244
)
4345
self.assertIn(f"Unsupported file extension: '{os.path.splitext(TEXT_FILE)[1]}'", str(cm.exception))
4446

47+
def _load_hf_dataset_with_other_loader_kwargs_and_assert_warning(self) -> None:
    """Load with the HUGGING_FACE adapter plus polars-only kwargs and assert the warning is logged."""
    output_stream = io.StringIO()
    handler = logging.StreamHandler(output_stream)
    logger.addHandler(handler)
    try:
        dataset_load(
            KaggleDatasetAdapter.HUGGING_FACE,
            DATASET_HANDLE,
            AUTO_COMPRESSED_FILE_NAME,
            polars_frame_type=PolarsFrameType.LAZY_FRAME,
            polars_kwargs={},
        )
    finally:
        # Always detach the capture handler so it does not accumulate on the
        # shared module logger and leak output into other tests.
        logger.removeHandler(handler)
    captured_output = output_stream.getvalue()
    self.assertIn(
        "polars_frame_type, polars_kwargs are invalid for KaggleDatasetAdapter.HUGGING_FACE", captured_output
    )
62+
4563
def _load_hf_dataset_with_multiple_tables_and_assert_raises(self) -> None:
4664
with self.assertRaises(ValueError) as cm:
4765
dataset_load(
@@ -85,6 +103,10 @@ def _load_hf_dataset_with_splits_and_assert_loaded(self) -> None:
85103
self.assertEqual(TEST_SPLIT_SIZE if split_name == "test" else TRAIN_SPLIT_SIZE, dataset.num_rows)
86104
self.assertEqual(SHAPES_COLUMNS, dataset.column_names)
87105

106+
def test_hf_dataset_with_other_loader_kwargs_prints_warning(self) -> None:
    """Run the HUGGING_FACE invalid-kwargs warning check inside a `create_test_cache()` context."""
    with create_test_cache():
        self._load_hf_dataset_with_other_loader_kwargs_and_assert_warning()
109+
88110
def test_hf_dataset_with_invalid_file_type_raises(self) -> None:
89111
with create_test_cache():
90112
self._load_hf_dataset_with_invalid_file_type_and_assert_raises()
@@ -139,6 +161,23 @@ def _load_pandas_dataset_with_invalid_file_type_and_assert_raises(self) -> None:
139161
)
140162
self.assertIn(f"Unsupported file extension: '{os.path.splitext(TEXT_FILE)[1]}'", str(cm.exception))
141163

164+
def _load_pandas_dataset_with_other_loader_kwargs_and_assert_warning(self) -> None:
    """Load with the PANDAS adapter plus hf/polars-only kwargs and assert the warning is logged."""
    output_stream = io.StringIO()
    handler = logging.StreamHandler(output_stream)
    logger.addHandler(handler)
    try:
        dataset_load(
            KaggleDatasetAdapter.PANDAS,
            DATASET_HANDLE,
            AUTO_COMPRESSED_FILE_NAME,
            hf_kwargs={},
            polars_frame_type=PolarsFrameType.LAZY_FRAME,
            polars_kwargs={},
        )
    finally:
        # Always detach the capture handler so it does not accumulate on the
        # shared module logger and leak output into other tests.
        logger.removeHandler(handler)
    captured_output = output_stream.getvalue()
    self.assertIn(
        "hf_kwargs, polars_frame_type, polars_kwargs are invalid for KaggleDatasetAdapter.PANDAS", captured_output
    )
180+
142181
def _load_pandas_simple_dataset_and_assert_loaded(
143182
self,
144183
file_extension: str,
@@ -187,6 +226,10 @@ def test_pandas_dataset_with_invalid_file_type_raises(self) -> None:
187226
with create_test_cache():
188227
self._load_pandas_dataset_with_invalid_file_type_and_assert_raises()
189228

229+
def test_pandas_dataset_with_other_loader_kwargs_prints_warning(self) -> None:
    """Run the PANDAS invalid-kwargs warning check inside a `create_test_cache()` context."""
    with create_test_cache():
        self._load_pandas_dataset_with_other_loader_kwargs_and_assert_warning()
232+
190233
def test_pandas_dataset_with_multiple_tables_succeeds(self) -> None:
191234
with create_test_cache():
192235
self._load_pandas_dataset_with_multiple_tables_and_assert_loaded()
@@ -249,6 +292,20 @@ def _load_polars_dataset_with_invalid_file_type_and_assert_raises(self) -> None:
249292
)
250293
self.assertIn(f"Unsupported file extension: '{os.path.splitext(TEXT_FILE)[1]}'", str(cm.exception))
251294

295+
def _load_polars_dataset_with_other_loader_kwargs_and_assert_warning(self) -> None:
    """Load with the POLARS adapter plus pandas/hf-only kwargs and assert the warning is logged."""
    output_stream = io.StringIO()
    handler = logging.StreamHandler(output_stream)
    logger.addHandler(handler)
    try:
        dataset_load(
            KaggleDatasetAdapter.POLARS,
            DATASET_HANDLE,
            AUTO_COMPRESSED_FILE_NAME,
            pandas_kwargs={},
            hf_kwargs={},
        )
    finally:
        # Always detach the capture handler so it does not accumulate on the
        # shared module logger and leak output into other tests.
        logger.removeHandler(handler)
    captured_output = output_stream.getvalue()
    self.assertIn("pandas_kwargs, hf_kwargs are invalid for KaggleDatasetAdapter.POLARS", captured_output)
308+
252309
def _load_polars_simple_dataset_and_assert_loaded(
253310
self,
254311
file_extension: str,
@@ -332,6 +389,10 @@ def test_polars_dataset_with_invalid_file_type_raises(self) -> None:
332389
with create_test_cache():
333390
self._load_polars_dataset_with_invalid_file_type_and_assert_raises()
334391

392+
def test_polars_dataset_with_other_loader_kwargs_prints_warning(self) -> None:
    """Run the POLARS invalid-kwargs warning check inside a `create_test_cache()` context."""
    with create_test_cache():
        self._load_polars_dataset_with_other_loader_kwargs_and_assert_warning()
395+
335396
def test_polars_dataset_with_multiple_tables_succeeds(self) -> None:
336397
with create_test_cache():
337398
self._load_polars_dataset_with_multiple_tables_and_assert_loaded(PolarsFrameType.LAZY_FRAME)

0 commit comments

Comments
 (0)