Skip to content

Commit

Permalink
fix based on review
Browse files Browse the repository at this point in the history
  • Loading branch information
zmezei committed Jul 6, 2023
1 parent e93c2b3 commit 9bf84ac
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 53 deletions.
63 changes: 20 additions & 43 deletions src/uagents/models.py
Original file line number Diff line number Diff line change
@@ -1,60 +1,37 @@
import hashlib
from typing import Type, Union, Dict, ClassVar, Any

from typing import Type, Union, Dict
import json
from pydantic import BaseModel
from pydantic.schema import model_schema, default_ref_template


class Model(BaseModel):
schema_no_descriptions: ClassVar[Union[Dict[str, Any], None]] = None
@staticmethod
def remove_descriptions(schema: Dict[str, Dict[str, str]]):
fields_with_descr = []
if not "properties" in schema:
return
for field_name, field_props in schema["properties"].items():
if "description" in field_props:
fields_with_descr.append(field_name)

@classmethod
def _remove_descriptions(
cls, orig_descriptions: Dict[str, Union[str, Dict]]
):
for field_name, field in cls.__fields__.items():
if field.field_info and field.field_info.description:
orig_descriptions[field_name] = field.field_info.description
field.field_info.description = None
elif issubclass(field.type_, Model):
orig_descriptions[field_name] = {}
Model._remove_descriptions(field.type_, orig_descriptions[field_name])
for field_name in fields_with_descr:
del schema["properties"][field_name]["description"]

@classmethod
def _restore_descriptions(cls, orig_descriptions: Dict[str, Union[str, Dict]]
):
for field_name, field in cls.__fields__.items():
if (
field.field_info
and field_name in orig_descriptions
and not issubclass(field.type_, Model)
):
field.field_info.description = orig_descriptions[field_name]
elif issubclass(field.type_, Model):
Model._restore_descriptions(field.type_, orig_descriptions[field_name])
if "definitions" in schema:
for definition in schema["definitions"].values():
Model.remove_descriptions(definition)

@classmethod
def _restore_schema_cache(cls):
schema = model_schema(cls, by_alias=True, ref_template=default_ref_template)
cls.__schema_cache__[(True, default_ref_template)] = schema
def schema_json_no_descr(cls) -> str:
orig_schema = json.loads(cls.schema_json(indent=None, sort_keys=True))
Model.remove_descriptions(orig_schema)
return json.dumps(orig_schema)

@staticmethod
def build_schema_digest(model: Union["Model", Type["Model"]]) -> str:
type_obj: Type["Model"] = model if isinstance(model, type) else model.__class__
if type_obj.schema_no_descriptions is None:
orig_descriptions: Dict[str, Union[str, Dict]] = {}
type_obj._remove_descriptions(orig_descriptions)
digest = (
hashlib.sha256(
model.schema_json(indent=None, sort_keys=True).encode("utf8")
)
.digest()
.hex()
hashlib.sha256(model.schema_json_no_descr().encode("utf-8")).digest().hex()
)
if type_obj.schema_no_descriptions is None:
type_obj.schema_no_descriptions = type_obj.schema()
type_obj._restore_descriptions(orig_descriptions)
type_obj._restore_schema_cache()
return f"model:{digest}"


Expand Down
7 changes: 3 additions & 4 deletions src/uagents/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ def manifest(self) -> Dict[str, Any]:

for schema_digest, model in all_models.items():
manifest["models"].append(
{"digest": schema_digest, "schema": model.schema_no_descriptions}
{"digest": schema_digest, "schema": model.schema_json_no_descr()}
)

for request, responses in self._replies.items():
Expand Down Expand Up @@ -204,7 +204,6 @@ def manifest(self) -> Dict[str, Any]:
manifest["models"].append(
{"digest": schema_digest, "schema": model.schema()}
)

final_manifest: Dict[str, Any] = copy.deepcopy(manifest)
final_manifest["metadata"] = metadata

Expand All @@ -213,9 +212,9 @@ def manifest(self) -> Dict[str, Any]:
@staticmethod
def compute_digest(manifest: Dict[str, Any]) -> str:
cleaned_manifest = copy.deepcopy(manifest)
if "metadata" in cleaned_manifest:
del cleaned_manifest["metadata"]
cleaned_manifest["metadata"] = {}
for model in cleaned_manifest["models"]:
Model.remove_descriptions(model["schema"])

encoded = json.dumps(cleaned_manifest, indent=None, sort_keys=True).encode(
"utf8"
Expand Down
46 changes: 40 additions & 6 deletions tests/test_field_descr.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
# pylint: disable=function-redefined
import unittest
from pydantic import Field
from uagents import Model, Protocol
Expand Down Expand Up @@ -33,14 +34,37 @@ def setUp(self) -> None:
self.protocol_with_descr = protocol_with_descr
return super().setUp()

def test_field_description(self):
message_with_descr = create_message_with_descr()
def test_schema_json(self):
class Message(Model):
message: str
id: str

Model.build_schema_digest(message_with_descr)
self.assertEqual(
Message.schema_json(indent=None, sort_keys=True),
Message.schema_json_no_descr(),
)

message_field_info = message_with_descr.__fields__["message"].field_info
self.assertIsNotNone(message_field_info)
self.assertEqual(message_field_info.description, "message field description")
class Message(Model):
message: str = Field(description="message field description")
id: str = Field(description="id field description")

self.assertNotEqual(
Message.schema_json(indent=None, sort_keys=True),
Message.schema_json_no_descr(),
)

class MessageArgs(Model):
arg: str = Field(description="arg field description")

class Message(Model):
message: str = Field(description="message field description")
id: str = Field(description="id field description")
args: MessageArgs

self.assertNotEqual(
Message.schema_json(indent=None, sort_keys=True),
Message.schema_json_no_descr(),
)

def test_model_digest(self):
model_digest_no_descr = Model.build_schema_digest(create_message_no_descr())
Expand Down Expand Up @@ -82,6 +106,16 @@ def _(_ctx, _sender, _msg):
self.assertEqual(model_digest_no_descr, model_digest_with_descr)
self.assertEqual(proto_digest_no_descr, proto_digest_with_descr)

def test_compute_digest(self):
protocol = Protocol(name="test", version="1.1.1")

@protocol.on_message(create_message_with_descr())
def _(_ctx, _sender, _msg):
pass

# computed_digest = Protocol.compute_digest(protocol.manifest())
# self.assertEqual(protocol.digest, computed_digest)


if __name__ == "__main__":
unittest.main()

0 comments on commit 9bf84ac

Please sign in to comment.