From 3662f6c7e4c3d19028292b680e21acd578fdd346 Mon Sep 17 00:00:00 2001 From: Matt Conflitti Date: Wed, 5 Feb 2025 12:18:33 -0500 Subject: [PATCH 01/10] add support for async server function --- shiny/_app.py | 10 +++++----- shiny/session/_session.py | 4 +++- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/shiny/_app.py b/shiny/_app.py index 2929192cc..f9a0cce41 100644 --- a/shiny/_app.py +++ b/shiny/_app.py @@ -6,7 +6,7 @@ from contextlib import AsyncExitStack, asynccontextmanager from inspect import signature from pathlib import Path -from typing import Any, Callable, Mapping, Optional, TypeVar, cast +from typing import Any, Awaitable, Callable, Mapping, Optional, TypeVar, cast import starlette.applications import starlette.exceptions @@ -57,8 +57,8 @@ class App: returns a UI definition, if you need the UI definition to be created dynamically for each pageview. server - A function which is called once for each session, ensuring that each session is - independent. + A sync or async function which is called once for each session, ensuring that + each session is independent. static_assets Static files to be served by the app. If this is a string or Path object, it must be a directory, and it will be mounted at `/`. If this is a dictionary, @@ -104,13 +104,13 @@ def server(input: Inputs, output: Outputs, session: Session): """ ui: RenderedHTML | Callable[[Request], Tag | TagList] - server: Callable[[Inputs, Outputs, Session], None] + server: Callable[[Inputs, Outputs, Session], Awaitable[None] | None] def __init__( self, ui: Tag | TagList | Callable[[Request], Tag | TagList] | Path, server: ( - Callable[[Inputs], None] | Callable[[Inputs, Outputs, Session], None] | None + Callable[[Inputs], Awaitable[None] | None] | Callable[[Inputs, Outputs, Session], Awaitable[None] | None] | None ), *, static_assets: Optional[str | Path | Mapping[str, str | Path]] = None, diff --git a/shiny/session/_session.py b/shiny/session/_session.py index 834fc2211..d0b3b5e4c 100644 --- a/shiny/session/_session.py +++ b/shiny/session/_session.py @@ -633,7 +633,9 @@ def verify_state(expected_state: ConnectionState) -> None: self._manage_inputs(message_obj["data"]) with session_context(self): - self.app.server(self.input, self.output, self) + result = self.app.server(self.input, self.output, self) + if isinstance(result, Awaitable): + await result elif message_obj["method"] == "update": verify_state(ConnectionState.Running) From 04bf6b80732f6a66e7a51e4aa730e755d2dcc62b Mon Sep 17 00:00:00 2001 From: Matt Conflitti Date: Wed, 5 Feb 2025 12:34:15 -0500 Subject: [PATCH 02/10] lint fixes --- shiny/_app.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/shiny/_app.py b/shiny/_app.py index f9a0cce41..48a690df2 100644 --- a/shiny/_app.py +++ b/shiny/_app.py @@ -110,7 +110,9 @@ def __init__( self, ui: Tag | TagList | Callable[[Request], Tag | TagList] | Path, server: ( - Callable[[Inputs], Awaitable[None] | None] | Callable[[Inputs, Outputs, Session], Awaitable[None] | None] | None + Callable[[Inputs], Awaitable[None] | None] + | Callable[[Inputs, Outputs, Session], Awaitable[None] | None] + | None ), *, static_assets: Optional[str | Path | Mapping[str, str | Path]] = None, From aaa9c6e04c6d6e870d73dc80173f6b739c6f1068 Mon Sep 17 00:00:00 2001 From: Matt Conflitti Date: Wed, 5 Feb 2025 13:28:47 -0500 Subject: [PATCH 03/10] updated annotations; make server async regardless of function passed in --- shiny/_app.py | 18 +++++++++--------- shiny/_utils.py | 4 ++-- shiny/session/_session.py | 4 +--- 3 files changed, 12 insertions(+), 14 deletions(-) diff --git a/shiny/_app.py b/shiny/_app.py index 48a690df2..9239c5dfa 100644 --- a/shiny/_app.py +++ b/shiny/_app.py @@ -29,7 +29,7 @@ from ._connection import Connection, StarletteConnection from ._error import ErrorMiddleware from ._shinyenv import is_pyodide -from ._utils import guess_mime_type, is_async_callable, sort_keys_length +from ._utils import guess_mime_type, is_async_callable, sort_keys_length, wrap_async from .html_dependencies import jquery_deps, require_deps, shiny_deps from .http_staticfiles import FileResponse, StaticFiles from .session._session import AppSession, Inputs, Outputs, Session, session_context @@ -104,7 +104,7 @@ def server(input: Inputs, output: Outputs, session: Session): """ ui: RenderedHTML | Callable[[Request], Tag | TagList] - server: Callable[[Inputs, Outputs, Session], Awaitable[None] | None] + server: Callable[[Inputs, Outputs, Session], Awaitable[None]] def __init__( self, @@ -123,13 +123,13 @@ def __init__( self._exit_stack = AsyncExitStack() if server is None: - self.server = noop_server_fn + self.server = wrap_async(noop_server_fn) elif len(signature(server).parameters) == 1: self.server = wrap_server_fn_with_output_session( - cast(Callable[[Inputs], None], server) + wrap_async(cast(Callable[[Inputs], Awaitable[None] | None], server)) ) elif len(signature(server).parameters) == 3: - self.server = cast(Callable[[Inputs, Outputs, Session], None], server) + self.server = wrap_async(cast(Callable[[Inputs, Outputs, Session], Awaitable[None] | None], server)) else: raise ValueError( "`server` must have 1 (Inputs) or 3 parameters (Inputs, Outputs, Session)" @@ -522,10 +522,10 @@ def noop_server_fn(input: Inputs, output: Outputs, session: Session) -> None: def wrap_server_fn_with_output_session( - server: Callable[[Inputs], None], -) -> Callable[[Inputs, Outputs, Session], None]: - def _server(input: Inputs, output: Outputs, session: Session): + server: Callable[[Inputs], Awaitable[None]], +) -> Callable[[Inputs, Outputs, Session], Awaitable[None]]: + async def _server(input: Inputs, output: Outputs, session: Session): # Only has 1 parameter, ignore output, session - server(input) + await server(input) return _server diff --git a/shiny/_utils.py b/shiny/_utils.py index ad84667de..4654e99b7 100644 --- a/shiny/_utils.py +++ b/shiny/_utils.py @@ -262,7 +262,7 @@ def private_seed() -> Generator[None, None, None]: def wrap_async( - fn: Callable[P, R] | Callable[P, Awaitable[R]], + fn: Callable[P, R] | Callable[P, Awaitable[R]] | Callable[P, Awaitable[R] | R], ) -> Callable[P, Awaitable[R]]: """ Given a synchronous function that returns R, return an async function that wraps the @@ -270,7 +270,7 @@ def wrap_async( """ if is_async_callable(fn): - return fn + return cast(Callable[P, Awaitable[R]], fn) fn = cast(Callable[P, R], fn) diff --git a/shiny/session/_session.py b/shiny/session/_session.py index d0b3b5e4c..f27591b81 100644 --- a/shiny/session/_session.py +++ b/shiny/session/_session.py @@ -633,9 +633,7 @@ def verify_state(expected_state: ConnectionState) -> None: self._manage_inputs(message_obj["data"]) with session_context(self): - result = self.app.server(self.input, self.output, self) - if isinstance(result, Awaitable): - await result + await self.app.server(self.input, self.output, self) elif message_obj["method"] == "update": verify_state(ConnectionState.Running) From 7162f5df462aab64c311ff6051ced13a23c5ac5d Mon Sep 17 00:00:00 2001 From: Matt Conflitti Date: Wed, 5 Feb 2025 13:45:30 -0500 Subject: [PATCH 04/10] lint fixes --- shiny/_app.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/shiny/_app.py b/shiny/_app.py index 9239c5dfa..7084ac076 100644 --- a/shiny/_app.py +++ b/shiny/_app.py @@ -129,7 +129,11 @@ def __init__( wrap_async(cast(Callable[[Inputs], Awaitable[None] | None], server)) ) elif len(signature(server).parameters) == 3: - self.server = wrap_async(cast(Callable[[Inputs, Outputs, Session], Awaitable[None] | None], server)) + self.server = wrap_async( + cast( + Callable[[Inputs, Outputs, Session], Awaitable[None] | None], server + ) + ) else: raise ValueError( "`server` must have 1 (Inputs) or 3 parameters (Inputs, Outputs, Session)" From f1aa19cd6af53c754bbda0d7fa90e825b4e0e032 Mon Sep 17 00:00:00 2001 From: Matt Conflitti Date: Wed, 5 Feb 2025 14:13:48 -0500 Subject: [PATCH 05/10] py3.9 needs union for cast --- shiny/_app.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/shiny/_app.py b/shiny/_app.py index 7084ac076..b893f4f4f 100644 --- a/shiny/_app.py +++ b/shiny/_app.py @@ -6,7 +6,7 @@ from contextlib import AsyncExitStack, asynccontextmanager from inspect import signature from pathlib import Path -from typing import Any, Awaitable, Callable, Mapping, Optional, TypeVar, cast +from typing import Any, Awaitable, Callable, Mapping, Optional, TypeVar, Union, cast import starlette.applications import starlette.exceptions @@ -126,12 +126,12 @@ def __init__( self.server = wrap_async(noop_server_fn) elif len(signature(server).parameters) == 1: self.server = wrap_server_fn_with_output_session( - wrap_async(cast(Callable[[Inputs], Awaitable[None] | None], server)) + wrap_async(cast(Callable[[Inputs], Union[Awaitable[None], None]], server)) ) elif len(signature(server).parameters) == 3: self.server = wrap_async( cast( - Callable[[Inputs, Outputs, Session], Awaitable[None] | None], server + Callable[[Inputs, Outputs, Session], Union[Awaitable[None], None]], server ) ) else: From 88bdf4b3717c7fe24ce2be4348ae826bcd97cb0b Mon Sep 17 00:00:00 2001 From: Matt Conflitti Date: Wed, 5 Feb 2025 14:20:01 -0500 Subject: [PATCH 06/10] lint fixes --- shiny/_app.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/shiny/_app.py b/shiny/_app.py index b893f4f4f..7c71f716c 100644 --- a/shiny/_app.py +++ b/shiny/_app.py @@ -126,12 +126,15 @@ def __init__( self.server = wrap_async(noop_server_fn) elif len(signature(server).parameters) == 1: self.server = wrap_server_fn_with_output_session( - wrap_async(cast(Callable[[Inputs], Union[Awaitable[None], None]], server)) + wrap_async( + cast(Callable[[Inputs], Union[Awaitable[None], None]], server) + ) ) elif len(signature(server).parameters) == 3: self.server = wrap_async( cast( - Callable[[Inputs, Outputs, Session], Union[Awaitable[None], None]], server + Callable[[Inputs, Outputs, Session], Union[Awaitable[None], None]], + server, ) ) else: From 58c0db5815c21551878276a1d81ccbded52f3274 Mon Sep 17 00:00:00 2001 From: Barret Schloerke Date: Fri, 28 Mar 2025 12:28:39 -0400 Subject: [PATCH 07/10] Allow for module server to be async --- shiny/module.py | 57 ++++++++++++++++++++++++++++++++++++++++--------- 1 file changed, 47 insertions(+), 10 deletions(-) diff --git a/shiny/module.py b/shiny/module.py index a47e52e5e..2f32a6f00 100644 --- a/shiny/module.py +++ b/shiny/module.py @@ -1,8 +1,7 @@ from __future__ import annotations -__all__ = ("current_namespace", "resolve_id", "ui", "server", "ResolvedId") - -from typing import TYPE_CHECKING, Callable, TypeVar +import functools +from typing import TYPE_CHECKING, Awaitable, Callable, TypeVar, overload from ._docstring import no_example from ._namespaces import ( @@ -13,10 +12,13 @@ resolve_id, ) from ._typing_extensions import Concatenate, ParamSpec +from ._utils import is_async_callable, not_is_async_callable if TYPE_CHECKING: from .session import Inputs, Outputs, Session +__all__ = ("current_namespace", "resolve_id", "ui", "server", "ResolvedId") + P = ParamSpec("P") R = TypeVar("R") @@ -34,15 +36,50 @@ def wrapper(id: Id, *args: P.args, **kwargs: P.kwargs) -> R: @no_example() +# Use overloads so the function type stays the same for when the user calls it +@overload +def server( + fn: Callable[Concatenate[Inputs, Outputs, Session, P], Awaitable[R]], +) -> Callable[Concatenate[str, P], Awaitable[R]]: ... +@overload def server( fn: Callable[Concatenate[Inputs, Outputs, Session, P], R], -) -> Callable[Concatenate[str, P], R]: +) -> Callable[Concatenate[str, P], R]: ... +def server( + fn: ( + Callable[Concatenate[Inputs, Outputs, Session, P], R] + | Callable[Concatenate[Inputs, Outputs, Session, P], Awaitable[R]] + ), +) -> Callable[Concatenate[str, P], R] | Callable[Concatenate[str, P], Awaitable[R]]: from .session import require_active_session, session_context - def wrapper(id: Id, *args: P.args, **kwargs: P.kwargs) -> R: - sess = require_active_session(None) - child_sess = sess.make_scope(id) - with session_context(child_sess): - return fn(child_sess.input, child_sess.output, child_sess, *args, **kwargs) + if is_async_callable(fn): - return wrapper + @functools.wraps(fn) + async def async_wrapper(id: Id, *args: P.args, **kwargs: P.kwargs) -> R: + sess = require_active_session(None) + child_sess = sess.make_scope(id) + with session_context(child_sess): + return await fn( + child_sess.input, child_sess.output, child_sess, *args, **kwargs + ) + + return async_wrapper + + # Required for type narrowing. `TypeIs` did not seem to work as expected here. + if not_is_async_callable(fn): + + @functools.wraps(fn) + def sync_wrapper(id: Id, *args: P.args, **kwargs: P.kwargs) -> R: + sess = require_active_session(None) + child_sess = sess.make_scope(id) + with session_context(child_sess): + return fn( + child_sess.input, child_sess.output, child_sess, *args, **kwargs + ) + + return sync_wrapper + + raise RuntimeError( + "The provided function must be either synchronous or asynchronous." + ) From eacb07b85510f9c249974c527c1244ef38a2f2ab Mon Sep 17 00:00:00 2001 From: Barret Schloerke Date: Fri, 28 Mar 2025 12:29:26 -0400 Subject: [PATCH 08/10] Allow for express module server to be async --- shiny/express/_module.py | 72 ++++++++++++++++++++++++++++++---------- 1 file changed, 55 insertions(+), 17 deletions(-) diff --git a/shiny/express/_module.py b/shiny/express/_module.py index 9036e287a..ce23a77ba 100644 --- a/shiny/express/_module.py +++ b/shiny/express/_module.py @@ -1,8 +1,9 @@ import functools -from typing import Callable, TypeVar +from typing import Awaitable, Callable, TypeVar, overload from .._docstring import add_example from .._typing_extensions import Concatenate, ParamSpec +from .._utils import is_async_callable, not_is_async_callable from ..module import Id from ..session._session import Inputs, Outputs, Session from ..session._utils import require_active_session, session_context @@ -16,9 +17,21 @@ @add_example(ex_dir="../api-examples/express_module") +# Use overloads so the function type stays the same for when the user calls it +@overload +def module( + fn: Callable[Concatenate[Inputs, Outputs, Session, P], Awaitable[R]], +) -> Callable[Concatenate[Id, P], Awaitable[R]]: ... +@overload def module( fn: Callable[Concatenate[Inputs, Outputs, Session, P], R], -) -> Callable[Concatenate[Id, P], R]: +) -> Callable[Concatenate[Id, P], R]: ... +def module( + fn: ( + Callable[Concatenate[Inputs, Outputs, Session, P], R] + | Callable[Concatenate[Inputs, Outputs, Session, P], Awaitable[R]] + ), +) -> Callable[Concatenate[Id, P], R] | Callable[Concatenate[Id, P], Awaitable[R]]: """ Create a Shiny module using Shiny Express syntax @@ -42,18 +55,43 @@ def module( """ fn = expressify(fn) - @functools.wraps(fn) - def wrapper(id: Id, *args: P.args, **kwargs: P.kwargs) -> R: - parent_session = require_active_session(None) - module_session = parent_session.make_scope(id) - - with session_context(module_session): - return fn( - module_session.input, - module_session.output, - module_session, - *args, - **kwargs, - ) - - return wrapper + if is_async_callable(fn): + # If the function is async, we need to wrap it in an async wrapper + @functools.wraps(fn) + async def async_wrapper(id: Id, *args: P.args, **kwargs: P.kwargs) -> R: + parent_session = require_active_session(None) + module_session = parent_session.make_scope(id) + + with session_context(module_session): + return await fn( + module_session.input, + module_session.output, + module_session, + *args, + **kwargs, + ) + + return async_wrapper + + # Required for type narrowing. `TypeIs` did not seem to work as expected here. + if not_is_async_callable(fn): + + @functools.wraps(fn) + def wrapper(id: Id, *args: P.args, **kwargs: P.kwargs) -> R: + parent_session = require_active_session(None) + module_session = parent_session.make_scope(id) + + with session_context(module_session): + return fn( + module_session.input, + module_session.output, + module_session, + *args, + **kwargs, + ) + + return wrapper + + raise RuntimeError( + "The provided function must be either synchronous or asynchronous." + ) From d6b2420db067d50235bdfb7bc2fb67897e04e238 Mon Sep 17 00:00:00 2001 From: Barret Schloerke Date: Fri, 28 Mar 2025 12:41:42 -0400 Subject: [PATCH 09/10] Uncomment `not_is_async_callable()` --- shiny/_utils.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/shiny/_utils.py b/shiny/_utils.py index 707990325..a4353facc 100644 --- a/shiny/_utils.py +++ b/shiny/_utils.py @@ -362,10 +362,10 @@ def is_async_callable( return False -# def not_is_async_callable( -# obj: Callable[P, T] | Callable[P, Awaitable[T]] -# ) -> TypeGuard[Callable[P, T]]: -# return not is_async_callable(obj) +def not_is_async_callable( + obj: Callable[P, T] | Callable[P, Awaitable[T]] +) -> TypeGuard[Callable[P, T]]: + return not is_async_callable(obj) # See https://stackoverflow.com/a/59780868/412655 for an excellent explanation From cbb3242ae7fdd9bf068aa81657f5f1accf2594ac Mon Sep 17 00:00:00 2001 From: Barret Schloerke Date: Fri, 28 Mar 2025 16:46:22 -0400 Subject: [PATCH 10/10] lints --- shiny/_app.py | 13 +++++++++++-- shiny/_utils.py | 2 +- shiny/express/_module.py | 2 ++ 3 files changed, 14 insertions(+), 3 deletions(-) diff --git a/shiny/_app.py b/shiny/_app.py index ee290f3ff..48975dac3 100644 --- a/shiny/_app.py +++ b/shiny/_app.py @@ -6,8 +6,17 @@ from contextlib import AsyncExitStack, asynccontextmanager from inspect import signature from pathlib import Path -from typing import Any, Awaitable, Callable, Mapping, Optional, TypeVar, Union, cast -from typing import Any, Callable, Literal, Mapping, Optional, TypeVar, cast +from typing import ( + Any, + Awaitable, + Callable, + Literal, + Mapping, + Optional, + TypeVar, + Union, + cast, +) import starlette.applications import starlette.exceptions diff --git a/shiny/_utils.py b/shiny/_utils.py index a4353facc..799988401 100644 --- a/shiny/_utils.py +++ b/shiny/_utils.py @@ -363,7 +363,7 @@ def is_async_callable( def not_is_async_callable( - obj: Callable[P, T] | Callable[P, Awaitable[T]] + obj: Callable[P, T] | Callable[P, Awaitable[T]], ) -> TypeGuard[Callable[P, T]]: return not is_async_callable(obj) diff --git a/shiny/express/_module.py b/shiny/express/_module.py index ce23a77ba..f49458ce2 100644 --- a/shiny/express/_module.py +++ b/shiny/express/_module.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import functools from typing import Awaitable, Callable, TypeVar, overload