diff --git a/src/uagents/models.py b/src/uagents/models.py index 7c3089f2..69b2f6eb 100644 --- a/src/uagents/models.py +++ b/src/uagents/models.py @@ -1,20 +1,48 @@ import hashlib -from typing import Type, Union - +from typing import Type, Union, Dict, Any +import json +import copy from pydantic import BaseModel class Model(BaseModel): + _schema_no_descriptions = None + + @staticmethod + def remove_descriptions(schema: Dict[str, Any]): + if not "properties" in schema: + return + + fields_with_descr = [] + for field_name, field_props in schema["properties"].items(): + if "description" in field_props: + fields_with_descr.append(field_name) + for field_name in fields_with_descr: + del schema["properties"][field_name]["description"] + + if "definitions" in schema: + for definition in schema["definitions"].values(): + Model.remove_descriptions(definition) + + @classmethod + def schema_no_descriptions(cls) -> Dict[str, Any]: + if cls._schema_no_descriptions is None: + schema = copy.deepcopy(cls.schema()) + Model.remove_descriptions(schema) + cls._schema_no_descriptions = schema + return cls._schema_no_descriptions + + @classmethod + def schema_json_no_descriptions(cls) -> str: + return json.dumps(cls.schema_no_descriptions(), indent=None, sort_keys=True) + @staticmethod def build_schema_digest(model: Union["Model", Type["Model"]]) -> str: digest = ( - hashlib.sha256( - model.schema_json(indent=None, sort_keys=True).encode("utf8") - ) + hashlib.sha256(model.schema_json_no_descriptions().encode("utf-8")) .digest() .hex() ) - return f"model:{digest}" diff --git a/src/uagents/protocol.py b/src/uagents/protocol.py index 40e2a17d..6ffb26e8 100644 --- a/src/uagents/protocol.py +++ b/src/uagents/protocol.py @@ -176,7 +176,7 @@ def manifest(self) -> Dict[str, Any]: for schema_digest, model in all_models.items(): manifest["models"].append( - {"digest": schema_digest, "schema": model.schema()} + {"digest": schema_digest, "schema": model.schema_no_descriptions()} ) for request, responses in self._replies.items(): @@ -199,6 +199,11 @@ def manifest(self) -> Dict[str, Any]: encoded = json.dumps(manifest, indent=None, sort_keys=True).encode("utf8") metadata["digest"] = f"proto:{hashlib.sha256(encoded).digest().hex()}" + manifest["models"] = [] + for schema_digest, model in all_models.items(): + manifest["models"].append( + {"digest": schema_digest, "schema": model.schema()} + ) final_manifest: Dict[str, Any] = copy.deepcopy(manifest) final_manifest["metadata"] = metadata @@ -207,9 +212,9 @@ def manifest(self) -> Dict[str, Any]: @staticmethod def compute_digest(manifest: Dict[str, Any]) -> str: cleaned_manifest = copy.deepcopy(manifest) - if "metadata" in cleaned_manifest: - del cleaned_manifest["metadata"] cleaned_manifest["metadata"] = {} + for model in cleaned_manifest["models"]: + Model.remove_descriptions(model["schema"]) encoded = json.dumps(cleaned_manifest, indent=None, sort_keys=True).encode( "utf8" diff --git a/tests/test_field_descr.py b/tests/test_field_descr.py new file mode 100644 index 00000000..38d4270e --- /dev/null +++ b/tests/test_field_descr.py @@ -0,0 +1,121 @@ +# pylint: disable=function-redefined +import unittest +from pydantic import Field +from uagents import Model, Protocol + + +protocol_no_descr = Protocol(name="test", version="1.1.1") +protocol_with_descr = Protocol(name="test", version="1.1.1") + + +def create_message_no_descr(): + class Message(Model): + message: str + + return Message + + +def create_message_with_descr(): + class Message(Model): + message: str = Field(description="message field description") + + return Message + + +def get_digests(protocol: Protocol): + model_digest = next(iter(protocol.models)) + proto_digest = protocol.digest + return (model_digest, proto_digest) + + +class TestFieldDescr(unittest.TestCase): + def setUp(self) -> None: + self.protocol_no_descr = protocol_no_descr + self.protocol_with_descr = protocol_with_descr + return super().setUp() + + def test_schema_json(self): + class Message(Model): + message: str + id: str + + self.assertEqual( + Message.schema_json(indent=None, sort_keys=True), + Message.schema_json_no_descriptions(), + ) + + class Message(Model): + message: str = Field(description="message field description") + id: str = Field(description="id field description") + + self.assertNotEqual( + Message.schema_json(indent=None, sort_keys=True), + Message.schema_json_no_descriptions(), + ) + + class MessageArgs(Model): + arg: str = Field(description="arg field description") + + class Message(Model): + message: str = Field(description="message field description") + id: str = Field(description="id field description") + args: MessageArgs + + self.assertNotEqual( + Message.schema_json(indent=None, sort_keys=True), + Message.schema_json_no_descriptions(), + ) + + def test_model_digest(self): + model_digest_no_descr = Model.build_schema_digest(create_message_no_descr()) + model_digest_with_descr = Model.build_schema_digest(create_message_with_descr()) + + self.assertEqual(model_digest_no_descr, model_digest_with_descr) + + def test_protocol(self): + @self.protocol_no_descr.on_message(create_message_no_descr()) + def _(_ctx, _sender, _msg): + pass + + self.assertEqual(len(self.protocol_no_descr.models), 1) + self.assertNotIn( + "description", + self.protocol_no_descr.manifest()["models"][0]["schema"]["properties"][ + "message" + ], + ) + (model_digest_no_descr, proto_digest_no_descr) = get_digests( + self.protocol_no_descr + ) + + @self.protocol_with_descr.on_message(create_message_with_descr()) + def _(_ctx, _sender, _msg): + pass + + self.assertEqual(len(self.protocol_with_descr.models), 1) + self.assertEqual( + self.protocol_with_descr.manifest()["models"][0]["schema"]["properties"][ + "message" + ]["description"], + "message field description", + ) + (model_digest_with_descr, proto_digest_with_descr) = get_digests( + self.protocol_with_descr + ) + + self.assertEqual(model_digest_no_descr, model_digest_with_descr) + self.assertEqual(proto_digest_no_descr, proto_digest_with_descr) + + def test_compute_digest(self): + protocol = Protocol(name="test", version="1.1.1") + + @protocol.on_message(create_message_with_descr()) + def _(_ctx, _sender, _msg): + pass + + computed_digest = Protocol.compute_digest(protocol.manifest()) + self.assertEqual(protocol.digest, computed_digest) + + +if __name__ == "__main__": + unittest.main()