Skip to content

Commit

Permalink
SDK changes to make use of 'V2' ontologies APIs, plus 'include_org_ac…
Browse files Browse the repository at this point in the history
…cess' feature (#807)
  • Loading branch information
alexey-cord-tech authored Dec 2, 2024
1 parent 6ed7af0 commit 8a90571
Show file tree
Hide file tree
Showing 8 changed files with 191 additions and 82 deletions.
11 changes: 10 additions & 1 deletion encord/http/v2/api_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,14 @@ def post(
) -> T:
return self._request_with_payload("POST", path, params, payload, result_type)

def put(
self,
path: str,
params: Optional[BaseDTO],
payload: Union[BaseDTO, Sequence[BaseDTO], None],
) -> None:
self._request_with_payload("PUT", path, params, payload, None, allow_none=True)

def patch(
self, path: str, params: Optional[BaseDTO], payload: Optional[BaseDTO], result_type: Optional[Type[T]]
) -> T:
Expand Down Expand Up @@ -143,6 +151,7 @@ def _request_with_payload(
params: Optional[BaseDTO],
payload: Union[BaseDTO, Sequence[BaseDTO], None],
result_type: Optional[Type[T]],
allow_none: bool = False,
) -> T:
params_dict = params.to_dict() if params is not None else None
payload_serialised = self._serialise_payload(payload)
Expand All @@ -154,7 +163,7 @@ def _request_with_payload(
json=payload_serialised,
).prepare()

return self._request(req, result_type=result_type) # type: ignore
return self._request(req, result_type=result_type, allow_none=allow_none) # type: ignore

def _request_without_payload(
self,
Expand Down
61 changes: 50 additions & 11 deletions encord/ontology.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,18 +10,19 @@
---
"""

import dataclasses
import datetime
from typing import Iterable, List, Union
from typing import Iterable, List, Optional, Union
from uuid import UUID

from encord.http.querier import Querier
from encord.http.v2.api_client import ApiClient
from encord.http.v2.payloads import Page
from encord.objects.ontology_structure import OntologyStructure
from encord.orm.group import AddOntologyGroupsPayload, OntologyGroup, RemoveGroupsParams
from encord.orm.ontology import CreateOrUpdateOntologyPayload
from encord.orm.ontology import Ontology as OrmOntology
from encord.utilities.hash_utilities import convert_to_uuid
from encord.utilities.ontology_user import OntologyUserRole
from encord.utilities.ontology_user import OntologyUserRole, OntologyWithUserRole


class Ontology:
Expand All @@ -30,8 +31,7 @@ class Ontology:
:meth:`encord.user_client.EncordUserClient.get_ontology()`
"""

def __init__(self, querier: Querier, instance: OrmOntology, api_client: ApiClient):
self._querier = querier
def __init__(self, instance: Union[OrmOntology], api_client: ApiClient):
self._ontology_instance = instance
self.api_client = api_client

Expand Down Expand Up @@ -85,6 +85,10 @@ def structure(self) -> OntologyStructure:
"""
return self._ontology_instance.structure

@property
def user_role(self) -> Optional[OntologyUserRole]:
return self._ontology_instance.user_role

def refetch_data(self) -> None:
"""
The Ontology class will only fetch its properties once. Use this function if you suspect the state of those
Expand All @@ -97,13 +101,31 @@ def save(self) -> None:
Sync local state to the server, if updates are made to structure, title or description fields
"""
if self._ontology_instance:
payload = dict(**self._ontology_instance)
payload["editor"] = self._ontology_instance.structure.to_dict() # we're using internal/legacy name here
payload.pop("structure", None)
self._querier.basic_put(OrmOntology, self._ontology_instance.ontology_hash, payload)
try:
structure_dict = self._ontology_instance.structure.to_dict()
except ValueError as e:
raise ValueError("Can't save an Ontology containing a Classification without any attributes. " + str(e))

payload = CreateOrUpdateOntologyPayload(
title=self.title,
description=self.description,
editor=structure_dict,
)

self.api_client.put(
f"ontologies/{self._ontology_instance.ontology_hash}",
params=None,
payload=payload,
)

def _get_ontology(self) -> OrmOntology:
ontology_model = self.api_client.get(
f"/ontologies/{self._ontology_instance.ontology_hash}",
params=None,
result_type=OntologyWithUserRole,
)

def _get_ontology(self):
return self._querier.basic_getter(OrmOntology, self._ontology_instance.ontology_hash)
return self._legacy_orm_from_api_payload(ontology_model)

def list_groups(self) -> Iterable[OntologyGroup]:
"""
Expand Down Expand Up @@ -151,3 +173,20 @@ def remove_group(self, group_hash: Union[List[UUID], UUID]):
group_hash = [group_hash]
params = RemoveGroupsParams(group_hash_list=group_hash)
self.api_client.delete(f"ontologies/{ontology_hash}/groups", params=params, result_type=None)

@staticmethod
def _legacy_orm_from_api_payload(
ontology_with_user_role: OntologyWithUserRole,
) -> OrmOntology:
flattened = ontology_with_user_role.to_dict(
by_alias=False, # we need 'python' names for the legacy ORM
)
flattened["ontology_hash"] = flattened.pop("ontology_uuid") # backwards compat
return OrmOntology.from_dict(flattened)

@staticmethod
def _from_api_payload(
ontology_with_user_role: OntologyWithUserRole,
api_client: ApiClient,
) -> "Ontology":
return Ontology(Ontology._legacy_orm_from_api_payload(ontology_with_user_role), api_client)
20 changes: 15 additions & 5 deletions encord/orm/ontology.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,9 @@

# pylint: disable=unused-import
from encord.objects.ontology_structure import OntologyStructure
from encord.orm.base_dto import BaseDTO, dto_validator
from encord.orm.formatter import Formatter


class OntologyUserRole(IntEnum):
ADMIN = 0
USER = 1
from encord.utilities.ontology_user import OntologyUserRole


class Ontology(dict, Formatter):
Expand All @@ -23,6 +20,7 @@ def __init__(
created_at: datetime,
last_edited_at: datetime,
description: Optional[str] = None,
user_role: Optional[OntologyUserRole] = None,
):
"""
DEPRECATED - prefer using the :class:`encord.ontology.Ontology` class instead.
Expand All @@ -43,6 +41,7 @@ def __init__(
"structure": structure,
"created_at": created_at,
"last_edited_at": last_edited_at,
"user_role": user_role,
}
)

Expand Down Expand Up @@ -82,6 +81,10 @@ def created_at(self) -> datetime:
def last_edited_at(self) -> datetime:
return self["last_edited_at"]

@property
def user_role(self) -> OntologyUserRole:
return self["user_role"]

@classmethod
def from_dict(cls, json_dict: Dict) -> Ontology:
return Ontology(
Expand All @@ -91,4 +94,11 @@ def from_dict(cls, json_dict: Dict) -> Ontology:
structure=OntologyStructure.from_dict(json_dict["editor"]),
created_at=json_dict["created_at"],
last_edited_at=json_dict["last_edited_at"],
user_role=json_dict.get("user_role"), # has to be like this to support the legacy endpoints for tests
)


class CreateOrUpdateOntologyPayload(BaseDTO):
title: str
description: str
editor: dict
51 changes: 28 additions & 23 deletions encord/user_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from __future__ import annotations

import base64
import dataclasses
import logging
import time
import uuid
Expand Down Expand Up @@ -75,6 +76,7 @@
DicomDeIdStartPayload,
)
from encord.orm.group import Group as OrmGroup
from encord.orm.ontology import CreateOrUpdateOntologyPayload
from encord.orm.ontology import Ontology as OrmOntology
from encord.orm.project import (
BenchmarkQaWorkflowSettings,
Expand Down Expand Up @@ -102,7 +104,7 @@
Issues,
LocalImport,
)
from encord.utilities.ontology_user import OntologyUserRole, OntologyWithUserRole
from encord.utilities.ontology_user import OntologiesFilterParams, OntologyUserRole, OntologyWithUserRole
from encord.utilities.project_user import ProjectUserRole

CVAT_LONG_POLLING_RESPONSE_RETRY_N = 3
Expand Down Expand Up @@ -187,15 +189,15 @@ def get_project(self, project_hash: str | UUID) -> Project:
client = EncordClientProject(querier=querier, config=self._config.config, api_client=self._api_client)
project_orm = client.get_project_v2()

orm_ontology = querier.basic_getter(OrmOntology, project_orm.ontology_hash)
project_ontology = Ontology(querier, orm_ontology, self._api_client)
project_ontology = self.get_ontology(project_orm.ontology_hash)

return Project(client, project_orm, project_ontology, self._api_client)

def get_ontology(self, ontology_hash: str) -> Ontology:
querier = Querier(self._config.config, resource_type=TYPE_ONTOLOGY, resource_id=ontology_hash)
orm_ontology = querier.basic_getter(OrmOntology, ontology_hash)
return Ontology(querier, orm_ontology, self._api_client)
ontology_with_user_role = self._api_client.get(
f"ontologies/{ontology_hash}", params=None, result_type=OntologyWithUserRole
)
return Ontology._from_api_payload(ontology_with_user_role, self._api_client)

@deprecated("0.1.104", alternative=".create_dataset")
def create_private_dataset(
Expand Down Expand Up @@ -785,6 +787,7 @@ def get_ontologies(
created_after: Optional[Union[str, datetime]] = None,
edited_before: Optional[Union[str, datetime]] = None,
edited_after: Optional[Union[str, datetime]] = None,
include_org_access: bool = False,
) -> List[Dict]:
"""
List either all (if called with no arguments) or matching ontologies the user has access to.
Expand All @@ -798,21 +801,25 @@ def get_ontologies(
created_after: optional creation date filter, 'greater'
edited_before: optional last modification date filter, 'less'
edited_after: optional last modification date filter, 'greater'
include_org_access: if set to true and the calling user is the organization admin, the
method will return all ontologies in the organization.
Returns:
list of (role, projects) pairs for ontologies matching filter conditions.
list of ontologies matching filter conditions, with the roles that the current user has on them. Each item
is a dictionary with `"ontology"` and `"user_role"` keys. If include_org_access is set to
True, some of the ontologies may have a `None` value for the `"user_role"` key.
"""
properties_filter = self.__validate_filter(locals())
properties_filter = OntologiesFilterParams.from_dict(self.__validate_filter(locals()))
properties_filter.include_org_access = include_org_access
page = self._api_client.get("ontologies", params=properties_filter, result_type=Page[OntologyWithUserRole])

# a hack to be able to share validation code without too much c&p
data = self._querier.get_multiple(OntologyWithUserRole, payload={"filter": properties_filter})
retval: List[Dict] = []
for row in data:
ontology = OrmOntology.from_dict(row.ontology)
querier = Querier(self._config, resource_type=TYPE_ONTOLOGY, resource_id=ontology.ontology_hash)
for row in page.results:
retval.append(
{
"ontology": Ontology(querier, ontology, api_client=self._api_client),
"user_role": OntologyUserRole(row.user_role),
"ontology": Ontology._from_api_payload(row, self._api_client),
"user_role": row.user_role,
}
)
return retval
Expand All @@ -828,17 +835,15 @@ def create_ontology(
except ValueError as e:
raise ValueError("Can't create an Ontology containing a Classification without any attributes. " + str(e))

ontology = {
"title": title,
"description": description,
"editor": structure_dict,
}
payload = CreateOrUpdateOntologyPayload(
title=title,
description=description,
editor=structure_dict,
)

retval = self._querier.basic_setter(OrmOntology, uid=None, payload=ontology)
ontology = OrmOntology.from_dict(retval)
querier = Querier(self._config, resource_type=TYPE_ONTOLOGY, resource_id=ontology.ontology_hash)
ontology = self._api_client.post("ontologies", payload=payload, params=None, result_type=OntologyWithUserRole)

return Ontology(querier, ontology, self._api_client)
return Ontology._from_api_payload(ontology, self._api_client)

def __validate_filter(self, properties_filter: Dict) -> Dict:
if not isinstance(properties_filter, dict):
Expand Down
36 changes: 30 additions & 6 deletions encord/utilities/ontology_user.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,20 +11,44 @@
"""

from dataclasses import dataclass
from datetime import datetime
from enum import IntEnum
from typing import Optional, Union
from uuid import UUID

from encord.orm.base_dto import BaseDTO


class OntologyUserRole(IntEnum):
ADMIN = 0
USER = 1


@dataclass(frozen=True)
class OntologyWithUserRole:
class OntologyWithUserRole(BaseDTO):
"""
An on-the-wire representation from /v2/public/ontologies endpoints
"""

ontology_uuid: UUID
title: str
description: str
editor: dict
created_at: datetime
last_edited_at: datetime
user_role: Optional[OntologyUserRole]


class OntologiesFilterParams(BaseDTO):
"""
This is a helper class denoting the relationship between the current user an an ontology
Filter parameters for the /v2/public/ontologies endpoint
"""

user_role: int
user_email: str
ontology: dict
title_eq: Optional[str] = None
title_like: Optional[str] = None
desc_eq: Optional[str] = None
desc_like: Optional[str] = None
created_before: Optional[Union[str, datetime]] = None
created_after: Optional[Union[str, datetime]] = None
edited_before: Optional[Union[str, datetime]] = None
edited_after: Optional[Union[str, datetime]] = None
include_org_access: bool = False
16 changes: 10 additions & 6 deletions tests/fixtures.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
import uuid
from datetime import datetime
from unittest.mock import patch
from unittest.mock import MagicMock, patch

import pytest
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric.ed25519 import Ed25519PrivateKey

from encord import EncordUserClient, Project
from encord.client import EncordClientProject
from encord.http.querier import Querier
from encord.ontology import Ontology
from encord.orm.ontology import Ontology as OrmOntology
from encord.orm.project import ProjectDTO, ProjectType
Expand All @@ -25,7 +24,7 @@

@pytest.fixture
def ontology() -> Ontology:
return Ontology(None, OrmOntology.from_dict(ONTOLOGY_BLURB), None)
return Ontology(OrmOntology.from_dict(ONTOLOGY_BLURB), MagicMock())


@pytest.fixture
Expand All @@ -35,9 +34,14 @@ def user_client() -> EncordUserClient:

@pytest.fixture
@patch.object(EncordClientProject, "get_project_v2")
@patch.object(Querier, "basic_getter")
def project(querier_mock: Querier, client_project_mock, user_client: EncordUserClient, ontology: Ontology) -> Project:
querier_mock.return_value = OrmOntology.from_dict(ONTOLOGY_BLURB)
@patch.object(EncordUserClient, "get_ontology")
def project(
client_ontology_mock: MagicMock,
client_project_mock: MagicMock,
user_client: EncordUserClient,
ontology: Ontology,
) -> Project:
client_ontology_mock.return_value = ontology

client_project_mock.return_value = ProjectDTO(
project_hash=uuid.uuid4(),
Expand Down
Loading

0 comments on commit 8a90571

Please sign in to comment.