From 6c669380138fd9b8c1956df9d628448cc60207e5 Mon Sep 17 00:00:00 2001 From: ArjunJagdale Date: Tue, 29 Jul 2025 01:07:54 +0530 Subject: [PATCH 01/12] Reimplemented partial split download support (revival of #6832) --- src/datasets/arrow_reader.py | 2 +- src/datasets/builder.py | 158 +++++++++++++++--- src/datasets/info.py | 2 +- src/datasets/load.py | 1 + src/datasets/naming.py | 21 +-- src/datasets/packaged_modules/arrow/arrow.py | 12 +- src/datasets/packaged_modules/cache/cache.py | 9 +- src/datasets/packaged_modules/csv/csv.py | 10 +- .../folder_based_builder.py | 7 +- src/datasets/packaged_modules/json/json.py | 12 +- .../packaged_modules/parquet/parquet.py | 10 +- src/datasets/packaged_modules/text/text.py | 12 +- .../packaged_modules/webdataset/webdataset.py | 12 +- src/datasets/utils/info_utils.py | 16 +- tests/test_arrow_reader.py | 21 ++- tests/test_download_manager.py | 1 - tests/test_load.py | 64 +++++++ 17 files changed, 294 insertions(+), 76 deletions(-) diff --git a/src/datasets/arrow_reader.py b/src/datasets/arrow_reader.py index 3bbb58a59c3..7011dbc9585 100644 --- a/src/datasets/arrow_reader.py +++ b/src/datasets/arrow_reader.py @@ -120,7 +120,7 @@ def make_file_instructions( dataset_name=name, split=info.name, filetype_suffix=filetype_suffix, - shard_lengths=name2shard_lengths[info.name], + num_shards=len(name2shard_lengths[info.name] or ()), ) for info in split_infos } diff --git a/src/datasets/builder.py b/src/datasets/builder.py index e63960dcabf..f4f13194284 100644 --- a/src/datasets/builder.py +++ b/src/datasets/builder.py @@ -18,6 +18,7 @@ import abc import contextlib import copy +import fnmatch import inspect import os import posixpath @@ -29,7 +30,7 @@ from dataclasses import dataclass from functools import partial from pathlib import Path -from typing import TYPE_CHECKING, Optional, Union +from typing import TYPE_CHECKING, Dict, Iterable, List, Mapping, Optional, Tuple, Union from unittest.mock import patch import fsspec @@ -59,7 +60,12 @@ from .info import DatasetInfo, PostProcessedInfo from .iterable_dataset import ArrowExamplesIterable, ExamplesIterable, IterableDataset from .keyhash import DuplicatedKeysError -from .naming import INVALID_WINDOWS_CHARACTERS_IN_PATH, camelcase_to_snakecase +from .naming import ( + INVALID_WINDOWS_CHARACTERS_IN_PATH, + camelcase_to_snakecase, + filenames_for_dataset_split, + filepattern_for_dataset_split, +) from .splits import Split, SplitDict, SplitGenerator, SplitInfo from .streaming import extend_dataset_builder_for_streaming from .table import CastError @@ -69,6 +75,7 @@ from .utils.file_utils import is_remote_url from .utils.info_utils import VerificationMode, get_size_checksum_dict, verify_checksums, verify_splits from .utils.py_utils import ( + NestedDataStructure, classproperty, convert_file_size_to_int, has_sufficient_disk_space, @@ -77,6 +84,7 @@ memoize, size_str, temporary_assignment, + unique_values, ) from .utils.sharding import _number_of_shards_in_gen_kwargs, _split_gen_kwargs from .utils.track import tracked_list @@ -679,6 +687,10 @@ def _info(self) -> DatasetInfo: info: (DatasetInfo) The dataset information """ raise NotImplementedError + + def _supports_partial_generation(self) -> bool: + """Whether the dataset supports generation of specific splits.""" + return hasattr(self, "_available_splits") and "splits" in inspect.signature(self._split_generators).parameters @classmethod def get_imported_module_dir(cls): @@ -691,6 +703,7 @@ def _rename(self, src: str, dst: str): def download_and_prepare( self, 
output_dir: Optional[str] = None, + split: Optional[Union[str, ReadInstruction, Split]] = None, download_config: Optional[DownloadConfig] = None, download_mode: Optional[Union[DownloadMode, str]] = None, verification_mode: Optional[Union[VerificationMode, str]] = None, @@ -710,6 +723,8 @@ def download_and_prepare( Default to this builder's `cache_dir`, which is inside `~/.cache/huggingface/datasets` by default. + split (`Union[str, ReadInstruction, Split]`, *optional*): + Splits to generate. Default to all splits. download_config (`DownloadConfig`, *optional*): Specific download configuration parameters. download_mode ([`DownloadMode`] or `str`, *optional*): @@ -828,12 +843,60 @@ def download_and_prepare( # File locking only with local paths; no file locking on GCS or S3 with FileLock(lock_path) if is_local else contextlib.nullcontext(): # Check if the data already exists - data_exists = self._fs.exists(posixpath.join(self._output_dir, config.DATASET_INFO_FILENAME)) - if data_exists and download_mode == DownloadMode.REUSE_DATASET_IF_EXISTS: - logger.info(f"Found cached dataset {self.dataset_name} ({self._output_dir})") + info_exists = self._fs.exists(posixpath.join(self._output_dir, config.DATASET_INFO_FILENAME)) + if info_exists: # We need to update the info in case some splits were added in the meantime # for example when calling load_dataset from multiple workers. self.info = self._load_info() + _dataset_name = self.name if self._check_legacy_cache() else self.dataset_name + splits: Optional[List[str]] = None + cached_split_filepatterns = [] + supports_partial_generation = self._supports_partial_generation() + if supports_partial_generation: + if split: + splits = [] + for split in NestedDataStructure(split).flatten(): + if not isinstance(split, ReadInstruction): + split = str(split) + if split == Split.ALL: + splits = None # generate all splits + break + split = ReadInstruction.from_spec(split) + split_names = [rel_instr.splitname for rel_instr in split._relative_instructions] + splits.extend(split_names) + splits = list(unique_values(splits)) # remove duplicates + available_splits = self._available_splits() + if splits is None: + splits = available_splits + missing_splits = set(splits) - set(available_splits) + if missing_splits: + raise ValueError(f"Splits {list(missing_splits)} not found. 
Available splits: {available_splits}") + if DownloadMode.REUSE_DATASET_IF_EXISTS: + for split_name in splits[:]: + num_shards = 1 + if self.info.splits is not None: + try: + num_shards = len(self.info.splits[split_name].shard_lengths or ()) + except Exception: + pass + split_filenames = filenames_for_dataset_split( + self._output_dir, + _dataset_name, + split_name, + filetype_suffix=file_format, + num_shards=num_shards, + ) + if self._fs.exists(split_filenames[0]): + splits.remove(split_name) + split_filepattern = filepattern_for_dataset_split( + self._output_dir, _dataset_name, split_name, filetype_suffix=file_format + ) + cached_split_filepatterns.append(split_filepattern) + # We cannot use info as the source of truth if the builder supports partial generation + # as the info can be incomplete in that case + requested_splits_exist = not splits if supports_partial_generation else info_exists + if requested_splits_exist and download_mode == DownloadMode.REUSE_DATASET_IF_EXISTS: + logger.info(f"Found cached dataset {self.dataset_name} ({self._output_dir})") self.download_post_processing_resources(dl_manager) return @@ -858,16 +921,33 @@ def incomplete_dir(dirname): try: yield tmp_dir if os.path.isdir(dirname): - shutil.rmtree(dirname) + for root, dirnames, filenames in os.walk(dirname, topdown=False): # LocalFileSystem.mv does copy + rm, it is more efficient to simply rename a local directory - shutil.move(tmp_dir, dirname) + for filename in filenames: + filename = os.path.join(root, filename) + delete_filename = True + for cached_split_filepattern in cached_split_filepatterns: + if fnmatch.fnmatch(filename, cached_split_filepattern): + delete_filename = False + break + if delete_filename: + os.remove(filename) + for dirname in dirnames: + dirname = os.path.join(root, dirname) + if len(os.listdir(dirname)) == 0: + os.rmdir(dirname) + for file_or_dir in os.listdir(tmp_dir): + try: + shutil.move(os.path.join(tmp_dir, file_or_dir), dirname) + except shutil.Error: + # If the file already exists in the distributed setup + pass + else: + shutil.move(tmp_dir, dirname) finally: if os.path.exists(tmp_dir): shutil.rmtree(tmp_dir) - # Print is intentional: we want this to always go to stdout so user has - # information needed to cancel download/preparation if needed. - # This comes right before the progress bar. if self.info.size_in_bytes: logger.info( f"Downloading and preparing dataset {self.dataset_name}/{self.config.name} " @@ -886,7 +966,7 @@ def incomplete_dir(dirname): # Temporarily assign _output_dir to tmp_data_dir to avoid having to forward # it to every sub function. 
with temporary_assignment(self, "_output_dir", tmp_output_dir): - prepare_split_kwargs = {"file_format": file_format} + prepare_split_kwargs = {"file_format": file_format, "splits": splits} if max_shard_size is not None: prepare_split_kwargs["max_shard_size"] = max_shard_size if num_proc is not None: @@ -898,7 +978,15 @@ def incomplete_dir(dirname): **download_and_prepare_kwargs, ) # Sync info + if supports_partial_generation and self.info.download_checksums is not None: + self.info.download_checksums.update(dl_manager.get_recorded_sizes_checksums()) + else: + self.info.download_checksums = dl_manager.get_recorded_sizes_checksums() + self.info.dataset_size = sum(split.num_bytes for split in self.info.splits.values()) + self.info.download_size = sum( + checksum["num_bytes"] for checksum in self.info.download_checksums.values() + ) self.info.download_checksums = dl_manager.get_recorded_sizes_checksums() if self.info.download_size is not None: self.info.size_in_bytes = self.info.dataset_size + self.info.download_size @@ -942,7 +1030,8 @@ def _download_and_prepare(self, dl_manager, verification_mode, **prepare_split_k if `NO_CHECKS`, do not perform any verification. prepare_split_kwargs: Additional options, such as `file_format`, `max_shard_size` """ - # Generating data for all splits + # If `splits` is specified and the builder supports `splits` in `_split_generators`, then only generate the specified splits. + # Otherwise, generate all splits split_dict = SplitDict(dataset_name=self.dataset_name) split_generators_kwargs = self._make_split_generators_kwargs(prepare_split_kwargs) split_generators = self._split_generators(dl_manager, **split_generators_kwargs) @@ -950,7 +1039,9 @@ def _download_and_prepare(self, dl_manager, verification_mode, **prepare_split_k # Checksums verification if verification_mode == VerificationMode.ALL_CHECKS and dl_manager.record_checksums: verify_checksums( - self.info.download_checksums, dl_manager.get_recorded_sizes_checksums(), "dataset source files" + self.info.download_checksums, + dl_manager.get_recorded_sizes_checksums(), + "dataset source files", ) # Build splits @@ -987,9 +1078,18 @@ def _download_and_prepare(self, dl_manager, verification_mode, **prepare_split_k if verification_mode == VerificationMode.BASIC_CHECKS or verification_mode == VerificationMode.ALL_CHECKS: verify_splits(self.info.splits, split_dict) - # Update the info object with the splits. - self.info.splits = split_dict - self.info.download_size = dl_manager.downloaded_size + # Update the info object with the generated splits. 
+ if self._supports_partial_generation(): + split_infos = self.info.splits or {} + ordered_split_infos = {} + for split_name in self._available_splits(): + if split_name in split_dict: + ordered_split_infos[split_name] = split_dict[split_name] + elif split_name in split_infos: + ordered_split_infos[split_name] = split_infos[split_name] + self.info.splits = SplitDict.from_split_dict(ordered_split_infos, dataset_name=self.dataset_name) + else: + self.info.splits = split_dict def download_post_processing_resources(self, dl_manager): for split in self.info.splits or []: @@ -1021,7 +1121,9 @@ def _save_info(self): def _make_split_generators_kwargs(self, prepare_split_kwargs): """Get kwargs for `self._split_generators()` from `prepare_split_kwargs`.""" - del prepare_split_kwargs + splits = prepare_split_kwargs.pop("splits", None) + if self._supports_partial_generation(): + return {"splits": splits} return {} def as_dataset( @@ -1075,11 +1177,12 @@ def as_dataset( "datasets.load_dataset() before trying to access the Dataset object." ) - logger.debug(f"Constructing Dataset for split {split or ', '.join(self.info.splits)}, from {self._output_dir}") + available_splits = self._available_splits() if self._supports_partial_generation() else self.info.splits + logger.debug(f'Constructing Dataset for split {split or ", ".join(available_splits)}, from {self._output_dir}') # By default, return all splits if split is None: - split = {s: s for s in self.info.splits} + split = {s: s for s in available_splits} verification_mode = VerificationMode(verification_mode or VerificationMode.BASIC_CHECKS) @@ -1107,10 +1210,11 @@ def _build_single_dataset( in_memory: bool = False, ): """as_dataset for a single split.""" + available_splits = self._available_splits() if self._supports_partial_generation() else self.info.splits if not isinstance(split, ReadInstruction): split = str(split) - if split == "all": - split = "+".join(self.info.splits.keys()) + if split == Split.ALL: + split = "+".join(available_splits) split = Split(split) # Build base dataset @@ -1222,8 +1326,12 @@ def as_streaming_dataset( data_dir=self.config.data_dir, ) self._check_manual_download(dl_manager) - splits_generators = {sg.name: sg for sg in self._split_generators(dl_manager)} - # By default, return all splits + splits_generators_kwargs = {} + if self._supports_partial_generation(): + splits_generators_kwargs["splits"] = [split] if split else None + splits_generators = {sg.name: sg for sg in self._split_generators(dl_manager, **splits_generators_kwargs)} + # We still need this in case the builder's `_splits_generators` does not support the `splits` argument + # to filter the splits if split is None: splits_generator = splits_generators elif split in splits_generators: @@ -1403,9 +1511,9 @@ def _prepare_split( ): max_shard_size = convert_file_size_to_int(max_shard_size or config.MAX_SHARD_SIZE) - if self.info.splits is not None: + try: split_info = self.info.splits[split_generator.name] - else: + except Exception: split_info = split_generator.split_info SUFFIX = "-JJJJJ-SSSSS-of-NNNNN" diff --git a/src/datasets/info.py b/src/datasets/info.py index 3723439fb91..1a1e86b5e10 100644 --- a/src/datasets/info.py +++ b/src/datasets/info.py @@ -176,7 +176,7 @@ def __post_init__(self): else: self.version = Version.from_dict(self.version) if self.splits is not None and not isinstance(self.splits, SplitDict): - self.splits = SplitDict.from_split_dict(self.splits) + self.splits = SplitDict.from_split_dict(self.splits, self.dataset_name) if 
self.supervised_keys is not None and not isinstance(self.supervised_keys, SupervisedKeysData): if isinstance(self.supervised_keys, (tuple, list)): self.supervised_keys = SupervisedKeysData(*self.supervised_keys) diff --git a/src/datasets/load.py b/src/datasets/load.py index bc2b0e679b6..17e2bd8d13d 100644 --- a/src/datasets/load.py +++ b/src/datasets/load.py @@ -1410,6 +1410,7 @@ def load_dataset( # Download and prepare data builder_instance.download_and_prepare( + split = split, download_config=download_config, download_mode=download_mode, verification_mode=verification_mode, diff --git a/src/datasets/naming.py b/src/datasets/naming.py index 65e7ede10dc..70eeb63e423 100644 --- a/src/datasets/naming.py +++ b/src/datasets/naming.py @@ -17,6 +17,7 @@ import itertools import os +import posixpath import re @@ -46,33 +47,33 @@ def snakecase_to_camelcase(name): def filename_prefix_for_name(name): - if os.path.basename(name) != name: + if posixpath.basename(name) != name: raise ValueError(f"Should be a dataset name, not a path: {name}") return camelcase_to_snakecase(name) def filename_prefix_for_split(name, split): - if os.path.basename(name) != name: + if posixpath.basename(name) != name: raise ValueError(f"Should be a dataset name, not a path: {name}") if not re.match(_split_re, split): raise ValueError(f"Split name should match '{_split_re}'' but got '{split}'.") return f"{filename_prefix_for_name(name)}-{split}" -def filepattern_for_dataset_split(dataset_name, split, data_dir, filetype_suffix=None): +def filepattern_for_dataset_split(path, dataset_name, split, filetype_suffix=None): prefix = filename_prefix_for_split(dataset_name, split) + filepath = posixpath.join(path, prefix) + filepath = f"{filepath}*" if filetype_suffix: - prefix += f".{filetype_suffix}" - filepath = os.path.join(data_dir, prefix) - return f"{filepath}*" + filepath += f".{filetype_suffix}" + return filepath -def filenames_for_dataset_split(path, dataset_name, split, filetype_suffix=None, shard_lengths=None): +def filenames_for_dataset_split(path, dataset_name, split, filetype_suffix=None, num_shards=1): prefix = filename_prefix_for_split(dataset_name, split) - prefix = os.path.join(path, prefix) + prefix = posixpath.join(path, prefix) - if shard_lengths: - num_shards = len(shard_lengths) + if num_shards > 1: filenames = [f"{prefix}-{shard_id:05d}-of-{num_shards:05d}" for shard_id in range(num_shards)] if filetype_suffix: filenames = [filename + f".{filetype_suffix}" for filename in filenames] diff --git a/src/datasets/packaged_modules/arrow/arrow.py b/src/datasets/packaged_modules/arrow/arrow.py index bcf31c473d2..501179d6d54 100644 --- a/src/datasets/packaged_modules/arrow/arrow.py +++ b/src/datasets/packaged_modules/arrow/arrow.py @@ -1,6 +1,6 @@ import itertools from dataclasses import dataclass -from typing import Optional +from typing import List, Optional import pyarrow as pa @@ -27,12 +27,18 @@ class Arrow(datasets.ArrowBasedBuilder): def _info(self): return datasets.DatasetInfo(features=self.config.features) - def _split_generators(self, dl_manager): + def _available_splits(self) -> Optional[List[str]]: + return [str(split) for split in self.config.data_files] if isinstance(self.config.data_files, dict) else None + + def _split_generators(self, dl_manager, splits: Optional[List[str]] = None): """We handle string, list and dicts in datafiles""" if not self.config.data_files: raise ValueError(f"At least one data file must be specified, but got data_files={self.config.data_files}") 
dl_manager.download_config.extract_on_the_fly = True - data_files = dl_manager.download_and_extract(self.config.data_files) + data_files = self.config.data_files + if splits and isinstance(data_files, dict): + data_files = {split: data_files[split] for split in splits} + data_files = dl_manager.download_and_extract(data_files) splits = [] for split_name, files in data_files.items(): if isinstance(files, str): diff --git a/src/datasets/packaged_modules/cache/cache.py b/src/datasets/packaged_modules/cache/cache.py index cdcfb4c20b6..a60b59af64d 100644 --- a/src/datasets/packaged_modules/cache/cache.py +++ b/src/datasets/packaged_modules/cache/cache.py @@ -153,10 +153,15 @@ def download_and_prepare(self, output_dir: Optional[str] = None, *args, **kwargs if output_dir is not None and output_dir != self.cache_dir: shutil.copytree(self.cache_dir, output_dir) - def _split_generators(self, dl_manager): + def _available_splits(self) -> Optional[List[str]]: + return [str(split) for split in self.info.splits] + + def _split_generators(self, dl_manager, splits: Optional[List[str]] = None): # used to stream from cache if isinstance(self.info.splits, datasets.SplitDict): split_infos: list[datasets.SplitInfo] = list(self.info.splits.values()) + if splits: + split_infos = [split_info for split_info in split_infos if split_info.name in splits] else: raise ValueError(f"Missing splits info for {self.dataset_name} in cache directory {self.cache_dir}") return [ @@ -168,7 +173,7 @@ def _split_generators(self, dl_manager): dataset_name=self.dataset_name, split=split_info.name, filetype_suffix="arrow", - shard_lengths=split_info.shard_lengths, + num_shards=len(split_info.shard_lengths or ()), ) }, ) diff --git a/src/datasets/packaged_modules/csv/csv.py b/src/datasets/packaged_modules/csv/csv.py index 2ae95ff5142..18edb6e23be 100644 --- a/src/datasets/packaged_modules/csv/csv.py +++ b/src/datasets/packaged_modules/csv/csv.py @@ -148,12 +148,18 @@ class Csv(datasets.ArrowBasedBuilder): def _info(self): return datasets.DatasetInfo(features=self.config.features) - def _split_generators(self, dl_manager): + def _available_splits(self) -> Optional[List[str]]: + return [str(split) for split in self.config.data_files] if isinstance(self.config.data_files, dict) else None + + def _split_generators(self, dl_manager, splits: Optional[List[str]] = None): """We handle string, list and dicts in datafiles""" if not self.config.data_files: raise ValueError(f"At least one data file must be specified, but got data_files={self.config.data_files}") dl_manager.download_config.extract_on_the_fly = True - data_files = dl_manager.download_and_extract(self.config.data_files) + data_files = self.config.data_files + if splits and isinstance(data_files, dict): + data_files = {split: data_files[split] for split in splits} + data_files = dl_manager.download_and_extract(data_files) splits = [] for split_name, files in data_files.items(): if isinstance(files, str): diff --git a/src/datasets/packaged_modules/folder_based_builder/folder_based_builder.py b/src/datasets/packaged_modules/folder_based_builder/folder_based_builder.py index 182de467b14..07b4822753e 100644 --- a/src/datasets/packaged_modules/folder_based_builder/folder_based_builder.py +++ b/src/datasets/packaged_modules/folder_based_builder/folder_based_builder.py @@ -67,7 +67,10 @@ def _info(self): return datasets.DatasetInfo(features=self.config.features) - def _split_generators(self, dl_manager): + def _available_splits(self) -> Optional[List[str]]: + return [str(split) for split in 
self.config.data_files] if isinstance(self.config.data_files, dict) else None + + def _split_generators(self, dl_manager, splits: Optional[List[str]] = None): if not self.config.data_files: raise ValueError(f"At least one data file must be specified, but got data_files={self.config.data_files}") dl_manager.download_config.extract_on_the_fly = True @@ -120,6 +123,8 @@ def analyze(files_or_archives, downloaded_files_or_dirs, split): ) data_files = self.config.data_files + if splits and isinstance(data_files, dict): + data_files = {split: data_files[split] for split in splits} splits = [] for split_name, files in data_files.items(): if isinstance(files, str): diff --git a/src/datasets/packaged_modules/json/json.py b/src/datasets/packaged_modules/json/json.py index c5d8bcd03fc..75cf6fe0fea 100644 --- a/src/datasets/packaged_modules/json/json.py +++ b/src/datasets/packaged_modules/json/json.py @@ -1,7 +1,7 @@ import io import itertools from dataclasses import dataclass -from typing import Optional +from typing import List, Optional import pandas as pd import pyarrow as pa @@ -70,12 +70,18 @@ def _info(self): raise ValueError("The JSON loader parameter `newlines_in_values` is no longer supported") return datasets.DatasetInfo(features=self.config.features) - def _split_generators(self, dl_manager): + def _available_splits(self) -> Optional[List[str]]: + return [str(split) for split in self.config.data_files] if isinstance(self.config.data_files, dict) else None + + def _split_generators(self, dl_manager, splits: Optional[List[str]] = None): """We handle string, list and dicts in datafiles""" if not self.config.data_files: raise ValueError(f"At least one data file must be specified, but got data_files={self.config.data_files}") dl_manager.download_config.extract_on_the_fly = True - data_files = dl_manager.download_and_extract(self.config.data_files) + data_files = self.config.data_files + if splits and isinstance(data_files, dict): + data_files = {split: data_files[split] for split in splits} + data_files = dl_manager.download_and_extract(data_files) splits = [] for split_name, files in data_files.items(): if isinstance(files, str): diff --git a/src/datasets/packaged_modules/parquet/parquet.py b/src/datasets/packaged_modules/parquet/parquet.py index 10797753657..f9129a57065 100644 --- a/src/datasets/packaged_modules/parquet/parquet.py +++ b/src/datasets/packaged_modules/parquet/parquet.py @@ -41,12 +41,18 @@ def _info(self): ) return datasets.DatasetInfo(features=self.config.features) - def _split_generators(self, dl_manager): + def _available_splits(self) -> Optional[List[str]]: + return [str(split) for split in self.config.data_files] if isinstance(self.config.data_files, dict) else None + + def _split_generators(self, dl_manager, splits: Optional[List[str]] = None): """We handle string, list and dicts in datafiles""" if not self.config.data_files: raise ValueError(f"At least one data file must be specified, but got data_files={self.config.data_files}") dl_manager.download_config.extract_on_the_fly = True - data_files = dl_manager.download_and_extract(self.config.data_files) + data_files = self.config.data_files + if splits and isinstance(data_files, dict): + data_files = {split: data_files[split] for split in splits} + data_files = dl_manager.download_and_extract(data_files) splits = [] for split_name, files in data_files.items(): if isinstance(files, str): diff --git a/src/datasets/packaged_modules/text/text.py b/src/datasets/packaged_modules/text/text.py index a1f3ff5a744..f72a50139d0 100644 
--- a/src/datasets/packaged_modules/text/text.py +++ b/src/datasets/packaged_modules/text/text.py @@ -1,7 +1,7 @@ import itertools from dataclasses import dataclass from io import StringIO -from typing import Optional +from typing import List, Optional import pyarrow as pa @@ -31,7 +31,10 @@ class Text(datasets.ArrowBasedBuilder): def _info(self): return datasets.DatasetInfo(features=self.config.features) - def _split_generators(self, dl_manager): + def _available_splits(self) -> Optional[List[str]]: + return [str(split) for split in self.config.data_files] if isinstance(self.config.data_files, dict) else None + + def _split_generators(self, dl_manager, splits: Optional[List[str]] = None): """The `data_files` kwarg in load_dataset() can be a str, List[str], Dict[str,str], or Dict[str,List[str]]. If str or List[str], then the dataset returns only the 'train' split. @@ -40,7 +43,10 @@ def _split_generators(self, dl_manager): if not self.config.data_files: raise ValueError(f"At least one data file must be specified, but got data_files={self.config.data_files}") dl_manager.download_config.extract_on_the_fly = True - data_files = dl_manager.download_and_extract(self.config.data_files) + data_files = self.config.data_files + if splits and isinstance(data_files, dict): + data_files = {split: data_files[split] for split in splits} + data_files = dl_manager.download_and_extract(data_files) splits = [] for split_name, files in data_files.items(): if isinstance(files, str): diff --git a/src/datasets/packaged_modules/webdataset/webdataset.py b/src/datasets/packaged_modules/webdataset/webdataset.py index 571276a4cd5..fbd5fb95025 100644 --- a/src/datasets/packaged_modules/webdataset/webdataset.py +++ b/src/datasets/packaged_modules/webdataset/webdataset.py @@ -2,7 +2,7 @@ import json import re from itertools import islice -from typing import Any, Callable +from typing import Any, Callable, Dict, List, Optional import fsspec import numpy as np @@ -59,12 +59,18 @@ def _get_pipeline_from_tar(cls, tar_path, tar_iterator): def _info(self) -> datasets.DatasetInfo: return datasets.DatasetInfo() - def _split_generators(self, dl_manager): + def _available_splits(self) -> Optional[List[str]]: + return [str(split) for split in self.config.data_files] if isinstance(self.config.data_files, dict) else None + + def _split_generators(self, dl_manager, splits: Optional[List[str]] = None): """We handle string, list and dicts in datafiles""" # Download the data files if not self.config.data_files: raise ValueError(f"At least one data file must be specified, but got data_files={self.config.data_files}") - data_files = dl_manager.download(self.config.data_files) + data_files = self.config.data_files + if splits and isinstance(data_files, dict): + data_files = {split: data_files[split] for split in splits} + data_files = dl_manager.download(data_files) splits = [] for split_name, tar_paths in data_files.items(): if isinstance(tar_paths, str): diff --git a/src/datasets/utils/info_utils.py b/src/datasets/utils/info_utils.py index d93f5b4509f..550d5219150 100644 --- a/src/datasets/utils/info_utils.py +++ b/src/datasets/utils/info_utils.py @@ -45,11 +45,11 @@ def verify_checksums(expected_checksums: Optional[dict], recorded_checksums: dic if expected_checksums is None: logger.info("Unable to verify checksums.") return - if len(set(expected_checksums) - set(recorded_checksums)) > 0: - raise ExpectedMoreDownloadedFilesError(str(set(expected_checksums) - set(recorded_checksums))) - if len(set(recorded_checksums) - 
set(expected_checksums)) > 0: - raise UnexpectedDownloadedFileError(str(set(recorded_checksums) - set(expected_checksums))) - bad_urls = [url for url in expected_checksums if expected_checksums[url] != recorded_checksums[url]] + bad_urls = [ + url + for url in (set(recorded_checksums) & set(expected_checksums)) + if expected_checksums[url] != recorded_checksums[url] + ] for_verification_name = " for " + verification_name if verification_name is not None else "" if len(bad_urls) > 0: raise NonMatchingChecksumError( @@ -64,13 +64,9 @@ def verify_splits(expected_splits: Optional[dict], recorded_splits: dict): if expected_splits is None: logger.info("Unable to verify splits sizes.") return - if len(set(expected_splits) - set(recorded_splits)) > 0: - raise ExpectedMoreSplitsError(str(set(expected_splits) - set(recorded_splits))) - if len(set(recorded_splits) - set(expected_splits)) > 0: - raise UnexpectedSplitsError(str(set(recorded_splits) - set(expected_splits))) bad_splits = [ {"expected": expected_splits[name], "recorded": recorded_splits[name]} - for name in expected_splits + for name in (set(recorded_splits) & set(expected_splits)) if expected_splits[name].num_examples != recorded_splits[name].num_examples ] if len(bad_splits) > 0: diff --git a/tests/test_arrow_reader.py b/tests/test_arrow_reader.py index 6987416f3a4..6320d401a27 100644 --- a/tests/test_arrow_reader.py +++ b/tests/test_arrow_reader.py @@ -1,4 +1,5 @@ import os +import posixpath import tempfile from pathlib import Path from unittest import TestCase @@ -103,8 +104,8 @@ def test_read_files(self): reader = ReaderTest(tmp_dir, info) files = [ - {"filename": os.path.join(tmp_dir, "train")}, - {"filename": os.path.join(tmp_dir, "test"), "skip": 10, "take": 10}, + {"filename": posixpath.join(tmp_dir, "train")}, + {"filename": posixpath.join(tmp_dir, "test"), "skip": 10, "take": 10}, ] dset = Dataset(**reader.read_files(files, original_instructions="train+test[10:20]")) self.assertEqual(dset.num_rows, 110) @@ -169,7 +170,7 @@ def test_make_file_instructions_basic(): assert isinstance(file_instructions, FileInstructions) assert file_instructions.num_examples == 33 assert file_instructions.file_instructions == [ - {"filename": os.path.join(prefix_path, f"{name}-train.arrow"), "skip": 0, "take": 33} + {"filename": posixpath.join(prefix_path, f"{name}-train.arrow"), "skip": 0, "take": 33} ] split_infos = [SplitInfo(name="train", num_examples=100, shard_lengths=[10] * 10)] @@ -177,10 +178,10 @@ def test_make_file_instructions_basic(): assert isinstance(file_instructions, FileInstructions) assert file_instructions.num_examples == 33 assert file_instructions.file_instructions == [ - {"filename": os.path.join(prefix_path, f"{name}-train-00000-of-00010.arrow"), "skip": 0, "take": -1}, - {"filename": os.path.join(prefix_path, f"{name}-train-00001-of-00010.arrow"), "skip": 0, "take": -1}, - {"filename": os.path.join(prefix_path, f"{name}-train-00002-of-00010.arrow"), "skip": 0, "take": -1}, - {"filename": os.path.join(prefix_path, f"{name}-train-00003-of-00010.arrow"), "skip": 0, "take": 3}, + {"filename": posixpath.join(prefix_path, f"{name}-train-00000-of-00010.arrow"), "skip": 0, "take": -1}, + {"filename": posixpath.join(prefix_path, f"{name}-train-00001-of-00010.arrow"), "skip": 0, "take": -1}, + {"filename": posixpath.join(prefix_path, f"{name}-train-00002-of-00010.arrow"), "skip": 0, "take": -1}, + {"filename": posixpath.join(prefix_path, f"{name}-train-00003-of-00010.arrow"), "skip": 0, "take": 3}, ] @@ -217,7 +218,7 @@ def 
test_make_file_instructions(split_name, instruction, shard_lengths, read_ran if not isinstance(shard_lengths, list): assert file_instructions.file_instructions == [ { - "filename": os.path.join(prefix_path, f"{name}-{split_name}.arrow"), + "filename": posixpath.join(prefix_path, f"{name}-{split_name}.arrow"), "skip": read_range[0], "take": read_range[1] - read_range[0], } @@ -226,7 +227,9 @@ def test_make_file_instructions(split_name, instruction, shard_lengths, read_ran file_instructions_list = [] shard_offset = 0 for i, shard_length in enumerate(shard_lengths): - filename = os.path.join(prefix_path, f"{name}-{split_name}-{i:05d}-of-{len(shard_lengths):05d}.arrow") + filename = posixpath.join( + prefix_path, f"{name}-{split_name}-{i:05d}-of-{len(shard_lengths):05d}.arrow" + ) if shard_offset <= read_range[0] < shard_offset + shard_length: file_instructions_list.append( { diff --git a/tests/test_download_manager.py b/tests/test_download_manager.py index 08eb77366c1..2b09741e3f2 100644 --- a/tests/test_download_manager.py +++ b/tests/test_download_manager.py @@ -131,7 +131,6 @@ def test_download_manager_delete_extracted_files(xz_file): assert extracted_path == dl_manager.extracted_paths[xz_file] extracted_path = Path(extracted_path) parts = extracted_path.parts - # import pdb; pdb.set_trace() assert parts[-1] == hash_url_to_filename(str(xz_file), etag=None) assert parts[-2] == extracted_subdir assert extracted_path.exists() diff --git a/tests/test_load.py b/tests/test_load.py index a532452eb4c..1fbc3c003a5 100644 --- a/tests/test_load.py +++ b/tests/test_load.py @@ -1048,6 +1048,70 @@ def test_load_dataset_with_unsupported_extensions(text_dir_with_unsupported_exte assert ds.num_rows == 4 +def test_load_dataset_specific_splits(data_dir): + with tempfile.TemporaryDirectory() as tmp_dir: + with load_dataset(data_dir, split="train", cache_dir=tmp_dir) as dataset: + assert isinstance(dataset, Dataset) + assert len(dataset) > 0 + + processed_dataset_dir = load_dataset_builder(data_dir, cache_dir=tmp_dir).cache_dir + arrow_files = Path(processed_dataset_dir).glob("*.arrow") + assert all(arrow_file.name.split("-", 1)[1].startswith("train") for arrow_file in arrow_files) + + with load_dataset(data_dir, split="test", cache_dir=tmp_dir) as dataset: + assert isinstance(dataset, Dataset) + assert len(dataset) > 0 + + arrow_files = Path(processed_dataset_dir).glob("*.arrow") + assert all(arrow_file.name.split("-", 1)[1].startswith(("train", "test")) for arrow_file in arrow_files) + + with pytest.raises(ValueError): + load_dataset(data_dir, split="non-existing-split", cache_dir=tmp_dir) + + +def test_load_dataset_specific_splits_then_full(data_dir): + with tempfile.TemporaryDirectory() as tmp_dir: + with load_dataset(data_dir, split="train", cache_dir=tmp_dir) as dataset: + assert isinstance(dataset, Dataset) + assert len(dataset) > 0 + + processed_dataset_dir = load_dataset_builder(data_dir, cache_dir=tmp_dir).cache_dir + arrow_files = Path(processed_dataset_dir).glob("*.arrow") + assert all(arrow_file.name.split("-", 1)[1].startswith("train") for arrow_file in arrow_files) + + with load_dataset(data_dir, cache_dir=tmp_dir) as dataset: + assert isinstance(dataset, DatasetDict) + assert len(dataset) > 0 + assert "train" in dataset + assert "test" in dataset + dataset_splits = list(dataset) + + arrow_files = Path(processed_dataset_dir).glob("*.arrow") + assert all(arrow_file.name.split("-", 1)[1].startswith(tuple(dataset_splits)) for arrow_file in arrow_files) + + +@pytest.mark.integration +def 
test_loading_from_dataset_from_hub_specific_splits(): + with tempfile.TemporaryDirectory() as tmp_dir: + with load_dataset(SAMPLE_DATASET_IDENTIFIER2, split="train", cache_dir=tmp_dir) as dataset: + assert isinstance(dataset, Dataset) + assert len(dataset) > 0 + + processed_dataset_dir = load_dataset_builder(SAMPLE_DATASET_IDENTIFIER2, cache_dir=tmp_dir).cache_dir + arrow_files = Path(processed_dataset_dir).glob("*.arrow") + assert all(arrow_file.name.split("-", 1)[1].startswith("train") for arrow_file in arrow_files) + + with load_dataset(SAMPLE_DATASET_IDENTIFIER2, split="test", cache_dir=tmp_dir) as dataset: + assert isinstance(dataset, Dataset) + assert len(dataset) > 0 + + arrow_files = Path(processed_dataset_dir).glob("*.arrow") + assert all(arrow_file.name.split("-", 1)[1].startswith(("train", "test")) for arrow_file in arrow_files) + + with pytest.raises(ValueError): + load_dataset(SAMPLE_DATASET_IDENTIFIER2, split="non-existing-split", cache_dir=tmp_dir) + + @pytest.mark.integration def test_loading_from_the_datasets_hub_with_token(): true_request = requests.Session().request From 54300a82ca2d61a146b70b468cb9b056f5f6bbf9 Mon Sep 17 00:00:00 2001 From: ArjunJagdale Date: Wed, 29 Oct 2025 15:47:03 +0530 Subject: [PATCH 02/12] fix style and typing imports for CI --- src/datasets/builder.py | 8 ++++---- src/datasets/load.py | 2 +- src/datasets/naming.py | 1 - src/datasets/packaged_modules/cache/cache.py | 2 +- src/datasets/packaged_modules/csv/csv.py | 2 +- .../folder_based_builder/folder_based_builder.py | 2 +- src/datasets/packaged_modules/parquet/parquet.py | 2 +- src/datasets/packaged_modules/webdataset/webdataset.py | 2 +- src/datasets/utils/info_utils.py | 4 ---- tests/test_load.py | 2 +- 10 files changed, 11 insertions(+), 16 deletions(-) diff --git a/src/datasets/builder.py b/src/datasets/builder.py index e9f1b30d00e..0f1efd35a8d 100644 --- a/src/datasets/builder.py +++ b/src/datasets/builder.py @@ -30,7 +30,7 @@ from dataclasses import dataclass from functools import partial from pathlib import Path -from typing import TYPE_CHECKING, Dict, Iterable, List, Mapping, Optional, Tuple, Union +from typing import TYPE_CHECKING, List, Optional, Union from unittest.mock import patch import fsspec @@ -690,7 +690,7 @@ def _info(self) -> DatasetInfo: info: (DatasetInfo) The dataset information """ raise NotImplementedError - + def _supports_partial_generation(self) -> bool: """Whether the dataset supports generation of specific splits.""" return hasattr(self, "_available_splits") and "splits" in inspect.signature(self._split_generators).parameters @@ -925,7 +925,7 @@ def incomplete_dir(dirname): yield tmp_dir if os.path.isdir(dirname): for root, dirnames, filenames in os.walk(dirname, topdown=False): - # LocalFileSystem.mv does copy + rm, it is more efficient to simply rename a local directory + # LocalFileSystem.mv does copy + rm, it is more efficient to simply rename a local directory for filename in filenames: filename = os.path.join(root, filename) delete_filename = True @@ -1181,7 +1181,7 @@ def as_dataset( ) available_splits = self._available_splits() if self._supports_partial_generation() else self.info.splits - logger.debug(f'Constructing Dataset for split {split or ", ".join(available_splits)}, from {self._output_dir}') + logger.debug(f"Constructing Dataset for split {split or ', '.join(available_splits)}, from {self._output_dir}") # By default, return all splits if split is None: diff --git a/src/datasets/load.py b/src/datasets/load.py index 14fc3ddadc2..7216ea68200 100644 --- 
a/src/datasets/load.py +++ b/src/datasets/load.py @@ -1415,7 +1415,7 @@ def load_dataset( # Download and prepare data builder_instance.download_and_prepare( - split = split, + split=split, download_config=download_config, download_mode=download_mode, verification_mode=verification_mode, diff --git a/src/datasets/naming.py b/src/datasets/naming.py index 70eeb63e423..05762700f2e 100644 --- a/src/datasets/naming.py +++ b/src/datasets/naming.py @@ -16,7 +16,6 @@ """Utilities for file names.""" import itertools -import os import posixpath import re diff --git a/src/datasets/packaged_modules/cache/cache.py b/src/datasets/packaged_modules/cache/cache.py index a60b59af64d..7b5166353a5 100644 --- a/src/datasets/packaged_modules/cache/cache.py +++ b/src/datasets/packaged_modules/cache/cache.py @@ -4,7 +4,7 @@ import shutil import time from pathlib import Path -from typing import Optional, Union +from typing import List, Optional, Union import pyarrow as pa diff --git a/src/datasets/packaged_modules/csv/csv.py b/src/datasets/packaged_modules/csv/csv.py index 18edb6e23be..a72087c1325 100644 --- a/src/datasets/packaged_modules/csv/csv.py +++ b/src/datasets/packaged_modules/csv/csv.py @@ -1,6 +1,6 @@ import itertools from dataclasses import dataclass -from typing import Any, Callable, Optional, Union +from typing import Any, Callable, List, Optional, Union import pandas as pd import pyarrow as pa diff --git a/src/datasets/packaged_modules/folder_based_builder/folder_based_builder.py b/src/datasets/packaged_modules/folder_based_builder/folder_based_builder.py index 07b4822753e..ee0946c9540 100644 --- a/src/datasets/packaged_modules/folder_based_builder/folder_based_builder.py +++ b/src/datasets/packaged_modules/folder_based_builder/folder_based_builder.py @@ -3,7 +3,7 @@ import itertools import os from dataclasses import dataclass -from typing import Any, Callable, Iterator, Optional, Union +from typing import Any, Callable, Iterator, List, Optional, Union import pandas as pd import pyarrow as pa diff --git a/src/datasets/packaged_modules/parquet/parquet.py b/src/datasets/packaged_modules/parquet/parquet.py index 658060ccff9..351ce934eca 100644 --- a/src/datasets/packaged_modules/parquet/parquet.py +++ b/src/datasets/packaged_modules/parquet/parquet.py @@ -1,6 +1,6 @@ import itertools from dataclasses import dataclass -from typing import Literal, Optional, Union +from typing import List, Literal, Optional, Union import pyarrow as pa import pyarrow.dataset as ds diff --git a/src/datasets/packaged_modules/webdataset/webdataset.py b/src/datasets/packaged_modules/webdataset/webdataset.py index c27f989f435..ffe850b24a5 100644 --- a/src/datasets/packaged_modules/webdataset/webdataset.py +++ b/src/datasets/packaged_modules/webdataset/webdataset.py @@ -2,7 +2,7 @@ import json import re from itertools import islice -from typing import Any, Callable, Dict, List, Optional +from typing import Any, Callable, List, Optional import fsspec import numpy as np diff --git a/src/datasets/utils/info_utils.py b/src/datasets/utils/info_utils.py index 550d5219150..2d55dad984a 100644 --- a/src/datasets/utils/info_utils.py +++ b/src/datasets/utils/info_utils.py @@ -6,12 +6,8 @@ from .. 
import config from ..exceptions import ( - ExpectedMoreDownloadedFilesError, - ExpectedMoreSplitsError, NonMatchingChecksumError, NonMatchingSplitsSizesError, - UnexpectedDownloadedFileError, - UnexpectedSplitsError, ) from .logging import get_logger diff --git a/tests/test_load.py b/tests/test_load.py index 89bcb200d1d..2c5a84cd988 100644 --- a/tests/test_load.py +++ b/tests/test_load.py @@ -1107,7 +1107,7 @@ def test_loading_from_dataset_from_hub_specific_splits(): with pytest.raises(ValueError): load_dataset(SAMPLE_DATASET_IDENTIFIER2, split="non-existing-split", cache_dir=tmp_dir) - + @pytest.mark.integration def test_loading_from_the_datasets_hub_with_token(): class CustomException(Exception): From 699ecbe62532e1188736932e47822597a8cf4119 Mon Sep 17 00:00:00 2001 From: Tobias Pitters Date: Sun, 9 Nov 2025 09:19:42 +0100 Subject: [PATCH 03/12] WIP: fix issue that download always happens, add couple todos and breakpoints --- src/datasets/builder.py | 9 ++++++++- src/datasets/packaged_modules/arrow/arrow.py | 1 + src/datasets/packaged_modules/parquet/parquet.py | 1 + 3 files changed, 10 insertions(+), 1 deletion(-) diff --git a/src/datasets/builder.py b/src/datasets/builder.py index 0f1efd35a8d..240c57a3a71 100644 --- a/src/datasets/builder.py +++ b/src/datasets/builder.py @@ -868,16 +868,19 @@ def download_and_prepare( split_names = [rel_instr.splitname for rel_instr in split._relative_instructions] splits.extend(split_names) splits = list(unique_values(splits)) # remove duplicates + # todo: can we simply use getattr(self.info.splits, 'keys', dict)() here? available_splits = self._available_splits() if splits is None: splits = available_splits missing_splits = set(splits) - set(available_splits) if missing_splits: raise ValueError(f"Splits {list(missing_splits)} not found. Available splits: {available_splits}") - if DownloadMode.REUSE_DATASET_IF_EXISTS: + # todo: the previous check didn't compare against anything, so it always evaluated to true! + if download_mode is DownloadMode.REUSE_DATASET_IF_EXISTS: for split_name in splits[:]: num_shards = 1 if self.info.splits is not None: + # todo: what is this exception for? try: num_shards = len(self.info.splits[split_name].shard_lengths or ()) except Exception: pass split_filenames = filenames_for_dataset_split( self._output_dir, _dataset_name, split_name, filetype_suffix=file_format, num_shards=num_shards, ) if self._fs.exists(split_filenames[0]): splits.remove(split_name) split_filepattern = filepattern_for_dataset_split( self._output_dir, _dataset_name, split_name, filetype_suffix=file_format ) cached_split_filepatterns.append(split_filepattern) + else: + # todo: how do we download files normally? 
+ pass + # We cannot use info as the source of truth if the builder supports partial generation # as the info can be incomplete in that case requested_splits_exist = not splits if supports_partial_generation else info_exists diff --git a/src/datasets/packaged_modules/arrow/arrow.py b/src/datasets/packaged_modules/arrow/arrow.py index 501179d6d54..eb4c398ea09 100644 --- a/src/datasets/packaged_modules/arrow/arrow.py +++ b/src/datasets/packaged_modules/arrow/arrow.py @@ -28,6 +28,7 @@ def _info(self): return datasets.DatasetInfo(features=self.config.features) def _available_splits(self) -> Optional[List[str]]: + import pdb; pdb.set_trace() return [str(split) for split in self.config.data_files] if isinstance(self.config.data_files, dict) else None def _split_generators(self, dl_manager, splits: Optional[List[str]] = None): diff --git a/src/datasets/packaged_modules/parquet/parquet.py b/src/datasets/packaged_modules/parquet/parquet.py index 351ce934eca..55808bc74ad 100644 --- a/src/datasets/packaged_modules/parquet/parquet.py +++ b/src/datasets/packaged_modules/parquet/parquet.py @@ -103,6 +103,7 @@ def _info(self): return datasets.DatasetInfo(features=self.config.features) def _available_splits(self) -> Optional[List[str]]: + import pdb; pdb.set_trace() return [str(split) for split in self.config.data_files] if isinstance(self.config.data_files, dict) else None def _split_generators(self, dl_manager, splits: Optional[List[str]] = None): From 9a7a1bf66aa33cd75ee38ff4eae51e4c99e50370 Mon Sep 17 00:00:00 2001 From: Tobias Pitters Date: Sun, 9 Nov 2025 18:48:44 +0100 Subject: [PATCH 04/12] WIP: before rewriting to delete on FORCE_REDOWNLOAD --- src/datasets/builder.py | 57 ++++++++++--------- .../packaged_modules/parquet/parquet.py | 1 - tests/test_load.py | 46 +++++++++++++++ 3 files changed, 76 insertions(+), 28 deletions(-) diff --git a/src/datasets/builder.py b/src/datasets/builder.py index 240c57a3a71..53c7965cff7 100644 --- a/src/datasets/builder.py +++ b/src/datasets/builder.py @@ -691,7 +691,7 @@ def _info(self) -> DatasetInfo: """ raise NotImplementedError - def _supports_partial_generation(self) -> bool: + def _supports_split_by_split_generation(self) -> bool: """Whether the dataset supports generation of specific splits.""" return hasattr(self, "_available_splits") and "splits" in inspect.signature(self._split_generators).parameters @@ -844,6 +844,7 @@ def download_and_prepare( lock_path = self._output_dir + "_builder.lock" # File locking only with local paths; no file locking on GCS or S3 + # import pdb; pdb.set_trace() with FileLock(lock_path) if is_local else contextlib.nullcontext(): # Check if the data already exists info_exists = self._fs.exists(posixpath.join(self._output_dir, config.DATASET_INFO_FILENAME)) @@ -853,8 +854,8 @@ def download_and_prepare( self.info = self._load_info() _dataset_name = self.name if self._check_legacy_cache() else self.dataset_name splits: Optional[List[str]] = None - cached_split_filepatterns = [] - supports_partial_generation = self._supports_partial_generation() + patterns_of_split_files_to_overwrite = [] + supports_partial_generation = self._supports_split_by_split_generation() if supports_partial_generation: if split: splits = [] @@ -874,17 +875,17 @@ def download_and_prepare( splits = available_splits missing_splits = set(splits) - set(available_splits) if missing_splits: + # import pdb; pdb.set_trace() raise ValueError(f"Splits {list(missing_splits)} not found. 
Available splits: {available_splits}") - # todo: this should check against anything, does always evaluate to true! - if download_mode is DownloadMode.REUSE_DATASET_IF_EXISTS: + if download_mode == DownloadMode.REUSE_DATASET_IF_EXISTS: for split_name in splits[:]: num_shards = 1 - if self.info.splits is not None: - # todo: what is this exception for? + if self.info.splits: try: num_shards = len(self.info.splits[split_name].shard_lengths or ()) - except Exception: + except (TypeError, ValueError): pass + # import pdb; pdb.set_trace() split_filenames = filenames_for_dataset_split( self._output_dir, _dataset_name, @@ -894,18 +895,17 @@ def download_and_prepare( ) if self._fs.exists(split_filenames[0]): splits.remove(split_name) + # import pdb; pdb.set_trace() split_filepattern = filepattern_for_dataset_split( self._output_dir, _dataset_name, split_name, filetype_suffix=file_format ) - cached_split_filepatterns.append(split_filepattern) - else: - # todo: how do we download files normally? - pass + patterns_of_split_files_to_overwrite.append(split_filepattern) # We cannot use info as the source of truth if the builder supports partial generation # as the info can be incomplete in that case requested_splits_exist = not splits if supports_partial_generation else info_exists if requested_splits_exist and download_mode == DownloadMode.REUSE_DATASET_IF_EXISTS: + import pdb; pdb.set_trace() logger.info(f"Found cached dataset {self.dataset_name} ({self._output_dir})") self.download_post_processing_resources(dl_manager) return @@ -930,15 +930,17 @@ def incomplete_dir(dirname): os.makedirs(tmp_dir, exist_ok=True) try: yield tmp_dir + # raise ValueError("debugging") if os.path.isdir(dirname): + import pdb; pdb.set_trace() for root, dirnames, filenames in os.walk(dirname, topdown=False): # LocalFileSystem.mv does copy + rm, it is more efficient to simply rename a local directory for filename in filenames: filename = os.path.join(root, filename) - delete_filename = True - for cached_split_filepattern in cached_split_filepatterns: - if fnmatch.fnmatch(filename, cached_split_filepattern): - delete_filename = False + delete_filename = False + for split_filepattern_to_overwrite in patterns_of_split_files_to_overwrite: + if fnmatch.fnmatch(filename, split_filepattern_to_overwrite): + delete_filename = True break if delete_filename: os.remove(filename) @@ -972,6 +974,7 @@ def incomplete_dir(dirname): self._check_manual_download(dl_manager) # Create a tmp dir and rename to self._output_dir on successful exit. + import pdb; pdb.set_trace() with incomplete_dir(self._output_dir) as tmp_output_dir: # Temporarily assign _output_dir to tmp_data_dir to avoid having to forward # it to every sub function. @@ -1040,8 +1043,8 @@ def _download_and_prepare(self, dl_manager, verification_mode, **prepare_split_k if `NO_CHECKS`, do not perform any verification. prepare_split_kwargs: Additional options, such as `file_format`, `max_shard_size` """ - # If `splits` is specified and the builder supports `splits` in `_split_generators`, then only generate the specified splits. 
- # Otherwise, generate all splits + # Generating data for all splits + # import pdb; pdb.set_trace() split_dict = SplitDict(dataset_name=self.dataset_name) split_generators_kwargs = self._make_split_generators_kwargs(prepare_split_kwargs) split_generators = self._split_generators(dl_manager, **split_generators_kwargs) @@ -1089,7 +1092,7 @@ def _download_and_prepare(self, dl_manager, verification_mode, **prepare_split_k verify_splits(self.info.splits, split_dict) # Update the info object with the generated splits. - if self._supports_partial_generation(): + if self._supports_split_by_split_generation(): split_infos = self.info.splits or {} ordered_split_infos = {} for split_name in self._available_splits(): @@ -1104,7 +1107,7 @@ def _download_and_prepare(self, dl_manager, verification_mode, **prepare_split_k def download_post_processing_resources(self, dl_manager): for split in self.info.splits or []: for resource_name, resource_file_name in self._post_processing_resources(split).items(): - if not not is_remote_filesystem(self._fs): + if not is_remote_filesystem(self._fs): raise NotImplementedError(f"Post processing is not supported on filesystem {self._fs}") if os.sep in resource_file_name: raise ValueError(f"Resources shouldn't be in a sub-directory: {resource_file_name}") @@ -1132,7 +1135,7 @@ def _save_info(self): def _make_split_generators_kwargs(self, prepare_split_kwargs): """Get kwargs for `self._split_generators()` from `prepare_split_kwargs`.""" splits = prepare_split_kwargs.pop("splits", None) - if self._supports_partial_generation(): + if self._supports_split_by_split_generation(): return {"splits": splits} return {} @@ -1187,7 +1190,7 @@ def as_dataset( "datasets.load_dataset() before trying to access the Dataset object." ) - available_splits = self._available_splits() if self._supports_partial_generation() else self.info.splits + available_splits = self._available_splits() if self._supports_split_by_split_generation() else self.info.splits logger.debug(f"Constructing Dataset for split {split or ', '.join(available_splits)}, from {self._output_dir}") # By default, return all splits @@ -1220,7 +1223,7 @@ def _build_single_dataset( in_memory: bool = False, ): """as_dataset for a single split.""" - available_splits = self._available_splits() if self._supports_partial_generation() else self.info.splits + available_splits = self._available_splits() if self._supports_split_by_split_generation() else self.info.splits if not isinstance(split, ReadInstruction): split = str(split) if split == Split.ALL: @@ -1337,7 +1340,7 @@ def as_streaming_dataset( ) self._check_manual_download(dl_manager) splits_generators_kwargs = {} - if self._supports_partial_generation(): + if self._supports_split_by_split_generation(): splits_generators_kwargs["splits"] = [split] if split else None splits_generators = {sg.name: sg for sg in self._split_generators(dl_manager, **splits_generators_kwargs)} # We still need this in case the builder's `_splits_generators` does not support the `splits` argument @@ -1521,9 +1524,9 @@ def _prepare_split( ): max_shard_size = convert_file_size_to_int(max_shard_size or config.MAX_SHARD_SIZE) - try: - split_info = self.info.splits[split_generator.name] - except Exception: + if self.info.splits: + split_info = self.info.splits.get(split_generator.name, split_generator.split_info) + else: split_info = split_generator.split_info SUFFIX = "-JJJJJ-SSSSS-of-NNNNN" diff --git a/src/datasets/packaged_modules/parquet/parquet.py b/src/datasets/packaged_modules/parquet/parquet.py index 
55808bc74ad..351ce934eca 100644 --- a/src/datasets/packaged_modules/parquet/parquet.py +++ b/src/datasets/packaged_modules/parquet/parquet.py @@ -103,7 +103,6 @@ def _info(self): return datasets.DatasetInfo(features=self.config.features) def _available_splits(self) -> Optional[List[str]]: - import pdb; pdb.set_trace() return [str(split) for split in self.config.data_files] if isinstance(self.config.data_files, dict) else None def _split_generators(self, dl_manager, splits: Optional[List[str]] = None): diff --git a/tests/test_load.py b/tests/test_load.py index 2c5a84cd988..03721754049 100644 --- a/tests/test_load.py +++ b/tests/test_load.py @@ -1064,6 +1064,25 @@ def test_load_dataset_specific_splits(data_dir): with pytest.raises(ValueError): load_dataset(data_dir, split="non-existing-split", cache_dir=tmp_dir) +def test_load_dataset_full_then_specific_split_force_redownload(data_dir): + with tempfile.TemporaryDirectory() as tmp_dir: + with load_dataset(data_dir, cache_dir=tmp_dir) as dataset: + assert isinstance(dataset, DatasetDict) + assert len(dataset) > 0 + assert "train" in dataset + assert "test" in dataset + + processed_dataset_dir = load_dataset_builder(data_dir, cache_dir=tmp_dir).cache_dir + arrow_files = Path(processed_dataset_dir).glob("*.arrow") + assert all(arrow_file.name.split("-", 1)[1].startswith(("train", "test")) for arrow_file in arrow_files) + + with load_dataset(data_dir, split="train", cache_dir=tmp_dir, download_mode="force_redownload") as dataset: + assert isinstance(dataset, Dataset) + assert len(dataset) > 0 + + arrow_files = Path(processed_dataset_dir).glob("*.arrow") + # make sure test is gone after force_redownload + assert all(arrow_file.name.split("-", 1)[1].startswith("train") for arrow_file in arrow_files) def test_load_dataset_specific_splits_then_full(data_dir): with tempfile.TemporaryDirectory() as tmp_dir: @@ -1086,6 +1105,33 @@ def test_load_dataset_specific_splits_then_full(data_dir): assert all(arrow_file.name.split("-", 1)[1].startswith(tuple(dataset_splits)) for arrow_file in arrow_files) +def test_load_dataset_specific_splits_missing_split(data_dir): + import re + with pytest.raises(ValueError, match=re.escape("Splits ['missing_split'] not found. 
Available splits: ['train', 'test']")): + with tempfile.TemporaryDirectory() as tmp_dir: + load_dataset(data_dir, split="missing_split", cache_dir=tmp_dir) + + +def test_load_dataset_specific_splits_then_other(data_dir): + with tempfile.TemporaryDirectory() as tmp_dir: + with load_dataset(data_dir, split="train", cache_dir=tmp_dir) as dataset: + assert isinstance(dataset, Dataset) + assert len(dataset) > 0 + + processed_dataset_dir = load_dataset_builder(data_dir, cache_dir=tmp_dir).cache_dir + arrow_files = Path(processed_dataset_dir).glob("*.arrow") + assert all(arrow_file.name.split("-", 1)[1].startswith("train") for arrow_file in arrow_files) + + import pdb; pdb.set_trace() + with load_dataset(data_dir, split="test", cache_dir=tmp_dir) as dataset: + assert isinstance(dataset, Dataset) + assert len(dataset) > 0 + + arrow_files = Path(processed_dataset_dir).glob("*.arrow") + # after loading both splits independently, we should have both locally + assert all(arrow_file.name.split("-", 1)[1].startswith(("train", "test")) for arrow_file in arrow_files) + + @pytest.mark.integration def test_loading_from_dataset_from_hub_specific_splits(): with tempfile.TemporaryDirectory() as tmp_dir: From dcc20d9da8ff5d9d06b522b3ba5b77f2e9f809b4 Mon Sep 17 00:00:00 2001 From: Tobias Pitters Date: Sun, 9 Nov 2025 19:50:20 +0100 Subject: [PATCH 05/12] split tests passing --- src/datasets/builder.py | 76 ++++++++++++++++++++++++++++------------- 1 file changed, 52 insertions(+), 24 deletions(-) diff --git a/src/datasets/builder.py b/src/datasets/builder.py index 53c7965cff7..f75eed82e4e 100644 --- a/src/datasets/builder.py +++ b/src/datasets/builder.py @@ -877,38 +877,51 @@ def download_and_prepare( if missing_splits: # import pdb; pdb.set_trace() raise ValueError(f"Splits {list(missing_splits)} not found. 
Available splits: {available_splits}") - if download_mode == DownloadMode.REUSE_DATASET_IF_EXISTS: - for split_name in splits[:]: - num_shards = 1 - if self.info.splits: - try: - num_shards = len(self.info.splits[split_name].shard_lengths or ()) - except (TypeError, ValueError): - pass - # import pdb; pdb.set_trace() - split_filenames = filenames_for_dataset_split( - self._output_dir, - _dataset_name, - split_name, - filetype_suffix=file_format, - num_shards=num_shards, - ) + if download_mode == DownloadMode.FORCE_REDOWNLOAD: + for split_name in set(available_splits) - set(splits): + split_filenames = self._get_filenames_for_split(split_name, + dataset_name=_dataset_name, + file_format=file_format) if self._fs.exists(split_filenames[0]): - splits.remove(split_name) - # import pdb; pdb.set_trace() split_filepattern = filepattern_for_dataset_split( self._output_dir, _dataset_name, split_name, filetype_suffix=file_format ) + # import pdb; pdb.set_trace() patterns_of_split_files_to_overwrite.append(split_filepattern) + # for split_name in splits[:]: + # num_shards = 1 + # if self.info.splits: + # try: + # num_shards = len(self.info.splits[split_name].shard_lengths or ()) + # except (TypeError, ValueError): + # pass + # # import pdb; pdb.set_trace() + # split_filenames = filenames_for_dataset_split( + # self._output_dir, + # _dataset_name, + # split_name, + # filetype_suffix=file_format, + # num_shards=num_shards, + # ) + # if self._fs.exists(split_filenames[0]): + # splits.remove(split_name) + # # import pdb; pdb.set_trace() + # split_filepattern = filepattern_for_dataset_split( + # self._output_dir, _dataset_name, split_name, filetype_suffix=file_format + # ) + # patterns_of_split_files_to_overwrite.append(split_filepattern) # We cannot use info as the source of truth if the builder supports partial generation # as the info can be incomplete in that case - requested_splits_exist = not splits if supports_partial_generation else info_exists - if requested_splits_exist and download_mode == DownloadMode.REUSE_DATASET_IF_EXISTS: - import pdb; pdb.set_trace() - logger.info(f"Found cached dataset {self.dataset_name} ({self._output_dir})") - self.download_post_processing_resources(dl_manager) - return + if download_mode == DownloadMode.REUSE_DATASET_IF_EXISTS and supports_partial_generation is True: + requested_splits_exist = True + for split_name in splits[:]: + file_names = self._get_filenames_for_split(split_name, dataset_name=_dataset_name, file_format=file_format) + if not self._fs.exists(file_names[0]): + requested_splits_exist = False + break + if requested_splits_exist: + self.download_post_processing_resources(dl_manager) logger.info(f"Generating dataset {self.dataset_name} ({self._output_dir})") if is_local: # if cache dir is local, check for available space @@ -1014,6 +1027,21 @@ def incomplete_dir(dirname): f"Subsequent calls will reuse this data." 
) + def _get_filenames_for_split(self, split_name: str, dataset_name: str, file_format: str) -> list[str]: + num_shards = 1 + if self.info.splits: + try: + num_shards = len(self.info.splits[split_name].shard_lengths or ()) + except (TypeError, ValueError): + pass + return filenames_for_dataset_split( + self._output_dir, + dataset_name, + split_name, + filetype_suffix=file_format, + num_shards=num_shards, + ) + def _check_manual_download(self, dl_manager): if self.manual_download_instructions is not None and dl_manager.manual_dir is None: raise ManualDownloadError( From 4e2ddccf5eb376f11a558b0a526ee406b7b33b12 Mon Sep 17 00:00:00 2001 From: Tobias Pitters Date: Sun, 9 Nov 2025 20:46:13 +0100 Subject: [PATCH 06/12] WIP: one test in test_load fails --- src/datasets/builder.py | 4 ++-- tests/test_load.py | 3 ++- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/src/datasets/builder.py b/src/datasets/builder.py index f75eed82e4e..f4bda287d7a 100644 --- a/src/datasets/builder.py +++ b/src/datasets/builder.py @@ -945,7 +945,7 @@ def incomplete_dir(dirname): yield tmp_dir # raise ValueError("debugging") if os.path.isdir(dirname): - import pdb; pdb.set_trace() + # import pdb; pdb.set_trace() for root, dirnames, filenames in os.walk(dirname, topdown=False): # LocalFileSystem.mv does copy + rm, it is more efficient to simply rename a local directory for filename in filenames: @@ -987,7 +987,7 @@ def incomplete_dir(dirname): self._check_manual_download(dl_manager) # Create a tmp dir and rename to self._output_dir on successful exit. - import pdb; pdb.set_trace() + # import pdb; pdb.set_trace() with incomplete_dir(self._output_dir) as tmp_output_dir: # Temporarily assign _output_dir to tmp_data_dir to avoid having to forward # it to every sub function. diff --git a/tests/test_load.py b/tests/test_load.py index 03721754049..6c7452e32f1 100644 --- a/tests/test_load.py +++ b/tests/test_load.py @@ -1122,7 +1122,7 @@ def test_load_dataset_specific_splits_then_other(data_dir): arrow_files = Path(processed_dataset_dir).glob("*.arrow") assert all(arrow_file.name.split("-", 1)[1].startswith("train") for arrow_file in arrow_files) - import pdb; pdb.set_trace() + # import pdb; pdb.set_trace() with load_dataset(data_dir, split="test", cache_dir=tmp_dir) as dataset: assert isinstance(dataset, Dataset) assert len(dataset) > 0 @@ -1278,6 +1278,7 @@ def test_load_dataset_then_move_then_reload(data_dir, tmp_path, caplog): del dataset os.rename(cache_dir1, cache_dir2) caplog.clear() + import pdb; pdb.set_trace() with caplog.at_level(INFO, logger=get_logger().name): dataset = load_dataset(data_dir, split="train", cache_dir=cache_dir2) assert "Found cached dataset" in caplog.text From 06e4cffc71fe8502e63d89b70cf361ae2bb227c2 Mon Sep 17 00:00:00 2001 From: Tobias Pitters Date: Sun, 9 Nov 2025 20:46:29 +0100 Subject: [PATCH 07/12] WIP: one test in test_load fails --- src/datasets/builder.py | 10 ++++++---- tests/test_load.py | 11 +++++++++-- 2 files changed, 15 insertions(+), 6 deletions(-) diff --git a/src/datasets/builder.py b/src/datasets/builder.py index f4bda287d7a..dbf9e8decce 100644 --- a/src/datasets/builder.py +++ b/src/datasets/builder.py @@ -879,9 +879,9 @@ def download_and_prepare( raise ValueError(f"Splits {list(missing_splits)} not found. 
Available splits: {available_splits}") if download_mode == DownloadMode.FORCE_REDOWNLOAD: for split_name in set(available_splits) - set(splits): - split_filenames = self._get_filenames_for_split(split_name, - dataset_name=_dataset_name, - file_format=file_format) + split_filenames = self._get_filenames_for_split( + split_name, dataset_name=_dataset_name, file_format=file_format + ) if self._fs.exists(split_filenames[0]): split_filepattern = filepattern_for_dataset_split( self._output_dir, _dataset_name, split_name, filetype_suffix=file_format @@ -916,7 +916,9 @@ def download_and_prepare( if download_mode == DownloadMode.REUSE_DATASET_IF_EXISTS and supports_partial_generation is True: requested_splits_exist = True for split_name in splits[:]: - file_names = self._get_filenames_for_split(split_name, dataset_name=_dataset_name, file_format=file_format) + file_names = self._get_filenames_for_split( + split_name, dataset_name=_dataset_name, file_format=file_format + ) if not self._fs.exists(file_names[0]): requested_splits_exist = False break diff --git a/tests/test_load.py b/tests/test_load.py index 6c7452e32f1..38146188cde 100644 --- a/tests/test_load.py +++ b/tests/test_load.py @@ -1064,6 +1064,7 @@ def test_load_dataset_specific_splits(data_dir): with pytest.raises(ValueError): load_dataset(data_dir, split="non-existing-split", cache_dir=tmp_dir) + def test_load_dataset_full_then_specific_split_force_redownload(data_dir): with tempfile.TemporaryDirectory() as tmp_dir: with load_dataset(data_dir, cache_dir=tmp_dir) as dataset: @@ -1084,6 +1085,7 @@ def test_load_dataset_full_then_specific_split_force_redownload(data_dir): # make sure test is gone after force_redownload assert all(arrow_file.name.split("-", 1)[1].startswith("train") for arrow_file in arrow_files) + def test_load_dataset_specific_splits_then_full(data_dir): with tempfile.TemporaryDirectory() as tmp_dir: with load_dataset(data_dir, split="train", cache_dir=tmp_dir) as dataset: @@ -1107,7 +1109,10 @@ def test_load_dataset_specific_splits_then_full(data_dir): def test_load_dataset_specific_splits_missing_split(data_dir): import re - with pytest.raises(ValueError, match=re.escape("Splits ['missing_split'] not found. Available splits: ['train', 'test']")): + + with pytest.raises( + ValueError, match=re.escape("Splits ['missing_split'] not found. 
Available splits: ['train', 'test']") + ): with tempfile.TemporaryDirectory() as tmp_dir: load_dataset(data_dir, split="missing_split", cache_dir=tmp_dir) @@ -1278,7 +1283,9 @@ def test_load_dataset_then_move_then_reload(data_dir, tmp_path, caplog): del dataset os.rename(cache_dir1, cache_dir2) caplog.clear() - import pdb; pdb.set_trace() + import pdb + + pdb.set_trace() with caplog.at_level(INFO, logger=get_logger().name): dataset = load_dataset(data_dir, split="train", cache_dir=cache_dir2) assert "Found cached dataset" in caplog.text From 0bbc6fe165b47e31417fa63a3d724dee2bb26fe5 Mon Sep 17 00:00:00 2001 From: Tobias Pitters Date: Sun, 9 Nov 2025 21:09:18 +0100 Subject: [PATCH 08/12] remove debug statements, cleanup things --- src/datasets/builder.py | 1 + src/datasets/packaged_modules/arrow/arrow.py | 1 - tests/test_load.py | 2 -- 3 files changed, 1 insertion(+), 3 deletions(-) diff --git a/src/datasets/builder.py b/src/datasets/builder.py index dbf9e8decce..bdb5d3f882c 100644 --- a/src/datasets/builder.py +++ b/src/datasets/builder.py @@ -923,6 +923,7 @@ def download_and_prepare( requested_splits_exist = False break if requested_splits_exist: + logger.info(f"Found cached dataset {self.dataset_name} ({self._output_dir})") self.download_post_processing_resources(dl_manager) logger.info(f"Generating dataset {self.dataset_name} ({self._output_dir})") diff --git a/src/datasets/packaged_modules/arrow/arrow.py b/src/datasets/packaged_modules/arrow/arrow.py index eb4c398ea09..501179d6d54 100644 --- a/src/datasets/packaged_modules/arrow/arrow.py +++ b/src/datasets/packaged_modules/arrow/arrow.py @@ -28,7 +28,6 @@ def _info(self): return datasets.DatasetInfo(features=self.config.features) def _available_splits(self) -> Optional[List[str]]: - import pdb; pdb.set_trace() return [str(split) for split in self.config.data_files] if isinstance(self.config.data_files, dict) else None def _split_generators(self, dl_manager, splits: Optional[List[str]] = None): diff --git a/tests/test_load.py b/tests/test_load.py index 38146188cde..03a88aa94c0 100644 --- a/tests/test_load.py +++ b/tests/test_load.py @@ -1283,9 +1283,7 @@ def test_load_dataset_then_move_then_reload(data_dir, tmp_path, caplog): del dataset os.rename(cache_dir1, cache_dir2) caplog.clear() - import pdb - pdb.set_trace() with caplog.at_level(INFO, logger=get_logger().name): dataset = load_dataset(data_dir, split="train", cache_dir=cache_dir2) assert "Found cached dataset" in caplog.text From 2fd4ad37aabe8f9657149ae530fd979d8d00ab1f Mon Sep 17 00:00:00 2001 From: Tobias Pitters Date: Sun, 9 Nov 2025 21:10:13 +0100 Subject: [PATCH 09/12] remove commented out code --- src/datasets/builder.py | 28 ---------------------------- 1 file changed, 28 deletions(-) diff --git a/src/datasets/builder.py b/src/datasets/builder.py index bdb5d3f882c..6312424ca57 100644 --- a/src/datasets/builder.py +++ b/src/datasets/builder.py @@ -844,7 +844,6 @@ def download_and_prepare( lock_path = self._output_dir + "_builder.lock" # File locking only with local paths; no file locking on GCS or S3 - # import pdb; pdb.set_trace() with FileLock(lock_path) if is_local else contextlib.nullcontext(): # Check if the data already exists info_exists = self._fs.exists(posixpath.join(self._output_dir, config.DATASET_INFO_FILENAME)) @@ -875,7 +874,6 @@ def download_and_prepare( splits = available_splits missing_splits = set(splits) - set(available_splits) if missing_splits: - # import pdb; pdb.set_trace() raise ValueError(f"Splits {list(missing_splits)} not found. 
Available splits: {available_splits}") if download_mode == DownloadMode.FORCE_REDOWNLOAD: for split_name in set(available_splits) - set(splits): @@ -886,30 +884,7 @@ def download_and_prepare( split_filepattern = filepattern_for_dataset_split( self._output_dir, _dataset_name, split_name, filetype_suffix=file_format ) - # import pdb; pdb.set_trace() patterns_of_split_files_to_overwrite.append(split_filepattern) - # for split_name in splits[:]: - # num_shards = 1 - # if self.info.splits: - # try: - # num_shards = len(self.info.splits[split_name].shard_lengths or ()) - # except (TypeError, ValueError): - # pass - # # import pdb; pdb.set_trace() - # split_filenames = filenames_for_dataset_split( - # self._output_dir, - # _dataset_name, - # split_name, - # filetype_suffix=file_format, - # num_shards=num_shards, - # ) - # if self._fs.exists(split_filenames[0]): - # splits.remove(split_name) - # # import pdb; pdb.set_trace() - # split_filepattern = filepattern_for_dataset_split( - # self._output_dir, _dataset_name, split_name, filetype_suffix=file_format - # ) - # patterns_of_split_files_to_overwrite.append(split_filepattern) # We cannot use info as the source of truth if the builder supports partial generation # as the info can be incomplete in that case @@ -948,7 +923,6 @@ def incomplete_dir(dirname): yield tmp_dir # raise ValueError("debugging") if os.path.isdir(dirname): - # import pdb; pdb.set_trace() for root, dirnames, filenames in os.walk(dirname, topdown=False): # LocalFileSystem.mv does copy + rm, it is more efficient to simply rename a local directory for filename in filenames: @@ -990,7 +964,6 @@ def incomplete_dir(dirname): self._check_manual_download(dl_manager) # Create a tmp dir and rename to self._output_dir on successful exit. - # import pdb; pdb.set_trace() with incomplete_dir(self._output_dir) as tmp_output_dir: # Temporarily assign _output_dir to tmp_data_dir to avoid having to forward # it to every sub function. 
@@ -1075,7 +1048,6 @@ def _download_and_prepare(self, dl_manager, verification_mode, **prepare_split_k prepare_split_kwargs: Additional options, such as `file_format`, `max_shard_size` """ # Generating data for all splits - # import pdb; pdb.set_trace() split_dict = SplitDict(dataset_name=self.dataset_name) split_generators_kwargs = self._make_split_generators_kwargs(prepare_split_kwargs) split_generators = self._split_generators(dl_manager, **split_generators_kwargs) From 26e71152b20b4f5b7381a775fd6421dcddfa5bd4 Mon Sep 17 00:00:00 2001 From: Tobias Pitters Date: Wed, 10 Dec 2025 12:15:58 +0100 Subject: [PATCH 10/12] WIP: try fix tests --- src/datasets/arrow_reader.py | 2 +- src/datasets/builder.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/datasets/arrow_reader.py b/src/datasets/arrow_reader.py index 75b56919be2..8776be74752 100644 --- a/src/datasets/arrow_reader.py +++ b/src/datasets/arrow_reader.py @@ -120,7 +120,7 @@ def make_file_instructions( dataset_name=name, split=info.name, filetype_suffix=filetype_suffix, - num_shards=len(name2shard_lengths[info.name] or ()), + shard_lengths=name2shard_lengths[info.name] or [], ) for info in split_infos } diff --git a/src/datasets/builder.py b/src/datasets/builder.py index 7e5736b3f2f..dae5347dd44 100644 --- a/src/datasets/builder.py +++ b/src/datasets/builder.py @@ -991,10 +991,10 @@ def incomplete_dir(dirname): ) def _get_filenames_for_split(self, split_name: str, dataset_name: str, file_format: str) -> list[str]: - num_shards = 1 + shard_lengths = [1] if self.info.splits: try: - num_shards = len(self.info.splits[split_name].shard_lengths or ()) + shard_lengths = self.info.splits[split_name].shard_lengths or [] except (TypeError, ValueError): pass return filenames_for_dataset_split( @@ -1002,7 +1002,7 @@ def _get_filenames_for_split(self, split_name: str, dataset_name: str, file_form dataset_name, split_name, filetype_suffix=file_format, - num_shards=num_shards, + shard_lengths=shard_lengths, ) def _check_manual_download(self, dl_manager): From 9105964c8d40c4e3089133fb78890341435d0bf9 Mon Sep 17 00:00:00 2001 From: Tobias Pitters Date: Wed, 10 Dec 2025 12:28:15 +0100 Subject: [PATCH 11/12] fix pre-commit and naming errors --- src/datasets/builder.py | 14 -------------- src/datasets/packaged_modules/cache/cache.py | 2 +- 2 files changed, 1 insertion(+), 15 deletions(-) diff --git a/src/datasets/builder.py b/src/datasets/builder.py index dae5347dd44..a30a879db15 100644 --- a/src/datasets/builder.py +++ b/src/datasets/builder.py @@ -1005,19 +1005,6 @@ def _get_filenames_for_split(self, split_name: str, dataset_name: str, file_form shard_lengths=shard_lengths, ) - def _check_manual_download(self, dl_manager): - if self.manual_download_instructions is not None and dl_manager.manual_dir is None: - raise ManualDownloadError( - textwrap.dedent( - f"""\ - The dataset {self.dataset_name} with config {self.config.name} requires manual data. - Please follow the manual download instructions: - {self.manual_download_instructions} - Manual data can be loaded with: - datasets.load_dataset("{self.repo_id or self.dataset_name}", data_dir="")""" - ) - ) - def _download_and_prepare(self, dl_manager, verification_mode, **prepare_split_kwargs): """Downloads and prepares dataset for reading. 
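# A standalone sketch of the shard-aware cache filenames that
# _get_filenames_for_split and the arrow_reader/cache hunks above rely on.
# The exact naming layout is an assumption for illustration; the real helper
# is datasets.naming.filenames_for_dataset_split.
import os
from typing import List, Optional


def split_filenames(path: str, dataset_name: str, split: str,
                    suffix: str = "arrow", shard_lengths: Optional[List[int]] = None) -> List[str]:
    prefix = os.path.join(path, f"{dataset_name}-{split}")
    if not shard_lengths or len(shard_lengths) <= 1:
        # unsharded split: a single "<dataset_name>-<split>.<suffix>" file
        return [f"{prefix}.{suffix}"]
    num_shards = len(shard_lengths)
    return [f"{prefix}-{index:05d}-of-{num_shards:05d}.{suffix}" for index in range(num_shards)]


# Checking only the first returned filename, as the callers above do, is enough
# to decide whether a split has already been materialized in the cache.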
@@ -1316,7 +1303,6 @@ def as_streaming_dataset( dataset_name=self.dataset_name, data_dir=self.config.data_dir, ) - self._check_manual_download(dl_manager) splits_generators_kwargs = {} if self._supports_split_by_split_generation(): splits_generators_kwargs["splits"] = [split] if split else None diff --git a/src/datasets/packaged_modules/cache/cache.py b/src/datasets/packaged_modules/cache/cache.py index cab5cca58e1..0594061afdb 100644 --- a/src/datasets/packaged_modules/cache/cache.py +++ b/src/datasets/packaged_modules/cache/cache.py @@ -174,7 +174,7 @@ def _split_generators(self, dl_manager, splits: Optional[List[str]] = None): dataset_name=self.dataset_name, split=split_info.name, filetype_suffix="arrow", - num_shards=len(split_info.shard_lengths or ()), + shard_lengths=split_info.shard_lengths or [], ) }, ) From 1022e48b431d950a79ec610791c46493388097d7 Mon Sep 17 00:00:00 2001 From: Tobias Pitters Date: Wed, 10 Dec 2025 12:57:21 +0100 Subject: [PATCH 12/12] WIP: try get tests passing --- src/datasets/builder.py | 10 ++++++++-- src/datasets/packaged_modules/cache/cache.py | 2 +- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/src/datasets/builder.py b/src/datasets/builder.py index a30a879db15..bb9c7f4eed0 100644 --- a/src/datasets/builder.py +++ b/src/datasets/builder.py @@ -835,8 +835,9 @@ def download_and_prepare( # File locking only with local paths; no file locking on GCS or S3 with FileLock(lock_path) if is_local else contextlib.nullcontext(): # Check if the data already exists - info_exists = self._fs.exists(posixpath.join(self._output_dir, config.DATASET_INFO_FILENAME)) - if info_exists: + data_exists = self._fs.exists(posixpath.join(self._output_dir, config.DATASET_INFO_FILENAME)) + if data_exists and download_mode == DownloadMode.REUSE_DATASET_IF_EXISTS: + logger.info(f"Found cached dataset {self.dataset_name} ({self._output_dir})") # We need to update the info in case some splits were added in the meantime # for example when calling load_dataset from multiple workers. self.info = self._load_info() @@ -1060,14 +1061,19 @@ def _download_and_prepare(self, dl_manager, verification_mode, **prepare_split_k if self._supports_split_by_split_generation(): split_infos = self.info.splits or {} ordered_split_infos = {} + downloaded_size = 0 for split_name in self._available_splits(): if split_name in split_dict: ordered_split_infos[split_name] = split_dict[split_name] + downloaded_size += ordered_split_infos[split_name].num_bytes or 0 elif split_name in split_infos: ordered_split_infos[split_name] = split_infos[split_name] + downloaded_size += ordered_split_infos[split_name].num_bytes or 0 self.info.splits = SplitDict.from_split_dict(ordered_split_infos, dataset_name=self.dataset_name) + self.info.download_size = downloaded_size else: self.info.splits = split_dict + self.info.download_size = dl_manager.downloaded_size def download_post_processing_resources(self, dl_manager): for split in self.info.splits or []: diff --git a/src/datasets/packaged_modules/cache/cache.py b/src/datasets/packaged_modules/cache/cache.py index 0594061afdb..62486fbeef7 100644 --- a/src/datasets/packaged_modules/cache/cache.py +++ b/src/datasets/packaged_modules/cache/cache.py @@ -174,7 +174,7 @@ def _split_generators(self, dl_manager, splits: Optional[List[str]] = None): dataset_name=self.dataset_name, split=split_info.name, filetype_suffix="arrow", - shard_lengths=split_info.shard_lengths or [], + shard_lengths=split_info.shard_lengths, ) }, )
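Taken together, the patches above pin down a user-facing contract for partial split generation that the tests in tests/test_load.py assert directly. The sketch below illustrates that contract; the dataset path and cache directory are placeholders, and the comments mirror the test assertions and the intent of the patches rather than documented guarantees.

    from datasets import load_dataset

    # Only the "train" split is downloaded and prepared; nothing is generated for "test".
    train = load_dataset("path/to/dataset_dir", split="train", cache_dir="/tmp/hf_cache")

    # A later call reuses the cached "train" shards; only the missing "test" split is generated.
    test = load_dataset("path/to/dataset_dir", split="test", cache_dir="/tmp/hf_cache")

    # force_redownload regenerates only the requested split and clears the other cached splits.
    train = load_dataset("path/to/dataset_dir", split="train",
                         cache_dir="/tmp/hf_cache", download_mode="force_redownload")

    # Requesting an unknown split fails fast, before any data is generated:
    # ValueError: Splits ['missing_split'] not found. Available splits: ['train', 'test']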