Skip to content

Commit

Permalink
Add stubs for python3-saml to fix mypy errors (#65)
Browse files Browse the repository at this point in the history
- Use stubgen to generate stubs for onelogin.saml2.auth and
  onelogin.saml2.settings and fix imports
- Add `stubs` directory to `mypy_path`
  • Loading branch information
taesungh authored Dec 10, 2023
1 parent e2eafe6 commit cd80b69
Show file tree
Hide file tree
Showing 5 changed files with 40 additions and 7 deletions.
2 changes: 1 addition & 1 deletion apps/api/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,6 @@ branch = true
show_missing = true

[tool.mypy]
mypy_path = "src"
mypy_path = "src,stubs"
explicit_package_bases = true
strict = true
11 changes: 6 additions & 5 deletions apps/api/src/routers/saml.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@

from fastapi import APIRouter, HTTPException, Request
from fastapi.responses import RedirectResponse, Response
from onelogin.saml2.auth import OneLogin_Saml2_Auth, OneLogin_Saml2_Settings
from onelogin.saml2.auth import OneLogin_Saml2_Auth
from onelogin.saml2.settings import OneLogin_Saml2_Settings

# from auth import user_identity

Expand Down Expand Up @@ -56,7 +57,7 @@ def _read_json(filename: str) -> dict[str, Any]:
return OneLogin_Saml2_Settings(settings, custom_base_path=str(BASE_PATH))


async def _prepare_saml_req(req: Request) -> dict[str, Any]:
async def _prepare_saml_req(req: Request) -> dict[str, object]:
"""Packages a FastAPI Request into a request dict for SAML Auth"""
return {
"http_host": req.url.hostname,
Expand Down Expand Up @@ -90,7 +91,7 @@ async def login(req: Request) -> RedirectResponse:


@router.post("/acs")
async def acs(req: Request) -> RedirectResponse:
async def acs(req: Request) -> str:
"""
SAML Assertion Consumer Service.
Accepts the response returned by the SAML Identity Provider and
Expand All @@ -113,7 +114,7 @@ async def acs(req: Request) -> RedirectResponse:
(email,) = auth.get_friendlyname_attribute("email")
(display_name,) = auth.get_friendlyname_attribute("displayName")
(ucinetid,) = auth.get_friendlyname_attribute("ucinetid")
affiliations: list[str] = auth.get_friendlyname_attribute("uciaffiliation")
affiliations = auth.get_friendlyname_attribute("uciaffiliation")
except (ValueError, TypeError) as e:
log.exception("Error decoding SAML Attributes: %s", e)
raise HTTPException(500, "Error decoding user identity")
Expand Down Expand Up @@ -142,7 +143,7 @@ async def sls(req: Request) -> str:
async def get_saml_metadata() -> Response:
"""Provides SAML metadata, used when registering service with IdP"""
saml_settings = _get_saml_settings()
metadata: bytes = saml_settings.get_sp_metadata()
metadata = saml_settings.get_sp_metadata()

errors = saml_settings.validate_metadata(metadata)
if errors:
Expand Down
23 changes: 23 additions & 0 deletions apps/api/stubs/onelogin/saml2/auth.pyi
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
from onelogin.saml2.settings import OneLogin_Saml2_Settings

class OneLogin_Saml2_Auth:
def __init__(
self,
request_data: dict[str, object],
old_settings: OneLogin_Saml2_Settings | dict[str, object] | None = ...,
custom_base_path: str | None = ...,
) -> None: ...
def process_response(self, request_id: str | None = None) -> None: ...
def is_authenticated(self) -> bool: ...
def get_friendlyname_attributes(self) -> dict[str, list[str]]: ...
def get_errors(self) -> list[str]: ...
def get_last_error_reason(self) -> str | None: ...
def get_friendlyname_attribute(self, friendlyname: str) -> list[str]: ...
def login(
self,
return_to: str | None = ...,
force_authn: bool = ...,
is_passive: bool = ...,
set_nameid_policy: bool = ...,
name_id_value_req: str | None = ...,
) -> str: ...
9 changes: 9 additions & 0 deletions apps/api/stubs/onelogin/saml2/settings.pyi
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
class OneLogin_Saml2_Settings:
def __init__(
self,
settings: dict[str, object] | None = None,
custom_base_path: str | None = None,
sp_validation_only: bool | None = None,
) -> None: ...
def get_sp_metadata(self) -> bytes | str: ...
def validate_metadata(self, xml: bytes | str) -> list[str]: ...
2 changes: 1 addition & 1 deletion apps/api/tests/test_saml.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from unittest.mock import MagicMock, patch

from fastapi.testclient import TestClient
from onelogin.saml2.auth import OneLogin_Saml2_Settings
from onelogin.saml2.settings import OneLogin_Saml2_Settings

from routers import saml

Expand Down

0 comments on commit cd80b69

Please sign in to comment.