diff --git a/.gitignore b/.gitignore index 5f0959b1..86214288 100644 --- a/.gitignore +++ b/.gitignore @@ -87,3 +87,4 @@ venv.bak/ site .bento/ +.vscode/ diff --git a/fastapi_utils/crud/__init__.py b/fastapi_utils/crud/__init__.py new file mode 100644 index 00000000..43ef8504 --- /dev/null +++ b/fastapi_utils/crud/__init__.py @@ -0,0 +1,4 @@ +from .base import CRUDBase +from .route import CRUDRoute + +__all__ = ["CRUDBase", "CRUDRoute"] diff --git a/fastapi_utils/crud/base.py b/fastapi_utils/crud/base.py new file mode 100644 index 00000000..2c46701b --- /dev/null +++ b/fastapi_utils/crud/base.py @@ -0,0 +1,158 @@ +from decimal import Decimal +from enum import Enum +from typing import Dict, Generic, List, Optional, Type, TypeVar, Union + +from fastapi.encoders import jsonable_encoder +from pydantic import BaseModel +from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy.orm import Session + +from fastapi_utils.camelcase import snake2camel +from sqlalchemy_filters import apply_filters, apply_sort + +ModelType = TypeVar("ModelType") +MultiSchemaType = TypeVar("MultiSchemaType", bound=BaseModel) +CreateSchemaType = TypeVar("CreateSchemaType", bound=BaseModel) +UpdateSchemaType = TypeVar("UpdateSchemaType", bound=BaseModel) +IDType = TypeVar("IDType") + + +class SortDirectionEnum(str, Enum): + ASC = "asc" + DESC = "desc" + + +class FilterOpEnum(str, Enum): + IS_NULL = "is_null" + IS_NOT_NULL = "is_not_null" + EQ_SYM = "==" + EQ = "eq" + NE_SYM = "!=" + NE = "ne" + GT_SYM = ">" + GT = "gt" + LT_SYM = "<" + LT = "lt" + GE_SYM = ">=" + GE = "ge" + LE_SYM = "<=" + LE = "le" + LIKE = "like" + ILIKE = "ilike" + IN = "in" + NOT_IN = "not_in" + + +class SortField(BaseModel): + field: str + model: Optional[str] = None + direction: SortDirectionEnum = SortDirectionEnum.DESC + + +class FilterField(BaseModel): + field: str + model: Optional[str] = None + op: FilterOpEnum + value: Union[str, int, Decimal] + + +def get_filter_field(field: str, field_name: str, split_character: str = ":") -> FilterField: + model = None + op, value = field.split(":") + if "__" in field_name: + model, field_name = field_name.split("__") + model = snake2camel(model, start_lower=False) + filter_field = FilterField(field=field_name, model=model, op=op, value=value) + return filter_field + + +def get_filter_fields(fields: Optional[Dict[str, str]], split_character: str = ":") -> List[FilterField]: + filter_fields = [] + if fields: + for field_name in fields: + if fields[field_name]: + filter_fields.append(get_filter_field(field=fields[field_name], field_name=field_name)) + return filter_fields + + +def get_sort_field(field: str) -> SortField: + model = None + field_name, direction = field.split(":") + if "__" in field_name: + model, field_name = field_name.split("__") + sort_field = SortField(model=model, field=field_name, direction=direction) + return sort_field + + +def get_sort_fields(sort_string: Optional[str], split_character: str = ",") -> List[SortField]: + sort_fields = [] + # There could be many sort fields + if sort_string: + sort_by_fields = sort_string.split(",") + for _to_sort in sort_by_fields: + sort_fields.append(get_sort_field(_to_sort)) + return sort_fields + + +class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]): + def __init__(self, model: Type[ModelType]): + """ + CRUD object with default methods to Create, Read, Update, Delete (CRUD). + **Parameters** + * `model`: A SQLAlchemy model class + * `schema`: A Pydantic model (schema) class + """ + self.model = model + + def get(self, db_session: Session, id: IDType) -> Optional[ModelType]: + return db_session.query(self.model).filter(self.model.id == id).first() # type: ignore + + def get_many( + self, + db_session: Session, + *, + skip: int = 0, + limit: int = 100, + filter_by: Optional[Dict[str, str]] = None, + sort_by: Optional[str] = None, + ) -> List[ModelType]: + + sort_spec_pydantic = get_sort_fields(sort_by) + filter_spec_pydantic = get_filter_fields(filter_by) + + sort_spec = [x.dict(exclude_none=True) for x in sort_spec_pydantic] + filter_spec = [x.dict(exclude_none=True) for x in filter_spec_pydantic] + + query = db_session.query(self.model) + query = apply_filters(query, filter_spec) + query = apply_sort(query, sort_spec) + + # count = query.count() + query = query.offset(skip).limit(limit) + + return query.all() + + def create(self, db_session: Session, *, obj_in: CreateSchemaType) -> ModelType: + obj_in_data = jsonable_encoder(obj_in) + db_obj = self.model(**obj_in_data) + db_session.add(db_obj) + db_session.commit() + db_session.refresh(db_obj) + return db_obj + + def update(self, db_session: Session, *, db_obj: ModelType, obj_in: UpdateSchemaType) -> ModelType: + obj_data = jsonable_encoder(db_obj) + update_data = obj_in.dict(skip_defaults=True) + for field in obj_data: + if field in update_data: + setattr(db_obj, field, update_data[field]) + db_session.add(db_obj) + db_session.commit() + db_session.refresh(db_obj) + return db_obj + + def remove(self, db_session: Session, *, id: IDType) -> ModelType: + obj = db_session.query(self.model).get(id) + db_session.delete(obj) + db_session.commit() + return obj diff --git a/fastapi_utils/crud/route.py b/fastapi_utils/crud/route.py new file mode 100644 index 00000000..332f4056 --- /dev/null +++ b/fastapi_utils/crud/route.py @@ -0,0 +1,107 @@ +from typing import ClassVar, Dict, Generic, Tuple, TypeVar + +from fastapi import Depends, HTTPException +from pydantic import BaseModel +from sqlalchemy.orm import Session + +from fastapi_utils.crud import CRUDBase + +ResponseModelType = TypeVar("ResponseModelType", bound=BaseModel) +ResponseModelManyType = TypeVar("ResponseModelManyType", bound=BaseModel) +CreateSchemaType = TypeVar("CreateSchemaType", bound=BaseModel) +UpdateSchemaType = TypeVar("UpdateSchemaType", bound=BaseModel) +CRUDBaseType = TypeVar("CRUDBaseType", bound=CRUDBase) +IDType = TypeVar("IDType") + + +class CRUDRoute(Generic[ResponseModelType, ResponseModelManyType, CreateSchemaType, UpdateSchemaType, IDType]): + """A base route that has the basic CRUD endpoints. + + For read_many + + """ + + crud_base: ClassVar[CRUDBaseType] + filter_fields: ClassVar[Tuple[str]] = () + db: Session = Depends(None) + object_name: ClassVar[str] = "CRUDBase" + + def read_many(self, skip: int = 0, limit: int = 100, sort_by: str = None, **kwargs) -> ResponseModelManyType: + """Reads many from the database with the provided filter and sort parameters. + + Filter parameters need to be specified by overriding this read_many method and calling it like: + + @router.get("/", response_model=List[Person]) + def read_persons( + self, skip: int = 0, limit: int = 100, sort_by: str = None, name: str = None, + ) -> List[Person]: + return super().read_many(skip=skip, limit=limit, sort_by=sort_by, name=name) + + Where the filter fields are defined as parameters. In this case "name" is a filter field + + Keyword Arguments: + skip {int} -- [description] (default: {0}) + limit {int} -- [description] (default: {100}) + sort_by {str} -- Expected in the form "model__field_name:asc,field_name:desc" (default: {None}) + + **kwargs {str} -- Filter field names expected in the form field_name or model__field_name if + filtering through + a join. The filter is defined as op:value. For example ==:paul or eq:paul + + The filter op is specified in the crud_base FilterOpEnum. + + Returns: + ResponseModelManyType -- [description] + """ + filter_fields: Dict[str, str] = {} + + for field in self.filter_fields: + filter_fields[field] = kwargs.pop(field, None) + + if len(kwargs) != 0: + raise ValueError(f"Method parameters have not been added to class filter fields {kwargs.keys()}") + + results = self.crud_base.get_many(self.db, skip=skip, limit=limit, filter_by=filter_fields, sort_by=sort_by) + return results + + def create(self, *, obj_in: CreateSchemaType,) -> ResponseModelType: + """ + Create new object. + """ + result = self.crud_base.create(db_session=self.db, obj_in=obj_in) + return result + + def update(self, *, id: IDType, obj_in: UpdateSchemaType,) -> ResponseModelType: + """ + Update an object. + """ + result = self.crud_base.get(db_session=self.db, id=id) + if not result: + raise HTTPException(status_code=404, detail=f"{self.object_name} not found") + # if not crud.user.is_superuser(current_user) and (object.owner_id != current_user.id): + # raise HTTPException(status_code=400, detail="Not enough permissions") + result = self.crud_base.update(db_session=self.db, db_obj=result, obj_in=obj_in) + return result + + def read(self, *, id: IDType,) -> ResponseModelType: + """ + Get object by ID. + """ + result = self.crud_base.get(db_session=self.db, id=id) + if not result: + raise HTTPException(status_code=404, detail=f"{self.object_name} not found") + # if not crud.user.is_superuser(current_user) and (object.owner_id != current_user.id): + # raise HTTPException(status_code=400, detail="Not enough permissions") + return result + + def delete(self, *, id: IDType,) -> ResponseModelType: + """ + Delete an object. + """ + result = self.crud_base.get(db_session=self.db, id=id) + if not result: + raise HTTPException(status_code=404, detail=f"{self.object_name} not found") + # if not crud.user.is_superuser(current_user) and (object.owner_id != current_user.id): + # raise HTTPException(status_code=400, detail="Not enough permissions") + result = self.crud_base.remove(db_session=self.db, id=id) + return result diff --git a/poetry.lock b/poetry.lock index 8ef126fa..bc862aa1 100644 --- a/poetry.lock +++ b/poetry.lock @@ -603,7 +603,7 @@ security = ["pyOpenSSL (>=0.14)", "cryptography (>=1.3.4)"] socks = ["PySocks (>=1.5.6,<1.5.7 || >1.5.7)", "win-inet-pton"] [[package]] -category = "dev" +category = "main" description = "Python 2 and 3 compatibility utilities" name = "six" optional = false @@ -630,6 +630,24 @@ postgresql_psycopg2binary = ["psycopg2-binary"] postgresql_psycopg2cffi = ["psycopg2cffi"] pymysql = ["pymysql"] +[[package]] +category = "main" +description = "A library to filter SQLAlchemy queries." +name = "sqlalchemy-filters" +optional = false +python-versions = "*" +version = "0.10.0" + +[package.dependencies] +six = ">=1.10.0" +sqlalchemy = ">=1.0.16" + +[package.extras] +dev = ["pytest (4.3.0)", "flake8 (3.7.7)", "coverage (4.5.3)", "sqlalchemy-utils (0.33.11)", "restructuredtext-lint (1.2.2)", "Pygments (2.3.1)"] +mysql = ["mysql-connector-python-rf (2.2.2)"] +postgresql = ["psycopg2 (2.7.7)"] +python2 = ["funcsigs (>=1.0.2)"] + [[package]] category = "dev" description = "SQLAlchemy stubs and mypy plugin" @@ -720,7 +738,7 @@ docs = ["sphinx", "jaraco.packaging (>=3.2)", "rst.linker (>=1.9)"] testing = ["jaraco.itertools", "func-timeout"] [metadata] -content-hash = "2b727851846408766afb773a03cef7946f566c0817623186d8bcfed2e9d62557" +content-hash = "aa4f526dbf926321768fab7d310bdc0a449537a074babec63c01ea88c0cadacf" python-versions = "^3.6" [metadata.files] @@ -886,11 +904,6 @@ markupsafe = [ {file = "MarkupSafe-1.1.1-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:ba59edeaa2fc6114428f1637ffff42da1e311e29382d81b339c1817d37ec93c6"}, {file = "MarkupSafe-1.1.1-cp37-cp37m-win32.whl", hash = "sha256:b00c1de48212e4cc9603895652c5c410df699856a2853135b3967591e4beebc2"}, {file = "MarkupSafe-1.1.1-cp37-cp37m-win_amd64.whl", hash = "sha256:9bf40443012702a1d2070043cb6291650a0841ece432556f784f004937f0f32c"}, - {file = "MarkupSafe-1.1.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:6788b695d50a51edb699cb55e35487e430fa21f1ed838122d722e0ff0ac5ba15"}, - {file = "MarkupSafe-1.1.1-cp38-cp38-manylinux1_i686.whl", hash = "sha256:cdb132fc825c38e1aeec2c8aa9338310d29d337bebbd7baa06889d09a60a1fa2"}, - {file = "MarkupSafe-1.1.1-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:13d3144e1e340870b25e7b10b98d779608c02016d5184cfb9927a9f10c689f42"}, - {file = "MarkupSafe-1.1.1-cp38-cp38-win32.whl", hash = "sha256:596510de112c685489095da617b5bcbbac7dd6384aeebeda4df6025d0256a81b"}, - {file = "MarkupSafe-1.1.1-cp38-cp38-win_amd64.whl", hash = "sha256:e8313f01ba26fbbe36c7be1966a7b7424942f670f38e666995b88d012765b9be"}, {file = "MarkupSafe-1.1.1.tar.gz", hash = "sha256:29872e92839765e546828bb7754a68c418d927cd064fd4708fab9fe9c8bb116b"}, ] mccabe = [ @@ -1044,6 +1057,10 @@ six = [ sqlalchemy = [ {file = "SQLAlchemy-1.3.13.tar.gz", hash = "sha256:64a7b71846db6423807e96820993fa12a03b89127d278290ca25c0b11ed7b4fb"}, ] +sqlalchemy-filters = [ + {file = "sqlalchemy-filters-0.10.0.tar.gz", hash = "sha256:3b0d4fc39075cd1079e6089ac3165c1930b74fb1804515f109ec80e75fec46c8"}, + {file = "sqlalchemy_filters-0.10.0-py3-none-any.whl", hash = "sha256:34265e3b4605aae6e7c7fe3082b1de148c6295409f4d34286447f8c195bac699"}, +] sqlalchemy-stubs = [ {file = "sqlalchemy-stubs-0.3.tar.gz", hash = "sha256:a3318c810697164e8c818aa2d90bac570c1a0e752ced3ec25455b309c0bee8fd"}, {file = "sqlalchemy_stubs-0.3-py3-none-any.whl", hash = "sha256:ca1250605a39648cc433f5c70cb1a6f9fe0b60bdda4c51e1f9a2ab3651daadc8"}, diff --git a/pyproject.toml b/pyproject.toml index 750a4ceb..44c1db00 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,6 +40,7 @@ python = "^3.6" fastapi = "*" pydantic = "^1.0" sqlalchemy = "^1.3.12" +sqlalchemy-filters = "^0.10.0" [tool.poetry.dev-dependencies] # Starlette features