Skip to content

Commit

Permalink
[TEST] Complete refactoring of old testing functions
Browse files Browse the repository at this point in the history
No need to be extensive, you want to be accurate
  • Loading branch information
Aedial committed Apr 27, 2023
1 parent 3436fef commit be02c90
Show file tree
Hide file tree
Showing 14 changed files with 3,095 additions and 5 deletions.
10 changes: 6 additions & 4 deletions .github/workflows/python-package.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,12 @@ jobs:
python-version: [3.7, 3.8, 3.9, "3.10", 3.11]

steps:
- uses: actions/checkout@v2
- uses: actions/checkout@v3
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
cache: 'pip'
- name: Install dependencies
run: |
pip install nox
Expand All @@ -40,11 +41,12 @@ jobs:
python-version: ["3.11"]

steps:
- uses: actions/checkout@v2
- uses: actions/checkout@v3
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
cache: 'pip'
- name: Install dependencies
run: |
pip install nox
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ sentencepiece = "^0.1.98"

[tool.poetry.group.dev.dependencies]
pytest-asyncio = "^0.20.1"
pytest-xdist = "^3.0.2"
pytest-randomly = "^3.12.0"
pylint = "^2.15.5"

[tool.flake8]
Expand Down
Empty file added tests/__init__.py
Empty file.
143 changes: 143 additions & 0 deletions tests/api/boilerplate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
import asyncio
import json
from logging import Logger, StreamHandler
from os import environ as env
from typing import Any, NoReturn

import pytest
from aiohttp import ClientConnectionError, ClientPayloadError, ClientSession

from novelai_api import NovelAIAPI, NovelAIError
from novelai_api.utils import get_encryption_key


class API:
_username: str
_password: str
_session: ClientSession
_sync: bool

logger: Logger
api: NovelAIAPI

def __init__(self, sync: bool = False):
if "NAI_USERNAME" not in env or "NAI_PASSWORD" not in env:
raise RuntimeError("Please ensure that NAI_USERNAME and NAI_PASSWORD are set in your environment")

self._username = env["NAI_USERNAME"]
self._password = env["NAI_PASSWORD"]
self._sync = sync

self.logger = Logger("NovelAI")
self.logger.addHandler(StreamHandler())

proxy = env["NAI_PROXY"] if "NAI_PROXY" in env else None

self.api = NovelAIAPI(logger=self.logger)
self.api.proxy = proxy

@property
def encryption_key(self):
return get_encryption_key(self._username, self._password)

def __enter__(self) -> NoReturn:
raise TypeError("Use async with instead")

async def __aenter__(self):
if not self._sync:
self._session = ClientSession()
await self._session.__aenter__()
self.api.attach_session(self._session)

await self.api.high_level.login(self._username, self._password)

return self

async def __aexit__(self, exc_type, exc_val, exc_tb):
if not self._sync:
await self._session.__aexit__(exc_type, exc_val, exc_tb)

async def run_test(self, func, *args, attempts: int = 5, wait: int = 5):
"""
Run the function ``func`` with the provided arguments and retry on error handling
The function must accept a NovelAIAPI object as first arguments
:param func: Function to run
:param args: Arguments to provide to the function
:param attempts: Number of attempts to do before raising the error
:param wait: Time (in seconds) to wait after each call
"""

err: Exception = RuntimeError("Error placeholder. Shouldn't happen")
for _ in range(attempts):
try:
res = await func(self.api, *args)
await asyncio.sleep(wait)

return res
except (ClientConnectionError, asyncio.TimeoutError, ClientPayloadError) as e:
err = e
retry = True

except NovelAIError as e:
err = e
retry = any(
[
e.status == 502, # Bad Gateway
e.status == 520, # Cloudflare Unknown Error
e.status == 524, # Cloudflare Gateway Error
]
)

if not retry:
break

# 10s wait between each retry
await asyncio.sleep(10)

# no internet: ping every 5 mins until connection is re-established
async with ClientSession() as session:
while True:
try:
rsp = await session.get("https://www.google.com", timeout=5 * 60)
rsp.raise_for_status()

break
except ClientConnectionError:
await asyncio.sleep(5 * 60)
except asyncio.TimeoutError:
pass

raise err


class JSONEncoder(json.JSONEncoder):
def default(self, o: Any) -> Any:
if isinstance(o, bytes):
return o.hex()

return super().default(o)


def dumps(e: Any) -> str:
return json.dumps(e, indent=4, ensure_ascii=False, cls=JSONEncoder)


@pytest.fixture(scope="session")
async def api_handle():
"""
API handle for an Async Test
"""

async with API() as api:
yield api


@pytest.fixture(scope="session")
async def api_handle_sync():
"""
API handle for a Sync Test
"""

async with API(sync=True) as api:
yield api
11 changes: 11 additions & 0 deletions tests/api/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
import asyncio

import pytest


# cannot put in boilerplate because pytest is a mess
@pytest.fixture(scope="session")
def event_loop():
loop = asyncio.get_event_loop()
yield loop
loop.close()
Loading

0 comments on commit be02c90

Please sign in to comment.