Skip to content

Commit

Permalink
Feat: Active presets init (#806)
Browse files Browse the repository at this point in the history
Co-authored-by: Jim Broadbent <[email protected]>
  • Loading branch information
deoracord and Jim-Encord authored Dec 13, 2024
1 parent b02faf5 commit 51e95c4
Show file tree
Hide file tree
Showing 4 changed files with 230 additions and 23 deletions.
161 changes: 154 additions & 7 deletions encord/filter_preset.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@
from typing import Iterator, List, Optional, Union
from uuid import UUID

from encord.client import EncordClientProject
from encord.exceptions import (
AuthorisationError,
EncordException,
)
from encord.http.v2.api_client import ApiClient
from encord.orm.filter_preset import (
Expand All @@ -12,9 +14,11 @@
FilterPresetDefinition,
GetPresetParams,
GetPresetsResponse,
GetProjectFilterPresetParams,
UpdatePresetPayload,
)
from encord.orm.filter_preset import FilterPreset as OrmFilterPreset
from encord.orm.filter_preset import ProjectFilterPreset as OrmProjectFilterPreset


class FilterPreset:
Expand Down Expand Up @@ -126,11 +130,13 @@ def _delete_preset(api_client: ApiClient, preset_uuid: UUID) -> None:
)

@staticmethod
def _create_preset(api_client: ApiClient, name: str, description: str = "", *, filter_preset_json: dict) -> UUID:
def _create_preset(api_client: ApiClient, name: str, *, filter_preset_json: dict) -> UUID:
filter_preset = FilterPresetDefinition.from_dict(filter_preset_json)
if not filter_preset.local_filters and not filter_preset.global_filters:
raise EncordException("We require there to be a non-zero number of filters in a preset")
payload = CreatePresetPayload(
name=name,
description=description,
filter_preset_json=FilterPresetDefinition.from_dict(filter_preset_json).to_dict(),
filter_preset_json=filter_preset.to_dict(),
)
return api_client.post(
"index/presets",
Expand All @@ -146,9 +152,7 @@ def get_filter_preset_json(self) -> FilterPresetDefinition:
result_type=FilterPresetDefinition,
)

def update_preset(
self, name: Optional[str] = None, description: Optional[str] = None, filter_preset_json: Optional[dict] = None
) -> None:
def update_preset(self, name: Optional[str] = None, filter_preset_json: Optional[dict] = None) -> None:
"""
Update the preset's definition.
Args:
Expand All @@ -161,10 +165,153 @@ def update_preset(
filters_definition = FilterPresetDefinition.from_dict(filter_preset_json)
elif isinstance(filter_preset_json, FilterPresetDefinition):
filters_definition = filter_preset_json
payload = UpdatePresetPayload(name=name, description=description, filter_preset=filters_definition)
if filters_definition:
if not filters_definition.local_filters and not filters_definition.global_filters:
raise EncordException("We require there to be a non-zero number of filters in a preset")
payload = UpdatePresetPayload(name=name, filter_preset=filters_definition)
self._client.patch(
f"index/presets/{self.uuid}",
params=None,
payload=payload,
result_type=None,
)


class ProjectFilterPreset:
"""
Represents Active filter presets.
"""

def __init__(
self,
project_uuid: UUID,
client: ApiClient,
orm_filter_preset: OrmProjectFilterPreset,
):
self._project_uuid = project_uuid
self._client = client
self._filter_preset_instance = orm_filter_preset

@property
def uuid(self) -> UUID:
"""
Get the filter preset unique identifier (UUID).
Returns:
UUID: The filter preset UUID.
"""
return self._filter_preset_instance.preset_uuid

@property
def name(self) -> str:
"""
Get the filter preset name.
Returns:
str: The collection name.
"""
return self._filter_preset_instance.name

@property
def created_at(self) -> Optional[datetime]:
"""
Get the filter preset creation timestamp.
Returns:
Optional[datetime]: The timestamp when the filter preset was created, or None if not available.
"""
return self._filter_preset_instance.created_at

@property
def updated_at(self) -> Optional[datetime]:
"""
Get the filter preset last edit timestamp.
Returns:
Optional[datetime]: The timestamp when the filter preset was last edited, or None if not available.
"""
return self._filter_preset_instance.updated_at

@property
def project_hash(self) -> UUID:
"""
Get the project hash of the filter preset.
Returns:
UUID: The project hash of the filter preset.
"""
return self._project_uuid

@staticmethod
def _get_filter_preset(
client: ApiClient,
project_uuid: UUID,
filter_preset_uuid: UUID,
) -> "ProjectFilterPreset":
params = GetProjectFilterPresetParams(preset_uuids=[filter_preset_uuid])
orm_items = list(
client.get_paged_iterator(
f"active/{project_uuid}/presets",
params=params,
result_type=OrmProjectFilterPreset,
)
)
if len(orm_items) > 0:
return ProjectFilterPreset(
project_uuid=project_uuid,
client=client,
orm_filter_preset=orm_items[0],
)
raise AuthorisationError("No collection found")

@staticmethod
def _list_filter_presets(
client: ApiClient,
project_uuid: UUID,
filter_preset_uuids: Union[List[UUID], None],
page_size: Optional[int] = None,
) -> Iterator["ProjectFilterPreset"]:
params = GetProjectFilterPresetParams(preset_uuids=filter_preset_uuids, page_size=page_size)
paged_filter_presets = client.get_paged_iterator(
f"active/{project_uuid}/presets",
params=params,
result_type=OrmProjectFilterPreset,
)
for filter_preset in paged_filter_presets:
yield ProjectFilterPreset(
project_uuid=project_uuid,
client=client,
orm_filter_preset=filter_preset,
)

@staticmethod
def _delete_filter_preset(client: ApiClient, project_uuid: UUID, filter_preset_uuid: UUID) -> None:
client.delete(
f"active/{project_uuid}/presets/{filter_preset_uuid}",
params=None,
result_type=None,
)

def get_filter_preset_json(self) -> FilterPresetDefinition:
return self._client.get(
f"active/{self._project_uuid}/presets/{self._filter_preset_instance.preset_uuid}/raw",
params=None,
result_type=FilterPresetDefinition,
)

def update_preset(self, name: Optional[str] = None, filter_preset: Optional[FilterPresetDefinition] = None) -> None:
if name is None and filter_preset is None:
return
payload = UpdatePresetPayload(name=name, filter_preset=filter_preset)
self._client.patch(
f"active/{self.project_hash}/presets/{self.uuid}", params=None, payload=payload, result_type=None
)

@staticmethod
def _create_filter_preset(
client: ApiClient, project_uuid: UUID, name: str, filter_preset: FilterPresetDefinition
) -> UUID:
if not filter_preset.local_filters and not filter_preset.global_filters:
raise EncordException("We require there to be a non-zero number of filters in a preset for creation")
payload = CreatePresetPayload(name=name, filter_preset_json=filter_preset.to_dict())
orm_resp = client.post(f"active/{project_uuid}/presets", params=None, payload=payload, result_type=UUID)
return orm_resp
28 changes: 15 additions & 13 deletions encord/orm/filter_preset.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,24 +21,28 @@ class FilterPreset(BaseDTO):
last_updated_at: Optional[datetime] = Field(default=None, alias="lastUpdatedAt")


class GetProjectFilterPresetParams(BaseDTO):
preset_uuids: Optional[List[uuid.UUID]] = Field(default=[])
page_token: Optional[str] = Field(default=None)
page_size: Optional[int] = Field(default=None)


class ProjectFilterPreset(BaseDTO):
preset_uuid: uuid.UUID = Field(alias="presetUuid")
name: str
created_at: Optional[datetime] = Field(default=None)
updated_at: Optional[datetime] = Field(default=None)


class FilterDefinition(BaseDTO):
filters: List[Dict] = Field(default_factory=list)


class FilterPresetDefinition(BaseDTO):
local_filters: Dict[str, FilterDefinition] = Field(
default_factory=lambda: {str(uuid.UUID(int=0)): FilterDefinition()}, alias="local_filters"
default_factory=lambda: {str(uuid.UUID(int=0)): FilterDefinition()},
)
global_filters: FilterDefinition = Field(default_factory=FilterDefinition, alias="global_filters")

@dto_validator(mode="after")
def check_not_empty(cls, self):
if len(self.global_filters.filters) == 0 and all(
[len(value.filters) == 0 for value in self.local_filters.values()]
):
raise ValueError("FilterPresetDefinition definition must contain at least one global or local filter.")

return self
global_filters: FilterDefinition = Field(default_factory=FilterDefinition)


class GetPresetsResponse(BaseDTO):
Expand All @@ -51,11 +55,9 @@ class CreatePresetParams(BaseDTO):

class CreatePresetPayload(BaseDTO):
name: str
description: Optional[str] = ""
filter_preset_json: Dict


class UpdatePresetPayload(BaseDTO):
name: Optional[str] = None
description: Optional[str] = None
filter_preset: Optional[FilterPresetDefinition] = None
60 changes: 60 additions & 0 deletions encord/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from encord.collection import ProjectCollection
from encord.common.deprecated import deprecated
from encord.constants.model import AutomationModels, Device
from encord.filter_preset import ProjectFilterPreset
from encord.http.bundle import Bundle
from encord.http.v2.api_client import ApiClient
from encord.objects import LabelRowV2, OntologyStructure
Expand All @@ -31,6 +32,7 @@
from encord.orm.cloud_integration import CloudIntegration
from encord.orm.collection import ProjectCollectionType
from encord.orm.dataset import Image, Video
from encord.orm.filter_preset import FilterPresetDefinition
from encord.orm.group import ProjectGroup
from encord.orm.label_log import LabelLog
from encord.orm.label_row import (
Expand Down Expand Up @@ -1242,3 +1244,61 @@ def active_import(self, project_mode: ActiveProjectMode, *, video_sampling_rate:
None
"""
self._client.active_import(project_mode, video_sampling_rate)

def list_filter_presets(
self,
filter_preset_uuids: Optional[List[Union[str, UUID]]] = None,
page_size: Optional[int] = None,
) -> Iterator[ProjectFilterPreset]:
"""
List all filter presets associated to the project.
Args:
filter_preset_uuids: The unique identifiers (UUIDs) of the filter presets to retrieve.
page_size (int): Number of items to return per page. Default if not specified is 100. Maximum value is 1000.
Returns:
The list of filter presets which match the given criteria.
Raises:
ValueError: If any of the filter preset uuids is a badly formed UUID.
:class:`encord.exceptions.AuthorizationError` : If the user does not have access to it.
"""
filter_presets = (
[
UUID(filter_preset) if isinstance(filter_preset, str) else filter_preset
for filter_preset in filter_preset_uuids
]
if filter_preset_uuids is not None
else None
)
return ProjectFilterPreset._list_filter_presets(
client=self._client._get_api_client(),
project_uuid=self._project_instance.project_hash,
filter_preset_uuids=filter_presets,
page_size=page_size,
)

def get_filter_preset(self, filter_preset_uuid: Union[str, UUID]) -> ProjectFilterPreset:
return ProjectFilterPreset._get_filter_preset(
client=self._client._get_api_client(),
project_uuid=self._project_instance.project_hash,
filter_preset_uuid=UUID(filter_preset_uuid) if isinstance(filter_preset_uuid, str) else filter_preset_uuid,
)

def delete_filter_preset(self, filter_preset_uuid: Union[str, UUID]) -> None:
ProjectFilterPreset._delete_filter_preset(
client=self._client._get_api_client(),
project_uuid=self._project_instance.project_hash,
filter_preset_uuid=UUID(filter_preset_uuid) if isinstance(filter_preset_uuid, str) else filter_preset_uuid,
)

def create_filter_preset(self, name: str, filter_preset: FilterPresetDefinition) -> ProjectFilterPreset:
uuid = ProjectFilterPreset._create_filter_preset(
client=self._client._get_api_client(),
project_uuid=self._project_instance.project_hash,
name=name,
filter_preset=filter_preset,
)
return ProjectFilterPreset._get_filter_preset(
client=self._client._get_api_client(),
project_uuid=self._project_instance.project_hash,
filter_preset_uuid=uuid,
)
4 changes: 1 addition & 3 deletions encord/user_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -1536,9 +1536,7 @@ def create_preset(self, name: str, filter_preset_json: dict, description: str =
Returns:
FilterPreset: Newly created collection.
"""
new_uuid = FilterPreset._create_preset(
self._api_client, name, description, filter_preset_json=filter_preset_json
)
new_uuid = FilterPreset._create_preset(self._api_client, name, filter_preset_json=filter_preset_json)
return self.get_filter_preset(new_uuid)

def delete_preset(self, preset_uuid: Union[str, UUID]) -> None:
Expand Down

0 comments on commit 51e95c4

Please sign in to comment.