|
| 1 | +import io |
| 2 | +import logging |
1 | 3 | import os |
2 | 4 | from typing import Any |
3 | 5 | from unittest.mock import MagicMock, patch |
4 | 6 |
|
5 | 7 | import polars as pl |
6 | 8 | from requests import Response |
7 | 9 |
|
8 | | -from kagglehub.datasets import KaggleDatasetAdapter, PolarsFrameType, dataset_load |
| 10 | +from kagglehub.datasets import KaggleDatasetAdapter, PolarsFrameType, dataset_load, logger |
9 | 11 | from kagglehub.exceptions import KaggleApiHTTPError |
10 | 12 | from tests.fixtures import BaseTestCase |
11 | 13 |
|
@@ -42,6 +44,22 @@ def _load_hf_dataset_with_invalid_file_type_and_assert_raises(self) -> None: |
42 | 44 | ) |
43 | 45 | self.assertIn(f"Unsupported file extension: '{os.path.splitext(TEXT_FILE)[1]}'", str(cm.exception)) |
44 | 46 |
|
| 47 | + def _load_hf_dataset_with_other_loader_kwargs_and_assert_warning(self) -> None: |
| 48 | + output_stream = io.StringIO() |
| 49 | + handler = logging.StreamHandler(output_stream) |
| 50 | + logger.addHandler(handler) |
| 51 | + dataset_load( |
| 52 | + KaggleDatasetAdapter.HUGGING_FACE, |
| 53 | + DATASET_HANDLE, |
| 54 | + AUTO_COMPRESSED_FILE_NAME, |
| 55 | + polars_frame_type=PolarsFrameType.LAZY_FRAME, |
| 56 | + polars_kwargs={}, |
| 57 | + ) |
| 58 | + captured_output = output_stream.getvalue() |
| 59 | + self.assertIn( |
| 60 | + "polars_frame_type, polars_kwargs are invalid for KaggleDatasetAdapter.HUGGING_FACE", captured_output |
| 61 | + ) |
| 62 | + |
45 | 63 | def _load_hf_dataset_with_multiple_tables_and_assert_raises(self) -> None: |
46 | 64 | with self.assertRaises(ValueError) as cm: |
47 | 65 | dataset_load( |
@@ -85,6 +103,10 @@ def _load_hf_dataset_with_splits_and_assert_loaded(self) -> None: |
85 | 103 | self.assertEqual(TEST_SPLIT_SIZE if split_name == "test" else TRAIN_SPLIT_SIZE, dataset.num_rows) |
86 | 104 | self.assertEqual(SHAPES_COLUMNS, dataset.column_names) |
87 | 105 |
|
| 106 | + def test_hf_dataset_with_other_loader_kwargs_prints_warning(self) -> None: |
| 107 | + with create_test_cache(): |
| 108 | + self._load_hf_dataset_with_other_loader_kwargs_and_assert_warning() |
| 109 | + |
88 | 110 | def test_hf_dataset_with_invalid_file_type_raises(self) -> None: |
89 | 111 | with create_test_cache(): |
90 | 112 | self._load_hf_dataset_with_invalid_file_type_and_assert_raises() |
@@ -139,6 +161,23 @@ def _load_pandas_dataset_with_invalid_file_type_and_assert_raises(self) -> None: |
139 | 161 | ) |
140 | 162 | self.assertIn(f"Unsupported file extension: '{os.path.splitext(TEXT_FILE)[1]}'", str(cm.exception)) |
141 | 163 |
|
| 164 | + def _load_pandas_dataset_with_other_loader_kwargs_and_assert_warning(self) -> None: |
| 165 | + output_stream = io.StringIO() |
| 166 | + handler = logging.StreamHandler(output_stream) |
| 167 | + logger.addHandler(handler) |
| 168 | + dataset_load( |
| 169 | + KaggleDatasetAdapter.PANDAS, |
| 170 | + DATASET_HANDLE, |
| 171 | + AUTO_COMPRESSED_FILE_NAME, |
| 172 | + hf_kwargs={}, |
| 173 | + polars_frame_type=PolarsFrameType.LAZY_FRAME, |
| 174 | + polars_kwargs={}, |
| 175 | + ) |
| 176 | + captured_output = output_stream.getvalue() |
| 177 | + self.assertIn( |
| 178 | + "hf_kwargs, polars_frame_type, polars_kwargs are invalid for KaggleDatasetAdapter.PANDAS", captured_output |
| 179 | + ) |
| 180 | + |
142 | 181 | def _load_pandas_simple_dataset_and_assert_loaded( |
143 | 182 | self, |
144 | 183 | file_extension: str, |
@@ -187,6 +226,10 @@ def test_pandas_dataset_with_invalid_file_type_raises(self) -> None: |
187 | 226 | with create_test_cache(): |
188 | 227 | self._load_pandas_dataset_with_invalid_file_type_and_assert_raises() |
189 | 228 |
|
| 229 | + def test_pandas_dataset_with_other_loader_kwargs_prints_warning(self) -> None: |
| 230 | + with create_test_cache(): |
| 231 | + self._load_pandas_dataset_with_other_loader_kwargs_and_assert_warning() |
| 232 | + |
190 | 233 | def test_pandas_dataset_with_multiple_tables_succeeds(self) -> None: |
191 | 234 | with create_test_cache(): |
192 | 235 | self._load_pandas_dataset_with_multiple_tables_and_assert_loaded() |
@@ -249,6 +292,20 @@ def _load_polars_dataset_with_invalid_file_type_and_assert_raises(self) -> None: |
249 | 292 | ) |
250 | 293 | self.assertIn(f"Unsupported file extension: '{os.path.splitext(TEXT_FILE)[1]}'", str(cm.exception)) |
251 | 294 |
|
| 295 | + def _load_polars_dataset_with_other_loader_kwargs_and_assert_warning(self) -> None: |
| 296 | + output_stream = io.StringIO() |
| 297 | + handler = logging.StreamHandler(output_stream) |
| 298 | + logger.addHandler(handler) |
| 299 | + dataset_load( |
| 300 | + KaggleDatasetAdapter.POLARS, |
| 301 | + DATASET_HANDLE, |
| 302 | + AUTO_COMPRESSED_FILE_NAME, |
| 303 | + pandas_kwargs={}, |
| 304 | + hf_kwargs={}, |
| 305 | + ) |
| 306 | + captured_output = output_stream.getvalue() |
| 307 | + self.assertIn("pandas_kwargs, hf_kwargs are invalid for KaggleDatasetAdapter.POLARS", captured_output) |
| 308 | + |
252 | 309 | def _load_polars_simple_dataset_and_assert_loaded( |
253 | 310 | self, |
254 | 311 | file_extension: str, |
@@ -332,6 +389,10 @@ def test_polars_dataset_with_invalid_file_type_raises(self) -> None: |
332 | 389 | with create_test_cache(): |
333 | 390 | self._load_polars_dataset_with_invalid_file_type_and_assert_raises() |
334 | 391 |
|
| 392 | + def test_polars_dataset_with_other_loader_kwargs_prints_warning(self) -> None: |
| 393 | + with create_test_cache(): |
| 394 | + self._load_polars_dataset_with_other_loader_kwargs_and_assert_warning() |
| 395 | + |
335 | 396 | def test_polars_dataset_with_multiple_tables_succeeds(self) -> None: |
336 | 397 | with create_test_cache(): |
337 | 398 | self._load_polars_dataset_with_multiple_tables_and_assert_loaded(PolarsFrameType.LAZY_FRAME) |
|
0 commit comments