Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add descriptions without modifying model and protocol digest #111

Closed
wants to merge 25 commits into from
Closed
Show file tree
Hide file tree
Changes from 19 commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 43 additions & 2 deletions src/uagents/models.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,61 @@
import hashlib
from typing import Type, Union
from typing import Type, Union, Dict, ClassVar, Any

from pydantic import BaseModel
from pydantic.schema import model_schema, default_ref_template


class Model(BaseModel):
schema_no_descriptions: ClassVar[Union[Dict[str, Any], None]] = None

@staticmethod
def _remove_descriptions(
model: Type["Model"], orig_descriptions: Dict[str, Union[str, Dict]]
):
for field_name, field in model.__fields__.items():
if field.field_info and field.field_info.description:
orig_descriptions[field_name] = field.field_info.description
field.field_info.description = None
elif issubclass(field.type_, Model):
orig_descriptions[field_name] = {}
Model._remove_descriptions(field.type_, orig_descriptions[field_name])

@staticmethod
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

From my point of view these methods shouldn't be static most of the times the model is even passed in why not use self then?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The model param passed to these newly created methods is actually not an instance of a Model subclass but rather a Model subclass type (so a class object) that's why I couldn't implement these methods as instance methods (so using self.). However you are right in the sense that these method could be class methods using cls.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In the refactored version I created class methods instead of static methods where possible, please check it.

def _restore_descriptions(
model: Type["Model"], orig_descriptions: Dict[str, Union[str, Dict]]
):
for field_name, field in model.__fields__.items():
if (
field.field_info
and field_name in orig_descriptions
and not issubclass(field.type_, Model)
):
field.field_info.description = orig_descriptions[field_name]
elif issubclass(field.type_, Model):
Model._restore_descriptions(field.type_, orig_descriptions[field_name])

@staticmethod
def _restore_schema_cache(model: Type["Model"]):
schema = model_schema(model, by_alias=True, ref_template=default_ref_template)
model.__schema_cache__[(True, default_ref_template)] = schema

@staticmethod
def build_schema_digest(model: Union["Model", Type["Model"]]) -> str:
type_obj = model if isinstance(model, type) else model.__class__
if type_obj.schema_no_descriptions is None:
orig_descriptions: Dict[str, Union[str, Dict]] = {}
Model._remove_descriptions(type_obj, orig_descriptions)
digest = (
hashlib.sha256(
model.schema_json(indent=None, sort_keys=True).encode("utf8")
)
.digest()
.hex()
)

if type_obj.schema_no_descriptions is None:
type_obj.schema_no_descriptions = type_obj.schema()
Model._restore_descriptions(type_obj, orig_descriptions)
Model._restore_schema_cache(type_obj)
return f"model:{digest}"


Expand Down
8 changes: 7 additions & 1 deletion src/uagents/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ def manifest(self) -> Dict[str, Any]:

for schema_digest, model in all_models.items():
manifest["models"].append(
{"digest": schema_digest, "schema": model.schema()}
{"digest": schema_digest, "schema": model.schema_no_descriptions}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it guaranteed that this value won't be None?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In the refactored version I call the schema_no_descriptions(), that method also uses a class attribute as a cache and that class attribute could possibly be None, but in the schema_no_descriptions() method I guarantee to set the value of that class variable now if that is None.

)

for request, responses in self._replies.items():
Expand All @@ -199,6 +199,12 @@ def manifest(self) -> Dict[str, Any]:
encoded = json.dumps(manifest, indent=None, sort_keys=True).encode("utf8")
metadata["digest"] = f"proto:{hashlib.sha256(encoded).digest().hex()}"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

digest is also computed by the compute_digest static method.
It takes as input the manifest which will contain the descriptions hence producing a different digest.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I modified this method accordingly and also created test for it in test_fields_descr.py, please check it.


manifest["models"] = []
for schema_digest, model in all_models.items():
manifest["models"].append(
{"digest": schema_digest, "schema": model.schema()}
)

final_manifest: Dict[str, Any] = copy.deepcopy(manifest)
final_manifest["metadata"] = metadata

Expand Down
87 changes: 87 additions & 0 deletions tests/test_field_descr.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
import unittest
from pydantic import Field
from uagents import Model, Protocol


protocol_no_descr = Protocol(name="test", version="1.1.1")
protocol_with_descr = Protocol(name="test", version="1.1.1")


def create_message_no_descr():
class Message(Model):
message: str

return Message


def create_message_with_descr():
class Message(Model):
message: str = Field(description="message field description")

return Message


def get_digests(protocol: Protocol):
model_digest = next(iter(protocol.models))
proto_digest = protocol.digest
return (model_digest, proto_digest)


class TestFieldDescr(unittest.TestCase):
def setUp(self) -> None:
self.protocol_no_descr = protocol_no_descr
self.protocol_with_descr = protocol_with_descr
return super().setUp()

def test_field_description(self):
message_with_descr = create_message_with_descr()

Model.build_schema_digest(message_with_descr)

message_field_info = message_with_descr.__fields__["message"].field_info
self.assertIsNotNone(message_field_info)
self.assertEqual(message_field_info.description, "message field description")

def test_model_digest(self):
model_digest_no_descr = Model.build_schema_digest(create_message_no_descr())
model_digest_with_descr = Model.build_schema_digest(create_message_with_descr())

self.assertEqual(model_digest_no_descr, model_digest_with_descr)

def test_protocol(self):
@self.protocol_no_descr.on_message(create_message_no_descr())
def _(_ctx, _sender, _msg):
pass

self.assertEqual(len(self.protocol_no_descr.models), 1)
self.assertNotIn(
"description",
self.protocol_no_descr.manifest()["models"][0]["schema"]["properties"][
"message"
],
)
(model_digest_no_descr, proto_digest_no_descr) = get_digests(
self.protocol_no_descr
)

@self.protocol_with_descr.on_message(create_message_with_descr())
def _(_ctx, _sender, _msg):
pass

self.assertEqual(len(self.protocol_with_descr.models), 1)
self.assertEqual(
self.protocol_with_descr.manifest()["models"][0]["schema"]["properties"][
"message"
]["description"],
"message field description",
)
(model_digest_with_descr, proto_digest_with_descr) = get_digests(
self.protocol_with_descr
)

self.assertEqual(model_digest_no_descr, model_digest_with_descr)
self.assertEqual(proto_digest_no_descr, proto_digest_with_descr)


if __name__ == "__main__":
unittest.main()