diff --git a/src/uagents/agent.py b/src/uagents/agent.py index ed576df2..647b0a6c 100644 --- a/src/uagents/agent.py +++ b/src/uagents/agent.py @@ -1,6 +1,7 @@ import asyncio import functools from typing import Dict, List, Optional, Set, Union, Type, Tuple, Any +import uuid from cosmpy.aerial.wallet import LocalWallet, PrivateKey from cosmpy.crypto.address import Address @@ -326,8 +327,10 @@ def include(self, protocol: Protocol): if protocol.digest is not None: self.protocols[protocol.digest] = protocol - async def handle_message(self, sender, schema_digest: str, message: JsonStr): - await self._message_queue.put((schema_digest, sender, message)) + async def handle_message( + self, sender, schema_digest: str, message: JsonStr, session: uuid.UUID + ): + await self._message_queue.put((schema_digest, sender, message, session)) async def _startup(self): for handler in self._on_startup: @@ -379,7 +382,7 @@ def run(self): async def _process_message_queue(self): while True: # get an element from the queue - schema_digest, sender, message = await self._message_queue.get() + schema_digest, sender, message, session = await self._message_queue.get() # lookup the model definition model_class: Model = self._models.get(schema_digest) @@ -398,6 +401,7 @@ async def _process_message_queue(self): self._wallet, self._ledger, self._queries, + session=session, replies=self._replies, interval_messages=self._interval_messages, message_received=MsgDigest( diff --git a/src/uagents/asgi.py b/src/uagents/asgi.py index f7eb039d..0b3956e7 100644 --- a/src/uagents/asgi.py +++ b/src/uagents/asgi.py @@ -158,7 +158,7 @@ async def __call__(self, scope, receive, send): return await dispatcher.dispatch( - env.sender, env.target, env.schema_digest, env.decode_payload() + env.sender, env.target, env.schema_digest, env.decode_payload(), env.session ) # wait for any queries to be resolved diff --git a/src/uagents/context.py b/src/uagents/context.py index ba58d245..a0e9b44b 100644 --- a/src/uagents/context.py +++ b/src/uagents/context.py @@ -46,6 +46,7 @@ def __init__( wallet: LocalWallet, ledger: LedgerClient, queries: Dict[str, asyncio.Future], + session: Optional[uuid.UUID] = None, replies: Optional[Dict[str, Set[Type[Model]]]] = None, interval_messages: Optional[Set[str]] = None, message_received: Optional[MsgDigest] = None, @@ -60,6 +61,7 @@ def __init__( self._resolver = resolve self._identity = identity self._queries = queries + self._session = session or uuid.uuid4() self._replies = replies self._interval_messages = interval_messages self._message_received = message_received @@ -84,6 +86,10 @@ def logger(self) -> logging.Logger: def protocols(self) -> Optional[Dict[str, Protocol]]: return self._protocols + @property + def session(self) -> uuid.UUID: + return self._session + def get_message_protocol(self, message_schema_digest) -> Optional[str]: for protocol_digest, protocol in self._protocols.items(): for reply_models in protocol.replies.values(): @@ -128,7 +134,7 @@ async def send( # handle local dispatch of messages if dispatcher.contains(destination): await dispatcher.dispatch( - self.address, destination, schema_digest, json_message + self.address, destination, schema_digest, json_message, self._session ) return @@ -154,7 +160,7 @@ async def send( version=1, sender=self.address, target=destination, - session=uuid.uuid4(), + session=self._session, schema_digest=schema_digest, protocol_digest=self.get_message_protocol(schema_digest), expires=expires, diff --git a/src/uagents/dispatch.py b/src/uagents/dispatch.py index 50bf145d..62ea014b 100644 --- a/src/uagents/dispatch.py +++ b/src/uagents/dispatch.py @@ -1,12 +1,15 @@ from abc import ABC, abstractmethod from typing import Dict, Set +import uuid JsonStr = str class Sink(ABC): @abstractmethod - async def handle_message(self, sender: str, schema_digest: str, message: JsonStr): + async def handle_message( + self, sender: str, schema_digest: str, message: JsonStr, session: uuid.UUID + ): pass @@ -28,10 +31,15 @@ def contains(self, address: str) -> bool: return address in self._sinks async def dispatch( - self, sender: str, destination: str, schema_digest: str, message: JsonStr + self, + sender: str, + destination: str, + schema_digest: str, + message: JsonStr, + session: uuid.UUID, ): for handler in self._sinks.get(destination, set()): - await handler.handle_message(sender, schema_digest, message) + await handler.handle_message(sender, schema_digest, message, session) dispatcher = Dispatcher() diff --git a/src/uagents/mailbox.py b/src/uagents/mailbox.py index a02de559..90553bf3 100644 --- a/src/uagents/mailbox.py +++ b/src/uagents/mailbox.py @@ -75,6 +75,7 @@ async def _handle_envelope(self, payload: dict): env.target, env.schema_digest, env.decode_payload(), + env.session, ) # queue envelope for deletion from server