diff --git a/encord/client.py b/encord/client.py index 41064cdc..06291605 100644 --- a/encord/client.py +++ b/encord/client.py @@ -164,39 +164,26 @@ class EncordClient: with a project (e.g. label rows, datasets). """ - def __init__(self, querier: Querier, config: Config, api_client: Optional[ApiClient] = None): + def __init__(self, querier: Querier, config: Config, api_client: ApiClient): self._querier = querier self._config = config self._api_client = api_client - def _get_api_client(self) -> ApiClient: - if not (isinstance(self._config, (SshConfig, BearerConfig))): - raise EncordException( - "This functionality requires private SSH key authentication. API keys are not supported." - ) - - if not self._api_client: - raise RuntimeError("ApiClient should exist when authenticated with SSH key.") - - return self._api_client - def get_cloud_integrations(self) -> List[CloudIntegration]: return [ CloudIntegration( id=str(x.integration_uuid), title=x.title, ) - for x in self._get_api_client() - .get( + for x in self._api_client.get( "cloud-integrations", params=None, result_type=GetCloudIntegrationsResponse, - ) - .result + ).result ] def get_bearer_token(self) -> BearerTokenResponse: - return self._get_api_client().get("user/bearer-token", None, result_type=BearerTokenResponse) + return self._api_client.get("user/bearer-token", None, result_type=BearerTokenResponse) class EncordClientDataset(EncordClient): @@ -211,6 +198,9 @@ def __init__( dataset_access_settings: DatasetAccessSettings = DEFAULT_DATASET_ACCESS_SETTINGS, api_client: Optional[ApiClient] = None, ): + if api_client is None: + raise ValueError("api_client is None") + super().__init__(querier, config, api_client) self._dataset_access_settings = dataset_access_settings @@ -307,17 +297,15 @@ def add_users(self, user_emails: List[str], user_role: DatasetUserRole) -> List[ return [DatasetUser.from_dict(user) for user in users] def list_groups(self, dataset_hash: uuid.UUID) -> Page[DatasetGroup]: - return self._get_api_client().get( - f"datasets/{dataset_hash}/groups", params=None, result_type=Page[DatasetGroup] - ) + return self._api_client.get(f"datasets/{dataset_hash}/groups", params=None, result_type=Page[DatasetGroup]) def add_groups(self, dataset_hash: str, group_hash: List[uuid.UUID], user_role: DatasetUserRole) -> None: payload = AddDatasetGroupsPayload(group_hash_list=group_hash, user_role=user_role) - self._get_api_client().post(f"datasets/{dataset_hash}/groups", params=None, payload=payload, result_type=None) + self._api_client.post(f"datasets/{dataset_hash}/groups", params=None, payload=payload, result_type=None) def remove_groups(self, dataset_hash: uuid.UUID, group_hash: List[uuid.UUID]) -> None: params = RemoveGroupsParams(group_hash_list=group_hash) - self._get_api_client().delete(f"datasets/{dataset_hash}/groups", params=params, result_type=None) + self._api_client.delete(f"datasets/{dataset_hash}/groups", params=params, result_type=None) def __add_data_to_dataset_get_result( self, @@ -881,7 +869,7 @@ def get_project_v2(self) -> ProjectOrmV2: This is an internal method, do not use it directly. Use :meth:`UserClient.get_project` instead. """ - return self._get_api_client().get(f"/projects/{self.project_hash}", params=None, result_type=ProjectOrmV2) + return self._api_client.get(f"/projects/{self.project_hash}", params=None, result_type=ProjectOrmV2) def list_label_rows( self, @@ -956,17 +944,15 @@ def add_users(self, user_emails: List[str], user_role: ProjectUserRole) -> List[ return [ProjectUser.from_dict(user) for user in users] def list_groups(self, project_hash: uuid.UUID) -> Page[ProjectGroup]: - return self._get_api_client().get( - f"projects/{project_hash}/groups", params=None, result_type=Page[ProjectGroup] - ) + return self._api_client.get(f"projects/{project_hash}/groups", params=None, result_type=Page[ProjectGroup]) def add_groups(self, project_hash: uuid.UUID, group_hash: List[uuid.UUID], user_role: ProjectUserRole) -> None: payload = AddProjectGroupsPayload(group_hash_list=group_hash, user_role=user_role) - self._get_api_client().post(f"projects/{project_hash}/groups", params=None, payload=payload, result_type=None) + self._api_client.post(f"projects/{project_hash}/groups", params=None, payload=payload, result_type=None) def remove_groups(self, group_hash: List[uuid.UUID]) -> None: params = RemoveGroupsParams(group_hash_list=group_hash) - self._get_api_client().delete(f"projects/{self.project_hash}/groups", params=params, result_type=None) + self._api_client.delete(f"projects/{self.project_hash}/groups", params=params, result_type=None) def copy_project( self, @@ -1133,11 +1119,9 @@ def remove_datasets(self, dataset_hashes: List[str]) -> bool: return self._querier.basic_delete(ProjectDataset, uid=dataset_hashes) def list_project_datasets(self, project_hash: UUID) -> Iterable[ProjectDataset]: - return ( - self._get_api_client() - .get(f"projects/{project_hash}/datasets", params=None, result_type=Page[ProjectDataset]) - .results - ) + return self._api_client.get( + f"projects/{project_hash}/datasets", params=None, result_type=Page[ProjectDataset] + ).results @deprecated("0.1.102", alternative="encord.ontology.Ontology class") def get_project_ontology(self) -> LegacyOntology: @@ -1354,7 +1338,7 @@ def model_train_start( message="You must pass weights from the `encord.constants.model_weights` module to train a model." ) - training_hash = self._get_api_client().post( + training_hash = self._api_client.post( f"ml-models/{model_hash}/training", params=None, payload=PublicModelTrainStartPayload( @@ -1394,7 +1378,7 @@ def model_train_get_result( polling_available_seconds = max(0, timeout_seconds - polling_elapsed_seconds) logger.info(f"__model_train_get_result started polling call {polling_elapsed_seconds=}") - tmp_res = self._get_api_client().get( + tmp_res = self._api_client.get( f"ml-models/{model_hash}/{training_hash}/training", params=PublicModelTrainGetResultParams( timeout_seconds=min( @@ -1584,7 +1568,7 @@ def workflow_complete(self, label_hashes: List[str]) -> None: ) def workflow_set_priority(self, priorities: List[Tuple[str, float]]) -> None: - self._get_api_client().post( + self._api_client.post( f"projects/{self.project_hash}/priorities", params=None, payload=TaskPriorityParams(priorities=priorities), @@ -1592,12 +1576,12 @@ def workflow_set_priority(self, priorities: List[Tuple[str, float]]) -> None: ) def get_collaborator_timers(self, params: CollaboratorTimerParams) -> Iterable[CollaboratorTimer]: - yield from self._get_api_client().get_paged_iterator( + yield from self._api_client.get_paged_iterator( "analytics/collaborators/timers", params=params, result_type=CollaboratorTimer ) def get_label_validation_errors(self, label_hash: str) -> List[str]: - errors = self._get_api_client().get( + errors = self._api_client.get( f"projects/{self.project_hash}/labels/{label_hash}/validation-state", params=None, result_type=LabelValidationState, @@ -1609,7 +1593,7 @@ def get_label_validation_errors(self, label_hash: str) -> List[str]: return errors.errors or [] def active_import(self, project_mode: ActiveProjectMode, video_sampling_rate: Optional[float] = None) -> None: - self._get_api_client().post( + self._api_client.post( f"active/{self.project_hash}/import", params=None, payload=ActiveProjectImportPayload(project_mode=project_mode, video_sampling_rate=video_sampling_rate), @@ -1618,7 +1602,7 @@ def active_import(self, project_mode: ActiveProjectMode, video_sampling_rate: Op logger.info("Import initiated in Active, please check the app to see progress") def active_sync(self) -> None: - self._get_api_client().post( + self._api_client.post( f"active/{self.project_hash}/sync", params=None, payload=None, diff --git a/encord/collection.py b/encord/collection.py index 79ba7715..c3a04ff5 100644 --- a/encord/collection.py +++ b/encord/collection.py @@ -328,7 +328,7 @@ def __init__( orm_collection: OrmProjectCollection, ): self._project_uuid = project_uuid - self._client = project_client._get_api_client() + self._client = project_client._api_client self._project_client = project_client self._ontology = ontology self._collection_instance = orm_collection @@ -410,7 +410,7 @@ def _get_collection( ) -> "ProjectCollection": params = GetProjectCollectionParams(uuids=[collection_uuid]) orm_items = list( - project_client._get_api_client().get_paged_iterator( + project_client._api_client.get_paged_iterator( f"active/{project_uuid}/collections", params=params, result_type=OrmProjectCollection, @@ -434,7 +434,7 @@ def _list_collections( page_size: Optional[int] = None, ) -> Iterator["ProjectCollection"]: params = GetProjectCollectionParams(projectHash=project_uuid, uuids=collection_uuids, pageSize=page_size) - paged_collections = project_client._get_api_client().get_paged_iterator( + paged_collections = project_client._api_client.get_paged_iterator( f"active/{project_uuid}/collections", params=params, result_type=OrmProjectCollection, diff --git a/encord/project.py b/encord/project.py index 37f7a11b..ea4441d3 100644 --- a/encord/project.py +++ b/encord/project.py @@ -1235,7 +1235,7 @@ def delete_collection(self, collection_uuid: Union[str, UUID]) -> None: if isinstance(collection_uuid, str): collection_uuid = UUID(collection_uuid) ProjectCollection._delete_collection( - self._client._get_api_client(), self._project_instance.project_hash, collection_uuid + self._client._api_client, self._project_instance.project_hash, collection_uuid ) def create_collection( @@ -1253,7 +1253,7 @@ def create_collection( :class:`encord.exceptions.AuthorizationError` : If the user does not have access to the folder. """ new_uuid = ProjectCollection._create_collection( - self._client._get_api_client(), self._project_instance.project_hash, name, description, collection_type + self._client._api_client, self._project_instance.project_hash, name, description, collection_type ) return self.get_collection(new_uuid) @@ -1296,7 +1296,7 @@ def list_filter_presets( else None ) return ProjectFilterPreset._list_filter_presets( - client=self._client._get_api_client(), + client=self._client._api_client, project_uuid=self._project_instance.project_hash, filter_preset_uuids=filter_presets, page_size=page_size, @@ -1304,27 +1304,27 @@ def list_filter_presets( def get_filter_preset(self, filter_preset_uuid: Union[str, UUID]) -> ProjectFilterPreset: return ProjectFilterPreset._get_filter_preset( - client=self._client._get_api_client(), + client=self._client._api_client, project_uuid=self._project_instance.project_hash, filter_preset_uuid=UUID(filter_preset_uuid) if isinstance(filter_preset_uuid, str) else filter_preset_uuid, ) def delete_filter_preset(self, filter_preset_uuid: Union[str, UUID]) -> None: ProjectFilterPreset._delete_filter_preset( - client=self._client._get_api_client(), + client=self._client._api_client, project_uuid=self._project_instance.project_hash, filter_preset_uuid=UUID(filter_preset_uuid) if isinstance(filter_preset_uuid, str) else filter_preset_uuid, ) def create_filter_preset(self, name: str, filter_preset: ActiveFilterPresetDefinition) -> ProjectFilterPreset: uuid = ProjectFilterPreset._create_filter_preset( - client=self._client._get_api_client(), + client=self._client._api_client, project_uuid=self._project_instance.project_hash, name=name, filter_preset=filter_preset, ) return ProjectFilterPreset._get_filter_preset( - client=self._client._get_api_client(), + client=self._client._api_client, project_uuid=self._project_instance.project_hash, filter_preset_uuid=uuid, )