Skip to content

Commit

Permalink
replace startup shutdown with lifespan
Browse files Browse the repository at this point in the history
Signed-off-by: samadpls <[email protected]>
  • Loading branch information
samadpls committed Mar 12, 2024
1 parent 04f8cbd commit ce631ca
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 72 deletions.
92 changes: 43 additions & 49 deletions app/main.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Main app."""

from contextlib import asynccontextmanager
import os
import warnings
from pathlib import Path
Expand Down Expand Up @@ -59,60 +60,53 @@ def overridden_redoc():
redoc_favicon_url=favicon_url,
)


@app.on_event("startup")
async def auth_check():
"""Checks whether username and password environment variables are set."""
if (
# TODO: Check if this error is still raised when variables are empty strings
os.environ.get(util.GRAPH_USERNAME.name) is None
or os.environ.get(util.GRAPH_PASSWORD.name) is None
):
raise RuntimeError(
f"The application was launched but could not find the {util.GRAPH_USERNAME.name} and / or {util.GRAPH_PASSWORD.name} environment variables."
)


@app.on_event("startup")
async def allowed_origins_check():
"""Raises warning if allowed origins environment variable has not been set or is an empty string."""
if os.environ.get(util.ALLOWED_ORIGINS.name, "") == "":
warnings.warn(
f"The API was launched without providing any values for the {util.ALLOWED_ORIGINS.name} environment variable. "
"This means that the API will only be accessible from the same origin it is hosted from: https://developer.mozilla.org/en-US/docs/Web/Security/Same-origin_policy. "
f"If you want to access the API from tools hosted at other origins such as the Neurobagel query tool, explicitly set the value of {util.ALLOWED_ORIGINS.name} to the origin(s) of these tools (e.g. http://localhost:3000). "
"Multiple allowed origins should be separated with spaces in a single string enclosed in quotes. "
)


@app.on_event("startup")
async def fetch_vocabularies_to_temp_dir():
@asynccontextmanager
async def lifespan(app: FastAPI):
"""
Create and store on the app instance a temporary directory for vocabulary term lookup JSON files
(each of which contain key-value pairings of IDs to human-readable names of terms),
and then fetch vocabularies using their respective native APIs and save them to the temporary directory for reuse.
Load and set up resources before the application starts receiving requests.
Clean up resources after the application finishes handling requests.
"""
# We use Starlette's ability (FastAPI is Starlette underneath) to store arbitrary state on the app instance (https://www.starlette.io/applications/#storing-state-on-the-app-instance)
# to store a temporary directory object and its corresponding path. These data are local to the instance and will be recreated on every app launch (i.e. not persisted).
app.state.vocab_dir = TemporaryDirectory()
app.state.vocab_dir_path = Path(app.state.vocab_dir.name)

# TODO: Maybe store these paths in one dictionary on the app instance instead of separate variables?
app.state.cogatlas_term_lookup_path = (
app.state.vocab_dir_path / "cogatlas_task_term_labels.json"
)
app.state.snomed_term_lookup_path = (
app.state.vocab_dir_path / "snomedct_disorder_term_labels.json"
)

util.fetch_and_save_cogatlas(app.state.cogatlas_term_lookup_path)
util.create_snomed_term_lookup(app.state.snomed_term_lookup_path)
vocab_dir = TemporaryDirectory()
try:
# Check if username and password environment variables are set
if (
# TODO: Check if this error is still raised when variables are empty strings
os.environ.get(util.GRAPH_USERNAME.name) is None
or os.environ.get(util.GRAPH_PASSWORD.name) is None
):
raise RuntimeError(
f"The application was launched but could not find the {util.GRAPH_USERNAME.name} and / or {util.GRAPH_PASSWORD.name} environment variables."
)

# Raises warning if allowed origins environment variable has not been set or is an empty string.
if os.environ.get(util.ALLOWED_ORIGINS.name, "") == "":
warnings.warn(
f"The API was launched without providing any values for the {util.ALLOWED_ORIGINS.name} environment variable. "
"This means that the API will only be accessible from the same origin it is hosted from: https://developer.mozilla.org/en-US/docs/Web/Security/Same-origin_policy. "
f"If you want to access the API from tools hosted at other origins such as the Neurobagel query tool, explicitly set the value of {util.ALLOWED_ORIGINS.name} to the origin(s) of these tools (e.g. http://localhost:3000). "
"Multiple allowed origins should be separated with spaces in a single string enclosed in quotes. "
)

# Fetch and store vocabularies to a temporary directory
# We use Starlette's ability (FastAPI is Starlette underneath) to store arbitrary state on the app instance (https://www.starlette.io/applications/#storing-state-on-the-app-instance)
vocab_dir_path = Path(vocab_dir.name)

# TODO: Maybe store these paths in one dictionary on the app instance instead of separate variables?
app.cogatlas_term_lookup_path = (
vocab_dir_path / "cogatlas_task_term_labels.json"
)
app.snomed_term_lookup_path = (
vocab_dir_path / "snomedct_disorder_term_labels.json"
)

util.fetch_and_save_cogatlas(app.cogatlas_term_lookup_path)
util.create_snomed_term_lookup(app.snomed_term_lookup_path)

@app.on_event("shutdown")
async def cleanup_temp_vocab_dir():
"""Clean up the temporary directory created on startup."""
app.state.vocab_dir.cleanup()
yield
finally:
# Clean up the temporary directory created on startup
vocab_dir.cleanup()


app.include_router(query.router)
Expand Down
50 changes: 27 additions & 23 deletions tests/test_app_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,17 @@
import pytest

from app.api import utility as util
from app.main import lifespan


def test_start_app_without_environment_vars_fails(test_app, monkeypatch):
@pytest.mark.asyncio
async def test_start_app_without_environment_vars_fails(test_app, monkeypatch):
"""Given non-existing username and password environment variables, raises an informative RuntimeError."""
monkeypatch.delenv(util.GRAPH_USERNAME.name, raising=False)
monkeypatch.delenv(util.GRAPH_PASSWORD.name, raising=False)

with pytest.raises(RuntimeError) as e_info:
with test_app:
async with lifespan(test_app):
pass
assert (
f"could not find the {util.GRAPH_USERNAME.name} and / or {util.GRAPH_PASSWORD.name} environment variables"
Expand All @@ -36,8 +38,8 @@ def mock_httpx_post(**kwargs):
response = test_app.get("/query/")
assert response.status_code == 401


def test_app_with_unset_allowed_origins(
@pytest.mark.asyncio
async def test_app_with_unset_allowed_origins(
test_app, monkeypatch, set_test_credentials
):
"""Tests that when the environment variable for allowed origins has not been set, a warning is raised and the app uses a default value."""
Expand All @@ -47,7 +49,7 @@ def test_app_with_unset_allowed_origins(
UserWarning,
match=f"API was launched without providing any values for the {util.ALLOWED_ORIGINS.name} environment variable",
):
with test_app:
async with lifespan(test_app):
pass

assert util.parse_origins_as_list(
Expand Down Expand Up @@ -83,7 +85,9 @@ def test_app_with_unset_allowed_origins(
),
],
)
def test_app_with_set_allowed_origins(

@pytest.mark.asyncio
async def test_app_with_set_allowed_origins(
test_app,
monkeypatch,
set_test_credentials,
Expand All @@ -98,7 +102,7 @@ def test_app_with_set_allowed_origins(
monkeypatch.setenv(util.ALLOWED_ORIGINS.name, allowed_origins)

with expectation:
with test_app:
async with lifespan(test_app):
pass

assert set(parsed_origins).issubset(
Expand All @@ -107,18 +111,18 @@ def test_app_with_set_allowed_origins(
)
)


def test_stored_vocab_lookup_file_created_on_startup(
@pytest.mark.asyncio
async def test_stored_vocab_lookup_file_created_on_startup(
test_app, set_test_credentials
):
"""Test that on startup, a non-empty temporary lookup file is created for term ID-label mappings for the locally stored SNOMED CT vocabulary."""
with test_app:
term_labels_path = test_app.app.state.snomed_term_lookup_path
async with lifespan(test_app) as p:
term_labels_path = test_app.snomed_term_lookup_path
assert term_labels_path.exists()
assert term_labels_path.stat().st_size > 0


def test_external_vocab_is_fetched_on_startup(
@pytest.mark.asyncio
async def test_external_vocab_is_fetched_on_startup(
test_app, monkeypatch, set_test_credentials
):
"""
Expand Down Expand Up @@ -147,8 +151,8 @@ def mock_httpx_get(**kwargs):

monkeypatch.setattr(httpx, "get", mock_httpx_get)

with test_app:
term_labels_path = test_app.app.state.cogatlas_term_lookup_path
async with lifespan(test_app):
term_labels_path = test_app.cogatlas_term_lookup_path
assert term_labels_path.exists()

with open(term_labels_path, "r") as f:
Expand All @@ -159,8 +163,8 @@ def mock_httpx_get(**kwargs):
"tsk_ccTKYnmv7tOZY": "Verbal Interference Test",
}


def test_failed_vocab_fetching_on_startup_raises_warning(
@pytest.mark.asyncio
async def test_failed_vocab_fetching_on_startup_raises_warning(
test_app, monkeypatch, set_test_credentials
):
"""
Expand All @@ -176,17 +180,17 @@ def mock_httpx_get(**kwargs):
monkeypatch.setattr(httpx, "get", mock_httpx_get)

with pytest.warns(UserWarning) as w:
with test_app:
assert test_app.app.state.cogatlas_term_lookup_path.exists()
async with lifespan(test_app):
assert test_app.cogatlas_term_lookup_path.exists()

assert any(
"unable to fetch the Cognitive Atlas task vocabulary (https://www.cognitiveatlas.org/tasks/a/) from the source and will default to using a local backup copy"
in str(warn.message)
for warn in w
)


def test_network_error_on_startup_raises_warning(
@pytest.mark.asyncio
async def test_network_error_on_startup_raises_warning(
test_app, monkeypatch, set_test_credentials
):
"""
Expand All @@ -200,8 +204,8 @@ def mock_httpx_get(**kwargs):
monkeypatch.setattr(httpx, "get", mock_httpx_get)

with pytest.warns(UserWarning) as w:
with test_app:
assert test_app.app.state.cogatlas_term_lookup_path.exists()
async with lifespan(test_app):
assert test_app.cogatlas_term_lookup_path.exists()

assert any(
"failed due to a network error" in str(warn.message) for warn in w
Expand Down

0 comments on commit ce631ca

Please sign in to comment.