Skip to content

Commit

Permalink
Allow transparent passing of aiohttp connection options to support co…
Browse files Browse the repository at this point in the history
…okie JWT-auth over websockets
  • Loading branch information
jegger committed Oct 11, 2024
1 parent c90166c commit abacd05
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 3 deletions.
6 changes: 5 additions & 1 deletion nats/aio/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,6 +330,7 @@ async def connect(
inbox_prefix: Union[str, bytes] = DEFAULT_INBOX_PREFIX,
pending_size: int = DEFAULT_PENDING_SIZE,
flush_timeout: Optional[float] = None,
ws_client_options: Optional[dict] = None,
) -> None:
"""
Establishes a connection to NATS.
Expand Down Expand Up @@ -450,6 +451,9 @@ async def subscribe_handler(msg):
self._nkeys_seed = nkeys_seed
self._nkeys_seed_str = nkeys_seed_str

# Options to customize aiohttp client in case of websocket transport
self._ws_client_options = ws_client_options

# Customizable options
self.options["verbose"] = verbose
self.options["pedantic"] = pedantic
Expand Down Expand Up @@ -1348,7 +1352,7 @@ async def _select_next_server(self) -> None:
s.last_attempt = time.monotonic()
if not self._transport:
if s.uri.scheme in ("ws", "wss"):
self._transport = WebSocketTransport()
self._transport = WebSocketTransport(self._ws_client_options)
else:
# use TcpTransport as a fallback
self._transport = TcpTransport()
Expand Down
4 changes: 2 additions & 2 deletions nats/aio/transport.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,13 +192,13 @@ def __bool__(self):

class WebSocketTransport(Transport):

def __init__(self):
def __init__(self, client_options: Optional[dict] = None):
if not aiohttp:
raise ImportError(
"Could not import aiohttp transport, please install it with `pip install aiohttp`"
)
self._ws: Optional[aiohttp.ClientWebSocketResponse] = None
self._client: aiohttp.ClientSession = aiohttp.ClientSession()
self._client: aiohttp.ClientSession = aiohttp.ClientSession(**client_options)
self._pending = asyncio.Queue()
self._close_task = asyncio.Future()
self._using_tls: Optional[bool] = None
Expand Down

0 comments on commit abacd05

Please sign in to comment.