diff --git a/pyproject.toml b/pyproject.toml index 5ed835f..906c35c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,7 +6,7 @@ build-backend = "setuptools.build_meta" name = "cems-nuclei" version = "2.0.1" description = "Python wrapper around NUCLEI's functionality." -dependencies = ["requests>=2.25.1,<3", "pyjwt>=2.6.0,<3"] +dependencies = ["aiohttp>=3.11.16,<4", "pyjwt>=2.6.0,<3"] requires-python = ">=3.11" license = { file = "LICENSE.txt" } readme = "README.md" diff --git a/requirements.txt b/requirements.txt index 2aa5b25..eab0245 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,11 @@ # This file was autogenerated by uv via the following command: # uv pip compile --extra=test --extra=docs --extra=lint --extra=client --output-file=requirements.txt pyproject.toml +aiohappyeyeballs==2.6.1 + # via aiohttp +aiohttp==3.11.16 + # via cems-nuclei (pyproject.toml) +aiosignal==1.3.2 + # via aiohttp alabaster==1.0.0 # via sphinx anywidget==0.9.18 @@ -8,6 +14,8 @@ asteroid-sphinx-theme==0.0.3 # via cems-nuclei (pyproject.toml) asttokens==3.0.0 # via stack-data +attrs==25.3.0 + # via aiohttp babel==2.17.0 # via sphinx black==25.1.0 @@ -34,8 +42,14 @@ docutils==0.21.2 # sphinx-rtd-theme executing==2.2.0 # via stack-data +frozenlist==1.5.0 + # via + # aiohttp + # aiosignal idna==3.10 - # via requests + # via + # requests + # yarl imagesize==1.4.1 # via sphinx iniconfig==2.1.0 @@ -60,6 +74,10 @@ markupsafe==3.0.2 # via jinja2 matplotlib-inline==0.1.7 # via ipython +multidict==6.4.2 + # via + # aiohttp + # yarl mypy-extensions==1.0.0 # via black numpy==2.2.4 @@ -83,6 +101,10 @@ pluggy==1.5.0 # via pytest prompt-toolkit==3.0.50 # via ipython +propcache==0.3.1 + # via + # aiohttp + # yarl psygnal==0.12.0 # via anywidget ptyprocess==0.7.0 @@ -100,7 +122,6 @@ pytest==8.3.5 # via cems-nuclei (pyproject.toml) requests==2.32.3 # via - # cems-nuclei (pyproject.toml) # coveralls # sphinx roman-numerals-py==3.1.0 @@ -148,3 +169,5 @@ wcwidth==0.2.13 # via prompt-toolkit widgetsnbextension==4.0.13 # via ipywidgets +yarl==1.19.0 + # via aiohttp diff --git a/src/nuclei/api/main.py b/src/nuclei/api/main.py index c51bf06..b74d093 100644 --- a/src/nuclei/api/main.py +++ b/src/nuclei/api/main.py @@ -2,10 +2,10 @@ import os import jwt -import requests +import aiohttp -def create_session() -> requests.Session: +def create_session() -> aiohttp.ClientSession: """ Initialising session object to call the NUCLEI endpoints. Provide the user token from https://nuclei.cemsbv.io/#/personal-access-tokens to @@ -16,28 +16,10 @@ def create_session() -> requests.Session: Session """ - # initialising session - _session = requests.Session() - # set bearer token - _session.headers.update({"Authorization": f"Bearer {authenticate()}"}) - - return _session - - -def authenticate() -> str: - """ - Returns a validated JWT token from backend. - - Prompt the user for a user-token if it is not stored as an environmental variable - "NUCLEI_TOKEN". A prompted user token will be stored as env-var after validation. + headers = {"Authorization": f"Bearer {_get_valid_user_token()}"} - Will throw an exception when authentication fails. - """ - - user_token = _get_valid_user_token() - - return user_token + return aiohttp.ClientSession(headers=headers) def _get_valid_user_token() -> str: @@ -46,6 +28,11 @@ def _get_valid_user_token() -> str: If the token is valid, it is stored as env-var "NUCLEI_TOKEN" as a side-effect. """ + global USER_TOKEN + + if USER_TOKEN: + return USER_TOKEN + # check if User Token is in environment variables if "NUCLEI_TOKEN" in os.environ: logging.info("user token found in environment") @@ -63,7 +50,10 @@ def _get_valid_user_token() -> str: os.environ["NUCLEI_TOKEN"] = token logging.info("user token set in environment") - return token + USER_TOKEN = token + + if USER_TOKEN: + return USER_TOKEN def _validate_user_token(token: str) -> None: diff --git a/src/nuclei/client/main.py b/src/nuclei/client/main.py index 027c559..f558ba5 100644 --- a/src/nuclei/client/main.py +++ b/src/nuclei/client/main.py @@ -8,7 +8,18 @@ import jwt from nuclei import create_session -from nuclei.client import utils + +# try import serialize functions +try: + from IPython.display import Image + + from nuclei.client import utils +except ImportError as e: + raise ImportError( + "Could not import one of dependencies [numpy, orjson, ipython]. " + "You must install nuclei[client] in order to use NucleiClient \n" + rf"Traceback: {e}" + ) ROUTING = { "PileCore": { @@ -52,9 +63,6 @@ def __init__(self) -> None: The connect timeout is the number of seconds. Default is 5 seconds """ - # initialize session - self.session = create_session() - # get routing table to application self.routing = ROUTING @@ -104,11 +112,12 @@ def user_permissions(self) -> List[str | None]: out : list[str] Names of the API's """ - return jwt.decode( - self.session.headers["Authorization"].split(" ")[1], # type: ignore - algorithms=["HS256"], - options={"verify_signature": False, "verify_exp": False}, - ).get("permissions", []) + with create_session() as session: + return jwt.decode( + session.headers["Authorization"].split(" ")[1], # type: ignore + algorithms=["HS256"], + options={"verify_signature": False, "verify_exp": False}, + ).get("permissions", []) @property def applications(self) -> List[str]: @@ -154,7 +163,7 @@ def get_versions(self, app: str) -> List[str]: return list(self.routing[app].keys()) @lru_cache(16) - def _get_app_specification(self, app: str, version: str = "latest") -> dict: + async def _get_app_specification(self, app: str, version: str = "latest") -> dict: """ Private methode to get the JSON schema of the API documentation. @@ -179,16 +188,17 @@ def _get_app_specification(self, app: str, version: str = "latest") -> dict: ValueError: Wrong value for `app` or `version` argument """ - response = self.session.get( - self.get_url(app, version) + "/openapi.json", timeout=self.timeout - ) - if response.status_code != 200: - raise ConnectionError( - "Unfortunately the application you are trying to reach is unavailable (status code: " - f"{response.status_code}). Please check you connection. If the problem persists contact " - "CEMS at info@cemsbv.nl" - ) - return response.json() + async with create_session() as session: + async with session.get( + self.get_url(app, version) + "/openapi.json", timeout=self.timeout + ) as response: + if response.status_code != 200: + raise ConnectionError( + "Unfortunately the application you are trying to reach is unavailable (status code: " + f"{response.status_code}). Please check you connection. If the problem persists contact " + "CEMS at info@cemsbv.nl" + ) + return await response.json() def get_application_version(self, app: str, version: str = "latest") -> str: """ @@ -338,8 +348,8 @@ def call_endpoint( content response out : Response requests response object - figure: bytes - PNG bytes, can be saved directly to a .png file + figure: Image + IPython display Image object Raises ------- @@ -448,7 +458,9 @@ def call_endpoint( content_type = response.headers["Content-Type"] if content_type == "image/png;base64": - return base64.b64decode(response.text) + return Image(base64.b64decode(response.text)) + elif content_type == "image/png": + return Image(response.content) elif content_type.endswith("json"): return response.json() elif content_type.startswith("text/"):