Skip to content

Commit

Permalink
reformat, pass type instead of instance when refreshing schema cache
Browse files Browse the repository at this point in the history
  • Loading branch information
zmezei committed Jul 3, 2023
1 parent fba70dc commit 94ee30d
Showing 1 changed file with 21 additions and 7 deletions.
28 changes: 21 additions & 7 deletions src/uagents/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,12 @@
from pydantic.schema import model_schema, default_ref_template
from pydantic.main import BaseModel


class Model(BaseModel):
@staticmethod
def _remove_descriptions(model: Union["Model", Type["Model"]], orig_descriptions: Dict[str, Union[str, Dict]]):
def _remove_descriptions(
model: Type["Model"], orig_descriptions: Dict[str, Union[str, Dict]]
):
for _, field in model.__fields__.items():
if field.field_info and field.field_info.description:
orig_descriptions[field.name] = field.field_info.description
Expand All @@ -17,31 +20,42 @@ def _remove_descriptions(model: Union["Model", Type["Model"]], orig_descriptions
Model._remove_descriptions(field.type_, orig_descriptions[field.name])

@staticmethod
def _restore_descriptions(model: Union["Model", Type["Model"]], orig_descriptions: Dict[str, Union[str, Dict]]):
def _restore_descriptions(
model: Type["Model"], orig_descriptions: Dict[str, Union[str, Dict]]
):
for _, field in model.__fields__.items():
if field.field_info and field.name in orig_descriptions and not issubclass(field.type_, Model):
if (
field.field_info
and field.name in orig_descriptions
and not issubclass(field.type_, Model)
):
field.field_info.description = orig_descriptions[field.name]
elif issubclass(field.type_, Model):
Model._restore_descriptions(field.type_, orig_descriptions[field.name])

@staticmethod
def _refresh_schema_cache(model: Union["Model", Type["Model"]], by_alias: bool = True, ref_template: str = default_ref_template):
def _refresh_schema_cache(
model: Type["Model"],
by_alias: bool = True,
ref_template: str = default_ref_template,
):
s = model_schema(model, by_alias, ref_template)
model.__schema_cache__[(True, default_ref_template)] = s

@staticmethod
def build_schema_digest(model: Union["Model", Type["Model"]]) -> str:
orig_descriptions: Dict[str, Union[str, Dict]] = {}
Model._remove_descriptions(model, orig_descriptions)
obj_for_descr_remove = model if isinstance(model, type) else model.__class__
Model._remove_descriptions(obj_for_descr_remove, orig_descriptions)
digest = (
hashlib.sha256(
model.schema_json(indent=None, sort_keys=True).encode("utf8")
)
.digest()
.hex()
)
Model._restore_descriptions(model, orig_descriptions)
Model._refresh_schema_cache(model)
Model._restore_descriptions(obj_for_descr_remove, orig_descriptions)
Model._refresh_schema_cache(obj_for_descr_remove)
return f"model:{digest}"


Expand Down

0 comments on commit 94ee30d

Please sign in to comment.