diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml
index 1dad050..8162bf5 100644
--- a/.github/workflows/build.yml
+++ b/.github/workflows/build.yml
@@ -13,7 +13,7 @@ jobs:
     runs-on: ubuntu-latest
     strategy:
       matrix:
-        python: ["3.8", "3.9", "3.10"]
+        python: ["3.8", "3.9", "3.10", "3.11", "3.12"]
     name: Lint - Python ${{ matrix.python }}
     steps:
       - uses: actions/checkout@v2
@@ -37,7 +37,7 @@ jobs:
     runs-on: ubuntu-latest
     strategy:
       matrix:
-        python: ["3.8", "3.9", "3.10"]
+        python: ["3.8", "3.9", "3.10", "3.11", "3.12"]
     name: Test - Python ${{ matrix.python }}
     steps:
       - uses: actions/checkout@v2
diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml
index ba28ee6..9eab418 100644
--- a/.github/workflows/release.yml
+++ b/.github/workflows/release.yml
@@ -7,56 +7,7 @@ env:
   PACKAGE_DIR: arango_datasets
   TESTS_DIR: tests
 
 jobs:
-  lint:
-    runs-on: ubuntu-latest
-    strategy:
-      matrix:
-        python: ["3.8", "3.9", "3.10"]
-    name: Lint - Python ${{ matrix.python }}
-    steps:
-      - uses: actions/checkout@v2
-      - name: Setup Python ${{ matrix.python }}
-        uses: actions/setup-python@v2
-        with:
-          python-version: ${{ matrix.python }}
-      - name: Install packages
-        run: pip install .[dev]
-      - name: Run black
-        run: black --check --verbose --diff --color ${{env.PACKAGE_DIR}} ${{env.TESTS_DIR}}
-      - name: Run flake8
-        run: flake8 ${{env.PACKAGE_DIR}} ${{env.TESTS_DIR}}
-      - name: Run isort
-        run: isort --check --profile=black ${{env.PACKAGE_DIR}} ${{env.TESTS_DIR}}
-      - name: Run mypy
-        run: mypy ${{env.PACKAGE_DIR}} ${{env.TESTS_DIR}}
-      - name: Run bandit
-        run: bandit --exclude "./tests/*" -r ./
-  test:
-    needs: lint
-    runs-on: ubuntu-latest
-    strategy:
-      matrix:
-        python: ["3.8", "3.9", "3.10"]
-    name: Test - Python ${{ matrix.python }}
-    steps:
-      - uses: actions/checkout@v2
-      - name: Setup Python ${{ matrix.python }}
-        uses: actions/setup-python@v2
-        with:
-          python-version: ${{ matrix.python }}
-      - name: Set up ArangoDB Instance via Docker
-        run: docker create --name adb -p 8529:8529 -e ARANGO_ROOT_PASSWORD=openSesame arangodb/arangodb:3.10.0
-      - name: Start ArangoDB Instance
-        run: docker start adb
-      - name: Setup pip
-        run: python -m pip install --upgrade pip setuptools wheel
-      - name: Install packages
-        run: pip install .[dev]
-      - name: Run pytest
-        run: pytest --cov=${{env.PACKAGE_DIR}} --cov-report xml --cov-report term-missing -v --color=yes --no-cov-on-fail --code-highlight=yes --cov-fail-under=75
-
   release:
-    needs: [lint, test]
     runs-on: ubuntu-latest
     name: Release package
     steps:
@@ -66,24 +17,22 @@ jobs:
       - name: Fetch complete history for all tags and branches
         run: git fetch --prune --unshallow
 
       - name: Setup python
-        uses: actions/setup-python@v2
+        uses: actions/setup-python@v4
         with:
-          python-version: "3.8"
+          python-version: "3.10"
 
       - name: Install release packages
-        run: pip install setuptools wheel twine setuptools-scm[toml]
-
-      - name: Install dependencies
-        run: pip install .[dev]
+        run: pip install build twine
 
       - name: Build distribution
-        run: python setup.py sdist bdist_wheel
+        run: python -m build
 
       - name: Publish to PyPI Test
         env:
           TWINE_USERNAME: __token__
           TWINE_PASSWORD: ${{ secrets.TWINE_PASSWORD_TEST }}
         run: twine upload --repository testpypi dist/* #--skip-existing
+
       - name: Publish to PyPI
         env:
           TWINE_USERNAME: __token__
diff --git a/README.md b/README.md
index 957990e..2dc9aad 100644
--- a/README.md
+++ b/README.md
@@ -2,20 +2,20 @@
 Package for loading example datasets into an ArangoDB Instance.
 
 ```py
-from arango_datasets.datasets import Datasets
 from arango import ArangoClient
+from arango_datasets import Datasets
 
 # Datasets requires a valid database object
 db = ArangoClient(hosts='http://localhost:8529').db("dbName", username="root", password="")
 
 datasets = Datasets(db)
 
-# list available datasets
+# List available datasets
 datasets.list_datasets()
 
-# list more information about the dataset files and characteristics
-#datasets.dataset_info("IMDB_X")
+# List more information about a particular dataset
+datasets.dataset_info("IMDB_X")
 
 # Import the dataset
-# datasets.load("IMDB_X")
+datasets.load("IMDB_X")
 ```
\ No newline at end of file
diff --git a/arango_datasets/__init__.py b/arango_datasets/__init__.py
index e69de29..c8c9def 100644
--- a/arango_datasets/__init__.py
+++ b/arango_datasets/__init__.py
@@ -0,0 +1 @@
+from arango_datasets.datasets import Datasets  # noqa: F401
diff --git a/arango_datasets/datasets.py b/arango_datasets/datasets.py
index 560624b..69c6135 100644
--- a/arango_datasets/datasets.py
+++ b/arango_datasets/datasets.py
@@ -1,12 +1,9 @@
 import json
-import sys
-from typing import Any, Dict, List
+from typing import Any, Callable, Dict, List, Optional
 
 import requests
 from arango.collection import StandardCollection
 from arango.database import Database
-from arango.exceptions import CollectionCreateError, DocumentInsertError
-from requests import ConnectionError, HTTPError
 
 from .utils import progress
 
@@ -16,12 +13,14 @@ class Datasets:
 
     :param db: A python-arango database instance
    :type db: arango.database.Database
-    :param batch_size:
-        Optional batch size supplied to python-arango import_bulk function
+    :param batch_size: Optional batch size supplied to the
+        python-arango `import_bulk` function. Defaults to 50.
    :type batch_size: int
-    :param metadata_file: Optional URL for datasets metadata file
+    :param metadata_file: URL for datasets metadata file. Defaults to
+        "https://arangodb-dataset-library-ml.s3.amazonaws.com/root_metadata.json".
    :type metadata_file: str
-    :param preserve_existing: Boolean to preserve existing data and graph definiton
+    :param preserve_existing: Whether to preserve the existing collections and
+        graph of the dataset (if any). Defaults to False.
    :type preserve_existing: bool
    """
@@ -32,150 +31,186 @@ def __init__(
         self,
         metadata_file: str = "https://arangodb-dataset-library-ml.s3.amazonaws.com/root_metadata.json",  # noqa: E501
         preserve_existing: bool = False,
     ):
-        self.metadata_file: str = metadata_file
-        self.metadata_contents: Dict[str, Any]
-        self.batch_size = batch_size
-        self.user_db = db
-        self.preserve_existing = preserve_existing
-        self.file_type: str
-        if issubclass(type(db), Database) is False:
+        if not isinstance(db, Database):
             msg = "**db** parameter must inherit from arango.database.Database"
             raise TypeError(msg)
 
-        try:
-            response = requests.get(self.metadata_file, timeout=6000)
-            response.raise_for_status()
-            self.metadata_contents = response.json()
-        except (HTTPError, ConnectionError) as e:
-            print("Unable to retrieve metadata information.")
-            print(e)
-            raise
+        self.user_db = db
+        self.batch_size = batch_size
+        self.metadata_file = metadata_file
+        self.preserve_existing = preserve_existing
 
-        self.labels = []
-        for label in self.metadata_contents:
-            self.labels.append(label)
+        self.__metadata: Dict[str, Dict[str, Any]]
+        self.__metadata = self.__get_response(self.metadata_file).json()
+        self.__dataset_names = [n for n in self.__metadata]
 
     def list_datasets(self) -> List[str]:
-        print(self.labels)
-        return self.labels
+        """List available datasets
 
-    def dataset_info(self, dataset_name: str) -> Dict[str, Any]:
-        for i in self.metadata_contents[str(dataset_name).upper()]:
-            print(f"{i}: {self.metadata_contents[str(dataset_name).upper()][i]} ")
-        print("")
-        return self.metadata_contents
+        :return: Names of the available datasets to load.
+        :rtype: List[str]
+        """
+        print(self.__dataset_names)
+        return self.__dataset_names
 
-    def insert_docs(
+    def dataset_info(self, dataset_name: str) -> Dict[str, Any]:
+        """Get information about a dataset
+
+        :param dataset_name: Name of the dataset.
+        :type dataset_name: str
+        :return: Some metadata about the dataset.
+        :rtype: Dict[str, Any]
+        :raises ValueError: If the dataset is not found.
+        """
+        if dataset_name.upper() not in self.__dataset_names:
+            raise ValueError(f"Dataset '{dataset_name}' not found")
+
+        info: Dict[str, Any] = self.__metadata[dataset_name.upper()]
+        print(info)
+        return info
+
+    def load(
         self,
-        collection: StandardCollection,
-        docs: List[Dict[Any, Any]],
-        collection_name: str,
+        dataset_name: str,
+        batch_size: Optional[int] = None,
+        preserve_existing: Optional[bool] = None,
     ) -> None:
-        try:
-            with progress(f"Collection: {collection_name}") as p:
-                p.add_task("insert_docs")
+        """Load a dataset into the database.
+
+        :param dataset_name: Name of the dataset.
+        :type dataset_name: str
+        :param batch_size: Batch size supplied to the
+            python-arango `import_bulk` function. Overrides the **batch_size**
+            supplied to the constructor. Defaults to None.
+        :type batch_size: Optional[int]
+        :param preserve_existing: Whether to preserve the existing collections and
+            graph of the dataset (if any). Overrides the **preserve_existing**
+            supplied to the constructor. Defaults to False.
+        :type preserve_existing: bool
+        :raises ValueError: If the dataset is not found.
+ """ + if dataset_name.upper() not in self.__dataset_names: + raise ValueError(f"Dataset '{dataset_name}' not found") + + dataset_contents = self.__metadata[dataset_name.upper()] + + # Backwards compatibility + self.batch_size = batch_size if batch_size is not None else self.batch_size + self.preserve_existing = ( + preserve_existing + if preserve_existing is not None + else self.preserve_existing + ) + + file_type = dataset_contents["file_type"] + load_file_function: Callable[[str], List[Dict[str, Any]]] + if file_type == "json": + load_file_function = self.__load_json + elif file_type == "jsonl": + load_file_function = self.__load_jsonl + else: + raise ValueError(f"Unsupported file type: {file_type}") + + for data, is_edge in [ + (dataset_contents["vertices"], False), + (dataset_contents["edges"], True), + ]: + for col_data in data: + col = self.__initialize_collection(col_data["collection_name"], is_edge) + + for file in col_data["files"]: + self.__import_bulk(col, load_file_function(file)) + + if edge_definitions := dataset_contents.get("edge_definitions"): + self.user_db.delete_graph(dataset_name, ignore_missing=True) + self.user_db.create_graph(dataset_name, edge_definitions) + + def __get_response(self, url: str, timeout: int = 60) -> requests.Response: + """Wrapper around requests.get() with a progress bar. + + :param url: URL to get a response from. + :type url: str + :param timeout: Timeout in seconds. Defaults to 60. + :type timeout: int + :raises ConnectionError: If the connection fails. + :raises HTTPError: If the HTTP request fails. + :return: The response from the URL. + :rtype: requests.Response + """ + with progress(f"GET: {url}") as p: + p.add_task("get_response") + + response = requests.get(url, timeout=timeout) + response.raise_for_status() + return response + + def __initialize_collection( + self, collection_name: str, is_edge: bool + ) -> StandardCollection: + """Initialize a collection. + + :param collection_name: Name of the collection. + :type collection_name: str + :param is_edge: Whether the collection is an edge collection. + :type is_edge: bool + :raises CollectionCreateError: If the collection cannot be created. + :return: The collection. + :rtype: arango.collection.StandardCollection + """ + if self.preserve_existing is False: + m = f"Collection '{collection_name}' already exists, dropping and creating with new data." # noqa: E501 + print(m) + + self.user_db.delete_collection(collection_name, ignore_missing=True) + + return self.user_db.create_collection(collection_name, edge=is_edge) + + def __load_json(self, file_url: str) -> List[Dict[str, Any]]: + """Load a JSON file into memory. + + :param file_url: URL of the JSON file. + :type file_url: str + :raises ConnectionError: If the connection fails. + :raises HTTPError: If the HTTP request fails. + :return: The JSON data. + :rtype: Dict[str, Any] + """ + json_data: List[Dict[str, Any]] = self.__get_response(file_url).json() + return json_data + + def __load_jsonl(self, file_url: str) -> List[Dict[str, Any]]: + """Load a JSONL file into memory. + + :param file_url: URL of the JSONL file. + :type file_url: str + :raises ConnectionError: If the connection fails. + :raises HTTPError: If the HTTP request fails. + :return: The JSONL data as a list of dictionaries. 
+ """ + json_data = [] + data = self.__get_response(file_url) - collection.import_bulk(docs, batch_size=self.batch_size) + if data.encoding is None: + data.encoding = "utf-8" - except DocumentInsertError as exec: - print("Document insertion failed due to the following error:") - print(exec.message) - sys.exit(1) + for line in data.iter_lines(decode_unicode=True): + if line: + json_data.append(json.loads(line)) - print(f"Finished loading current file for collection: {collection_name}") + return json_data - def load_json( - self, - collection_name: str, - edge_type: bool, - file_url: str, - collection: StandardCollection, - ) -> None: - try: - with progress(f"Downloading file for: {collection_name}") as p: - p.add_task("load_file") - data = requests.get(file_url, timeout=6000).json() - except (HTTPError, ConnectionError) as e: - print("Unable to download file.") - print(e) - raise e - print(f"Downloaded file for: {collection_name}, now importing... ") - self.insert_docs(collection, data, collection_name) - - def load_jsonl( - self, - collection_name: str, - edge_type: bool, - file_url: str, - collection: StandardCollection, + def __import_bulk( + self, collection: StandardCollection, docs: List[Dict[str, Any]] ) -> None: - json_data = [] - try: - with progress(f"Downloading file for: {collection_name}") as p: - p.add_task("load_file") - data = requests.get(file_url, timeout=6000) - - if data.encoding is None: - data.encoding = "utf-8" - - for line in data.iter_lines(decode_unicode=True): - if line: - json_data.append(json.loads(line)) - - except (HTTPError, ConnectionError) as e: - print("Unable to download file.") - print(e) - raise - print(f"Downloaded file for: {collection_name}, now importing... ") - self.insert_docs(collection, json_data, collection_name) - - def load_file(self, collection_name: str, edge_type: bool, file_url: str) -> None: - collection: StandardCollection - try: - collection = self.user_db.create_collection(collection_name, edge=edge_type) - except CollectionCreateError as exec: - print( - f"""Failed to create {collection_name} collection due - to the following error:""" - ) - print(exec.error_message) - sys.exit(1) - if self.file_type == "json": - self.load_json(collection_name, edge_type, file_url, collection) - elif self.file_type == "jsonl": - self.load_jsonl(collection_name, edge_type, file_url, collection) - else: - raise ValueError(f"Unsupported file type: {self.file_type}") - - def cleanup_collections(self, collection_name: str) -> None: - if ( - self.user_db.has_collection(collection_name) - and self.preserve_existing is False - ): - print( - f""" - Old collection found - ${collection_name}, - dropping and creating with new data.""" - ) - self.user_db.delete_collection(collection_name) - - def load(self, dataset_name: str) -> None: - if str(dataset_name).upper() in self.labels: - self.file_type = self.metadata_contents[str(dataset_name).upper()][ - "file_type" - ] - - for edge in self.metadata_contents[str(dataset_name).upper()]["edges"]: - self.cleanup_collections(collection_name=edge["collection_name"]) - for e in edge["files"]: - self.load_file(edge["collection_name"], True, e) - - for vertex in self.metadata_contents[str(dataset_name).upper()]["vertices"]: - self.cleanup_collections(collection_name=vertex["collection_name"]) - for v in vertex["files"]: - self.load_file(vertex["collection_name"], False, v) - - else: - print(f"Dataset `{str(dataset_name.upper())}` not found") - sys.exit(1) + """Wrapper around python-arango's import_bulk() with a progress 
+        bar.
+
+        :param collection: The collection to insert the documents into.
+        :type collection: arango.collection.StandardCollection
+        :param docs: The documents to insert.
+        :type docs: List[Dict[Any, Any]]
+        :raises DocumentInsertError: If the document cannot be inserted.
+        """
+        with progress(f"Collection: {collection.name}") as p:
+            p.add_task("insert_docs")
+
+            collection.import_bulk(docs, batch_size=self.batch_size)
diff --git a/tests/test_datasets.py b/tests/test_datasets.py
new file mode 100644
index 0000000..e21e50d
--- /dev/null
+++ b/tests/test_datasets.py
@@ -0,0 +1,65 @@
+from typing import Any
+
+import pytest
+from requests import ConnectionError
+
+from arango_datasets.datasets import Datasets
+
+from .conftest import cleanup_collections, db
+
+test_metadata_url = "https://arangodb-dataset-library-ml.s3.amazonaws.com/test_metadata.json"  # noqa: E501
+bad_metadata_url = "http://bad_url.arangodb.com/"
+
+
+def test_dataset_constructor() -> None:
+    with pytest.raises(Exception):
+        Datasets({})
+
+    with pytest.raises(TypeError):
+        Datasets(db="some none db object")
+
+    with pytest.raises(ConnectionError):
+        Datasets(db, metadata_file=bad_metadata_url)
+
+
+def test_list_datasets(capfd: Any) -> None:
+    datasets = Datasets(db, metadata_file=test_metadata_url).list_datasets()
+    assert type(datasets) is list
+    assert "TEST" in datasets
+
+    out, _ = capfd.readouterr()
+    assert "TEST" in out
+
+
+def test_dataset_info(capfd: Any) -> None:
+    with pytest.raises(ValueError):
+        Datasets(db).dataset_info("invalid")
+
+    datasets = Datasets(db, metadata_file=test_metadata_url)
+
+    dataset = datasets.dataset_info("TEST")
+    assert type(dataset) is dict
+    assert dataset["file_type"] == "json"
+
+    dataset = datasets.dataset_info("TEST_JSONL")
+    assert type(dataset) is dict
+    assert dataset["file_type"] == "jsonl"
+
+    out, _ = capfd.readouterr()
+    assert len(out.replace("\n", "")) > 0
+
+
+def test_load_json() -> None:
+    cleanup_collections()
+    Datasets(db, metadata_file=test_metadata_url).load("TEST")
+    assert db.collection("test_vertex").count() == 2
+    assert db.collection("test_edge").count() == 1
+    assert db.has_graph("TEST")
+
+
+def test_load_jsonl() -> None:
+    cleanup_collections()
+    Datasets(db, metadata_file=test_metadata_url).load("TEST_JSONL")
+    assert db.collection("test_vertex").count() == 2
+    assert db.collection("test_edge").count() == 1
+    assert db.has_graph("TEST_JSONL")
diff --git a/tests/test_main.py b/tests/test_main.py
deleted file mode 100644
index 1c40d70..0000000
--- a/tests/test_main.py
+++ /dev/null
@@ -1,165 +0,0 @@
-from typing import Any, no_type_check
-
-import pytest
-from requests import ConnectionError
-
-from arango_datasets.datasets import Datasets
-
-from .conftest import cleanup_collections, db
-
-global test_metadata_url
-global root_metadata_url
-global bad_metadata_url
-test_metadata_url = (
-    "https://arangodb-dataset-library.s3.amazonaws.com/test_metadata.json"  # noqa: E501
-)
-root_metadata_url = (
-    "https://arangodb-dataset-library.s3.amazonaws.com/root_metadata.json"  # noqa: E501
-)
-bad_metadata_url = "http://bad_url.arangodb.com/"
-
-
-@no_type_check
-def test_dataset_constructor() -> None:
-    assert Datasets(db)
-    assert Datasets(db, batch_size=1000)
-    assert Datasets(
-        db,
-        batch_size=1000,
-    )
-    assert Datasets(
-        db,
-        batch_size=1000,
-        metadata_file=root_metadata_url,
-    )
-    with pytest.raises(TypeError):
-        assert Datasets(
-            db="some none db object",
-            batch_size=1000,
-            metadata_file=root_metadata_url,
-        )
-    with pytest.raises(Exception):
-        assert Datasets({})
-
-    with pytest.raises(ConnectionError):
-        assert Datasets(db, metadata_file=bad_metadata_url)
-
-
-@no_type_check
-def test_list_datasets(capfd: Any) -> None:
-    datasets = Datasets(
-        db,
-        metadata_file=test_metadata_url,
-    ).list_datasets()
-    out, err = capfd.readouterr()
-    assert "TEST" in out
-    assert type(datasets) is list
-    assert "TEST" in datasets
-
-
-@no_type_check
-def test_dataset_info(capfd: Any) -> None:
-    with pytest.raises(Exception):
-        Datasets.dataset_info()
-
-    with pytest.raises(Exception):
-        Datasets(db).dataset_info(2)
-
-    dataset = Datasets(
-        db,
-        metadata_file=test_metadata_url,
-    ).dataset_info("TEST")
-    assert type(dataset) is dict
-
-    assert dataset["TEST"]["file_type"] == "json"
-
-    out, err = capfd.readouterr()
-    assert len(out) > 0
-
-
-@no_type_check
-def test_load_file() -> None:
-    with pytest.raises(Exception):
-        Datasets.load_file(collection_name="test", edge_type=None, file_url="false")
-
-
-@no_type_check
-def test_load_json() -> None:
-    cleanup_collections()
-    collection_name = "test_vertex"
-    edge_type = False
-    file_url = "https://arangodb-dataset-library.s3.amazonaws.com/test_files/json/vertex_collection/test_vertex.json"  # noqa: E501
-    collection = db.create_collection("test_vertex")
-    assert None == (
-        Datasets.load_json(
-            Datasets(db),
-            collection_name=collection_name,
-            edge_type=edge_type,
-            file_url=file_url,
-            collection=collection,
-        )
-    )
-
-
-@no_type_check
-def json_bad_url() -> None:
-    cleanup_collections()
-    collection_name = "test_vertex"
-    edge_type = False
-    collection = db.create_collection("test_vertex")
-
-    with pytest.raises(ConnectionError):
-        Datasets.load_json(
-            Datasets(db),
-            collection_name=collection_name,
-            edge_type=edge_type,
-            file_url=bad_metadata_url,
-            collection=collection,
-        )
-
-
-@no_type_check
-def test_load_jsonl() -> None:
-    cleanup_collections()
-    collection_name = "test_vertex"
-    edge_type = False
-    file_url = "https://arangodb-dataset-library.s3.amazonaws.com/test_files/jsonl/vertex_collection/test_vertex.jsonl"  # noqa: E501
-    collection = db.create_collection("test_vertex")
-    assert None == (
-        Datasets.load_jsonl(
-            Datasets(db),
-            collection_name=collection_name,
-            edge_type=edge_type,
-            file_url=file_url,
-            collection=collection,
-        )
-    )
-
-
-@no_type_check
-def jsonl_bad_url() -> None:
-    cleanup_collections()
-    collection_name = "test_vertex"
-    edge_type = False
-    collection = db.create_collection("test_vertex")
-    with pytest.raises(ConnectionError):
-        Datasets.load_jsonl(
-            Datasets(db),
-            collection_name=collection_name,
-            edge_type=edge_type,
-            file_url=bad_metadata_url,
-            collection=collection,
-        )
-
-
-@no_type_check
-def test_load() -> None:
-    cleanup_collections()
-    Datasets(
-        db,
-        metadata_file=test_metadata_url,
-    ).load("TEST")
-    with pytest.raises(Exception):
-        Datasets(db).load(2)
-    assert db.collection("test_vertex").count() == 2
-    assert db.collection("test_edge").count() == 1