Skip to content

Commit

Permalink
fix: raw text data for gcp (#36)
Browse files Browse the repository at this point in the history
* fix: raw text data for gcp

* fix: update cli test to conform with application json format

* fix: fastapi editor agents to comply with application/json

* fix: gcp editor agents to comply with application/json

* docs: update fastapi docs to no longer use form data

* fix: Vary header and specific allow origin

* feat: add CORS middleware for fastapi to make it easier to use

* feat: simpler CORS middleware for fastapi

* refactor: use common origin matching for cors

* docs: make docs display CORS middleware in reference section

* docs: typo

* fix: make mypy happy

* fix: cors domain regex account for us deployment

* chore: add todo for removing if else statement on Jan 31st

* test: add unittest for the CORS regex.
  • Loading branch information
frederik-encord authored Dec 20, 2024
1 parent 89d63ac commit b575385
Show file tree
Hide file tree
Showing 12 changed files with 171 additions and 46 deletions.
9 changes: 3 additions & 6 deletions docs/code_examples/fastapi/frame_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,13 @@
from anthropic import Anthropic
from encord.objects.ontology_labels_impl import LabelRowV2
from fastapi import Depends, FastAPI, Form
from fastapi.middleware.cors import CORSMiddleware
from numpy.typing import NDArray
from typing_extensions import Annotated

from encord_agents.core.data_model import Frame
from encord_agents.core.ontology import OntologyDataModel
from encord_agents.core.utils import get_user_client
from encord_agents.fastapi.cors import EncordCORSMiddleware
from encord_agents.fastapi.dependencies import (
FrameData,
dep_label_row,
Expand All @@ -19,10 +19,7 @@

# Initialize FastAPI app
app = FastAPI()
app.add_middleware(
CORSMiddleware,
allow_origins=["*", "https://app.encord.com"],
)
app.add_middleware(EncordCORSMiddleware)

# Setup project and data model
client = get_user_client()
Expand All @@ -47,7 +44,7 @@

@app.post("/frame_classification")
async def classify_frame(
frame_data: Annotated[FrameData, Form()],
frame_data: FrameData,
lr: Annotated[LabelRowV2, Depends(dep_label_row)],
content: Annotated[NDArray[np.uint8], Depends(dep_single_frame)],
):
Expand Down
11 changes: 4 additions & 7 deletions docs/code_examples/fastapi/object_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,13 @@

from anthropic import Anthropic
from encord.objects.ontology_labels_impl import LabelRowV2
from fastapi import Depends, FastAPI, Form
from fastapi.middleware.cors import CORSMiddleware
from fastapi import Depends, FastAPI
from typing_extensions import Annotated

from encord_agents.core.data_model import InstanceCrop
from encord_agents.core.ontology import OntologyDataModel
from encord_agents.core.utils import get_user_client
from encord_agents.fastapi.cors import EncordCORSMiddleware
from encord_agents.fastapi.dependencies import (
FrameData,
dep_label_row,
Expand All @@ -17,10 +17,7 @@

# Initialize FastAPI app
app = FastAPI()
app.add_middleware(
CORSMiddleware,
allow_origins=["*", "https://app.encord.com"],
)
app.add_middleware(EncordCORSMiddleware)

# User client and ontology setup
client = get_user_client()
Expand Down Expand Up @@ -49,7 +46,7 @@

@app.post("/object_classification")
async def classify_objects(
frame_data: Annotated[FrameData, Form()],
frame_data: FrameData,
lr: Annotated[LabelRowV2, Depends(dep_label_row)],
crops: Annotated[
list[InstanceCrop],
Expand Down
16 changes: 8 additions & 8 deletions docs/editor_agents/examples/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -1047,27 +1047,27 @@ Let us go through the code section by section.
First, we import dependencies and setup the FastAPI app with CORS middleware:

<!--codeinclude-->
[main.py](../../code_examples/fastapi/frame_classification.py) lines:1-25
[main.py](../../code_examples/fastapi/frame_classification.py) lines:1-22
<!--/codeinclude-->

The CORS middleware is crucial as it allows the Encord platform to make requests to your API.

Next, we set up the Project and create a data model based on the Ontology:

<!--codeinclude-->
[main.py](../../code_examples/fastapi/frame_classification.py) lines:28-30
[main.py](../../code_examples/fastapi/frame_classification.py) lines:25-27
<!--/codeinclude-->

We create the system prompt that tells Claude how to structure its response:

<!--codeinclude-->
[main.py](../../code_examples/fastapi/frame_classification.py) lines:33-45
[main.py](../../code_examples/fastapi/frame_classification.py) lines:30-42
<!--/codeinclude-->

Finally, we define the endpoint to handle the classification:

<!--codeinclude-->
[main.py](../../code_examples/fastapi/frame_classification.py) lines:48-78
[main.py](../../code_examples/fastapi/frame_classification.py) lines:45-75
<!--/codeinclude-->

The endpoint:
Expand Down Expand Up @@ -1155,25 +1155,25 @@ Let's walk through the key components.
First, we setup the FastAPI app and CORS middleware:

<!--codeinclude-->
[main.py](../../code_examples/fastapi/object_classification.py) lines:1-23
[main.py](../../code_examples/fastapi/object_classification.py) lines:1-20
<!--/codeinclude-->

Then we setup the client, Project, and extract the generic Ontology object:

<!--codeinclude-->
[main.py](../../code_examples/fastapi/object_classification.py) lines:26-32
[main.py](../../code_examples/fastapi/object_classification.py) lines:23-29
<!--/codeinclude-->

We create the data model and system prompt for Claude:

<!--codeinclude-->
[main.py](../../code_examples/fastapi/object_classification.py) lines:34-47
[main.py](../../code_examples/fastapi/object_classification.py) lines:32-44
<!--/codeinclude-->

Finally, we define our object classification endpoint:

<!--codeinclude-->
[main.py](../../code_examples/fastapi/object_classification.py) lines:50-97
[main.py](../../code_examples/fastapi/object_classification.py) lines:47-94
<!--/codeinclude-->

The endpoint:
Expand Down
9 changes: 3 additions & 6 deletions docs/editor_agents/fastapi.md
Original file line number Diff line number Diff line change
Expand Up @@ -37,19 +37,16 @@ from typing_extensions import Annotated
from encord.objects.ontology_labels_impl import LabelRowV2
from encord_agents import FrameData
from encord_agents.fastapi import dep_label_row
from encord_agents.fastapi.cors import EncordCORSMiddleware

from fastapi import FastAPI, Depends, Form
from fastapi.middleware.cors import CORSMiddleware

app = FastAPI()
app.add_middleware(
CORSMiddleware,
allow_origins=["*", "https://app.encord.com"],
)
app.add_middleware(EncordCORSMiddleware)

@app.post("/my_agent")
def my_agent(
frame_data: Annotated[FrameData, Form()],
frame_data: FrameData,
label_row: Annotated[LabelRowV2, Depends(dep_label_row)],
):
# ... Do your edits to the labels
Expand Down
2 changes: 1 addition & 1 deletion docs/reference/editor_agents.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

## FastAPI

::: encord_agents.fastapi.dependencies
::: encord_agents.fastapi
options:
show_if_no_docstring: false
show_subodules: false
9 changes: 5 additions & 4 deletions encord_agents/cli/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,8 @@ def local(
request = requests.Request(
"POST",
f"http://localhost:{port}{target}",
data=payload,
headers={"Content-type": "application/x-www-form-urlencoded"},
json=payload,
headers={"Content-type": "application/json"},
)
prepped = request.prepare()

Expand All @@ -96,7 +96,8 @@ def local(
table.add_section()
table.add_row("[green]Request[/green]")
table.add_row("url", prepped.url)
table.add_row("data", prepped.body) # type: ignore
body_json_str = prepped.body.decode("utf-8") # type: ignore
table.add_row("data", body_json_str)
table_headers = ", ".join([f"'{k}': '{v}'" for k, v in prepped.headers.items()])
table.add_row("headers", f"{{{table_headers}}}")

Expand All @@ -115,7 +116,7 @@ def local(

headers = ["'{0}: {1}'".format(k, v) for k, v in prepped.headers.items()]
str_headers = " -H ".join(headers)
curl_command = f"curl -X {prepped.method} \\{os.linesep} -H {str_headers} \\{os.linesep} -d '{prepped.body!r}' \\{os.linesep} '{prepped.url}'"
curl_command = f"curl -X {prepped.method} \\{os.linesep} -H {str_headers} \\{os.linesep} -d '{body_json_str}' \\{os.linesep} '{prepped.url}'"
table.add_row("curl", curl_command)

rich.print(table)
3 changes: 3 additions & 0 deletions encord_agents/core/constants.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
ENCORD_DOMAIN_REGEX = (
r"^https:\/\/(?:(?:cord-ai-development--[\w\d]+-[\w\d]+\.web.app)|(?:(?:dev|staging|app)\.(us\.)?encord\.com))$"
)
60 changes: 60 additions & 0 deletions encord_agents/fastapi/cors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
"""
Convenience method to easily extend FastAPI servers
with the appropriate CORS Middleware to allow
interactions from the Encord platform.
"""

import typing

try:
from fastapi.middleware.cors import CORSMiddleware
from starlette.types import ASGIApp
except ModuleNotFoundError:
print(
'To use the `fastapi` dependencies, you must also install fastapi. `python -m pip install "fastapi[standard]"'
)
exit()

from encord_agents.core.constants import ENCORD_DOMAIN_REGEX


# Type checking does not work here because we do not enforce people to
# install fastapi as they can use package for, e.g., task runner wo fastapi.
class EncordCORSMiddleware(CORSMiddleware): # type: ignore [misc]
"""
Like a regular `fastapi.midleware.cors.CORSMiddleware` but matches against
the Encord origin by default.
**Example:**
```python
from fastapi import FastAPI
from encord_agents.fastapi.cors import EncordCORSMiddleware
app = FastAPI()
app.add_middleware(EncordCORSMiddleware)
```
The CORS middleware will allow POST requests from the Encord domain.
"""

def __init__(
self,
app: ASGIApp,
allow_origins: typing.Sequence[str] = (),
allow_methods: typing.Sequence[str] = ("POST",),
allow_headers: typing.Sequence[str] = (),
allow_credentials: bool = False,
allow_origin_regex: str = ENCORD_DOMAIN_REGEX,
expose_headers: typing.Sequence[str] = (),
max_age: int = 3600,
) -> None:
super().__init__(
app,
allow_origins,
allow_methods,
allow_headers,
allow_credentials,
allow_origin_regex,
expose_headers,
max_age,
)
16 changes: 6 additions & 10 deletions encord_agents/fastapi/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
...
@app.post("/my-agent-route")
def my_agent(
frame_data: Annotated[FrameData, Form()],
frame_data: FrameData,
):
...
```
Expand Down Expand Up @@ -117,15 +117,15 @@ def my_route(
"""

def wrapper(frame_data: Annotated[FrameData, Form()]) -> LabelRowV2:
def wrapper(frame_data: FrameData) -> LabelRowV2:
return get_initialised_label_row(
frame_data, include_args=label_row_metadata_include_args, init_args=label_row_initialise_labels_args
)

return wrapper


def dep_label_row(frame_data: Annotated[FrameData, Form()]) -> LabelRowV2:
def dep_label_row(frame_data: FrameData) -> LabelRowV2:
"""
Dependency to provide an initialized label row.
Expand Down Expand Up @@ -154,9 +154,7 @@ def my_route(
return get_initialised_label_row(frame_data)


def dep_single_frame(
lr: Annotated[LabelRowV2, Depends(dep_label_row)], frame_data: Annotated[FrameData, Form()]
) -> NDArray[np.uint8]:
def dep_single_frame(lr: Annotated[LabelRowV2, Depends(dep_label_row)], frame_data: FrameData) -> NDArray[np.uint8]:
"""
Dependency to inject the underlying asset of the frame data.
Expand Down Expand Up @@ -266,9 +264,7 @@ def my_route(
yield iter_video(asset)


def dep_project(
frame_data: Annotated[FrameData, Form()], client: Annotated[EncordUserClient, Depends(dep_client)]
) -> Project:
def dep_project(frame_data: FrameData, client: Annotated[EncordUserClient, Depends(dep_client)]) -> Project:
r"""
Dependency to provide an instantiated
[Project](https://docs.encord.com/sdk-documentation/sdk-references/LabelRowV2){ target="\_blank", rel="noopener noreferrer" }.
Expand Down Expand Up @@ -327,7 +323,7 @@ def dep_data_lookup(lookup: Annotated[DataLookup, Depends(_lookup_adapter)]) ->
...
@app.post("/my-agent")
def my_agent(
frame_data: Annotated[FrameData, Form()],
frame_data: FrameData,
lookup: Annotated[DataLookup, Depends(dep_data_lookup)]
):
# Client will authenticated and ready to use.
Expand Down
4 changes: 1 addition & 3 deletions encord_agents/fastapi/utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
import os

from pydantic import ValidationError

from encord_agents.core.settings import Settings
from encord_agents.core.utils import get_user_client
from encord_agents.exceptions import PrintableError
Expand All @@ -20,7 +18,7 @@ def verify_auth() -> None:
on_startup=[verify_auth]
```
This will make the server fail early if auth is not setup.
This will make the server fail early if auth is not set up.
"""
from datetime import datetime, timedelta

Expand Down
28 changes: 27 additions & 1 deletion encord_agents/gcp/wrappers.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
import re
from contextlib import ExitStack
from functools import wraps
from typing import Any, Callable
Expand All @@ -8,6 +9,7 @@
from flask import Request, Response, make_response

from encord_agents import FrameData
from encord_agents.core.constants import ENCORD_DOMAIN_REGEX
from encord_agents.core.data_model import LabelRowInitialiseLabelsArgs, LabelRowMetadataIncludeArgs
from encord_agents.core.dependencies.models import Context
from encord_agents.core.dependencies.utils import get_dependant, solve_dependencies
Expand Down Expand Up @@ -49,10 +51,34 @@ def editor_agent(

def context_wrapper_inner(func: AgentFunction) -> Callable[[Request], Response]:
dependant = get_dependant(func=func)
cors_regex = re.compile(ENCORD_DOMAIN_REGEX)

@wraps(func)
def wrapper(request: Request) -> Response:
frame_data = FrameData.model_validate_json(orjson.dumps(request.form.to_dict()))
if request.method == "OPTIONS":
response = make_response("")
response.headers["Vary"] = "Origin"

if not cors_regex.fullmatch(request.origin):
response.status_code = 403
return response

headers = {
"Access-Control-Allow-Origin": request.origin,
"Access-Control-Allow-Methods": "POST",
"Access-Control-Allow-Headers": "Content-Type",
"Access-Control-Max-Age": "3600",
}
response.headers.update(headers)
response.status_code = 204
return response

# TODO: We'll remove FF from FE on Jan. 31 2025.
# At that point, only the if statement applies and the else should be removed.
if request.is_json:
frame_data = FrameData.model_validate(request.get_json())
else:
frame_data = FrameData.model_validate_json(request.get_data())
logging.info(f"Request: {frame_data}")

client = get_user_client()
Expand Down
Loading

0 comments on commit b575385

Please sign in to comment.