Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add descriptions without modifying model and protocol digest #111

Closed
wants to merge 25 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 34 additions & 6 deletions src/uagents/models.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,48 @@
import hashlib
from typing import Type, Union

from typing import Type, Union, Dict, Any
import json
import copy
from pydantic import BaseModel


class Model(BaseModel):
_schema_no_descriptions = None

@staticmethod
def remove_descriptions(schema: Dict[str, Any]):
if not "properties" in schema:
return

fields_with_descr = []
for field_name, field_props in schema["properties"].items():
if "description" in field_props:
fields_with_descr.append(field_name)
for field_name in fields_with_descr:
del schema["properties"][field_name]["description"]

if "definitions" in schema:
for definition in schema["definitions"].values():
Model.remove_descriptions(definition)

@classmethod
def schema_no_descriptions(cls) -> Dict[str, Any]:
if cls._schema_no_descriptions is None:
schema = copy.deepcopy(cls.schema())
Model.remove_descriptions(schema)
cls._schema_no_descriptions = schema
return cls._schema_no_descriptions

@classmethod
def schema_json_no_descriptions(cls) -> str:
return json.dumps(cls.schema_no_descriptions(), indent=None, sort_keys=True)

@staticmethod
def build_schema_digest(model: Union["Model", Type["Model"]]) -> str:
digest = (
hashlib.sha256(
model.schema_json(indent=None, sort_keys=True).encode("utf8")
)
hashlib.sha256(model.schema_json_no_descriptions().encode("utf-8"))
.digest()
.hex()
)

return f"model:{digest}"


Expand Down
11 changes: 8 additions & 3 deletions src/uagents/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ def manifest(self) -> Dict[str, Any]:

for schema_digest, model in all_models.items():
manifest["models"].append(
{"digest": schema_digest, "schema": model.schema()}
{"digest": schema_digest, "schema": model.schema_no_descriptions()}
)

for request, responses in self._replies.items():
Expand All @@ -199,6 +199,11 @@ def manifest(self) -> Dict[str, Any]:
encoded = json.dumps(manifest, indent=None, sort_keys=True).encode("utf8")
metadata["digest"] = f"proto:{hashlib.sha256(encoded).digest().hex()}"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

digest is also computed by the compute_digest static method.
It takes as input the manifest which will contain the descriptions hence producing a different digest.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I modified this method accordingly and also created test for it in test_fields_descr.py, please check it.


manifest["models"] = []
for schema_digest, model in all_models.items():
manifest["models"].append(
{"digest": schema_digest, "schema": model.schema()}
)
final_manifest: Dict[str, Any] = copy.deepcopy(manifest)
final_manifest["metadata"] = metadata

Expand All @@ -207,9 +212,9 @@ def manifest(self) -> Dict[str, Any]:
@staticmethod
def compute_digest(manifest: Dict[str, Any]) -> str:
cleaned_manifest = copy.deepcopy(manifest)
if "metadata" in cleaned_manifest:
del cleaned_manifest["metadata"]
cleaned_manifest["metadata"] = {}
for model in cleaned_manifest["models"]:
Model.remove_descriptions(model["schema"])

encoded = json.dumps(cleaned_manifest, indent=None, sort_keys=True).encode(
"utf8"
Expand Down
121 changes: 121 additions & 0 deletions tests/test_field_descr.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
# pylint: disable=function-redefined
import unittest
from pydantic import Field
from uagents import Model, Protocol


protocol_no_descr = Protocol(name="test", version="1.1.1")
protocol_with_descr = Protocol(name="test", version="1.1.1")


def create_message_no_descr():
class Message(Model):
message: str

return Message


def create_message_with_descr():
class Message(Model):
message: str = Field(description="message field description")

return Message


def get_digests(protocol: Protocol):
model_digest = next(iter(protocol.models))
proto_digest = protocol.digest
return (model_digest, proto_digest)


class TestFieldDescr(unittest.TestCase):
def setUp(self) -> None:
self.protocol_no_descr = protocol_no_descr
self.protocol_with_descr = protocol_with_descr
return super().setUp()

def test_schema_json(self):
class Message(Model):
message: str
id: str

self.assertEqual(
Message.schema_json(indent=None, sort_keys=True),
Message.schema_json_no_descriptions(),
)

class Message(Model):
message: str = Field(description="message field description")
id: str = Field(description="id field description")

self.assertNotEqual(
Message.schema_json(indent=None, sort_keys=True),
Message.schema_json_no_descriptions(),
)

class MessageArgs(Model):
arg: str = Field(description="arg field description")

class Message(Model):
message: str = Field(description="message field description")
id: str = Field(description="id field description")
args: MessageArgs

self.assertNotEqual(
Message.schema_json(indent=None, sort_keys=True),
Message.schema_json_no_descriptions(),
)

def test_model_digest(self):
model_digest_no_descr = Model.build_schema_digest(create_message_no_descr())
model_digest_with_descr = Model.build_schema_digest(create_message_with_descr())

self.assertEqual(model_digest_no_descr, model_digest_with_descr)

def test_protocol(self):
@self.protocol_no_descr.on_message(create_message_no_descr())
def _(_ctx, _sender, _msg):
pass

self.assertEqual(len(self.protocol_no_descr.models), 1)
self.assertNotIn(
"description",
self.protocol_no_descr.manifest()["models"][0]["schema"]["properties"][
"message"
],
)
(model_digest_no_descr, proto_digest_no_descr) = get_digests(
self.protocol_no_descr
)

@self.protocol_with_descr.on_message(create_message_with_descr())
def _(_ctx, _sender, _msg):
pass

self.assertEqual(len(self.protocol_with_descr.models), 1)
self.assertEqual(
self.protocol_with_descr.manifest()["models"][0]["schema"]["properties"][
"message"
]["description"],
"message field description",
)
(model_digest_with_descr, proto_digest_with_descr) = get_digests(
self.protocol_with_descr
)

self.assertEqual(model_digest_no_descr, model_digest_with_descr)
self.assertEqual(proto_digest_no_descr, proto_digest_with_descr)

def test_compute_digest(self):
protocol = Protocol(name="test", version="1.1.1")

@protocol.on_message(create_message_with_descr())
def _(_ctx, _sender, _msg):
pass

computed_digest = Protocol.compute_digest(protocol.manifest())
self.assertEqual(protocol.digest, computed_digest)


if __name__ == "__main__":
unittest.main()