Skip to content

Commit

Permalink
feat: use session id correctly (#104)
Browse files Browse the repository at this point in the history
  • Loading branch information
jrriehl committed Jun 15, 2023
1 parent 54139c4 commit e3514d0
Show file tree
Hide file tree
Showing 5 changed files with 28 additions and 9 deletions.
10 changes: 7 additions & 3 deletions src/uagents/agent.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import asyncio
import functools
from typing import Dict, List, Optional, Set, Union, Type, Tuple, Any
import uuid

from cosmpy.aerial.wallet import LocalWallet, PrivateKey
from cosmpy.crypto.address import Address
Expand Down Expand Up @@ -326,8 +327,10 @@ def include(self, protocol: Protocol):
if protocol.digest is not None:
self.protocols[protocol.digest] = protocol

async def handle_message(self, sender, schema_digest: str, message: JsonStr):
await self._message_queue.put((schema_digest, sender, message))
async def handle_message(
self, sender, schema_digest: str, message: JsonStr, session: uuid.UUID
):
await self._message_queue.put((schema_digest, sender, message, session))

async def _startup(self):
for handler in self._on_startup:
Expand Down Expand Up @@ -379,7 +382,7 @@ def run(self):
async def _process_message_queue(self):
while True:
# get an element from the queue
schema_digest, sender, message = await self._message_queue.get()
schema_digest, sender, message, session = await self._message_queue.get()

# lookup the model definition
model_class: Model = self._models.get(schema_digest)
Expand All @@ -398,6 +401,7 @@ async def _process_message_queue(self):
self._wallet,
self._ledger,
self._queries,
session=session,
replies=self._replies,
interval_messages=self._interval_messages,
message_received=MsgDigest(
Expand Down
2 changes: 1 addition & 1 deletion src/uagents/asgi.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ async def __call__(self, scope, receive, send):
return

await dispatcher.dispatch(
env.sender, env.target, env.schema_digest, env.decode_payload()
env.sender, env.target, env.schema_digest, env.decode_payload(), env.session
)

# wait for any queries to be resolved
Expand Down
10 changes: 8 additions & 2 deletions src/uagents/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ def __init__(
wallet: LocalWallet,
ledger: LedgerClient,
queries: Dict[str, asyncio.Future],
session: Optional[uuid.UUID] = None,
replies: Optional[Dict[str, Set[Type[Model]]]] = None,
interval_messages: Optional[Set[str]] = None,
message_received: Optional[MsgDigest] = None,
Expand All @@ -60,6 +61,7 @@ def __init__(
self._resolver = resolve
self._identity = identity
self._queries = queries
self._session = session or uuid.uuid4()
self._replies = replies
self._interval_messages = interval_messages
self._message_received = message_received
Expand All @@ -84,6 +86,10 @@ def logger(self) -> logging.Logger:
def protocols(self) -> Optional[Dict[str, Protocol]]:
return self._protocols

@property
def session(self) -> uuid.UUID:
return self._session

def get_message_protocol(self, message_schema_digest) -> Optional[str]:
for protocol_digest, protocol in self._protocols.items():
for reply_models in protocol.replies.values():
Expand Down Expand Up @@ -128,7 +134,7 @@ async def send(
# handle local dispatch of messages
if dispatcher.contains(destination):
await dispatcher.dispatch(
self.address, destination, schema_digest, json_message
self.address, destination, schema_digest, json_message, self._session
)
return

Expand All @@ -154,7 +160,7 @@ async def send(
version=1,
sender=self.address,
target=destination,
session=uuid.uuid4(),
session=self._session,
schema_digest=schema_digest,
protocol_digest=self.get_message_protocol(schema_digest),
expires=expires,
Expand Down
14 changes: 11 additions & 3 deletions src/uagents/dispatch.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
from abc import ABC, abstractmethod
from typing import Dict, Set
import uuid

JsonStr = str


class Sink(ABC):
@abstractmethod
async def handle_message(self, sender: str, schema_digest: str, message: JsonStr):
async def handle_message(
self, sender: str, schema_digest: str, message: JsonStr, session: uuid.UUID
):
pass


Expand All @@ -28,10 +31,15 @@ def contains(self, address: str) -> bool:
return address in self._sinks

async def dispatch(
self, sender: str, destination: str, schema_digest: str, message: JsonStr
self,
sender: str,
destination: str,
schema_digest: str,
message: JsonStr,
session: uuid.UUID,
):
for handler in self._sinks.get(destination, set()):
await handler.handle_message(sender, schema_digest, message)
await handler.handle_message(sender, schema_digest, message, session)


dispatcher = Dispatcher()
1 change: 1 addition & 0 deletions src/uagents/mailbox.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ async def _handle_envelope(self, payload: dict):
env.target,
env.schema_digest,
env.decode_payload(),
env.session,
)

# queue envelope for deletion from server
Expand Down

0 comments on commit e3514d0

Please sign in to comment.