From 3648752f1259c84e115fc3ed43bb8d3f8e7fa7fd Mon Sep 17 00:00:00 2001
From: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com>
Date: Tue, 8 Aug 2023 17:16:21 +0200
Subject: [PATCH] Fix authentication issues (#6127)

* Fix hf_token fixture

* Do not store token but pass it explicitly

* Fix test with no token

* Fix style

* Test private load_dataset_builder and get_dataset_config_info

* Fix DownloadConfig to pass token to storage_options

* Set config HUB_DATASETS_HFFS_URL

* Use HUB_DATASETS_HFFS_URL in Audio/Image decode_example

* Pass download_config create_builder_configs_from_metadata_configs
---
 src/datasets/config.py                   |  1 +
 src/datasets/download/download_config.py |  8 ++++++++
 src/datasets/features/audio.py           |  5 ++++-
 src/datasets/features/image.py           |  7 ++++++-
 src/datasets/load.py                     |  3 +++
 tests/fixtures/hub.py                    |  6 +-----
 tests/test_inspect.py                    |  5 +++++
 tests/test_load.py                       | 19 +++++++++++--------
 tests/test_upstream_hub.py               |  6 ++++--
 9 files changed, 43 insertions(+), 17 deletions(-)

diff --git a/src/datasets/config.py b/src/datasets/config.py
index 08644c895b0..1a1d68d39c2 100644
--- a/src/datasets/config.py
+++ b/src/datasets/config.py
@@ -24,6 +24,7 @@
 # Hub
 HF_ENDPOINT = os.environ.get("HF_ENDPOINT", "https://huggingface.co")
 HUB_DATASETS_URL = HF_ENDPOINT + "/datasets/{repo_id}/resolve/{revision}/{path}"
+HUB_DATASETS_HFFS_URL = "hf://datasets/{repo_id}@{revision}/{path}"
 HUB_DEFAULT_VERSION = "main"
 
 PY_VERSION = version.parse(platform.python_version())
diff --git a/src/datasets/download/download_config.py b/src/datasets/download/download_config.py
index 27603f4953a..8ba032f75ba 100644
--- a/src/datasets/download/download_config.py
+++ b/src/datasets/download/download_config.py
@@ -92,3 +92,11 @@ def __post_init__(self, use_auth_token):
 
     def copy(self) -> "DownloadConfig":
         return self.__class__(**{k: copy.deepcopy(v) for k, v in self.__dict__.items()})
+
+    def __setattr__(self, name, value):
+        if name == "token" and getattr(self, "storage_options", None) is not None:
+            if "hf" not in self.storage_options:
+                self.storage_options["hf"] = {"token": value, "endpoint": config.HF_ENDPOINT}
+            elif getattr(self.storage_options["hf"], "token", None) is None:
+                self.storage_options["hf"]["token"] = value
+        super().__setattr__(name, value)
diff --git a/src/datasets/features/audio.py b/src/datasets/features/audio.py
index 69c85a2e03a..b5d517cd955 100644
--- a/src/datasets/features/audio.py
+++ b/src/datasets/features/audio.py
@@ -173,8 +173,11 @@ def decode_example(
         if file is None:
             token_per_repo_id = token_per_repo_id or {}
             source_url = path.split("::")[-1]
+            pattern = (
+                config.HUB_DATASETS_URL if source_url.startswith(config.HF_ENDPOINT) else config.HUB_DATASETS_HFFS_URL
+            )
             try:
-                repo_id = string_to_dict(source_url, config.HUB_DATASETS_URL)["repo_id"]
+                repo_id = string_to_dict(source_url, pattern)["repo_id"]
                 token = token_per_repo_id[repo_id]
             except (ValueError, KeyError):
                 token = None
diff --git a/src/datasets/features/image.py b/src/datasets/features/image.py
index 212e6c234d7..86ee6eb646b 100644
--- a/src/datasets/features/image.py
+++ b/src/datasets/features/image.py
@@ -166,8 +166,13 @@ def decode_example(self, value: dict, token_per_repo_id=None) -> "PIL.Image.Imag
                 image = PIL.Image.open(path)
             else:
                 source_url = path.split("::")[-1]
+                pattern = (
+                    config.HUB_DATASETS_URL
+                    if source_url.startswith(config.HF_ENDPOINT)
+                    else config.HUB_DATASETS_HFFS_URL
+                )
                 try:
-                    repo_id = string_to_dict(source_url, config.HUB_DATASETS_URL)["repo_id"]
+                    repo_id = string_to_dict(source_url, pattern)["repo_id"]
                     token = token_per_repo_id.get(repo_id)
                 except ValueError:
                     token = None
diff --git a/src/datasets/load.py b/src/datasets/load.py
index e39800df3a1..1f9ebc0e69b 100644
--- a/src/datasets/load.py
+++ b/src/datasets/load.py
@@ -539,6 +539,7 @@ def create_builder_configs_from_metadata_configs(
     base_path: Optional[str] = None,
     default_builder_kwargs: Dict[str, Any] = None,
     allowed_extensions: Optional[List[str]] = None,
+    download_config: Optional[DownloadConfig] = None,
 ) -> Tuple[List[BuilderConfig], str]:
     builder_cls = import_main_class(module_path)
     builder_config_cls = builder_cls.BUILDER_CONFIG_CLASS
@@ -560,6 +561,7 @@ def create_builder_configs_from_metadata_configs(
                 config_patterns,
                 base_path=config_base_path,
                 allowed_extensions=ALL_ALLOWED_EXTENSIONS,
+                download_config=download_config,
             )
         except EmptyDatasetError as e:
             raise EmptyDatasetError(
@@ -1070,6 +1072,7 @@ def get_module(self) -> DatasetModule:
                 base_path=base_path,
                 supports_metadata=supports_metadata,
                 default_builder_kwargs=default_builder_kwargs,
+                download_config=self.download_config,
             )
         else:
             builder_configs, default_config_name = None, None
diff --git a/tests/fixtures/hub.py b/tests/fixtures/hub.py
index 28c54d91089..bda110406c0 100644
--- a/tests/fixtures/hub.py
+++ b/tests/fixtures/hub.py
@@ -48,12 +48,8 @@ def hf_api():
 
 
 @pytest.fixture(scope="session")
-def hf_token(hf_api: HfApi):
-    previous_token = HfFolder.get_token()
-    HfFolder.save_token(CI_HUB_USER_TOKEN)
+def hf_token():
     yield CI_HUB_USER_TOKEN
-    if previous_token is not None:
-        HfFolder.save_token(previous_token)
 
 
 @pytest.fixture
diff --git a/tests/test_inspect.py b/tests/test_inspect.py
index 034cbd57918..c6f578baa3a 100644
--- a/tests/test_inspect.py
+++ b/tests/test_inspect.py
@@ -47,6 +47,11 @@ def test_get_dataset_config_info(path, config_name, expected_splits):
     assert list(info.splits.keys()) == expected_splits
 
 
+def test_get_dataset_config_info_private(hf_token, hf_private_dataset_repo_txt_data):
+    info = get_dataset_config_info(hf_private_dataset_repo_txt_data, config_name="default", token=hf_token)
+    assert list(info.splits.keys()) == ["train"]
+
+
 @pytest.mark.parametrize(
     "path, config_name, expected_exception",
     [
diff --git a/tests/test_load.py b/tests/test_load.py
index 880912596c5..0b48ed56a69 100644
--- a/tests/test_load.py
+++ b/tests/test_load.py
@@ -38,6 +38,7 @@
     PackagedDatasetModuleFactory,
     infer_module_for_data_files_list,
     infer_module_for_data_files_list_in_archives,
+    load_dataset_builder,
 )
 from datasets.packaged_modules.audiofolder.audiofolder import AudioFolder, AudioFolderConfig
 from datasets.packaged_modules.imagefolder.imagefolder import ImageFolder, ImageFolderConfig
@@ -1223,13 +1224,19 @@ def assert_auth(method, url, *args, headers, **kwargs):
 
 @pytest.mark.integration
 def test_load_streaming_private_dataset(hf_token, hf_private_dataset_repo_txt_data):
-    ds = load_dataset(hf_private_dataset_repo_txt_data, streaming=True)
+    ds = load_dataset(hf_private_dataset_repo_txt_data, streaming=True, token=hf_token)
     assert next(iter(ds)) is not None
 
 
+@pytest.mark.integration
+def test_load_dataset_builder_private_dataset(hf_token, hf_private_dataset_repo_txt_data):
+    builder = load_dataset_builder(hf_private_dataset_repo_txt_data, token=hf_token)
+    assert isinstance(builder, DatasetBuilder)
+
+
 @pytest.mark.integration
 def test_load_streaming_private_dataset_with_zipped_data(hf_token, hf_private_dataset_repo_zipped_txt_data):
-    ds = load_dataset(hf_private_dataset_repo_zipped_txt_data, streaming=True)
+    ds = load_dataset(hf_private_dataset_repo_zipped_txt_data, streaming=True, token=hf_token)
     assert next(iter(ds)) is not None
 
 
@@ -1309,13 +1316,9 @@ def test_load_hub_dataset_without_script_with_metadata_config_in_parallel():
 
 @require_pil
 @pytest.mark.integration
-@pytest.mark.parametrize("implicit_token", [True])
 @pytest.mark.parametrize("streaming", [True])
-def test_load_dataset_private_zipped_images(
-    hf_private_dataset_repo_zipped_img_data, hf_token, streaming, implicit_token
-):
-    token = None if implicit_token else hf_token
-    ds = load_dataset(hf_private_dataset_repo_zipped_img_data, split="train", streaming=streaming, token=token)
+def test_load_dataset_private_zipped_images(hf_private_dataset_repo_zipped_img_data, hf_token, streaming):
+    ds = load_dataset(hf_private_dataset_repo_zipped_img_data, split="train", streaming=streaming, token=hf_token)
     assert isinstance(ds, IterableDataset if streaming else Dataset)
     ds_items = list(ds)
     assert len(ds_items) == 2
diff --git a/tests/test_upstream_hub.py b/tests/test_upstream_hub.py
index 7ab1de5d2ec..1c722c65446 100644
--- a/tests/test_upstream_hub.py
+++ b/tests/test_upstream_hub.py
@@ -33,12 +33,12 @@
 
 
 @for_all_test_methods(xfail_if_500_502_http_error)
-@pytest.mark.usefixtures("set_ci_hub_access_token", "ci_hfh_hf_hub_url")
+@pytest.mark.usefixtures("ci_hub_config", "ci_hfh_hf_hub_url")
 class TestPushToHub:
     _api = HfApi(endpoint=CI_HUB_ENDPOINT)
     _token = CI_HUB_USER_TOKEN
 
-    def test_push_dataset_dict_to_hub_no_token(self, temporary_repo):
+    def test_push_dataset_dict_to_hub_no_token(self, temporary_repo, set_ci_hub_access_token):
         ds = Dataset.from_dict({"x": [1, 2, 3], "y": [4, 5, 6]})
         local_ds = DatasetDict({"train": ds})
 
@@ -778,6 +778,7 @@ def test_push_dataset_to_hub_with_config_no_metadata_configs(self, temporary_rep
             path_in_repo="data/train-00000-of-00001.parquet",
             repo_id=ds_name,
             repo_type="dataset",
+            token=self._token,
         )
         ds_another_config.push_to_hub(ds_name, "another_config", token=self._token)
         ds_builder = load_dataset_builder(ds_name, download_mode="force_redownload")
@@ -811,6 +812,7 @@ def test_push_dataset_dict_to_hub_with_config_no_metadata_configs(self, temporar
             path_in_repo="data/random-00000-of-00001.parquet",
             repo_id=ds_name,
             repo_type="dataset",
+            token=self._token,
         )
         local_ds_another_config.push_to_hub(ds_name, "another_config", token=self._token)
         ds_builder = load_dataset_builder(ds_name, download_mode="force_redownload")
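Below is a minimal usage sketch of the token flow this patch establishes, assuming a `datasets` install with the patch applied. The token string and the private repo id are placeholders for illustration, not real credentials or repositories.

from datasets import DownloadConfig

# Assigning `token` on an existing DownloadConfig triggers the new __setattr__
# hook added in this patch: the token (plus the Hub endpoint) is mirrored into
# storage_options["hf"], so hf:// filesystem accesses are authenticated as well.
dl_config = DownloadConfig()
dl_config.token = "hf_placeholder_token"  # placeholder, not a real token
print(dl_config.storage_options)
# expected shape, e.g.: {'hf': {'token': 'hf_placeholder_token', 'endpoint': 'https://huggingface.co'}}

# The updated tests pass the token explicitly instead of relying on a token
# previously saved via HfFolder, e.g. (hypothetical private repo):
# ds = load_dataset("user/private-dataset", streaming=True, token="hf_placeholder_token")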