Skip to content

Commit b6c2cb1

Browse files
KaliszSFBegiello
andauthored
Add fernet encryption & CORS (#70)
* [feature] Add CORS middleware * [feature] Add CORS parsing in config, add supporting features * [feature] Add fernet encryptor for secret fields in config --------- Co-authored-by: Filip Begiełło <[email protected]>
1 parent bd6e326 commit b6c2cb1

File tree

5 files changed

+145
-5
lines changed

5 files changed

+145
-5
lines changed

{{project_name}}/app/config.py.jinja

Lines changed: 34 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,16 @@
11
from functools import lru_cache
22
from pathlib import Path
3+
from typing import Any
34

4-
{%if project_type =="agent" %}
5-
from pydantic import SecretStr, field_validator
6-
{%else%}
7-
from pydantic import SecretStr
8-
{%endif%}
5+
from pydantic import AnyHttpUrl, Field, SecretStr, ValidationInfo, field_validator
96
from pydantic_settings import BaseSettings, SettingsConfigDict
107

8+
from app.utils.config_utils import (
9+
EncryptedField,
10+
EnvironmentType,
11+
FernetDecryptorField,
12+
)
13+
1114

1215
class Settings(BaseSettings):
1316
model_config = SettingsConfigDict(
@@ -18,11 +21,18 @@ class Settings(BaseSettings):
1821
# dotenv search .env when module is imported, without usecwd it starts from the file it was called
1922
)
2023

24+
# CORE SETTINGS
25+
fernet_decryptor: FernetDecryptorField = Field("MASTER_KEY")
26+
debug: bool = False
27+
environment: EnvironmentType = EnvironmentType.LOCAL
28+
2129
# API SETTINGS
2230
api_name: str = f"{{project_name}} API"
2331
api_v1: str = "/api/v1"
2432
api_latest: str = api_v1
2533
paging_limit: int = 100
34+
cors_origins: list[AnyHttpUrl] = []
35+
cors_allow_all: bool = False
2636

2737
{% if project_type in ["api-monolith", "api-microservice"] %}
2838
# DATABASE SETTINGS
@@ -58,6 +68,25 @@ class Settings(BaseSettings):
5868
return v
5969
{% endif %}
6070

71+
@field_validator("cors_origins", mode="after")
72+
@classmethod
73+
def assemble_cors_origins(cls, v: str | list[str]) -> list[str] | str:
74+
if isinstance(v, str) and not v.startswith("["):
75+
return [i.strip() for i in v.split(",")]
76+
if isinstance(v, (list, str)):
77+
return v
78+
79+
# This should never be reached given the type annotation, but ensures type safety
80+
raise ValueError(f"Unexpected type for cors_origins: {type(v)}")
81+
82+
@field_validator("*", mode="after")
83+
@classmethod
84+
def _decryptor(cls, v: Any, validation_info: ValidationInfo, *args, **kwargs) -> Any:
85+
if isinstance(v, EncryptedField):
86+
return v.get_decrypted_value(validation_info.data["fernet_decryptor"])
87+
return v
88+
89+
6190
{%if project_type == "mcp-server" %}
6291
# MCP SETTINGS
6392
mcp_server_name: str = f"MCP Server"

{{project_name}}/app/main.py.jinja

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ from app.mcp import mcp_router
2525
from app.database import engine
2626
from app.integrations.sqladmin.views import add_admin_views
2727
{% endif %}
28+
from app.middlewares import add_cors_middleware
2829

2930
basicConfig(level=INFO, format="[%(asctime)s - %(name)s] (%(levelname)s) %(message)s")
3031

@@ -35,6 +36,7 @@ admin = Admin(app=api, engine=engine)
3536
add_admin_views(admin)
3637
{% endif %}
3738

39+
add_cors_middleware(api)
3840

3941
@api.get("/")
4042
async def root() -> dict[str, str]:
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
from fastapi import FastAPI
2+
from fastapi.middleware.cors import CORSMiddleware
3+
4+
from app.config import settings
5+
6+
7+
def add_cors_middleware(app: FastAPI) -> None:
8+
cors_origins = [str(origin).rstrip("/") for origin in settings.cors_origins]
9+
if settings.cors_allow_all:
10+
cors_origins = ["*"]
11+
12+
app.add_middleware(
13+
CORSMiddleware,
14+
allow_origins=cors_origins,
15+
allow_credentials=True,
16+
allow_methods=["*"],
17+
allow_headers=["*"],
18+
)
Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
import os
2+
from enum import Enum
3+
from functools import wraps
4+
from typing import Any, Callable, Generator, Protocol
5+
6+
from cryptography.fernet import Fernet
7+
from pydantic import ValidationInfo
8+
9+
CallableGenerator = Generator[Callable[..., Any], None, None]
10+
11+
12+
class EnvironmentType(str, Enum):
13+
LOCAL = "local"
14+
TEST = "test"
15+
STAGING = "staging"
16+
PRODUCTION = "production"
17+
18+
19+
class Decryptor(Protocol):
20+
def decrypt(self, value: bytes) -> bytes: ...
21+
22+
23+
class FakeFernet:
24+
def decrypt(self, value: bytes) -> bytes:
25+
return value
26+
27+
28+
class EncryptedField(str):
29+
@classmethod
30+
def __get_pydantic_json_schema__(cls, field_schema: dict[str, Any]) -> None:
31+
field_schema.update(type="str", writeOnly=True)
32+
33+
@classmethod
34+
def __get_validators__(cls) -> "CallableGenerator":
35+
yield cls.validate
36+
37+
@classmethod
38+
def validate(cls, value: str, _: ValidationInfo) -> "EncryptedField":
39+
if isinstance(value, cls):
40+
return value
41+
return cls(value)
42+
43+
def __init__(self, value: str):
44+
self._secret_value = "".join(value.splitlines()).strip().encode("utf-8")
45+
self.decrypted = False
46+
47+
def get_decrypted_value(self, decryptor: Decryptor) -> str:
48+
if not self.decrypted:
49+
value = decryptor.decrypt(self._secret_value)
50+
self._secret_value = value
51+
self.decrypted = True
52+
return self._secret_value.decode("utf-8")
53+
54+
55+
class FernetDecryptorField(str):
56+
def __get_pydantic_json_schema__(self, field_schema: dict[str, Any]) -> None:
57+
field_schema.update(type="str", writeOnly=True)
58+
59+
@classmethod
60+
def __get_validators__(cls) -> "CallableGenerator":
61+
yield cls.validate
62+
63+
@classmethod
64+
def validate(cls, value: str, _: ValidationInfo) -> Decryptor:
65+
master_key = os.environ.get(value)
66+
if not master_key:
67+
return FakeFernet()
68+
return Fernet(os.environ[value])
69+
70+
71+
def set_env_from_settings(func: Callable[..., Any]) -> Callable[..., Any]:
72+
"""
73+
Decorator to set environment variables from settings.
74+
This decorator is useful for encrypted fields and providers that
75+
require API keys to be available as environment variables.
76+
"""
77+
78+
@wraps(func)
79+
def wrapper(*args, **kwargs) -> Any:
80+
settings = func(*args, **kwargs)
81+
# os.environ["EXAMPLE_API_KEY"] = settings.EXAMPLE_API_KEY
82+
return settings # noqa: RET504
83+
84+
return wrapper

{{project_name}}/config/.env.example.jinja

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,10 @@
1+
#--- APP ---#
2+
ENVIRONMENT="local"
3+
DEBUG=True
4+
CORS_ORIGINS=["http://localhost:8000", "https://localhost:8000", "http://localhost", "https://localhost"]
5+
SERVER_HOST="https://new_project_name.dev"
6+
7+
#--- DB ---#
18
DB_HOST=localhost
29
DP_PORT=5432
310
DB_NAME={{project_name}}

0 commit comments

Comments
 (0)