Skip to content

Commit

Permalink
Enforce DATADIR to be type pathlib.Path (#35)
Browse files Browse the repository at this point in the history
  • Loading branch information
xuhdev authored and GitHub Enterprise committed Nov 11, 2020
1 parent d9610ce commit 4c8cfbd
Show file tree
Hide file tree
Showing 5 changed files with 24 additions and 6 deletions.
3 changes: 2 additions & 1 deletion .mypy.ini
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
[mypy]
disallow_untyped_defs = True
show_error_codes = True
show_error_codes = True
plugins = pydantic.mypy
7 changes: 3 additions & 4 deletions pydax/_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,18 +17,17 @@
"Module for defining and modifying global configs"


from dataclasses import dataclass
import pathlib

from . import _typing
from pydantic.dataclasses import dataclass


@dataclass(frozen=True)
class Config:
"""Global read-only configurations for PyDAX.
"""
# DATADIR is the default dir where datasets files are downloaded/loaded to/from.
DATADIR: _typing.PathLike = pathlib.Path.home() / '.pydax' / 'data'
DATADIR: pathlib.Path = pathlib.Path.home() / '.pydax' / 'data'


def get_config() -> Config:
Expand All @@ -39,7 +38,7 @@ def get_config() -> Config:
return global_config # type: ignore [name-defined]


def init(**kwargs: _typing.PathLike) -> None:
def init(**kwargs: pathlib.Path) -> None:
"""
(Re-)initialize the PyDAX library. This includes updating PyDAX global configs.
Expand Down
2 changes: 1 addition & 1 deletion pydax/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ def load_dataset(name: str, *,

schema = load_schemata().schemata['dataset_schema'].export_schema('datasets', name, version)

data_dir = pathlib.Path(get_config().DATADIR) / name / version # TODO issue 646
data_dir = get_config().DATADIR / name / version
dataset = Dataset(schema=schema, data_dir=data_dir, mode=Dataset.InitializationMode.LAZY)
if download and not dataset.is_downloaded():
dataset.download()
Expand Down
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
install_requires=[
"packaging >= 20.4",
"pandas >= 1.1.0",
"pydantic >= 1.7.2",
"PyYAML >= 5.3.1",
"requests >= 2.24.0"],
classifiers=[
Expand Down
17 changes: 17 additions & 0 deletions tests/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,10 @@
#

import pathlib
import re

import pytest
from pydantic import ValidationError

from pydax import get_config, init
from pydax.dataset import Dataset
Expand All @@ -25,12 +29,25 @@ def test_default_data_dir(wikitext103_schema):

pydax_data_home = pathlib.Path.home() / '.pydax' / 'data'
assert get_config().DATADIR == pydax_data_home
assert isinstance(get_config().DATADIR, pathlib.Path)


def test_custom_data_dir(tmp_path, wikitext103_schema):
"Test to make sure Dataset constructor uses new global data dir if one was supplied earlier to pydax.init."

init(DATADIR=tmp_path)
assert get_config().DATADIR == tmp_path
assert isinstance(get_config().DATADIR, pathlib.Path)
wikitext = Dataset(wikitext103_schema, data_dir=tmp_path, mode=Dataset.InitializationMode.LAZY)
assert wikitext._data_dir == tmp_path
assert isinstance(wikitext._data_dir, pathlib.Path)


def test_non_path_data_dir():
"Test exception when a nonpath is passed as DATADIR."

with pytest.raises(ValidationError) as e:
init(DATADIR=10)

assert re.search(r"1 validation error for Config\s+DATADIR\s+value is not a valid path \(type=type_error.path\)",
str(e.value))

0 comments on commit 4c8cfbd

Please sign in to comment.