Skip to content

Commit 3732def

Browse files
committed
feat(litestar): add get_session to plugin
1 parent b1fea84 commit 3732def

File tree

3 files changed

+116
-17
lines changed

3 files changed

+116
-17
lines changed

advanced_alchemy/extensions/litestar/plugins/__init__.py

Lines changed: 95 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,12 @@
1-
from collections.abc import Sequence
2-
from typing import Union
1+
from collections.abc import AsyncGenerator, Generator, Sequence
2+
from contextlib import asynccontextmanager, contextmanager
3+
from typing import Any, Callable, Optional, Union, cast
34

45
from litestar.config.app import AppConfig
56
from litestar.plugins import InitPluginProtocol
7+
from sqlalchemy import Engine
8+
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession
9+
from sqlalchemy.orm import Session
610

711
from advanced_alchemy.extensions.litestar.plugins import _slots_base
812
from advanced_alchemy.extensions.litestar.plugins.init import (
@@ -41,10 +45,99 @@ def on_app_init(self, app_config: AppConfig) -> AppConfig:
4145
4246
Args:
4347
app_config: The :class:`AppConfig <.config.app.AppConfig>` instance.
48+
49+
Returns:
50+
The :class:`AppConfig <.config.app.AppConfig>` instance.
4451
"""
4552
app_config.plugins.extend([SQLAlchemyInitPlugin(config=self._config), SQLAlchemySerializationPlugin()])
4653
return app_config
4754

55+
def _get_config(self, key: Optional[str] = None) -> Union[SQLAlchemyAsyncConfig, SQLAlchemySyncConfig]:
56+
"""Get a configuration by key.
57+
58+
Args:
59+
key: Optional key to identify the configuration. If not provided, uses the first config.
60+
61+
Raises:
62+
ValueError: If no configuration is found.
63+
64+
Returns:
65+
The SQLAlchemy configuration.
66+
"""
67+
if key is None:
68+
return self._config[0]
69+
for config in self._config:
70+
if getattr(config, "key", None) == key:
71+
return config
72+
msg = f"No configuration found with key {key}"
73+
raise ValueError(msg)
74+
75+
def get_session(
76+
self,
77+
key: Optional[str] = None,
78+
) -> Union[AsyncGenerator[AsyncSession, None], Generator[Session, None, None]]:
79+
"""Get a SQLAlchemy session.
80+
81+
Args:
82+
key: Optional key to identify the configuration. If not provided, uses the first config.
83+
84+
Returns:
85+
A SQLAlchemy session.
86+
"""
87+
config = self._get_config(key)
88+
89+
if isinstance(config, SQLAlchemyAsyncConfig):
90+
91+
@asynccontextmanager
92+
async def async_gen() -> AsyncGenerator[AsyncSession, None]:
93+
async with config.get_session() as session:
94+
yield session
95+
96+
return cast("AsyncGenerator[AsyncSession, None]", async_gen())
97+
98+
@contextmanager
99+
def sync_gen() -> Generator[Session, None, None]:
100+
with config.get_session() as session:
101+
yield session
102+
103+
return cast("Generator[Session, None, None]", sync_gen())
104+
105+
def provide_session(
106+
self,
107+
key: Optional[str] = None,
108+
) -> Callable[..., Union[AsyncGenerator[AsyncSession, None], Generator[Session, None, None]]]:
109+
"""Get a session provider for dependency injection.
110+
111+
Args:
112+
key: Optional key to identify the configuration. If not provided, uses the first config.
113+
114+
Returns:
115+
A callable that returns a session provider.
116+
"""
117+
118+
def provider(
119+
*args: Any, # noqa: ARG001
120+
**kwargs: Any, # noqa: ARG001
121+
) -> Union[AsyncGenerator[AsyncSession, None], Generator[Session, None, None]]:
122+
return self.get_session(key)
123+
124+
return provider
125+
126+
def get_engine(
127+
self,
128+
key: Optional[str] = None,
129+
) -> Union[AsyncEngine, Engine]:
130+
"""Get the SQLAlchemy engine.
131+
132+
Args:
133+
key: Optional key to identify the configuration. If not provided, uses the first config.
134+
135+
Returns:
136+
The SQLAlchemy engine.
137+
"""
138+
config = self._get_config(key)
139+
return config.get_engine()
140+
48141

49142
__all__ = (
50143
"EngineConfig",

advanced_alchemy/extensions/litestar/plugins/init/config/sync.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -54,9 +54,6 @@ def handler(message: "Message", scope: "Scope") -> None:
5454
Args:
5555
message: ASGI-``Message``
5656
scope: An ASGI-``Scope``
57-
58-
Returns:
59-
None
6057
"""
6158
session = cast("Optional[Session]", get_aa_scope_state(scope, session_scope_key))
6259
if session and message["type"] in SESSION_TERMINUS_ASGI_EVENTS:

examples/litestar/litestar_service.py

Lines changed: 21 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -101,10 +101,13 @@ async def create_author(self, authors_service: AuthorService, data: AuthorCreate
101101
async def get_author(
102102
self,
103103
authors_service: AuthorService,
104-
author_id: UUID = Parameter(
105-
title="Author ID",
106-
description="The author to retrieve.",
107-
),
104+
author_id: Annotated[
105+
UUID,
106+
Parameter(
107+
title="Author ID",
108+
description="The author to retrieve.",
109+
),
110+
],
108111
) -> Author:
109112
"""Get an existing author."""
110113
obj = await authors_service.get(author_id)
@@ -115,10 +118,13 @@ async def update_author(
115118
self,
116119
authors_service: AuthorService,
117120
data: AuthorUpdate,
118-
author_id: UUID = Parameter(
119-
title="Author ID",
120-
description="The author to update.",
121-
),
121+
author_id: Annotated[
122+
UUID,
123+
Parameter(
124+
title="Author ID",
125+
description="The author to update.",
126+
),
127+
],
122128
) -> Author:
123129
"""Update an author."""
124130
obj = await authors_service.update(data, item_id=author_id, auto_commit=True)
@@ -128,10 +134,13 @@ async def update_author(
128134
async def delete_author(
129135
self,
130136
authors_service: AuthorService,
131-
author_id: UUID = Parameter(
132-
title="Author ID",
133-
description="The author to delete.",
134-
),
137+
author_id: Annotated[
138+
UUID,
139+
Parameter(
140+
title="Author ID",
141+
description="The author to delete.",
142+
),
143+
],
135144
) -> None:
136145
"""Delete a author from the system."""
137146
_ = await authors_service.delete(author_id)

0 commit comments

Comments
 (0)