Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions encord/metadata_schema.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import json
from enum import Enum
from enum import Enum, StrEnum
from typing import Dict, Literal, Sequence, Union

from pydantic import BaseModel, Field
Expand Down Expand Up @@ -41,9 +41,13 @@ class _ClientMetadataSchemaTypeEnum(BaseModel):
values: Sequence[str] = Field([], min_length=1, max_length=256)


class ClientMetadataSchemaTypeEmbeddingModel(str, Enum):
    """Closed set of embedding model identifiers accepted by the metadata schema.

    Declared as a ``str``/``Enum`` mix-in rather than ``enum.StrEnum`` because
    ``StrEnum`` only exists on Python 3.11+. Members still compare equal to
    their plain string value and are instances of ``str``.
    """

    CLIP = "CLIP"

    def __str__(self) -> str:
        # Mirror ``StrEnum``: ``str(member)`` and f-strings yield the raw value
        # instead of ``ClassName.MEMBER``.
        return str(self.value)

class _ClientMetadataSchemaTypeEmbedding(BaseModel):
    """Schema entry for a metadata key whose type tag is ``"embedding"``."""

    # Fixed type tag; the only accepted value is the literal "embedding".
    ty: Literal["embedding"] = "embedding"
    # Required embedding size, validated as 0 < size <= 4096.
    size: int = Field(gt=0, le=4096)
    # Optional embedding model identifier; None when no model is specified.
    # Written with ``Union`` (already imported by this module) instead of
    # ``X | None`` because pydantic evaluates field annotations at class
    # creation, and the ``|`` form requires Python 3.10+.
    model: Union[ClientMetadataSchemaTypeEmbeddingModel, None] = None


class _ClientMetadataSchemaTypeText(BaseModel):
Expand Down Expand Up @@ -214,7 +218,7 @@ def save(self) -> None:
)
self._dirty = False

def add_embedding(
    self,
    k: str,
    *,
    size: int,
    embedding_model: Union[ClientMetadataSchemaTypeEmbeddingModel, None] = None,
) -> None:
    """Adds a new embedding to the metadata schema.

    **Parameters:**

    - k : str: The key under which the embedding is stored in the schema.
    - size : int: The size of the embedding. The embedding schema type
      validates it as ``0 < size <= 4096``.
    - embedding_model : ClientMetadataSchemaTypeEmbeddingModel, optional:
      The model recorded on the embedding entry. Defaults to ``None``,
      which leaves the model unset.

    **Raises:**

    MetadataSchemaError: If the key `k` is already defined in the schema.
    """
    if k in self._schema:
        raise MetadataSchemaError(f"{k} is already defined")
    # Key validation is delegated to the module-level helper, which is
    # expected to raise for invalid keys (defined outside this method).
    _assert_valid_metadata_key(k)
    # Pass the caller's ``embedding_model`` through unchanged: it is already
    # either an enum member or ``None``.
    embedding = _ClientMetadataSchemaTypeEmbedding(size=size, model=embedding_model)
    self._schema[k] = _ClientMetadataSchemaOption(root=embedding)
    # Flag the schema as modified; ``save`` resets this flag.
    self._dirty = True

def add_enum(self, k: str, *, values: Sequence[str]) -> None:
Expand Down