diff --git a/encord_agents/fastapi/cors.py b/encord_agents/fastapi/cors.py index 7e87ed04..4234484b 100644 --- a/encord_agents/fastapi/cors.py +++ b/encord_agents/fastapi/cors.py @@ -4,10 +4,16 @@ interactions from the Encord platform. """ +import asyncio +import json import typing from http import HTTPStatus +from uuid import UUID from encord.exceptions import AuthorisationError +from pydantic import ValidationError + +from encord_agents.core.data_model import FrameData try: from fastapi import FastAPI, Request @@ -102,6 +108,45 @@ async def _authorization_error_exception_handler(request: Request, exc: Authoris ) +class FieldPairLockMiddleware(BaseHTTPMiddleware): + def __init__( + self, + app: ASGIApp, + ): + super().__init__(app) + self.field_locks: dict[tuple[UUID, UUID], asyncio.Lock] = {} + self.locks_lock = asyncio.Lock() + + async def get_lock(self, frame_data: FrameData) -> asyncio.Lock: + lock_key = (frame_data.project_hash, frame_data.data_hash) + async with self.locks_lock: + if lock_key not in self.field_locks: + self.field_locks[lock_key] = asyncio.Lock() + return self.field_locks[lock_key] + + async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response: + if request.method != "POST": + return await call_next(request) + try: + body = await request.body() + try: + frame_data = FrameData.model_validate_json(body) + except ValidationError: + # Hope that route doesn't use FrameData + return await call_next(request) + lock = await self.get_lock(frame_data) + async with lock: + # Create a new request with the same body since we've already consumed it + request._body = body + return await call_next(request) + except Exception as e: + return Response( + content=json.dumps({"detail": f"Error in middleware: {str(e)}"}), + status_code=500, + media_type="application/json", + ) + + def get_encord_app(*, custom_cors_regex: str | None = None) -> FastAPI: """ Get a FastAPI app with the Encord middleware. @@ -114,10 +159,12 @@ def get_encord_app(*, custom_cors_regex: str | None = None) -> FastAPI: FastAPI: A FastAPI app with the Encord middleware. """ app = FastAPI() + app.add_middleware( EncordCORSMiddleware, allow_origin_regex=custom_cors_regex or ENCORD_DOMAIN_REGEX, ) app.add_middleware(EncordTestHeaderMiddleware) + app.add_middleware(FieldPairLockMiddleware) app.exception_handlers[AuthorisationError] = _authorization_error_exception_handler return app diff --git a/tests/integration_tests/fastapi/test_dependencies.py b/tests/integration_tests/fastapi/test_dependencies.py index 5d72a5ab..f1d3ee40 100644 --- a/tests/integration_tests/fastapi/test_dependencies.py +++ b/tests/integration_tests/fastapi/test_dependencies.py @@ -1,3 +1,4 @@ +import asyncio from http import HTTPStatus from typing import Annotated, NamedTuple from uuid import uuid4 @@ -306,3 +307,47 @@ def post_client(client: Annotated[EncordUserClient, Depends(dep_client)]) -> Non resp = client.post("/client", headers={"Origin": "https://example.com"}) assert resp.status_code == 200, resp.content assert "Access-Control-Allow-Origin" not in resp.headers + + +class TestFieldPairLockMiddleware: + context: SharedResolutionContext + client: TestClient + list_holder: list[int | tuple[int, str]] + + # Set the project and first label row for the class + @classmethod + @pytest.fixture(autouse=True) + def setup(cls, context: SharedResolutionContext) -> None: + cls.context = context + app = get_encord_app() + cls.list_holder = [] + + @app.post("/threadsafe-endpoint") + async def threadsafe_endpoint(frame_data: FrameData) -> None: + cls.list_holder.append(frame_data.frame) + await asyncio.sleep(0.1) + cls.list_holder.append((frame_data.frame, "DONE")) + + cls.client = TestClient(app) + + def test_field_pair_lock_middleware(self) -> None: + assert self.list_holder == [] + resp = self.client.post( + "/threadsafe-endpoint", + json={ + "projectHash": self.context.project.project_hash, + "dataHash": self.context.video_label_row.data_hash, + "frame": 0, + }, + ) + assert resp.status_code == 200, resp.content + resp = self.client.post( + "/threadsafe-endpoint", + json={ + "projectHash": self.context.project.project_hash, + "dataHash": self.context.video_label_row.data_hash, + "frame": 1, + }, + ) + assert resp.status_code == 200, resp.content + assert self.list_holder == [0, (0, "DONE"), 1, (1, "DONE")]