Skip to content

Commit

Permalink
Add datasets full offline mode with HF_DATASETS_OFFLINE (#1976)
Browse files Browse the repository at this point in the history
* add HF_DATASETS_OFFLINE env var

* fail network calls instantly if offline is true

* update offline simulation for test env

* update tests with new offline  mode

* minor

* docs

* Update test_file_utils.py
  • Loading branch information
lhoestq committed Mar 3, 2021
1 parent d5afa3c commit 96acc04
Show file tree
Hide file tree
Showing 8 changed files with 124 additions and 29 deletions.
14 changes: 14 additions & 0 deletions docs/source/loading_datasets.rst
Original file line number Diff line number Diff line change
Expand Up @@ -417,3 +417,17 @@ For example, run the following to skip integrity verifications when loading the
>>> from datasets import load_dataset
>>> dataset = load_dataset('imdb', ignore_verifications=True)
Loading datasets offline
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Each dataset builder (e.g. "squad") is a python script that is downloaded and cached either from the huggingface/datasets GitHub repository or from the `HuggingFace Hub <https://huggingface.co/datasets>`__.
Only the ``text``, ``csv``, ``json`` and ``pandas`` builders are included in ``datasets`` without requiring external downloads.

Therefore if you don't have an internet connection you can't load a dataset that is not packaged with ``datasets``, unless the dataset is already cached.
Indeed, if you've already loaded the dataset once before (when you had an internet connection), then the dataset is reloaded from the cache and you can use it offline.

You can even set the environment variable ``HF_DATASETS_OFFLINE`` to ``1`` to tell ``datasets`` to run in full offline mode.
This mode disables all the network calls of the library.
This way, instead of waiting for a dataset builder download to time out, the library looks directly at the cache.
6 changes: 6 additions & 0 deletions src/datasets/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,3 +100,9 @@
# Batch size constants. For more info, see:
# https://github.com/apache/arrow/blob/master/docs/source/cpp/arrays.rst#size-limitations-and-recommendations)
DEFAULT_MAX_BATCH_SIZE = 10_000

# Full offline mode: when enabled, network helpers raise OfflineModeIsEnabled
# instantly instead of timing out, so cached files are used directly.
# Accepted truthy values (case-insensitive): "1", "ON", "YES", "TRUE" —
# "TRUE" is accepted in addition to the original set, matching the common
# Hugging Face env-var convention; anything else (including unset) means False.
HF_DATASETS_OFFLINE = os.environ.get("HF_DATASETS_OFFLINE", "AUTO").upper() in ("1", "ON", "YES", "TRUE")
20 changes: 19 additions & 1 deletion src/datasets/utils/file_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -388,6 +388,18 @@ def get_authentication_headers_for_url(url: str, use_auth_token: Optional[Union[
return headers


class OfflineModeIsEnabled(ConnectionError):
    """Raised instead of attempting any network call when HF_DATASETS_OFFLINE is enabled."""


def _raise_if_offline_mode_is_enabled(msg: Optional[str] = None):
    """Raise an OfflineModeIsEnabled error (subclass of ConnectionError) if HF_DATASETS_OFFLINE is True."""
    if not config.HF_DATASETS_OFFLINE:
        return
    error_message = "Offline mode is enabled."
    if msg is not None:
        error_message = error_message + " " + str(msg)
    raise OfflineModeIsEnabled(error_message)
)


def _request_with_retry(
method: str,
url: str,
Expand All @@ -397,7 +409,9 @@ def _request_with_retry(
timeout: float = 10.0,
**params,
) -> requests.Response:
"""Wrapper around requests to retry in case it fails with a ConnectTimeout, with exponential backoff
"""Wrapper around requests to retry in case it fails with a ConnectTimeout, with exponential backoff.
Note that if the environment variable HF_DATASETS_OFFLINE is set to 1, then a OfflineModeIsEnabled error is raised.
Args:
method (str): HTTP method, such as 'GET' or 'HEAD'
Expand All @@ -408,6 +422,7 @@ def _request_with_retry(
max_wait_time (float): Maximum amount of time between two retries, in seconds
**params: Params to pass to `requests.request`
"""
_raise_if_offline_mode_is_enabled(f"Tried to reach {url}")
tries, success = 0, False
while not success:
tries += 1
Expand All @@ -425,6 +440,7 @@ def _request_with_retry(


def ftp_head(url, timeout=10.0):
_raise_if_offline_mode_is_enabled(f"Tried to reach {url}")
try:
with closing(urllib.request.urlopen(url, timeout=timeout)) as r:
r.read(1)
Expand All @@ -434,6 +450,7 @@ def ftp_head(url, timeout=10.0):


def ftp_get(url, temp_file, proxies=None, resume_size=0, headers=None, cookies=None, timeout=10.0):
_raise_if_offline_mode_is_enabled(f"Tried to reach {url}")
try:
logger.info(f"Getting through FTP {url} into {temp_file.name}")
with closing(urllib.request.urlopen(url, timeout=timeout)) as r:
Expand Down Expand Up @@ -595,6 +612,7 @@ def get_from_cache(
)
elif response is not None and response.status_code == 404:
raise FileNotFoundError("Couldn't find file at {}".format(url))
_raise_if_offline_mode_is_enabled(f"Tried to reach {url}")
raise ConnectionError("Couldn't reach {}".format(url))

# Try a second time
Expand Down
6 changes: 3 additions & 3 deletions tests/test_dataset_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
from datasets.utils.file_utils import is_remote_url
from datasets.utils.logging import get_logger

from .utils import for_all_test_methods, local, offline, packaged, remote, slow
from .utils import OfflineSimulationMode, for_all_test_methods, local, offline, packaged, remote, slow


logger = get_logger(__name__)
Expand Down Expand Up @@ -281,8 +281,8 @@ def setUp(self):
self.dataset_tester = DatasetTester(self)

def test_load_dataset_offline(self, dataset_name):
    # The dataset must still load (from cache / dummy data) under every offline simulation mode.
    for simulation_mode in OfflineSimulationMode:
        with offline(simulation_mode):
            first_config = self.dataset_tester.load_all_configs(dataset_name)[:1]
            self.dataset_tester.check_load_dataset(dataset_name, first_config, use_local_dummy_data=True)

Expand Down
36 changes: 35 additions & 1 deletion tests/test_file_utils.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,21 @@
import os
from pathlib import Path
from unittest import TestCase
from unittest.mock import patch

import numpy as np
import pytest

from datasets.utils.file_utils import DownloadConfig, cached_path, temp_seed
from datasets.utils.file_utils import (
DownloadConfig,
OfflineModeIsEnabled,
cached_path,
ftp_get,
ftp_head,
http_get,
http_head,
temp_seed,
)

from .utils import require_tf, require_torch

Expand Down Expand Up @@ -92,3 +102,27 @@ def test_cached_path_missing_local(tmp_path):
missing_file = "./__missing_file__.txt"
with pytest.raises(FileNotFoundError):
cached_path(missing_file)


def test_cached_path_offline():
    # With offline mode on, cached_path must fail fast instead of reaching the network.
    with patch("datasets.config.HF_DATASETS_OFFLINE", True):
        with pytest.raises(OfflineModeIsEnabled):
            cached_path("https://huggingface.co")


def test_http_offline(tmp_path_factory):
    # Both http_get and http_head must raise instantly when offline mode is on.
    destination = tmp_path_factory.mktemp("data") / "file.html"
    with patch("datasets.config.HF_DATASETS_OFFLINE", True):
        with pytest.raises(OfflineModeIsEnabled):
            http_get("https://huggingface.co", temp_file=destination)
        with pytest.raises(OfflineModeIsEnabled):
            http_head("https://huggingface.co")


def test_ftp_offline(tmp_path_factory):
    # Both ftp_get and ftp_head must raise instantly when offline mode is on.
    destination = tmp_path_factory.mktemp("data") / "file.html"
    with patch("datasets.config.HF_DATASETS_OFFLINE", True):
        with pytest.raises(OfflineModeIsEnabled):
            ftp_get("ftp://huggingface.co", temp_file=destination)
        with pytest.raises(OfflineModeIsEnabled):
            ftp_head("ftp://huggingface.co")
22 changes: 11 additions & 11 deletions tests/test_load.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import datasets
from datasets import load_dataset

from .utils import offline
from .utils import OfflineSimulationMode, offline


DATASET_LOADING_SCRIPT_NAME = "__dummy_dataset1__"
Expand Down Expand Up @@ -113,8 +113,8 @@ def test_prepare_module(self):
self.assertEqual(dummy_module.MY_DUMMY_VARIABLE, "general kenobi")
self.assertEqual(module_hash, sha256(dummy_code.encode("utf-8")).hexdigest())
# missing module
for connection_times_out in (False, True):
with offline(connection_times_out=connection_times_out):
for offline_simulation_mode in list(OfflineSimulationMode):
with offline(offline_simulation_mode):
with self.assertRaises((FileNotFoundError, ConnectionError, requests.exceptions.ConnectionError)):
datasets.load.prepare_module(
"__missing_dummy_module_name__", dynamic_modules_path=self.dynamic_modules_path
Expand All @@ -133,8 +133,8 @@ def test_offline_prepare_module(self):
importable_module_path2, _ = datasets.load.prepare_module(
module_dir, dynamic_modules_path=self.dynamic_modules_path
)
for connection_times_out in (False, True):
with offline(connection_times_out=connection_times_out):
for offline_simulation_mode in list(OfflineSimulationMode):
with offline(offline_simulation_mode):
self._caplog.clear()
# allow provide the module name without an explicit path to remote or local actual file
importable_module_path3, _ = datasets.load.prepare_module(
Expand All @@ -158,8 +158,8 @@ def test_load_dataset_canonical(self):
"https://raw.githubusercontent.com/huggingface/datasets/0.0.0/datasets/_dummy/_dummy.py",
str(context.exception),
)
for connection_times_out in (False, True):
with offline(connection_times_out=connection_times_out):
for offline_simulation_mode in list(OfflineSimulationMode):
with offline(offline_simulation_mode):
with self.assertRaises(ConnectionError) as context:
datasets.load_dataset("_dummy")
self.assertIn(
Expand All @@ -174,8 +174,8 @@ def test_load_dataset_users(self):
"https://huggingface.co/datasets/lhoestq/_dummy/resolve/main/_dummy.py",
str(context.exception),
)
for connection_times_out in (False, True):
with offline(connection_times_out=connection_times_out):
for offline_simulation_mode in list(OfflineSimulationMode):
with offline(offline_simulation_mode):
with self.assertRaises(ConnectionError) as context:
datasets.load_dataset("lhoestq/_dummy")
self.assertIn(
Expand All @@ -191,8 +191,8 @@ def test_load_dataset_local(dataset_loading_script_dir, data_dir, keep_in_memory
increased_allocated_memory = (pa.total_allocated_bytes() - previous_allocated_memory) > 0
assert len(dataset) == 2
assert increased_allocated_memory == keep_in_memory
for connection_times_out in (False, True):
with offline(connection_times_out=connection_times_out):
for offline_simulation_mode in list(OfflineSimulationMode):
with offline(offline_simulation_mode):
caplog.clear()
# Load dataset from cache
dataset = datasets.load_dataset(DATASET_LOADING_SCRIPT_NAME, data_dir=data_dir)
Expand Down
14 changes: 11 additions & 3 deletions tests/test_offline_util.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,26 @@
import pytest
import requests

from .utils import RequestWouldHangIndefinitelyError, offline
from datasets.utils.file_utils import http_head

from .utils import OfflineSimulationMode, RequestWouldHangIndefinitelyError, offline


def test_offline_with_timeout():
    url = "https://huggingface.co"
    with offline(OfflineSimulationMode.CONNECTION_TIMES_OUT):
        # Without an explicit timeout, the mocked request would hang indefinitely.
        with pytest.raises(RequestWouldHangIndefinitelyError):
            requests.request("GET", url)
        # With a timeout, the hang surfaces as a ConnectTimeout instead.
        with pytest.raises(requests.exceptions.ConnectTimeout):
            requests.request("GET", url, timeout=1.0)


def test_offline_with_connection_error():
    url = "https://huggingface.co"
    # CONNECTION_FAILS mocks the socket layer, so any request errors out immediately.
    with offline(OfflineSimulationMode.CONNECTION_FAILS):
        with pytest.raises(requests.exceptions.ConnectionError):
            requests.request("GET", url)


def test_offline_with_datasets_offline_mode_enabled():
    # HF_DATASETS_OFFLINE makes the library's own http helpers fail instantly.
    with offline(OfflineSimulationMode.HF_DATASETS_OFFLINE_SET_TO_1), pytest.raises(ConnectionError):
        http_head("https://huggingface.co")
35 changes: 25 additions & 10 deletions tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import unittest
from contextlib import contextmanager
from distutils.util import strtobool
from enum import Enum
from pathlib import Path
from unittest.mock import patch

Expand Down Expand Up @@ -189,17 +190,26 @@ class RequestWouldHangIndefinitelyError(Exception):
pass


class OfflineSimulationMode(Enum):
    """Ways the ``offline`` context manager can simulate a missing internet connection."""

    CONNECTION_FAILS = 0  # network calls raise a ConnectionError immediately
    CONNECTION_TIMES_OUT = 1  # network calls hang until they time out
    HF_DATASETS_OFFLINE_SET_TO_1 = 2  # the library's full offline mode is switched on


@contextmanager
def offline(connection_times_out=False, timeout=1e-16):
def offline(mode=OfflineSimulationMode.CONNECTION_FAILS, timeout=1e-16):
"""
Simulate offline mode.
By default a ConnectionError is raised for each network call.
With connection_times_out=True on the other hand, the connection hangs until it times out.
The default timeout value is low (1e-16) to speed up the tests.
There are three offline simulation modes:
Connection errors are created by mocking socket.socket,
while the timeout errors are created by mocking requests.request.
CONNECTION_FAILS (default mode): a ConnectionError is raised for each network call.
Connection errors are created by mocking socket.socket
CONNECTION_TIMES_OUT: the connection hangs until it times out.
The default timeout value is low (1e-16) to speed up the tests.
Timeout errors are created by mocking requests.request
HF_DATASETS_OFFLINE_SET_TO_1: the HF_DATASETS_OFFLINE environment variable is set to 1.
This makes the http/ftp calls of the library instantly fail and raise an OfflineModeIsEnabled error.
"""
import socket

Expand All @@ -226,14 +236,19 @@ def timeout_request(method, url, **kwargs):
def offline_socket(*args, **kwargs):
raise socket.error("Offline mode is enabled.")

if connection_times_out:
if mode is OfflineSimulationMode.CONNECTION_FAILS:
# inspired from https://stackoverflow.com/a/18601897
with patch("socket.socket", offline_socket):
yield
elif mode is OfflineSimulationMode.CONNECTION_TIMES_OUT:
# inspired from https://stackoverflow.com/a/904609
with patch("requests.request", timeout_request):
yield
else:
# inspired from https://stackoverflow.com/a/18601897
with patch("socket.socket", offline_socket):
elif mode is OfflineSimulationMode.HF_DATASETS_OFFLINE_SET_TO_1:
with patch("datasets.config.HF_DATASETS_OFFLINE", True):
yield
else:
raise ValueError("Please use a value from the OfflineSimulationMode enum.")


@contextmanager
Expand Down

1 comment on commit 96acc04

@github-actions
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Show benchmarks

PyArrow==0.17.1

Show updated benchmarks!

Benchmark: benchmark_array_xd.json

metric read_batch_formatted_as_numpy after write_array2d read_batch_formatted_as_numpy after write_flattened_sequence read_batch_formatted_as_numpy after write_nested_sequence read_batch_unformated after write_array2d read_batch_unformated after write_flattened_sequence read_batch_unformated after write_nested_sequence read_col_formatted_as_numpy after write_array2d read_col_formatted_as_numpy after write_flattened_sequence read_col_formatted_as_numpy after write_nested_sequence read_col_unformated after write_array2d read_col_unformated after write_flattened_sequence read_col_unformated after write_nested_sequence read_formatted_as_numpy after write_array2d read_formatted_as_numpy after write_flattened_sequence read_formatted_as_numpy after write_nested_sequence read_unformated after write_array2d read_unformated after write_flattened_sequence read_unformated after write_nested_sequence write_array2d write_flattened_sequence write_nested_sequence
new / old (diff) 0.018656 / 0.011353 (0.007303) 0.017515 / 0.011008 (0.006507) 0.048794 / 0.038508 (0.010286) 0.035176 / 0.023109 (0.012067) 0.219401 / 0.275898 (-0.056497) 0.247244 / 0.323480 (-0.076236) 0.006604 / 0.007986 (-0.001382) 0.004815 / 0.004328 (0.000486) 0.007091 / 0.004250 (0.002841) 0.051598 / 0.037052 (0.014546) 0.226370 / 0.258489 (-0.032119) 0.260555 / 0.293841 (-0.033286) 0.176916 / 0.128546 (0.048370) 0.152552 / 0.075646 (0.076905) 0.455350 / 0.419271 (0.036079) 0.427167 / 0.043533 (0.383634) 0.235497 / 0.255139 (-0.019642) 0.249633 / 0.283200 (-0.033567) 1.774224 / 0.141683 (1.632541) 1.869541 / 1.452155 (0.417386) 2.042894 / 1.492716 (0.550177)

Benchmark: benchmark_indices_mapping.json

metric select shard shuffle sort train_test_split
new / old (diff) 0.044886 / 0.037411 (0.007475) 0.019249 / 0.014526 (0.004724) 0.028205 / 0.176557 (-0.148352) 0.096856 / 0.737135 (-0.640279) 0.033016 / 0.296338 (-0.263323)

Benchmark: benchmark_iterating.json

metric read 5000 read 50000 read_batch 50000 10 read_batch 50000 100 read_batch 50000 1000 read_formatted numpy 5000 read_formatted pandas 5000 read_formatted tensorflow 5000 read_formatted torch 5000 read_formatted_batch numpy 5000 10 read_formatted_batch numpy 5000 1000 shuffled read 5000 shuffled read 50000 shuffled read_batch 50000 10 shuffled read_batch 50000 100 shuffled read_batch 50000 1000 shuffled read_formatted numpy 5000 shuffled read_formatted_batch numpy 5000 10 shuffled read_formatted_batch numpy 5000 1000
new / old (diff) 0.282260 / 0.215209 (0.067051) 2.684108 / 2.077655 (0.606454) 1.395072 / 1.504120 (-0.109047) 1.225713 / 1.541195 (-0.315482) 1.232509 / 1.468490 (-0.235982) 7.679576 / 4.584777 (3.094799) 6.691820 / 3.745712 (2.946108) 9.388986 / 5.269862 (4.119124) 8.252658 / 4.565676 (3.686981) 0.769267 / 0.424275 (0.344992) 0.010794 / 0.007607 (0.003187) 0.303448 / 0.226044 (0.077403) 3.236900 / 2.268929 (0.967972) 1.881774 / 55.444624 (-53.562850) 1.615333 / 6.876477 (-5.261144) 1.633646 / 2.142072 (-0.508427) 7.607424 / 4.805227 (2.802197) 5.603401 / 6.500664 (-0.897263) 10.239433 / 0.075469 (10.163964)

Benchmark: benchmark_map_filter.json

metric filter map fast-tokenizer batched map identity map identity batched map no-op batched map no-op batched numpy map no-op batched pandas map no-op batched pytorch map no-op batched tensorflow
new / old (diff) 12.098088 / 1.841788 (10.256301) 15.437794 / 8.074308 (7.363486) 23.157745 / 10.191392 (12.966353) 0.516379 / 0.680424 (-0.164045) 0.319442 / 0.534201 (-0.214759) 0.863782 / 0.579283 (0.284499) 0.692088 / 0.434364 (0.257724) 0.782711 / 0.540337 (0.242373) 1.748296 / 1.386936 (0.361360)
PyArrow==1.0
Show updated benchmarks!

Benchmark: benchmark_array_xd.json

metric read_batch_formatted_as_numpy after write_array2d read_batch_formatted_as_numpy after write_flattened_sequence read_batch_formatted_as_numpy after write_nested_sequence read_batch_unformated after write_array2d read_batch_unformated after write_flattened_sequence read_batch_unformated after write_nested_sequence read_col_formatted_as_numpy after write_array2d read_col_formatted_as_numpy after write_flattened_sequence read_col_formatted_as_numpy after write_nested_sequence read_col_unformated after write_array2d read_col_unformated after write_flattened_sequence read_col_unformated after write_nested_sequence read_formatted_as_numpy after write_array2d read_formatted_as_numpy after write_flattened_sequence read_formatted_as_numpy after write_nested_sequence read_unformated after write_array2d read_unformated after write_flattened_sequence read_unformated after write_nested_sequence write_array2d write_flattened_sequence write_nested_sequence
new / old (diff) 0.020605 / 0.011353 (0.009252) 0.016813 / 0.011008 (0.005805) 0.047144 / 0.038508 (0.008636) 0.034490 / 0.023109 (0.011380) 0.369271 / 0.275898 (0.093373) 0.401985 / 0.323480 (0.078505) 0.007004 / 0.007986 (-0.000982) 0.004744 / 0.004328 (0.000416) 0.006616 / 0.004250 (0.002365) 0.056227 / 0.037052 (0.019174) 0.362513 / 0.258489 (0.104024) 0.429114 / 0.293841 (0.135273) 0.173696 / 0.128546 (0.045150) 0.137160 / 0.075646 (0.061514) 0.453726 / 0.419271 (0.034454) 0.467593 / 0.043533 (0.424060) 0.354628 / 0.255139 (0.099489) 0.394698 / 0.283200 (0.111498) 1.813389 / 0.141683 (1.671706) 1.913925 / 1.452155 (0.461770) 1.984044 / 1.492716 (0.491328)

Benchmark: benchmark_indices_mapping.json

metric select shard shuffle sort train_test_split
new / old (diff) 0.043243 / 0.037411 (0.005832) 0.022020 / 0.014526 (0.007495) 0.027350 / 0.176557 (-0.149207) 0.048941 / 0.737135 (-0.688194) 0.078561 / 0.296338 (-0.217777)

Benchmark: benchmark_iterating.json

metric read 5000 read 50000 read_batch 50000 10 read_batch 50000 100 read_batch 50000 1000 read_formatted numpy 5000 read_formatted pandas 5000 read_formatted tensorflow 5000 read_formatted torch 5000 read_formatted_batch numpy 5000 10 read_formatted_batch numpy 5000 1000 shuffled read 5000 shuffled read 50000 shuffled read_batch 50000 10 shuffled read_batch 50000 100 shuffled read_batch 50000 1000 shuffled read_formatted numpy 5000 shuffled read_formatted_batch numpy 5000 10 shuffled read_formatted_batch numpy 5000 1000
new / old (diff) 0.336883 / 0.215209 (0.121674) 3.475078 / 2.077655 (1.397424) 2.161452 / 1.504120 (0.657332) 1.993838 / 1.541195 (0.452644) 2.066361 / 1.468490 (0.597871) 7.428588 / 4.584777 (2.843811) 6.504262 / 3.745712 (2.758550) 9.126560 / 5.269862 (3.856698) 8.027257 / 4.565676 (3.461580) 0.723168 / 0.424275 (0.298892) 0.011048 / 0.007607 (0.003441) 0.374627 / 0.226044 (0.148583) 3.823416 / 2.268929 (1.554487) 2.584518 / 55.444624 (-52.860107) 2.225079 / 6.876477 (-4.651398) 2.253029 / 2.142072 (0.110956) 7.453804 / 4.805227 (2.648577) 6.607841 / 6.500664 (0.107177) 6.850517 / 0.075469 (6.775047)

Benchmark: benchmark_map_filter.json

metric filter map fast-tokenizer batched map identity map identity batched map no-op batched map no-op batched numpy map no-op batched pandas map no-op batched pytorch map no-op batched tensorflow
new / old (diff) 12.352724 / 1.841788 (10.510936) 14.440168 / 8.074308 (6.365860) 23.207600 / 10.191392 (13.016208) 0.944495 / 0.680424 (0.264071) 0.620055 / 0.534201 (0.085854) 0.839931 / 0.579283 (0.260647) 0.662678 / 0.434364 (0.228314) 0.757645 / 0.540337 (0.217308) 1.682810 / 1.386936 (0.295874)

CML watermark

Please sign in to comment.