Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 47 additions & 0 deletions encord_agents/fastapi/cors.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,16 @@
interactions from the Encord platform.
"""

import asyncio
import json
import typing
from http import HTTPStatus
from uuid import UUID

from encord.exceptions import AuthorisationError
from pydantic import ValidationError

from encord_agents.core.data_model import FrameData

try:
from fastapi import FastAPI, Request
Expand Down Expand Up @@ -102,6 +108,45 @@ async def _authorization_error_exception_handler(request: Request, exc: Authoris
)


class FieldPairLockMiddleware(BaseHTTPMiddleware):
def __init__(
self,
app: ASGIApp,
):
super().__init__(app)
self.field_locks: dict[tuple[UUID, UUID], asyncio.Lock] = {}
self.locks_lock = asyncio.Lock()

async def get_lock(self, frame_data: FrameData) -> asyncio.Lock:
lock_key = (frame_data.project_hash, frame_data.data_hash)
async with self.locks_lock:
if lock_key not in self.field_locks:
self.field_locks[lock_key] = asyncio.Lock()
return self.field_locks[lock_key]

async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response:
if request.method != "POST":
return await call_next(request)
try:
body = await request.body()
try:
frame_data = FrameData.model_validate_json(body)
except ValidationError:
# Hope that route doesn't use FrameData
return await call_next(request)
lock = await self.get_lock(frame_data)
async with lock:
# Create a new request with the same body since we've already consumed it
request._body = body
return await call_next(request)
except Exception as e:
return Response(
content=json.dumps({"detail": f"Error in middleware: {str(e)}"}),
status_code=500,
media_type="application/json",
)


def get_encord_app(*, custom_cors_regex: str | None = None) -> FastAPI:
"""
Get a FastAPI app with the Encord middleware.
Expand All @@ -114,10 +159,12 @@ def get_encord_app(*, custom_cors_regex: str | None = None) -> FastAPI:
FastAPI: A FastAPI app with the Encord middleware.
"""
app = FastAPI()

app.add_middleware(
EncordCORSMiddleware,
allow_origin_regex=custom_cors_regex or ENCORD_DOMAIN_REGEX,
)
app.add_middleware(EncordTestHeaderMiddleware)
app.add_middleware(FieldPairLockMiddleware)
app.exception_handlers[AuthorisationError] = _authorization_error_exception_handler
return app
45 changes: 45 additions & 0 deletions tests/integration_tests/fastapi/test_dependencies.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
from http import HTTPStatus
from typing import Annotated, NamedTuple
from uuid import uuid4
Expand Down Expand Up @@ -306,3 +307,47 @@ def post_client(client: Annotated[EncordUserClient, Depends(dep_client)]) -> Non
resp = client.post("/client", headers={"Origin": "https://example.com"})
assert resp.status_code == 200, resp.content
assert "Access-Control-Allow-Origin" not in resp.headers


class TestFieldPairLockMiddleware:
context: SharedResolutionContext
client: TestClient
list_holder: list[int | tuple[int, str]]

# Set the project and first label row for the class
@classmethod
@pytest.fixture(autouse=True)
def setup(cls, context: SharedResolutionContext) -> None:
cls.context = context
app = get_encord_app()
cls.list_holder = []

@app.post("/threadsafe-endpoint")
async def threadsafe_endpoint(frame_data: FrameData) -> None:
cls.list_holder.append(frame_data.frame)
await asyncio.sleep(0.1)
cls.list_holder.append((frame_data.frame, "DONE"))

cls.client = TestClient(app)

def test_field_pair_lock_middleware(self) -> None:
assert self.list_holder == []
resp = self.client.post(
"/threadsafe-endpoint",
json={
"projectHash": self.context.project.project_hash,
"dataHash": self.context.video_label_row.data_hash,
"frame": 0,
},
)
assert resp.status_code == 200, resp.content
resp = self.client.post(
"/threadsafe-endpoint",
json={
"projectHash": self.context.project.project_hash,
"dataHash": self.context.video_label_row.data_hash,
"frame": 1,
},
)
assert resp.status_code == 200, resp.content
assert self.list_holder == [0, (0, "DONE"), 1, (1, "DONE")]