Skip to content

Commit

Permalink
feat: protocol-aware envelopes (#99)
Browse files Browse the repository at this point in the history
  • Loading branch information
jrriehl committed Jun 6, 2023
1 parent 9b04c67 commit 54139c4
Show file tree
Hide file tree
Showing 13 changed files with 100 additions and 44 deletions.
22 changes: 14 additions & 8 deletions docs/protocol.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,15 @@ Envelopes have the following form and are quite similar to blockchain transactio
```python
@dataclass
class Envelope:
sender: str: # bech32-encoded public address
target: str: # bech32-encoded public address
session: str # UUID
protocol: str # protocol digest
payload: bytes # JSON type: base64 str
expires: int # Unix timestamp in seconds
signature: str # bech32-encoded signature
sender: str: # bech32-encoded public address
target: str: # bech32-encoded public address
session: str # UUID
schema_digest: str # digest of message schema used for routing
protocol_digest: str # digest of protocol containing message
payload: bytes # JSON type: base64 str
expires: int # Unix timestamp in seconds
nonce: int # unique message nonce
signature: str # bech32-encoded signature
```

### Semantics
Expand All @@ -47,12 +49,16 @@ The **sender** field exposes the address of the sender of the message.

The **target** field exposes the address of the recipient of the message.

The **protocol** contains the unique schema digest string for the message.
The **schema_digest** contains the unique schema digest string for the message.

The **protocol_digest** contains the unique digest for protocol containing the message if available.

The **payload** field exposes the payload of the protocol. Its JSON representation should be a base64 encoded string.

The **expires** field contains the Unix timestamp in seconds at which the message is no longer valid.

The **nonce** is a sequential number used to ensure each message is unique.

The **signature** field contains the signature that is used to authenticate that the message has been sent from the **sender** agent.

Envelopes are then JSON encoded and sent to endpoints of other agents or services.
Expand Down
16 changes: 10 additions & 6 deletions examples/09-booking-protocol-demo/query.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,22 @@
import asyncio

from protocols.query import GetTotalQueries
from protocols.query import GetTotalQueries, TotalQueries

from uagents.contrib.protocols.protocol_query import ProtocolQuery, ProtocolResponse
from uagents.query import query


RESTAURANT_ADDRESS = "agent1qw50wcs4nd723ya9j8mwxglnhs2kzzhh0et0yl34vr75hualsyqvqdzl990"

get_total_queries = GetTotalQueries()
RESTAURANT_ADDRESS = "agent1qfpqn9jhvp9cg33f27q6jvmuv52dgyg9rfuu37rmxrletlqe7lewwjed5gy"


async def main():
env = await query(RESTAURANT_ADDRESS, get_total_queries)
print(f"Query response: {env.decode_payload()}")
env = await query(RESTAURANT_ADDRESS, GetTotalQueries())
msg = TotalQueries.parse_raw(env.decode_payload())
print(f"Query response: {msg.json()}\n\n")

env = await query(RESTAURANT_ADDRESS, ProtocolQuery())
msg = ProtocolResponse.parse_raw(env.decode_payload())
print("Protocol query response:", msg.json(indent=4))


if __name__ == "__main__":
Expand Down
2 changes: 2 additions & 0 deletions examples/09-booking-protocol-demo/restaurant.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from protocols.query import query_proto, TableStatus

from uagents import Agent
from uagents.contrib.protocols.protocol_query import proto_query
from uagents.setup import fund_agent_if_low


Expand All @@ -19,6 +20,7 @@
# build the restaurant agent from stock protocols
restaurant.include(query_proto)
restaurant.include(book_proto)
restaurant.include(proto_query)

TABLES = {
1: TableStatus(seats=2, time_start=16, time_end=22),
Expand Down
25 changes: 14 additions & 11 deletions src/uagents/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,18 @@ def __init__(
self._models: Dict[str, Type[Model]] = {}
self._replies: Dict[str, Set[Type[Model]]] = {}
self._queries: Dict[str, asyncio.Future] = {}
self._dispatcher = dispatcher
self._message_queue = asyncio.Queue()
self._on_startup = []
self._on_shutdown = []
self._version = version or "0.1.0"

# initialize the internal agent protocol
self._protocol = Protocol(name=self._name, version=self._version)

# keep track of supported protocols
self.protocols: Dict[str, Protocol] = {}

self._ctx = Context(
self._identity.address,
self._name,
Expand All @@ -122,19 +134,9 @@ def __init__(
self._queries,
replies=self._replies,
interval_messages=self._interval_messages,
protocols=self.protocols,
logger=self._logger,
)
self._dispatcher = dispatcher
self._message_queue = asyncio.Queue()
self._on_startup = []
self._on_shutdown = []
self._version = version or "0.1.0"

# initialize the internal agent protocol
self._protocol = Protocol(name=self._name, version=self._version)

# keep track of supported protocols
self.protocols: Dict[str, Protocol] = {}

# register with the dispatcher
self._dispatcher.register(self.address, self)
Expand Down Expand Up @@ -401,6 +403,7 @@ async def _process_message_queue(self):
message_received=MsgDigest(
message=message, schema_digest=schema_digest
),
protocols=self.protocols,
logger=self._logger,
)

Expand Down
2 changes: 1 addition & 1 deletion src/uagents/asgi.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ async def __call__(self, scope, receive, send):
return

await dispatcher.dispatch(
env.sender, env.target, env.protocol, env.decode_payload()
env.sender, env.target, env.schema_digest, env.decode_payload()
)

# wait for any queries to be resolved
Expand Down
22 changes: 20 additions & 2 deletions src/uagents/context.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from __future__ import annotations
import asyncio
import logging
import uuid
from dataclasses import dataclass
from time import time
from typing import Dict, Set, Optional, Callable, Any, Awaitable, Type
from typing import Dict, Set, Optional, Callable, Any, Awaitable, Type, TYPE_CHECKING

import aiohttp
from cosmpy.aerial.client import LedgerClient
Expand All @@ -17,6 +18,9 @@
from uagents.resolver import Resolver
from uagents.storage import KeyValueStore

if TYPE_CHECKING:
from uagents.protocol import Protocol

IntervalCallback = Callable[["Context"], Awaitable[None]]
MessageCallback = Callable[["Context", str, Any], Awaitable[None]]
EventCallback = Callable[["Context"], Awaitable[None]]
Expand Down Expand Up @@ -45,6 +49,7 @@ def __init__(
replies: Optional[Dict[str, Set[Type[Model]]]] = None,
interval_messages: Optional[Set[str]] = None,
message_received: Optional[MsgDigest] = None,
protocols: Optional[Dict[str, Protocol]] = None,
logger: Optional[logging.Logger] = None,
):
self.storage = storage
Expand All @@ -58,6 +63,7 @@ def __init__(
self._replies = replies
self._interval_messages = interval_messages
self._message_received = message_received
self._protocols = protocols or {}
self._logger = logger

@property
Expand All @@ -74,6 +80,17 @@ def address(self) -> str:
def logger(self) -> logging.Logger:
return self._logger

@property
def protocols(self) -> Optional[Dict[str, Protocol]]:
return self._protocols

def get_message_protocol(self, message_schema_digest) -> Optional[str]:
for protocol_digest, protocol in self._protocols.items():
for reply_models in protocol.replies.values():
if message_schema_digest in reply_models:
return protocol_digest
return None

async def send(
self,
destination: str,
Expand Down Expand Up @@ -138,7 +155,8 @@ async def send(
sender=self.address,
target=destination,
session=uuid.uuid4(),
protocol=schema_digest,
schema_digest=schema_digest,
protocol_digest=self.get_message_protocol(schema_digest),
expires=expires,
)
env.encode_payload(json_message)
Expand Down
Empty file added src/uagents/contrib/__init__.py
Empty file.
Empty file.
26 changes: 26 additions & 0 deletions src/uagents/contrib/protocols/protocol_query.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
from typing import Any, Dict, List, Optional

from uagents import Context, Model, Protocol


class ProtocolQuery(Model):
protocol_digest: Optional[str]


class ProtocolResponse(Model):
manifests: List[Dict[str, Any]]


proto_query = Protocol(name="QueryProtocolManifests", version="0.1.0")


@proto_query.on_query(ProtocolQuery)
async def send_protocol_manifests(ctx: Context, sender: str, msg: ProtocolQuery):
manifests = []
if msg.protocol_digest is not None:
if msg.protocol_digest in ctx.protocols:
manifests = [ctx.protocols[msg.protocol_digest].manifest()]
else:
manifests = [proto.manifest() for proto in ctx.protocols.values()]

await ctx.send(sender, ProtocolResponse(manifests=manifests))
13 changes: 10 additions & 3 deletions src/uagents/envelope.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import struct
from typing import Optional, Any

from pydantic import BaseModel, UUID4
from pydantic import BaseModel, Field, UUID4

from uagents.crypto import Identity
from uagents.dispatch import JsonStr
Expand All @@ -14,11 +14,16 @@ class Envelope(BaseModel):
sender: str
target: str
session: UUID4
protocol: str
schema_digest: str = Field(alias="protocol")
protocol_digest: Optional[str] = None
payload: Optional[str] = None
expires: Optional[int] = None
nonce: Optional[int] = None
signature: Optional[str] = None

class Config:
allow_population_by_field_name = True

def encode_payload(self, value: JsonStr):
self.payload = base64.b64encode(value.encode()).decode()

Expand All @@ -42,9 +47,11 @@ def _digest(self) -> bytes:
hasher.update(self.sender.encode())
hasher.update(self.target.encode())
hasher.update(str(self.session).encode())
hasher.update(self.protocol.encode())
hasher.update(self.schema_digest.encode())
if self.payload is not None:
hasher.update(self.payload.encode())
if self.expires is not None:
hasher.update(struct.pack(">Q", self.expires))
if self.nonce is not None:
hasher.update(struct.pack(">Q", self.nonce))
return hasher.digest()
2 changes: 1 addition & 1 deletion src/uagents/mailbox.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ async def _handle_envelope(self, payload: dict):
await dispatcher.dispatch(
env.sender,
env.target,
env.protocol,
env.schema_digest,
env.decode_payload(),
)

Expand Down
4 changes: 2 additions & 2 deletions src/uagents/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ async def query(
sender=generate_user_address(),
target=destination,
session=uuid.uuid4(),
protocol=schema_digest,
schema_digest=schema_digest,
expires=expires,
)
env.encode_payload(json_message)
Expand Down Expand Up @@ -73,7 +73,7 @@ def enclose_response(message: Model, sender: str, session: str) -> str:
sender=sender,
target="",
session=session,
protocol=Model.build_schema_digest(message),
schema_digest=Model.build_schema_digest(message),
)
response_env.encode_payload(message.json())
return response_env.json()
10 changes: 0 additions & 10 deletions tests/test_basic.py

This file was deleted.

0 comments on commit 54139c4

Please sign in to comment.