Skip to content

Commit

Permalink
chore: Update SQLAlchemy to 2.0
Browse files Browse the repository at this point in the history
Updated SQLAlchemy to 2.0 to have proper type hints and newest API.
  • Loading branch information
gbdlin committed Aug 22, 2024
1 parent 4d7ccd6 commit bbc27bb
Show file tree
Hide file tree
Showing 8 changed files with 112 additions and 110 deletions.
17 changes: 8 additions & 9 deletions plugin_store/database/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,7 @@
from fastapi import Depends
from sqlalchemy import asc, desc
from sqlalchemy.exc import NoResultFound, SQLAlchemyError
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.orm import sessionmaker
from sqlalchemy.ext.asyncio import async_sessionmaker, AsyncSession, create_async_engine
from sqlalchemy.sql import delete, select, update

from constants import SortDirection, SortType
Expand All @@ -21,21 +20,21 @@
from .models.Version import Version

if TYPE_CHECKING:
from typing import AsyncIterator, Iterable
from typing import AsyncIterator, Iterable, Sequence

logger = logging.getLogger()

UTC = ZoneInfo("UTC")


db_url = getenv("DB_URL")
if not db_url:
raise Exception("DB_URL not provided or invalid!")
async_engine = create_async_engine(
getenv("DB_URL"),
db_url,
pool_pre_ping=True,
# echo=settings.ECHO_SQL,
)
AsyncSessionLocal = sessionmaker(
bind=async_engine, autoflush=False, future=True, expire_on_commit=False, class_=AsyncSession
)
AsyncSessionLocal = async_sessionmaker(bind=async_engine, autoflush=False, future=True, expire_on_commit=False)

db_lock = Lock()

Expand Down Expand Up @@ -158,7 +157,7 @@ async def search(
sort_direction: SortDirection = SortDirection.DESC,
limit: int = 50,
page: int = 0,
) -> list["Artifact"]:
) -> "Sequence[Artifact]":
statement = select(Artifact).offset(limit * page)
if name:
statement = statement.where(Artifact.name.like(f"%{name}%"))
Expand Down
5 changes: 4 additions & 1 deletion plugin_store/database/migrations/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,11 @@ async def run_migrations_online() -> None:
and associate a connection with the context.
"""
db_url = getenv("DB_URL")
if not db_url:
raise Exception("DB_URL not provided or invalid!")
connectable = create_async_engine(
getenv("DB_URL"),
db_url,
poolclass=pool.NullPool,
future=True,
)
Expand Down
27 changes: 14 additions & 13 deletions plugin_store/database/models/Artifact.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from urllib.parse import quote

from sqlalchemy import Boolean, Column, ForeignKey, func, Integer, select, Table, Text, UniqueConstraint
from sqlalchemy.orm import column_property, relationship
from sqlalchemy.orm import column_property, Mapped, relationship

import constants

Expand All @@ -29,30 +29,31 @@ class Tag(Base):
class Artifact(Base):
__tablename__ = "artifacts"

id: int = Column(Integer, autoincrement=True, primary_key=True)
name: str = Column(Text)
author: str = Column(Text)
description: str = Column(Text)
_image_path: "str | None" = Column("image_path", Text, nullable=True)
tags: "list[Tag]" = relationship(
id: Mapped[int] = Column(Integer, autoincrement=True, primary_key=True)
name: Mapped[str] = Column(Text)
author: Mapped[str] = Column(Text)
description: Mapped[str] = Column(Text)
_image_path: Mapped[str | None] = Column("image_path", Text, nullable=True)
tags: "Mapped[list[Tag]]" = relationship(
"Tag", secondary=PluginTag, cascade="all, delete", order_by="Tag.tag", lazy="selectin"
)
versions: "list[Version]" = relationship(
versions: "Mapped[list[Version]]" = relationship(
"Version", cascade="all, delete", lazy="selectin", order_by="Version.created.desc(), Version.id.asc()"
)
visible: bool = Column(Boolean, default=True)
visible: Mapped[bool] = Column(Boolean, default=True)

downloads: int = column_property(
# Properties computed from relations
downloads: Mapped[int] = column_property(
select(func.sum(Version.downloads)).where(Version.artifact_id == id).correlate_except(Version).scalar_subquery()
)
updates: int = column_property(
updates: Mapped[int] = column_property(
select(func.sum(Version.updates)).where(Version.artifact_id == id).correlate_except(Version).scalar_subquery()
)

created: datetime = column_property(
created: Mapped[datetime] = column_property(
select(func.min(Version.created)).where(Version.artifact_id == id).correlate_except(Version).scalar_subquery()
)
updated: datetime = column_property(
updated: Mapped[datetime] = column_property(
select(func.max(Version.created)).where(Version.artifact_id == id).correlate_except(Version).scalar_subquery()
)

Expand Down
6 changes: 4 additions & 2 deletions plugin_store/database/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,14 @@
from sqlalchemy.types import TypeDecorator

if TYPE_CHECKING:
from typing import Any

from sqlalchemy.engine import Dialect

UTC = ZoneInfo("UTC")


class TZDateTime(TypeDecorator):
class TZDateTime(TypeDecorator[datetime]):
"""
A DateTime type which can only store tz-aware DateTimes.
"""
Expand All @@ -26,7 +28,7 @@ def process_bind_param(self, value: "datetime | None", dialect: "Dialect"):
return value.astimezone(UTC)
return value

def process_result_value(self, value: "datetime", dialect: "Dialect") -> "datetime | None":
def process_result_value(self, value: "Any | None", dialect: "Dialect") -> "datetime | None":
if isinstance(value, datetime) and value.tzinfo is None:
return value.replace(tzinfo=UTC)
return value
Expand Down
Loading

0 comments on commit bbc27bb

Please sign in to comment.