diff --git a/apps/api/pyproject.toml b/apps/api/pyproject.toml index 1a03a62d..8fc96c1c 100644 --- a/apps/api/pyproject.toml +++ b/apps/api/pyproject.toml @@ -15,6 +15,6 @@ branch = true show_missing = true [tool.mypy] -mypy_path = "src" +mypy_path = "src,stubs" explicit_package_bases = true strict = true diff --git a/apps/api/src/routers/saml.py b/apps/api/src/routers/saml.py index 74d84fcb..8878118c 100644 --- a/apps/api/src/routers/saml.py +++ b/apps/api/src/routers/saml.py @@ -7,7 +7,8 @@ from fastapi import APIRouter, HTTPException, Request from fastapi.responses import RedirectResponse, Response -from onelogin.saml2.auth import OneLogin_Saml2_Auth, OneLogin_Saml2_Settings +from onelogin.saml2.auth import OneLogin_Saml2_Auth +from onelogin.saml2.settings import OneLogin_Saml2_Settings # from auth import user_identity @@ -56,7 +57,7 @@ def _read_json(filename: str) -> dict[str, Any]: return OneLogin_Saml2_Settings(settings, custom_base_path=str(BASE_PATH)) -async def _prepare_saml_req(req: Request) -> dict[str, Any]: +async def _prepare_saml_req(req: Request) -> dict[str, object]: """Packages a FastAPI Request into a request dict for SAML Auth""" return { "http_host": req.url.hostname, @@ -90,7 +91,7 @@ async def login(req: Request) -> RedirectResponse: @router.post("/acs") -async def acs(req: Request) -> RedirectResponse: +async def acs(req: Request) -> str: """ SAML Assertion Consumer Service. Accepts the response returned by the SAML Identity Provider and @@ -113,7 +114,7 @@ async def acs(req: Request) -> RedirectResponse: (email,) = auth.get_friendlyname_attribute("email") (display_name,) = auth.get_friendlyname_attribute("displayName") (ucinetid,) = auth.get_friendlyname_attribute("ucinetid") - affiliations: list[str] = auth.get_friendlyname_attribute("uciaffiliation") + affiliations = auth.get_friendlyname_attribute("uciaffiliation") except (ValueError, TypeError) as e: log.exception("Error decoding SAML Attributes: %s", e) raise HTTPException(500, "Error decoding user identity") @@ -142,7 +143,7 @@ async def sls(req: Request) -> str: async def get_saml_metadata() -> Response: """Provides SAML metadata, used when registering service with IdP""" saml_settings = _get_saml_settings() - metadata: bytes = saml_settings.get_sp_metadata() + metadata = saml_settings.get_sp_metadata() errors = saml_settings.validate_metadata(metadata) if errors: diff --git a/apps/api/stubs/onelogin/saml2/auth.pyi b/apps/api/stubs/onelogin/saml2/auth.pyi new file mode 100644 index 00000000..b932a053 --- /dev/null +++ b/apps/api/stubs/onelogin/saml2/auth.pyi @@ -0,0 +1,23 @@ +from onelogin.saml2.settings import OneLogin_Saml2_Settings + +class OneLogin_Saml2_Auth: + def __init__( + self, + request_data: dict[str, object], + old_settings: OneLogin_Saml2_Settings | dict[str, object] | None = ..., + custom_base_path: str | None = ..., + ) -> None: ... + def process_response(self, request_id: str | None = None) -> None: ... + def is_authenticated(self) -> bool: ... + def get_friendlyname_attributes(self) -> dict[str, list[str]]: ... + def get_errors(self) -> list[str]: ... + def get_last_error_reason(self) -> str | None: ... + def get_friendlyname_attribute(self, friendlyname: str) -> list[str]: ... + def login( + self, + return_to: str | None = ..., + force_authn: bool = ..., + is_passive: bool = ..., + set_nameid_policy: bool = ..., + name_id_value_req: str | None = ..., + ) -> str: ... diff --git a/apps/api/stubs/onelogin/saml2/settings.pyi b/apps/api/stubs/onelogin/saml2/settings.pyi new file mode 100644 index 00000000..57939076 --- /dev/null +++ b/apps/api/stubs/onelogin/saml2/settings.pyi @@ -0,0 +1,9 @@ +class OneLogin_Saml2_Settings: + def __init__( + self, + settings: dict[str, object] | None = None, + custom_base_path: str | None = None, + sp_validation_only: bool | None = None, + ) -> None: ... + def get_sp_metadata(self) -> bytes | str: ... + def validate_metadata(self, xml: bytes | str) -> list[str]: ... diff --git a/apps/api/tests/test_saml.py b/apps/api/tests/test_saml.py index 1c199d0e..27e94017 100644 --- a/apps/api/tests/test_saml.py +++ b/apps/api/tests/test_saml.py @@ -1,7 +1,7 @@ from unittest.mock import MagicMock, patch from fastapi.testclient import TestClient -from onelogin.saml2.auth import OneLogin_Saml2_Settings +from onelogin.saml2.settings import OneLogin_Saml2_Settings from routers import saml