diff --git a/fastapi_health/route.py b/fastapi_health/route.py index c208d72..418538c 100644 --- a/fastapi_health/route.py +++ b/fastapi_health/route.py @@ -1,3 +1,4 @@ +import functools from inspect import Parameter, Signature from typing import TypeVar from typing import Any, Awaitable, Callable, Dict, List, Union @@ -65,12 +66,17 @@ async def endpoint(**dependencies) -> JSONResponse: params = [] for condition in conditions: + dependency = Depends(condition) + + while isinstance(condition, functools.partial): + condition = condition.func + params.append( Parameter( f"{condition.__name__}", kind=Parameter.POSITIONAL_OR_KEYWORD, annotation=bool, - default=Depends(condition), + default=dependency, ) ) endpoint.__signature__ = Signature(params) diff --git a/tests/test_health.py b/tests/test_health.py index 491ed3b..1f5f021 100644 --- a/tests/test_health.py +++ b/tests/test_health.py @@ -1,4 +1,5 @@ -from typing import Any +from typing import Any, Dict +from functools import partial import pytest from fastapi import Depends, FastAPI @@ -44,13 +45,16 @@ async def success_handler(**kwargs): return kwargs +def healthy_partial(healthy: bool) -> Dict[str, bool]: + return {"healthy": healthy} + + async def custom_failure_handler(**kwargs): is_success = all(kwargs.values()) return { "status": "success" if is_success else "failure", "results": [ - {"condition": condition, "output": value} - for condition, value in kwargs.items() + {"condition": condition, "output": value} for condition, value in kwargs.items() ], } @@ -64,9 +68,8 @@ async def custom_failure_handler(**kwargs): multiple_healthy_dict_app = create_app([healthy_dict, another_health_dict]) hybrid_app = create_app([healthy, sick, healthy_dict]) success_handler_app = create_app([healthy], success_handler=success_handler) -failure_handler_app = create_app( - [sick, healthy], failure_handler=custom_failure_handler -) +failure_handler_app = create_app([sick, healthy], failure_handler=custom_failure_handler) +partial_healthy_app = create_app([partial(healthy_partial, healthy=True)]) @pytest.mark.asyncio @@ -93,10 +96,11 @@ async def custom_failure_handler(**kwargs): ], }, ), + (partial_healthy_app, 200, {"healthy": True}), ), ) async def test_health(app: FastAPI, status_code: int, body: dict) -> None: async with AsyncClient(app=app, base_url="http://test") as client: res = await client.get("/health") - assert res.status_code == status_code + assert res.status_code == status_code, res.json() assert res.json() == body