From 96acc0421efa2b57a5446c713df58a577f2cada1 Mon Sep 17 00:00:00 2001 From: Quentin Lhoest <42851186+lhoestq@users.noreply.github.com> Date: Wed, 3 Mar 2021 16:45:30 +0100 Subject: [PATCH] Add datasets full offline mode with HF_DATASETS_OFFLINE (#1976) * add HF_DATASETS_OFFLINE env var * fail network calls instantly if offline is true * update offline simulation for test env * update tests with new offline mode * minor * docs * Update test_file_utils.py --- docs/source/loading_datasets.rst | 14 +++++++++++++ src/datasets/config.py | 6 ++++++ src/datasets/utils/file_utils.py | 20 +++++++++++++++++- tests/test_dataset_common.py | 6 +++--- tests/test_file_utils.py | 36 +++++++++++++++++++++++++++++++- tests/test_load.py | 22 +++++++++---------- tests/test_offline_util.py | 14 ++++++++++--- tests/utils.py | 35 ++++++++++++++++++++++--------- 8 files changed, 124 insertions(+), 29 deletions(-) diff --git a/docs/source/loading_datasets.rst b/docs/source/loading_datasets.rst index 00f0509e0c7..612e811b7f9 100644 --- a/docs/source/loading_datasets.rst +++ b/docs/source/loading_datasets.rst @@ -417,3 +417,17 @@ For example, run the following to skip integrity verifications when loading the >>> from datasets import load_dataset >>> dataset = load_dataset('imdb', ignore_verifications=True) + + +Loading datasets offline +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Each dataset builder (e.g. "squad") is a python script that is downloaded and cached from either from the huggingface/datasets GitHub repository or from the `HuggingFace Hub `__. +Only the ``text``, ``csv``, ``json`` and ``pandas`` builders are included in ``datasets`` without requiring external downloads. + +Therefore if you don't have an internet connection you can't load a dataset that is not packaged with ``datasets``, unless the dataset is already cached. +Indeed, if you've already loaded the dataset once before (when you had an internet connection), then the dataset is reloaded from the cache and you can use it offline. + +You can even set the environment variable `HF_DATASETS_OFFLINE` to ``1`` to tell ``datasets`` to run in full offline mode. +This mode disables all the network calls of the library. +This way, instead of waiting for a dataset builder download to time out, the library looks directly at the cache. diff --git a/src/datasets/config.py b/src/datasets/config.py index 35c14b91db1..dcbbb405b88 100644 --- a/src/datasets/config.py +++ b/src/datasets/config.py @@ -100,3 +100,9 @@ # Batch size constants. For more info, see: # https://github.com/apache/arrow/blob/master/docs/source/cpp/arrays.rst#size-limitations-and-recommendations) DEFAULT_MAX_BATCH_SIZE = 10_000 + +HF_DATASETS_OFFLINE = os.environ.get("HF_DATASETS_OFFLINE", "AUTO").upper() +if HF_DATASETS_OFFLINE in ("1", "ON", "YES"): + HF_DATASETS_OFFLINE = True +else: + HF_DATASETS_OFFLINE = False diff --git a/src/datasets/utils/file_utils.py b/src/datasets/utils/file_utils.py index 100ebf1d138..2e1ac647928 100644 --- a/src/datasets/utils/file_utils.py +++ b/src/datasets/utils/file_utils.py @@ -388,6 +388,18 @@ def get_authentication_headers_for_url(url: str, use_auth_token: Optional[Union[ return headers +class OfflineModeIsEnabled(ConnectionError): + pass + + +def _raise_if_offline_mode_is_enabled(msg: Optional[str] = None): + """Raise a OfflineModeIsEnabled error (subclass of ConnectionError) if HF_DATASETS_OFFLINE is True.""" + if config.HF_DATASETS_OFFLINE: + raise OfflineModeIsEnabled( + "Offline mode is enabled." if msg is None else "Offline mode is enabled. " + str(msg) + ) + + def _request_with_retry( method: str, url: str, @@ -397,7 +409,9 @@ def _request_with_retry( timeout: float = 10.0, **params, ) -> requests.Response: - """Wrapper around requests to retry in case it fails with a ConnectTimeout, with exponential backoff + """Wrapper around requests to retry in case it fails with a ConnectTimeout, with exponential backoff. + + Note that if the environment variable HF_DATASETS_OFFLINE is set to 1, then a OfflineModeIsEnabled error is raised. Args: method (str): HTTP method, such as 'GET' or 'HEAD' @@ -408,6 +422,7 @@ def _request_with_retry( max_wait_time (float): Maximum amount of time between two retries, in seconds **params: Params to pass to `requests.request` """ + _raise_if_offline_mode_is_enabled(f"Tried to reach {url}") tries, success = 0, False while not success: tries += 1 @@ -425,6 +440,7 @@ def _request_with_retry( def ftp_head(url, timeout=10.0): + _raise_if_offline_mode_is_enabled(f"Tried to reach {url}") try: with closing(urllib.request.urlopen(url, timeout=timeout)) as r: r.read(1) @@ -434,6 +450,7 @@ def ftp_head(url, timeout=10.0): def ftp_get(url, temp_file, proxies=None, resume_size=0, headers=None, cookies=None, timeout=10.0): + _raise_if_offline_mode_is_enabled(f"Tried to reach {url}") try: logger.info(f"Getting through FTP {url} into {temp_file.name}") with closing(urllib.request.urlopen(url, timeout=timeout)) as r: @@ -595,6 +612,7 @@ def get_from_cache( ) elif response is not None and response.status_code == 404: raise FileNotFoundError("Couldn't find file at {}".format(url)) + _raise_if_offline_mode_is_enabled(f"Tried to reach {url}") raise ConnectionError("Couldn't reach {}".format(url)) # Try a second time diff --git a/tests/test_dataset_common.py b/tests/test_dataset_common.py index cef1cee268c..b4b0402bde0 100644 --- a/tests/test_dataset_common.py +++ b/tests/test_dataset_common.py @@ -43,7 +43,7 @@ from datasets.utils.file_utils import is_remote_url from datasets.utils.logging import get_logger -from .utils import for_all_test_methods, local, offline, packaged, remote, slow +from .utils import OfflineSimulationMode, for_all_test_methods, local, offline, packaged, remote, slow logger = get_logger(__name__) @@ -281,8 +281,8 @@ def setUp(self): self.dataset_tester = DatasetTester(self) def test_load_dataset_offline(self, dataset_name): - for connection_times_out in (False, True): - with offline(connection_times_out=connection_times_out): + for offline_simulation_mode in list(OfflineSimulationMode): + with offline(offline_simulation_mode): configs = self.dataset_tester.load_all_configs(dataset_name)[:1] self.dataset_tester.check_load_dataset(dataset_name, configs, use_local_dummy_data=True) diff --git a/tests/test_file_utils.py b/tests/test_file_utils.py index 3ae262934c1..978d80c7742 100644 --- a/tests/test_file_utils.py +++ b/tests/test_file_utils.py @@ -1,11 +1,21 @@ import os from pathlib import Path from unittest import TestCase +from unittest.mock import patch import numpy as np import pytest -from datasets.utils.file_utils import DownloadConfig, cached_path, temp_seed +from datasets.utils.file_utils import ( + DownloadConfig, + OfflineModeIsEnabled, + cached_path, + ftp_get, + ftp_head, + http_get, + http_head, + temp_seed, +) from .utils import require_tf, require_torch @@ -92,3 +102,27 @@ def test_cached_path_missing_local(tmp_path): missing_file = "./__missing_file__.txt" with pytest.raises(FileNotFoundError): cached_path(missing_file) + + +@patch("datasets.config.HF_DATASETS_OFFLINE", True) +def test_cached_path_offline(): + with pytest.raises(OfflineModeIsEnabled): + cached_path("https://huggingface.co") + + +@patch("datasets.config.HF_DATASETS_OFFLINE", True) +def test_http_offline(tmp_path_factory): + filename = tmp_path_factory.mktemp("data") / "file.html" + with pytest.raises(OfflineModeIsEnabled): + http_get("https://huggingface.co", temp_file=filename) + with pytest.raises(OfflineModeIsEnabled): + http_head("https://huggingface.co") + + +@patch("datasets.config.HF_DATASETS_OFFLINE", True) +def test_ftp_offline(tmp_path_factory): + filename = tmp_path_factory.mktemp("data") / "file.html" + with pytest.raises(OfflineModeIsEnabled): + ftp_get("ftp://huggingface.co", temp_file=filename) + with pytest.raises(OfflineModeIsEnabled): + ftp_head("ftp://huggingface.co") diff --git a/tests/test_load.py b/tests/test_load.py index 86bf5563bc5..791f9d34734 100644 --- a/tests/test_load.py +++ b/tests/test_load.py @@ -14,7 +14,7 @@ import datasets from datasets import load_dataset -from .utils import offline +from .utils import OfflineSimulationMode, offline DATASET_LOADING_SCRIPT_NAME = "__dummy_dataset1__" @@ -113,8 +113,8 @@ def test_prepare_module(self): self.assertEqual(dummy_module.MY_DUMMY_VARIABLE, "general kenobi") self.assertEqual(module_hash, sha256(dummy_code.encode("utf-8")).hexdigest()) # missing module - for connection_times_out in (False, True): - with offline(connection_times_out=connection_times_out): + for offline_simulation_mode in list(OfflineSimulationMode): + with offline(offline_simulation_mode): with self.assertRaises((FileNotFoundError, ConnectionError, requests.exceptions.ConnectionError)): datasets.load.prepare_module( "__missing_dummy_module_name__", dynamic_modules_path=self.dynamic_modules_path @@ -133,8 +133,8 @@ def test_offline_prepare_module(self): importable_module_path2, _ = datasets.load.prepare_module( module_dir, dynamic_modules_path=self.dynamic_modules_path ) - for connection_times_out in (False, True): - with offline(connection_times_out=connection_times_out): + for offline_simulation_mode in list(OfflineSimulationMode): + with offline(offline_simulation_mode): self._caplog.clear() # allow provide the module name without an explicit path to remote or local actual file importable_module_path3, _ = datasets.load.prepare_module( @@ -158,8 +158,8 @@ def test_load_dataset_canonical(self): "https://raw.githubusercontent.com/huggingface/datasets/0.0.0/datasets/_dummy/_dummy.py", str(context.exception), ) - for connection_times_out in (False, True): - with offline(connection_times_out=connection_times_out): + for offline_simulation_mode in list(OfflineSimulationMode): + with offline(offline_simulation_mode): with self.assertRaises(ConnectionError) as context: datasets.load_dataset("_dummy") self.assertIn( @@ -174,8 +174,8 @@ def test_load_dataset_users(self): "https://huggingface.co/datasets/lhoestq/_dummy/resolve/main/_dummy.py", str(context.exception), ) - for connection_times_out in (False, True): - with offline(connection_times_out=connection_times_out): + for offline_simulation_mode in list(OfflineSimulationMode): + with offline(offline_simulation_mode): with self.assertRaises(ConnectionError) as context: datasets.load_dataset("lhoestq/_dummy") self.assertIn( @@ -191,8 +191,8 @@ def test_load_dataset_local(dataset_loading_script_dir, data_dir, keep_in_memory increased_allocated_memory = (pa.total_allocated_bytes() - previous_allocated_memory) > 0 assert len(dataset) == 2 assert increased_allocated_memory == keep_in_memory - for connection_times_out in (False, True): - with offline(connection_times_out=connection_times_out): + for offline_simulation_mode in list(OfflineSimulationMode): + with offline(offline_simulation_mode): caplog.clear() # Load dataset from cache dataset = datasets.load_dataset(DATASET_LOADING_SCRIPT_NAME, data_dir=data_dir) diff --git a/tests/test_offline_util.py b/tests/test_offline_util.py index 0d0848c244f..467f519066e 100644 --- a/tests/test_offline_util.py +++ b/tests/test_offline_util.py @@ -1,11 +1,13 @@ import pytest import requests -from .utils import RequestWouldHangIndefinitelyError, offline +from datasets.utils.file_utils import http_head + +from .utils import OfflineSimulationMode, RequestWouldHangIndefinitelyError, offline def test_offline_with_timeout(): - with offline(connection_times_out=True): + with offline(OfflineSimulationMode.CONNECTION_TIMES_OUT): with pytest.raises(RequestWouldHangIndefinitelyError): requests.request("GET", "https://huggingface.co") with pytest.raises(requests.exceptions.ConnectTimeout): @@ -13,6 +15,12 @@ def test_offline_with_timeout(): def test_offline_with_connection_error(): - with offline(connection_times_out=False): + with offline(OfflineSimulationMode.CONNECTION_FAILS): with pytest.raises(requests.exceptions.ConnectionError): requests.request("GET", "https://huggingface.co") + + +def test_offline_with_datasets_offline_mode_enabled(): + with offline(OfflineSimulationMode.HF_DATASETS_OFFLINE_SET_TO_1): + with pytest.raises(ConnectionError): + http_head("https://huggingface.co") diff --git a/tests/utils.py b/tests/utils.py index 4c8562d15a1..c9e2af0e9f8 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -3,6 +3,7 @@ import unittest from contextlib import contextmanager from distutils.util import strtobool +from enum import Enum from pathlib import Path from unittest.mock import patch @@ -189,17 +190,26 @@ class RequestWouldHangIndefinitelyError(Exception): pass +class OfflineSimulationMode(Enum): + CONNECTION_FAILS = 0 + CONNECTION_TIMES_OUT = 1 + HF_DATASETS_OFFLINE_SET_TO_1 = 2 + + @contextmanager -def offline(connection_times_out=False, timeout=1e-16): +def offline(mode=OfflineSimulationMode.CONNECTION_FAILS, timeout=1e-16): """ Simulate offline mode. - By default a ConnectionError is raised for each network call. - With connection_times_out=True on the other hand, the connection hangs until it times out. - The default timeout value is low (1e-16) to speed up the tests. + There are three offline simulatiom modes: - Connection errors are created by mocking socket.socket, - while the timeout errors are created by mocking requests.request. + CONNECTION_FAILS (default mode): a ConnectionError is raised for each network call. + Connection errors are created by mocking socket.socket + CONNECTION_TIMES_OUT: the connection hangs until it times out. + The default timeout value is low (1e-16) to speed up the tests. + Timeout errors are created by mocking requests.request + HF_DATASETS_OFFLINE_SET_TO_1: the HF_DATASETS_OFFLINE environment variable is set to 1. + This makes the http/ftp calls of the library instantly fail and raise an OfflineModeEmabled error. """ import socket @@ -226,14 +236,19 @@ def timeout_request(method, url, **kwargs): def offline_socket(*args, **kwargs): raise socket.error("Offline mode is enabled.") - if connection_times_out: + if mode is OfflineSimulationMode.CONNECTION_FAILS: + # inspired from https://stackoverflow.com/a/18601897 + with patch("socket.socket", offline_socket): + yield + elif mode is OfflineSimulationMode.CONNECTION_TIMES_OUT: # inspired from https://stackoverflow.com/a/904609 with patch("requests.request", timeout_request): yield - else: - # inspired from https://stackoverflow.com/a/18601897 - with patch("socket.socket", offline_socket): + elif mode is OfflineSimulationMode.HF_DATASETS_OFFLINE_SET_TO_1: + with patch("datasets.config.HF_DATASETS_OFFLINE", True): yield + else: + raise ValueError("Please use a value from the OfflineSimulationMode enum.") @contextmanager