From 54139c47caea282b77f9e673e3a7bc00e2ddb116 Mon Sep 17 00:00:00 2001 From: James Riehl <33920192+jrriehl@users.noreply.github.com> Date: Tue, 6 Jun 2023 09:51:29 +0100 Subject: [PATCH] feat: protocol-aware envelopes (#99) --- docs/protocol.md | 22 ++++++++++------ examples/09-booking-protocol-demo/query.py | 16 +++++++----- .../09-booking-protocol-demo/restaurant.py | 2 ++ src/uagents/agent.py | 25 ++++++++++-------- src/uagents/asgi.py | 2 +- src/uagents/context.py | 22 ++++++++++++++-- src/uagents/contrib/__init__.py | 0 src/uagents/contrib/protocols/__init__.py | 0 .../contrib/protocols/protocol_query.py | 26 +++++++++++++++++++ src/uagents/envelope.py | 13 +++++++--- src/uagents/mailbox.py | 2 +- src/uagents/query.py | 4 +-- tests/test_basic.py | 10 ------- 13 files changed, 100 insertions(+), 44 deletions(-) create mode 100644 src/uagents/contrib/__init__.py create mode 100644 src/uagents/contrib/protocols/__init__.py create mode 100644 src/uagents/contrib/protocols/protocol_query.py delete mode 100644 tests/test_basic.py diff --git a/docs/protocol.md b/docs/protocol.md index a0d02c1c..ac6ce511 100644 --- a/docs/protocol.md +++ b/docs/protocol.md @@ -32,13 +32,15 @@ Envelopes have the following form and are quite similar to blockchain transactio ```python @dataclass class Envelope: - sender: str: # bech32-encoded public address - target: str: # bech32-encoded public address - session: str # UUID - protocol: str # protocol digest - payload: bytes # JSON type: base64 str - expires: int # Unix timestamp in seconds - signature: str # bech32-encoded signature + sender: str: # bech32-encoded public address + target: str: # bech32-encoded public address + session: str # UUID + schema_digest: str # digest of message schema used for routing + protocol_digest: str # digest of protocol containing message + payload: bytes # JSON type: base64 str + expires: int # Unix timestamp in seconds + nonce: int # unique message nonce + signature: str # bech32-encoded signature ``` ### Semantics @@ -47,12 +49,16 @@ The **sender** field exposes the address of the sender of the message. The **target** field exposes the address of the recipient of the message. -The **protocol** contains the unique schema digest string for the message. +The **schema_digest** contains the unique schema digest string for the message. + +The **protocol_digest** contains the unique digest for protocol containing the message if available. The **payload** field exposes the payload of the protocol. Its JSON representation should be a base64 encoded string. The **expires** field contains the Unix timestamp in seconds at which the message is no longer valid. +The **nonce** is a sequential number used to ensure each message is unique. + The **signature** field contains the signature that is used to authenticate that the message has been sent from the **sender** agent. Envelopes are then JSON encoded and sent to endpoints of other agents or services. diff --git a/examples/09-booking-protocol-demo/query.py b/examples/09-booking-protocol-demo/query.py index 06c5262f..6edf5066 100644 --- a/examples/09-booking-protocol-demo/query.py +++ b/examples/09-booking-protocol-demo/query.py @@ -1,18 +1,22 @@ import asyncio -from protocols.query import GetTotalQueries +from protocols.query import GetTotalQueries, TotalQueries +from uagents.contrib.protocols.protocol_query import ProtocolQuery, ProtocolResponse from uagents.query import query -RESTAURANT_ADDRESS = "agent1qw50wcs4nd723ya9j8mwxglnhs2kzzhh0et0yl34vr75hualsyqvqdzl990" - -get_total_queries = GetTotalQueries() +RESTAURANT_ADDRESS = "agent1qfpqn9jhvp9cg33f27q6jvmuv52dgyg9rfuu37rmxrletlqe7lewwjed5gy" async def main(): - env = await query(RESTAURANT_ADDRESS, get_total_queries) - print(f"Query response: {env.decode_payload()}") + env = await query(RESTAURANT_ADDRESS, GetTotalQueries()) + msg = TotalQueries.parse_raw(env.decode_payload()) + print(f"Query response: {msg.json()}\n\n") + + env = await query(RESTAURANT_ADDRESS, ProtocolQuery()) + msg = ProtocolResponse.parse_raw(env.decode_payload()) + print("Protocol query response:", msg.json(indent=4)) if __name__ == "__main__": diff --git a/examples/09-booking-protocol-demo/restaurant.py b/examples/09-booking-protocol-demo/restaurant.py index 3b24770d..0551d63e 100644 --- a/examples/09-booking-protocol-demo/restaurant.py +++ b/examples/09-booking-protocol-demo/restaurant.py @@ -2,6 +2,7 @@ from protocols.query import query_proto, TableStatus from uagents import Agent +from uagents.contrib.protocols.protocol_query import proto_query from uagents.setup import fund_agent_if_low @@ -19,6 +20,7 @@ # build the restaurant agent from stock protocols restaurant.include(query_proto) restaurant.include(book_proto) +restaurant.include(proto_query) TABLES = { 1: TableStatus(seats=2, time_start=16, time_end=22), diff --git a/src/uagents/agent.py b/src/uagents/agent.py index 38a1de89..ed576df2 100644 --- a/src/uagents/agent.py +++ b/src/uagents/agent.py @@ -111,6 +111,18 @@ def __init__( self._models: Dict[str, Type[Model]] = {} self._replies: Dict[str, Set[Type[Model]]] = {} self._queries: Dict[str, asyncio.Future] = {} + self._dispatcher = dispatcher + self._message_queue = asyncio.Queue() + self._on_startup = [] + self._on_shutdown = [] + self._version = version or "0.1.0" + + # initialize the internal agent protocol + self._protocol = Protocol(name=self._name, version=self._version) + + # keep track of supported protocols + self.protocols: Dict[str, Protocol] = {} + self._ctx = Context( self._identity.address, self._name, @@ -122,19 +134,9 @@ def __init__( self._queries, replies=self._replies, interval_messages=self._interval_messages, + protocols=self.protocols, logger=self._logger, ) - self._dispatcher = dispatcher - self._message_queue = asyncio.Queue() - self._on_startup = [] - self._on_shutdown = [] - self._version = version or "0.1.0" - - # initialize the internal agent protocol - self._protocol = Protocol(name=self._name, version=self._version) - - # keep track of supported protocols - self.protocols: Dict[str, Protocol] = {} # register with the dispatcher self._dispatcher.register(self.address, self) @@ -401,6 +403,7 @@ async def _process_message_queue(self): message_received=MsgDigest( message=message, schema_digest=schema_digest ), + protocols=self.protocols, logger=self._logger, ) diff --git a/src/uagents/asgi.py b/src/uagents/asgi.py index 0418211c..f7eb039d 100644 --- a/src/uagents/asgi.py +++ b/src/uagents/asgi.py @@ -158,7 +158,7 @@ async def __call__(self, scope, receive, send): return await dispatcher.dispatch( - env.sender, env.target, env.protocol, env.decode_payload() + env.sender, env.target, env.schema_digest, env.decode_payload() ) # wait for any queries to be resolved diff --git a/src/uagents/context.py b/src/uagents/context.py index ca813a16..ba58d245 100644 --- a/src/uagents/context.py +++ b/src/uagents/context.py @@ -1,9 +1,10 @@ +from __future__ import annotations import asyncio import logging import uuid from dataclasses import dataclass from time import time -from typing import Dict, Set, Optional, Callable, Any, Awaitable, Type +from typing import Dict, Set, Optional, Callable, Any, Awaitable, Type, TYPE_CHECKING import aiohttp from cosmpy.aerial.client import LedgerClient @@ -17,6 +18,9 @@ from uagents.resolver import Resolver from uagents.storage import KeyValueStore +if TYPE_CHECKING: + from uagents.protocol import Protocol + IntervalCallback = Callable[["Context"], Awaitable[None]] MessageCallback = Callable[["Context", str, Any], Awaitable[None]] EventCallback = Callable[["Context"], Awaitable[None]] @@ -45,6 +49,7 @@ def __init__( replies: Optional[Dict[str, Set[Type[Model]]]] = None, interval_messages: Optional[Set[str]] = None, message_received: Optional[MsgDigest] = None, + protocols: Optional[Dict[str, Protocol]] = None, logger: Optional[logging.Logger] = None, ): self.storage = storage @@ -58,6 +63,7 @@ def __init__( self._replies = replies self._interval_messages = interval_messages self._message_received = message_received + self._protocols = protocols or {} self._logger = logger @property @@ -74,6 +80,17 @@ def address(self) -> str: def logger(self) -> logging.Logger: return self._logger + @property + def protocols(self) -> Optional[Dict[str, Protocol]]: + return self._protocols + + def get_message_protocol(self, message_schema_digest) -> Optional[str]: + for protocol_digest, protocol in self._protocols.items(): + for reply_models in protocol.replies.values(): + if message_schema_digest in reply_models: + return protocol_digest + return None + async def send( self, destination: str, @@ -138,7 +155,8 @@ async def send( sender=self.address, target=destination, session=uuid.uuid4(), - protocol=schema_digest, + schema_digest=schema_digest, + protocol_digest=self.get_message_protocol(schema_digest), expires=expires, ) env.encode_payload(json_message) diff --git a/src/uagents/contrib/__init__.py b/src/uagents/contrib/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/uagents/contrib/protocols/__init__.py b/src/uagents/contrib/protocols/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/uagents/contrib/protocols/protocol_query.py b/src/uagents/contrib/protocols/protocol_query.py new file mode 100644 index 00000000..5053bfc5 --- /dev/null +++ b/src/uagents/contrib/protocols/protocol_query.py @@ -0,0 +1,26 @@ +from typing import Any, Dict, List, Optional + +from uagents import Context, Model, Protocol + + +class ProtocolQuery(Model): + protocol_digest: Optional[str] + + +class ProtocolResponse(Model): + manifests: List[Dict[str, Any]] + + +proto_query = Protocol(name="QueryProtocolManifests", version="0.1.0") + + +@proto_query.on_query(ProtocolQuery) +async def send_protocol_manifests(ctx: Context, sender: str, msg: ProtocolQuery): + manifests = [] + if msg.protocol_digest is not None: + if msg.protocol_digest in ctx.protocols: + manifests = [ctx.protocols[msg.protocol_digest].manifest()] + else: + manifests = [proto.manifest() for proto in ctx.protocols.values()] + + await ctx.send(sender, ProtocolResponse(manifests=manifests)) diff --git a/src/uagents/envelope.py b/src/uagents/envelope.py index 33d1f9d1..5349be10 100644 --- a/src/uagents/envelope.py +++ b/src/uagents/envelope.py @@ -3,7 +3,7 @@ import struct from typing import Optional, Any -from pydantic import BaseModel, UUID4 +from pydantic import BaseModel, Field, UUID4 from uagents.crypto import Identity from uagents.dispatch import JsonStr @@ -14,11 +14,16 @@ class Envelope(BaseModel): sender: str target: str session: UUID4 - protocol: str + schema_digest: str = Field(alias="protocol") + protocol_digest: Optional[str] = None payload: Optional[str] = None expires: Optional[int] = None + nonce: Optional[int] = None signature: Optional[str] = None + class Config: + allow_population_by_field_name = True + def encode_payload(self, value: JsonStr): self.payload = base64.b64encode(value.encode()).decode() @@ -42,9 +47,11 @@ def _digest(self) -> bytes: hasher.update(self.sender.encode()) hasher.update(self.target.encode()) hasher.update(str(self.session).encode()) - hasher.update(self.protocol.encode()) + hasher.update(self.schema_digest.encode()) if self.payload is not None: hasher.update(self.payload.encode()) if self.expires is not None: hasher.update(struct.pack(">Q", self.expires)) + if self.nonce is not None: + hasher.update(struct.pack(">Q", self.nonce)) return hasher.digest() diff --git a/src/uagents/mailbox.py b/src/uagents/mailbox.py index 57ca56b1..a02de559 100644 --- a/src/uagents/mailbox.py +++ b/src/uagents/mailbox.py @@ -73,7 +73,7 @@ async def _handle_envelope(self, payload: dict): await dispatcher.dispatch( env.sender, env.target, - env.protocol, + env.schema_digest, env.decode_payload(), ) diff --git a/src/uagents/query.py b/src/uagents/query.py index 8cbb28f5..678367dc 100644 --- a/src/uagents/query.py +++ b/src/uagents/query.py @@ -44,7 +44,7 @@ async def query( sender=generate_user_address(), target=destination, session=uuid.uuid4(), - protocol=schema_digest, + schema_digest=schema_digest, expires=expires, ) env.encode_payload(json_message) @@ -73,7 +73,7 @@ def enclose_response(message: Model, sender: str, session: str) -> str: sender=sender, target="", session=session, - protocol=Model.build_schema_digest(message), + schema_digest=Model.build_schema_digest(message), ) response_env.encode_payload(message.json()) return response_env.json() diff --git a/tests/test_basic.py b/tests/test_basic.py deleted file mode 100644 index 6fa47bb5..00000000 --- a/tests/test_basic.py +++ /dev/null @@ -1,10 +0,0 @@ -import unittest - - -class MyTestCase(unittest.TestCase): - def test_something(self): - self.assertEqual(True, True) # add assertion here - - -if __name__ == "__main__": - unittest.main()