Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 25 additions & 14 deletions lib/hermes/server/transport/streamable_http.ex
Original file line number Diff line number Diff line change
Expand Up @@ -64,17 +64,20 @@ defmodule Hermes.Server.Transport.StreamableHTTP do

- `:server` - The server process (required)
- `:name` - Name for registering the GenServer (required)
- `:call_timeout` - Timeout for internal GenServer calls in milliseconds (default: 30 seconds)
"""
@type option ::
{:server, GenServer.server()}
| {:name, GenServer.name()}
| {:call_timeout, pos_integer()}
| GenServer.option()

defschema(:parse_options, [
{:server, {:required, Hermes.get_schema(:process_name)}},
{:name, {:required, {:custom, &Hermes.genserver_name/1}}},
{:registry, {:atom, {:default, Hermes.Server.Registry}}},
{:request_timeout, {:integer, {:default, to_timeout(second: 30)}}},
{:call_timeout, {:integer, {:default, to_timeout(second: 30)}}},
{:task_supervisor, {:required, {:custom, &Hermes.genserver_name/1}}}
])

Expand Down Expand Up @@ -106,8 +109,9 @@ defmodule Hermes.Server.Transport.StreamableHTTP do
"""
@impl Transport
@spec send_message(GenServer.server(), binary(), keyword()) :: :ok | {:error, term()}
def send_message(transport, message, _opts \\ []) when is_binary(message) do
GenServer.call(transport, {:send_message, message})
def send_message(transport, message, opts \\ []) when is_binary(message) do
timeout = Keyword.get(opts, :call_timeout, 5000)
GenServer.call(transport, {:send_message, message}, timeout)
end

@doc """
Expand All @@ -133,10 +137,11 @@ defmodule Hermes.Server.Transport.StreamableHTTP do
Called by the Plug when establishing an SSE connection.
The calling process becomes the SSE handler for the session.
"""
@spec register_sse_handler(GenServer.server(), String.t()) ::
@spec register_sse_handler(GenServer.server(), String.t(), keyword()) ::
:ok | {:error, term()}
def register_sse_handler(transport, session_id) do
GenServer.call(transport, {:register_sse_handler, session_id, self()})
def register_sse_handler(transport, session_id, opts \\ []) do
timeout = Keyword.get(opts, :call_timeout, 5000)
GenServer.call(transport, {:register_sse_handler, session_id, self()}, timeout)
end

@doc """
Expand All @@ -154,10 +159,11 @@ defmodule Hermes.Server.Transport.StreamableHTTP do

Called by the Plug when a message is received via HTTP POST.
"""
@spec handle_message(GenServer.server(), String.t(), map() | list(map), map()) ::
@spec handle_message(GenServer.server(), String.t(), map() | list(map), map(), keyword()) ::
{:ok, binary() | nil} | {:error, term()}
def handle_message(transport, session_id, message, context) do
GenServer.call(transport, {:handle_message, session_id, message, context})
def handle_message(transport, session_id, message, context, opts \\ []) do
timeout = Keyword.get(opts, :call_timeout, 5000)
GenServer.call(transport, {:handle_message, session_id, message, context}, timeout)
end

@doc """
Expand All @@ -166,12 +172,15 @@ defmodule Hermes.Server.Transport.StreamableHTTP do
This allows the Plug to know whether to stream the response via SSE
or return it as a regular HTTP response.
"""
@spec handle_message_for_sse(GenServer.server(), String.t(), map(), map()) ::
@spec handle_message_for_sse(GenServer.server(), String.t(), map(), map(), keyword()) ::
{:ok, binary()} | {:sse, binary()} | {:error, term()}
def handle_message_for_sse(transport, session_id, message, context) do
def handle_message_for_sse(transport, session_id, message, context, opts \\ []) do
timeout = Keyword.get(opts, :call_timeout, 5000)

GenServer.call(
transport,
{:handle_message_for_sse, session_id, message, context}
{:handle_message_for_sse, session_id, message, context},
timeout
)
end

Expand All @@ -181,9 +190,10 @@ defmodule Hermes.Server.Transport.StreamableHTTP do
Returns the pid of the process handling SSE for this session,
or nil if no SSE connection exists.
"""
@spec get_sse_handler(GenServer.server(), String.t()) :: pid() | nil
def get_sse_handler(transport, session_id) do
GenServer.call(transport, {:get_sse_handler, session_id})
@spec get_sse_handler(GenServer.server(), String.t(), keyword()) :: pid() | nil
def get_sse_handler(transport, session_id, opts \\ []) do
timeout = Keyword.get(opts, :call_timeout, 5000)
GenServer.call(transport, {:get_sse_handler, session_id}, timeout)
end

@doc """
Expand All @@ -207,6 +217,7 @@ defmodule Hermes.Server.Transport.StreamableHTTP do
server: server,
registry: opts.registry,
request_timeout: opts.request_timeout,
call_timeout: opts.call_timeout,
task_supervisor: opts.task_supervisor,
# Map of session_id => {pid, monitor_ref}
sse_handlers: %{},
Expand Down
56 changes: 32 additions & 24 deletions lib/hermes/server/transport/streamable_http/plug.ex
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ if Code.ensure_loaded?(Plug) do

## SSE Streaming Architecture

This Plug handles SSE streaming by keeping the request process alive
This Plug handles SSE streaming by keeping the request process alive
and managing the streaming loop for server-to-client communication.

## Usage in Phoenix Router
Expand Down Expand Up @@ -81,8 +81,9 @@ if Code.ensure_loaded?(Plug) do
transport = registry.transport(server, :streamable_http)
session_header = Keyword.get(opts, :session_header, @default_session_header)
timeout = Keyword.get(opts, :timeout, @default_timeout)
call_timeout = Keyword.get(opts, :call_timeout, @default_timeout)

%{transport: transport, session_header: session_header, timeout: timeout}
%{transport: transport, session_header: session_header, timeout: timeout, call_timeout: call_timeout}
end

@impl Plug
Expand All @@ -97,11 +98,11 @@ if Code.ensure_loaded?(Plug) do

# GET request handler - establishes SSE connection

defp handle_get(conn, %{transport: transport, session_header: session_header}) do
defp handle_get(conn, %{transport: transport, session_header: session_header, call_timeout: call_timeout}) do
if wants_sse?(conn) do
session_id = get_or_create_session_id(conn, session_header)

case StreamableHTTP.register_sse_handler(transport, session_id) do
case StreamableHTTP.register_sse_handler(transport, session_id, call_timeout: call_timeout) do
:ok ->
start_sse_streaming(conn, transport, session_id, session_header)

Expand Down Expand Up @@ -129,7 +130,7 @@ if Code.ensure_loaded?(Plug) do
session_id: session_id
})

process_message(message, conn, transport, session_id, context, session_header)
process_message(message, conn, transport, session_id, context, session_header, opts.call_timeout)
else
{:error, :invalid_accept_header} ->
send_error(
Expand All @@ -156,20 +157,22 @@ if Code.ensure_loaded?(Plug) do
end
end

defp process_message(message, conn, transport, session_id, context, session_header) when is_map(message) do
defp process_message(message, conn, transport, session_id, context, session_header, call_timeout)
when is_map(message) do
if Message.is_request(message) do
handle_request_with_possible_sse(
conn,
transport,
session_id,
message,
context,
session_header
session_header,
call_timeout
)
else
# Notification
transport
|> StreamableHTTP.handle_message(session_id, message, context)
|> StreamableHTTP.handle_message(session_id, message, context, call_timeout: call_timeout)
|> format_notification_response(conn)
end
end
Expand Down Expand Up @@ -212,15 +215,16 @@ if Code.ensure_loaded?(Plug) do

# Handle requests that might need SSE streaming

defp handle_request_with_possible_sse(conn, transport, session_id, body, context, session_header) do
defp handle_request_with_possible_sse(conn, transport, session_id, body, context, session_header, call_timeout) do
if wants_sse?(conn) do
handle_sse_request(
conn,
transport,
session_id,
body,
context,
session_header
session_header,
call_timeout
)
else
handle_json_request(
Expand All @@ -229,17 +233,19 @@ if Code.ensure_loaded?(Plug) do
session_id,
body,
context,
session_header
session_header,
call_timeout
)
end
end

defp handle_sse_request(conn, transport, session_id, body, context, session_header) do
defp handle_sse_request(conn, transport, session_id, body, context, session_header, call_timeout) do
case StreamableHTTP.handle_message_for_sse(
transport,
session_id,
body,
context
context,
call_timeout: call_timeout
) do
{:sse, response} ->
route_sse_response(
Expand All @@ -249,7 +255,8 @@ if Code.ensure_loaded?(Plug) do
response,
body,
context,
session_header
session_header,
call_timeout
)

{:ok, response} ->
Expand All @@ -263,8 +270,8 @@ if Code.ensure_loaded?(Plug) do
end
end

defp handle_json_request(conn, transport, session_id, body, context, session_header) do
case StreamableHTTP.handle_message(transport, session_id, body, context) do
defp handle_json_request(conn, transport, session_id, body, context, session_header, call_timeout) do
case StreamableHTTP.handle_message(transport, session_id, body, context, call_timeout: call_timeout) do
{:ok, response} ->
conn
|> put_resp_content_type("application/json")
Expand All @@ -276,8 +283,8 @@ if Code.ensure_loaded?(Plug) do
end
end

defp route_sse_response(conn, transport, session_id, response, body, context, session_header) do
if handler_pid = StreamableHTTP.get_sse_handler(transport, session_id) do
defp route_sse_response(conn, transport, session_id, response, body, context, session_header, call_timeout) do
if handler_pid = StreamableHTTP.get_sse_handler(transport, session_id, call_timeout: call_timeout) do
send(handler_pid, {:sse_message, response})

conn
Expand All @@ -290,7 +297,8 @@ if Code.ensure_loaded?(Plug) do
session_id,
body,
context,
session_header
session_header,
call_timeout
)
end
end
Expand All @@ -309,10 +317,10 @@ if Code.ensure_loaded?(Plug) do
)
end

defp establish_sse_for_request(conn, transport, session_id, body, context, session_header) do
case StreamableHTTP.register_sse_handler(transport, session_id) do
defp establish_sse_for_request(conn, transport, session_id, body, context, session_header, call_timeout) do
case StreamableHTTP.register_sse_handler(transport, session_id, call_timeout: call_timeout) do
:ok ->
start_background_request(transport, session_id, body, context)
start_background_request(transport, session_id, body, context, call_timeout)
start_sse_streaming(conn, transport, session_id, session_header)

{:error, reason} ->
Expand All @@ -326,11 +334,11 @@ if Code.ensure_loaded?(Plug) do
end
end

defp start_background_request(transport, session_id, body, context) do
defp start_background_request(transport, session_id, body, context, call_timeout) do
self_pid = self()

Task.start(fn ->
case StreamableHTTP.handle_message(transport, session_id, body, context) do
case StreamableHTTP.handle_message(transport, session_id, body, context, call_timeout: call_timeout) do
{:ok, response} when is_binary(response) ->
send(self_pid, {:sse_message, response})

Expand Down