Skip to content

Commit

Permalink
Remove redundant function (#829)
Browse files Browse the repository at this point in the history
  • Loading branch information
adam-encord authored Jan 3, 2025
1 parent 4751e20 commit e6715e0
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 50 deletions.
64 changes: 24 additions & 40 deletions encord/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,39 +164,26 @@ class EncordClient:
with a project (e.g. label rows, datasets).
"""

def __init__(self, querier: Querier, config: Config, api_client: Optional[ApiClient] = None):
def __init__(self, querier: Querier, config: Config, api_client: ApiClient):
self._querier = querier
self._config = config
self._api_client = api_client

def _get_api_client(self) -> ApiClient:
if not (isinstance(self._config, (SshConfig, BearerConfig))):
raise EncordException(
"This functionality requires private SSH key authentication. API keys are not supported."
)

if not self._api_client:
raise RuntimeError("ApiClient should exist when authenticated with SSH key.")

return self._api_client

def get_cloud_integrations(self) -> List[CloudIntegration]:
return [
CloudIntegration(
id=str(x.integration_uuid),
title=x.title,
)
for x in self._get_api_client()
.get(
for x in self._api_client.get(
"cloud-integrations",
params=None,
result_type=GetCloudIntegrationsResponse,
)
.result
).result
]

def get_bearer_token(self) -> BearerTokenResponse:
return self._get_api_client().get("user/bearer-token", None, result_type=BearerTokenResponse)
return self._api_client.get("user/bearer-token", None, result_type=BearerTokenResponse)


class EncordClientDataset(EncordClient):
Expand All @@ -211,6 +198,9 @@ def __init__(
dataset_access_settings: DatasetAccessSettings = DEFAULT_DATASET_ACCESS_SETTINGS,
api_client: Optional[ApiClient] = None,
):
if api_client is None:
raise ValueError("api_client is None")

super().__init__(querier, config, api_client)
self._dataset_access_settings = dataset_access_settings

Expand Down Expand Up @@ -307,17 +297,15 @@ def add_users(self, user_emails: List[str], user_role: DatasetUserRole) -> List[
return [DatasetUser.from_dict(user) for user in users]

def list_groups(self, dataset_hash: uuid.UUID) -> Page[DatasetGroup]:
return self._get_api_client().get(
f"datasets/{dataset_hash}/groups", params=None, result_type=Page[DatasetGroup]
)
return self._api_client.get(f"datasets/{dataset_hash}/groups", params=None, result_type=Page[DatasetGroup])

def add_groups(self, dataset_hash: str, group_hash: List[uuid.UUID], user_role: DatasetUserRole) -> None:
payload = AddDatasetGroupsPayload(group_hash_list=group_hash, user_role=user_role)
self._get_api_client().post(f"datasets/{dataset_hash}/groups", params=None, payload=payload, result_type=None)
self._api_client.post(f"datasets/{dataset_hash}/groups", params=None, payload=payload, result_type=None)

def remove_groups(self, dataset_hash: uuid.UUID, group_hash: List[uuid.UUID]) -> None:
params = RemoveGroupsParams(group_hash_list=group_hash)
self._get_api_client().delete(f"datasets/{dataset_hash}/groups", params=params, result_type=None)
self._api_client.delete(f"datasets/{dataset_hash}/groups", params=params, result_type=None)

def __add_data_to_dataset_get_result(
self,
Expand Down Expand Up @@ -881,7 +869,7 @@ def get_project_v2(self) -> ProjectOrmV2:
This is an internal method, do not use it directly.
Use :meth:`UserClient.get_project` instead.
"""
return self._get_api_client().get(f"/projects/{self.project_hash}", params=None, result_type=ProjectOrmV2)
return self._api_client.get(f"/projects/{self.project_hash}", params=None, result_type=ProjectOrmV2)

def list_label_rows(
self,
Expand Down Expand Up @@ -956,17 +944,15 @@ def add_users(self, user_emails: List[str], user_role: ProjectUserRole) -> List[
return [ProjectUser.from_dict(user) for user in users]

def list_groups(self, project_hash: uuid.UUID) -> Page[ProjectGroup]:
return self._get_api_client().get(
f"projects/{project_hash}/groups", params=None, result_type=Page[ProjectGroup]
)
return self._api_client.get(f"projects/{project_hash}/groups", params=None, result_type=Page[ProjectGroup])

def add_groups(self, project_hash: uuid.UUID, group_hash: List[uuid.UUID], user_role: ProjectUserRole) -> None:
payload = AddProjectGroupsPayload(group_hash_list=group_hash, user_role=user_role)
self._get_api_client().post(f"projects/{project_hash}/groups", params=None, payload=payload, result_type=None)
self._api_client.post(f"projects/{project_hash}/groups", params=None, payload=payload, result_type=None)

def remove_groups(self, group_hash: List[uuid.UUID]) -> None:
params = RemoveGroupsParams(group_hash_list=group_hash)
self._get_api_client().delete(f"projects/{self.project_hash}/groups", params=params, result_type=None)
self._api_client.delete(f"projects/{self.project_hash}/groups", params=params, result_type=None)

def copy_project(
self,
Expand Down Expand Up @@ -1133,11 +1119,9 @@ def remove_datasets(self, dataset_hashes: List[str]) -> bool:
return self._querier.basic_delete(ProjectDataset, uid=dataset_hashes)

def list_project_datasets(self, project_hash: UUID) -> Iterable[ProjectDataset]:
return (
self._get_api_client()
.get(f"projects/{project_hash}/datasets", params=None, result_type=Page[ProjectDataset])
.results
)
return self._api_client.get(
f"projects/{project_hash}/datasets", params=None, result_type=Page[ProjectDataset]
).results

@deprecated("0.1.102", alternative="encord.ontology.Ontology class")
def get_project_ontology(self) -> LegacyOntology:
Expand Down Expand Up @@ -1354,7 +1338,7 @@ def model_train_start(
message="You must pass weights from the `encord.constants.model_weights` module to train a model."
)

training_hash = self._get_api_client().post(
training_hash = self._api_client.post(
f"ml-models/{model_hash}/training",
params=None,
payload=PublicModelTrainStartPayload(
Expand Down Expand Up @@ -1394,7 +1378,7 @@ def model_train_get_result(
polling_available_seconds = max(0, timeout_seconds - polling_elapsed_seconds)

logger.info(f"__model_train_get_result started polling call {polling_elapsed_seconds=}")
tmp_res = self._get_api_client().get(
tmp_res = self._api_client.get(
f"ml-models/{model_hash}/{training_hash}/training",
params=PublicModelTrainGetResultParams(
timeout_seconds=min(
Expand Down Expand Up @@ -1584,20 +1568,20 @@ def workflow_complete(self, label_hashes: List[str]) -> None:
)

def workflow_set_priority(self, priorities: List[Tuple[str, float]]) -> None:
self._get_api_client().post(
self._api_client.post(
f"projects/{self.project_hash}/priorities",
params=None,
payload=TaskPriorityParams(priorities=priorities),
result_type=None,
)

def get_collaborator_timers(self, params: CollaboratorTimerParams) -> Iterable[CollaboratorTimer]:
yield from self._get_api_client().get_paged_iterator(
yield from self._api_client.get_paged_iterator(
"analytics/collaborators/timers", params=params, result_type=CollaboratorTimer
)

def get_label_validation_errors(self, label_hash: str) -> List[str]:
errors = self._get_api_client().get(
errors = self._api_client.get(
f"projects/{self.project_hash}/labels/{label_hash}/validation-state",
params=None,
result_type=LabelValidationState,
Expand All @@ -1609,7 +1593,7 @@ def get_label_validation_errors(self, label_hash: str) -> List[str]:
return errors.errors or []

def active_import(self, project_mode: ActiveProjectMode, video_sampling_rate: Optional[float] = None) -> None:
self._get_api_client().post(
self._api_client.post(
f"active/{self.project_hash}/import",
params=None,
payload=ActiveProjectImportPayload(project_mode=project_mode, video_sampling_rate=video_sampling_rate),
Expand All @@ -1618,7 +1602,7 @@ def active_import(self, project_mode: ActiveProjectMode, video_sampling_rate: Op
logger.info("Import initiated in Active, please check the app to see progress")

def active_sync(self) -> None:
self._get_api_client().post(
self._api_client.post(
f"active/{self.project_hash}/sync",
params=None,
payload=None,
Expand Down
6 changes: 3 additions & 3 deletions encord/collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,7 +328,7 @@ def __init__(
orm_collection: OrmProjectCollection,
):
self._project_uuid = project_uuid
self._client = project_client._get_api_client()
self._client = project_client._api_client
self._project_client = project_client
self._ontology = ontology
self._collection_instance = orm_collection
Expand Down Expand Up @@ -410,7 +410,7 @@ def _get_collection(
) -> "ProjectCollection":
params = GetProjectCollectionParams(uuids=[collection_uuid])
orm_items = list(
project_client._get_api_client().get_paged_iterator(
project_client._api_client.get_paged_iterator(
f"active/{project_uuid}/collections",
params=params,
result_type=OrmProjectCollection,
Expand All @@ -434,7 +434,7 @@ def _list_collections(
page_size: Optional[int] = None,
) -> Iterator["ProjectCollection"]:
params = GetProjectCollectionParams(projectHash=project_uuid, uuids=collection_uuids, pageSize=page_size)
paged_collections = project_client._get_api_client().get_paged_iterator(
paged_collections = project_client._api_client.get_paged_iterator(
f"active/{project_uuid}/collections",
params=params,
result_type=OrmProjectCollection,
Expand Down
14 changes: 7 additions & 7 deletions encord/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -1235,7 +1235,7 @@ def delete_collection(self, collection_uuid: Union[str, UUID]) -> None:
if isinstance(collection_uuid, str):
collection_uuid = UUID(collection_uuid)
ProjectCollection._delete_collection(
self._client._get_api_client(), self._project_instance.project_hash, collection_uuid
self._client._api_client, self._project_instance.project_hash, collection_uuid
)

def create_collection(
Expand All @@ -1253,7 +1253,7 @@ def create_collection(
:class:`encord.exceptions.AuthorizationError` : If the user does not have access to the folder.
"""
new_uuid = ProjectCollection._create_collection(
self._client._get_api_client(), self._project_instance.project_hash, name, description, collection_type
self._client._api_client, self._project_instance.project_hash, name, description, collection_type
)
return self.get_collection(new_uuid)

Expand Down Expand Up @@ -1296,35 +1296,35 @@ def list_filter_presets(
else None
)
return ProjectFilterPreset._list_filter_presets(
client=self._client._get_api_client(),
client=self._client._api_client,
project_uuid=self._project_instance.project_hash,
filter_preset_uuids=filter_presets,
page_size=page_size,
)

def get_filter_preset(self, filter_preset_uuid: Union[str, UUID]) -> ProjectFilterPreset:
return ProjectFilterPreset._get_filter_preset(
client=self._client._get_api_client(),
client=self._client._api_client,
project_uuid=self._project_instance.project_hash,
filter_preset_uuid=UUID(filter_preset_uuid) if isinstance(filter_preset_uuid, str) else filter_preset_uuid,
)

def delete_filter_preset(self, filter_preset_uuid: Union[str, UUID]) -> None:
ProjectFilterPreset._delete_filter_preset(
client=self._client._get_api_client(),
client=self._client._api_client,
project_uuid=self._project_instance.project_hash,
filter_preset_uuid=UUID(filter_preset_uuid) if isinstance(filter_preset_uuid, str) else filter_preset_uuid,
)

def create_filter_preset(self, name: str, filter_preset: ActiveFilterPresetDefinition) -> ProjectFilterPreset:
uuid = ProjectFilterPreset._create_filter_preset(
client=self._client._get_api_client(),
client=self._client._api_client,
project_uuid=self._project_instance.project_hash,
name=name,
filter_preset=filter_preset,
)
return ProjectFilterPreset._get_filter_preset(
client=self._client._get_api_client(),
client=self._client._api_client,
project_uuid=self._project_instance.project_hash,
filter_preset_uuid=uuid,
)

0 comments on commit e6715e0

Please sign in to comment.