diff --git a/encord/metadata_schema.py b/encord/metadata_schema.py
index 9a8f6558a..d3fa0d1c6 100644
--- a/encord/metadata_schema.py
+++ b/encord/metadata_schema.py
@@ -1,5 +1,5 @@
 import json
 from enum import Enum
-from typing import Dict, Literal, Sequence, Union
+from typing import Dict, Literal, Optional, Sequence, Union
 
 from pydantic import BaseModel, Field
@@ -41,9 +41,21 @@ class _ClientMetadataSchemaTypeEnum(BaseModel):
     values: Sequence[str] = Field([], min_length=1, max_length=256)
 
 
+class ClientMetadataSchemaTypeEmbeddingModel(str, Enum):
+    """Embedding models that can be recorded on an embedding schema entry.
+
+    Mixes in ``str`` rather than using ``enum.StrEnum`` (Python 3.11+) so the
+    module still imports on older interpreters.
+    """
+
+    CLIP = "CLIP"
+
+
 class _ClientMetadataSchemaTypeEmbedding(BaseModel):
     ty: Literal["embedding"] = "embedding"
     size: int = Field(gt=0, le=4096)
+    # Optional model associated with the embedding; defaults to None.
+    model: Optional[ClientMetadataSchemaTypeEmbeddingModel] = None
 
 
 class _ClientMetadataSchemaTypeText(BaseModel):
@@ -214,7 +226,16 @@ def save(self) -> None:
             )
             self._dirty = False
 
-    def add_embedding(self, k: str, *, size: int) -> None:
+    def add_embedding(
+        self,
+        k: str,
+        *,
+        size: int,
+        embedding_model: Optional[ClientMetadataSchemaTypeEmbeddingModel] = None,
+    ) -> None:
         """Adds a new embedding to the metadata schema.
 
+        ``embedding_model`` optionally names the model associated with the
+        embedding and is stored on the schema entry; it defaults to ``None``.
+
         **Parameters:**
@@ -229,7 +250,9 @@ def add_embedding(self, k: str, *, size: int) -> None:
         if k in self._schema:
             raise MetadataSchemaError(f"{k} is already defined")
         _assert_valid_metadata_key(k)
-        self._schema[k] = _ClientMetadataSchemaOption(root=_ClientMetadataSchemaTypeEmbedding(size=size))
+        self._schema[k] = _ClientMetadataSchemaOption(
+            root=_ClientMetadataSchemaTypeEmbedding(size=size, model=embedding_model)
+        )
         self._dirty = True
 
     def add_enum(self, k: str, *, values: Sequence[str]) -> None: