From e693c4d03b1874c1fb36df0c368853a70ed14ed9 Mon Sep 17 00:00:00 2001 From: "@fanny.gaudin" Date: Wed, 18 Sep 2024 16:27:41 +0200 Subject: [PATCH] feat(LAB-3105): add param for new llm.export path --- src/kili/llm/presentation/client/llm.py | 3 + src/kili/llm/services/export/__init__.py | 99 ++++++++++++++++++- src/kili/llm/services/export/dynamic.py | 46 +-------- src/kili/llm/services/export/static.py | 28 +----- src/kili/services/export/__init__.py | 2 +- .../unit/llm/services/export/test_dynamic.py | 9 +- tests/unit/llm/services/export/test_static.py | 63 +----------- tests/unit/services/export/test_export.py | 71 +++++++++++++ 8 files changed, 182 insertions(+), 139 deletions(-) diff --git a/src/kili/llm/presentation/client/llm.py b/src/kili/llm/presentation/client/llm.py index 5a009d797..306f1bc56 100644 --- a/src/kili/llm/presentation/client/llm.py +++ b/src/kili/llm/presentation/client/llm.py @@ -40,6 +40,7 @@ def export( disable_tqdm: Optional[bool] = False, asset_ids: Optional[List[str]] = None, external_ids: Optional[List[str]] = None, + include_sent_back_labels: Optional[bool] = False, ) -> Optional[List[Dict[str, Union[List[str], str]]]]: # pylint: disable=line-too-long """Returns an export of llm assets with valid labels. @@ -49,6 +50,7 @@ def export( asset_ids: Optional list of the assets internal IDs from which to export the labels. disable_tqdm: Disable the progress bar if True. external_ids: Optional list of the assets external IDs from which to export the labels. + include_sent_back_labels: Include sent back labels if True. !!! Example ```python @@ -79,6 +81,7 @@ def export( project_id=ProjectId(project_id), asset_filter=asset_filter, disable_tqdm=disable_tqdm, + include_sent_back_labels=include_sent_back_labels, ) except NoCompatibleJobError as excp: warnings.warn(str(excp), stacklevel=2) diff --git a/src/kili/llm/services/export/__init__.py b/src/kili/llm/services/export/__init__.py index bc32b28fb..d8fc44216 100644 --- a/src/kili/llm/services/export/__init__.py +++ b/src/kili/llm/services/export/__init__.py @@ -2,6 +2,7 @@ from typing import Dict, List, Optional, Union +from kili.adapters.kili_api_gateway.helpers.queries import QueryOptions from kili.adapters.kili_api_gateway.kili_api_gateway import KiliAPIGateway from kili.domain.asset.asset import AssetFilters from kili.domain.project import ProjectId @@ -9,23 +10,111 @@ from .dynamic import LLMDynamicExporter from .static import LLMStaticExporter +CHAT_ITEMS_NEEDED_FIELDS = [ + "id", + "content", + "createdAt", + "modelId", + "parentId", + "role", +] + +LABELS_NEEDED_FIELDS = [ + "annotations.id", + "author.id", + "author.email", + "author.firstname", + "author.lastname", + *(f"chatItems.{field}" for field in CHAT_ITEMS_NEEDED_FIELDS), + "createdAt", + "id", + "isLatestLabelForUser", + "isSentBackToQueue", + "jsonResponse", # This is needed to keep annotations + "labelType", + "modelName", +] + +ASSET_DYNAMIC_NEEDED_FIELDS = [ + "assetProjectModels.id", + "assetProjectModels.configuration", + "assetProjectModels.name", + "content", + "externalId", + "jsonMetadata", + *(f"labels.{field}" for field in LABELS_NEEDED_FIELDS), + "status", +] + +ASSET_STATIC_NEEDED_FIELDS = [ + "content", + "externalId", + "jsonMetadata", + "labels.jsonResponse", + "labels.author.id", + "labels.author.email", + "labels.author.firstname", + "labels.author.lastname", + "labels.createdAt", + "labels.isLatestLabelForUser", + "labels.isSentBackToQueue", + "labels.labelType", + "labels.modelName", + "status", +] + def export( # pylint: disable=too-many-arguments, too-many-locals kili_api_gateway: KiliAPIGateway, project_id: ProjectId, asset_filter: AssetFilters, disable_tqdm: Optional[bool], + include_sent_back_labels: Optional[bool], ) -> Optional[List[Dict[str, Union[List[str], str]]]]: """Export the selected assets with their labels into the required format, and save it into a file archive.""" project = kili_api_gateway.get_project(project_id, ["id", "inputType", "jsonInterface"]) input_type = project["inputType"] + fields = get_fields_to_fetch(input_type) + assets = list( + kili_api_gateway.list_assets(asset_filter, fields, QueryOptions(disable_tqdm=disable_tqdm)) + ) + cleaned_assets = preprocess_assets(assets, include_sent_back_labels or False) if input_type == "LLM_RLHF": - return LLMStaticExporter(kili_api_gateway, disable_tqdm).export( - project_id, asset_filter, project["jsonInterface"] + return LLMStaticExporter(kili_api_gateway).export( + cleaned_assets, project_id, project["jsonInterface"] ) if input_type == "LLM_INSTR_FOLLOWING": - return LLMDynamicExporter(kili_api_gateway, disable_tqdm).export( - asset_filter, project["jsonInterface"] - ) + return LLMDynamicExporter(kili_api_gateway).export(cleaned_assets, project["jsonInterface"]) raise ValueError(f'Project Input type "{input_type}" cannot be used for llm exports.') + + +def get_fields_to_fetch(input_type: str) -> List[str]: + """Return the fields to fetch depending on the export type.""" + if input_type == "LLM_RLHF": + return ASSET_STATIC_NEEDED_FIELDS + return ASSET_DYNAMIC_NEEDED_FIELDS + + +def preprocess_assets(assets: List[Dict], include_sent_back_labels: bool) -> List[Dict]: + """Format labels in the requested format, and filter out autosave labels.""" + assets_in_format = [] + for asset in assets: + if "labels" in asset: + labels_of_asset = [] + for label in asset["labels"]: + labels_of_asset.append(label) + if not include_sent_back_labels: + labels_of_asset = list( + filter(lambda label: label["isSentBackToQueue"] is False, labels_of_asset) + ) + if len(labels_of_asset) > 0: + asset["labels"] = labels_of_asset + assets_in_format.append(asset) + if "latestLabel" in asset: + label = asset["latestLabel"] + if label is not None: + asset["latestLabel"] = label + if include_sent_back_labels or asset["latestLabel"]["isSentBackToQueue"] is False: + assets_in_format.append(asset) + return assets_in_format diff --git a/src/kili/llm/services/export/dynamic.py b/src/kili/llm/services/export/dynamic.py index 3519e590c..644bfe24a 100644 --- a/src/kili/llm/services/export/dynamic.py +++ b/src/kili/llm/services/export/dynamic.py @@ -1,61 +1,21 @@ """Handle LLM_INSTR_FOLLOWING project exports.""" import logging -from typing import Dict, List, Optional, Union +from typing import Dict, List, Union -from kili.adapters.kili_api_gateway.helpers.queries import QueryOptions from kili.adapters.kili_api_gateway.kili_api_gateway import KiliAPIGateway -from kili.domain.asset.asset import AssetFilters - -CHAT_ITEMS_NEEDED_FIELDS = [ - "id", - "content", - "createdAt", - "modelId", - "parentId", - "role", -] - -LABELS_NEEDED_FIELDS = [ - "annotations.id", - "author.id", - "author.email", - "author.firstname", - "author.lastname", - *(f"chatItems.{field}" for field in CHAT_ITEMS_NEEDED_FIELDS), - "createdAt", - "id", - "isLatestLabelForUser", - "jsonResponse", # This is needed to keep annotations - "labelType", - "modelName", -] - -ASSET_NEEDED_FIELDS = [ - "assetProjectModels.id", - "assetProjectModels.configuration", - "assetProjectModels.name", - "content", - "externalId", - "jsonMetadata", - *(f"labels.{field}" for field in LABELS_NEEDED_FIELDS), - "status", -] class LLMDynamicExporter: """Handle exports of LLM_RLHF projects.""" - def __init__(self, kili_api_gateway: KiliAPIGateway, disable_tqdm: Optional[bool]): + def __init__(self, kili_api_gateway: KiliAPIGateway): self.kili_api_gateway = kili_api_gateway - self.disable_tqdm = disable_tqdm def export( - self, asset_filter: AssetFilters, json_interface: Dict + self, assets: List[Dict], json_interface: Dict ) -> List[Dict[str, Union[List[str], str]]]: """Asset content depends of each label.""" - options = QueryOptions(disable_tqdm=self.disable_tqdm) - assets = self.kili_api_gateway.list_assets(asset_filter, ASSET_NEEDED_FIELDS, options) export_res = [] for asset in assets: # obfuscate models here diff --git a/src/kili/llm/services/export/static.py b/src/kili/llm/services/export/static.py index b7d2702b3..55fa7e81a 100644 --- a/src/kili/llm/services/export/static.py +++ b/src/kili/llm/services/export/static.py @@ -5,48 +5,24 @@ from pathlib import Path from typing import Dict, List, Optional, Union -from kili.adapters.kili_api_gateway.helpers.queries import QueryOptions from kili.adapters.kili_api_gateway.kili_api_gateway import KiliAPIGateway -from kili.domain.asset.asset import AssetFilters from kili.domain.project import ProjectId from kili.services.asset_import.helpers import SEPARATOR from kili.services.export.format.llm.types import ExportLLMItem from kili.use_cases.asset.media_downloader import MediaDownloader from kili.utils.tempfile import TemporaryDirectory -ASSET_NEEDED_FIELDS = [ - "content", - "externalId", - "jsonMetadata", - "labels.jsonResponse", - "labels.author.id", - "labels.author.email", - "labels.author.firstname", - "labels.author.lastname", - "labels.createdAt", - "labels.isLatestLabelForUser", - "labels.labelType", - "labels.modelName", - "status", -] - class LLMStaticExporter: """Handle exports of LLM_RLHF projects.""" - def __init__(self, kili_api_gateway: KiliAPIGateway, disable_tqdm: Optional[bool]): + def __init__(self, kili_api_gateway: KiliAPIGateway): self.kili_api_gateway = kili_api_gateway - self.disable_tqdm = disable_tqdm def export( - self, project_id: ProjectId, asset_filter: AssetFilters, json_interface: Dict + self, assets: List[Dict], project_id: ProjectId, json_interface: Dict ) -> List[Dict[str, Union[List[str], str]]]: """Assets are static, with n labels.""" - assets = list( - self.kili_api_gateway.list_assets( - asset_filter, ASSET_NEEDED_FIELDS, QueryOptions(disable_tqdm=self.disable_tqdm) - ) - ) with TemporaryDirectory() as tmpdirname: assets = MediaDownloader( tmpdirname, diff --git a/src/kili/services/export/__init__.py b/src/kili/services/export/__init__.py index e2ed8f237..fd38c81cf 100644 --- a/src/kili/services/export/__init__.py +++ b/src/kili/services/export/__init__.py @@ -55,7 +55,7 @@ def export_labels( # pylint: disable=too-many-arguments, too-many-locals annotation_modifier=annotation_modifier, asset_filter_kwargs=asset_filter_kwargs, normalized_coordinates=normalized_coordinates, - include_sent_back_labels=include_sent_back_labels, + include_sent_back_labels=include_sent_back_labels if label_format != "llm_v1" else False, ) logger = get_logger(log_level) diff --git a/tests/unit/llm/services/export/test_dynamic.py b/tests/unit/llm/services/export/test_dynamic.py index bc8a95cc9..7453b41d3 100644 --- a/tests/unit/llm/services/export/test_dynamic.py +++ b/tests/unit/llm/services/export/test_dynamic.py @@ -181,9 +181,10 @@ ], "createdAt": "2024-08-06T12:30:42.122Z", "isLatestLabelForUser": True, + "isSentBackToQueue": False, "id": "clzief6q2003e7tc91jm46uii", "jsonResponse": {}, - "labelType": "AUTOSAVE", + "labelType": "DEFAULT", "modelName": None, } ], @@ -257,7 +258,7 @@ { "author": "test+admin@kili-technology.com", "created_at": "2024-08-06T12:30:42.122Z", - "label_type": "AUTOSAVE", + "label_type": "DEFAULT", "label": {"COMPARISON_JOB": "A_3", "CLASSIFICATION_JOB": ["BOTH_ARE_GOOD"]}, } ], @@ -325,7 +326,7 @@ { "author": "test+admin@kili-technology.com", "created_at": "2024-08-06T12:30:42.122Z", - "label_type": "AUTOSAVE", + "label_type": "DEFAULT", "label": {"COMPARISON_JOB": "B_1"}, } ], @@ -407,7 +408,7 @@ { "author": "test+admin@kili-technology.com", "created_at": "2024-08-06T12:30:42.122Z", - "label_type": "AUTOSAVE", + "label_type": "DEFAULT", "label": {"COMPARISON_JOB": "A_2"}, } ], diff --git a/tests/unit/llm/services/export/test_static.py b/tests/unit/llm/services/export/test_static.py index 2fe219f44..9380de8d6 100644 --- a/tests/unit/llm/services/export/test_static.py +++ b/tests/unit/llm/services/export/test_static.py @@ -54,6 +54,7 @@ }, "createdAt": "2024-08-05T13:03:00.051Z", "isLatestLabelForUser": True, + "isSentBackToQueue": False, "labelType": "DEFAULT", "modelName": None, } @@ -77,6 +78,7 @@ }, "createdAt": "2024-08-05T13:03:03.061Z", "isLatestLabelForUser": True, + "isSentBackToQueue": False, "labelType": "DEFAULT", "modelName": None, } @@ -101,6 +103,7 @@ }, "createdAt": "2024-08-05T13:03:16.028Z", "isLatestLabelForUser": True, + "isSentBackToQueue": True, "labelType": "DEFAULT", "modelName": None, } @@ -257,66 +260,6 @@ } ], }, - { - "raw_data": [ - { - "role": "user", - "content": "BLABLABLA", - "id": None, - "chat_id": None, - "model": None, - }, - { - "role": "assistant", - "content": "response A1", - "id": None, - "chat_id": None, - "model": None, - }, - { - "role": "assistant", - "content": "response B1", - "id": None, - "chat_id": None, - "model": None, - }, - { - "role": "user", - "content": "BLIBLIBLI", - "id": None, - "chat_id": None, - "model": None, - }, - { - "role": "assistant", - "content": "response A2", - "id": None, - "chat_id": None, - "model": None, - }, - { - "role": "assistant", - "content": "response B2", - "id": None, - "chat_id": None, - "model": None, - }, - ], - "status": "LABELED", - "external_id": "asset#2", - "metadata": {}, - "labels": [ - { - "author": "test+admin@kili-technology.com", - "created_at": "2024-08-05T13:03:16.028Z", - "label_type": "DEFAULT", - "label": { - "CLASSIFICATION_JOB": ["TIE"], - "TRANSCRIPTION_JOB": "There is only some formatting changes\n", - }, - } - ], - }, ] diff --git a/tests/unit/services/export/test_export.py b/tests/unit/services/export/test_export.py index a8d971a04..8333a18e2 100644 --- a/tests/unit/services/export/test_export.py +++ b/tests/unit/services/export/test_export.py @@ -859,6 +859,7 @@ def test_export_with_asset_filter_kwargs(mocker): "latestLabel.author.lastname", "latestLabel.createdAt", "latestLabel.isLatestLabelForUser", + "latestLabel.isSentBackToQueue", "latestLabel.labelType", "latestLabel.modelName", ] @@ -989,6 +990,7 @@ def test_when_exporting_geotiff_asset_with_incompatible_options_then_it_crashes( }, "createdAt": "2023-07-19T09:06:03.028Z", "isLatestLabelForUser": True, + "isSentBackToQueue": False, "labelType": "DEFAULT", "modelName": None, }, @@ -1100,3 +1102,72 @@ def test_given_kili_when_exporting_it_does_not_call_dataconnection_resolver( # Then process_and_save_mock.assert_called_once() kili.graphql_client.execute.assert_not_called() # pyright: ignore[reportGeneralTypeIssues] + + +def test_when_exporting_asset_with_include_sent_back_labels_parameter_it_filter_asset_exported( + mocker: pytest_mock.MockerFixture, +): + mocker.patch( + "kili.services.export.format.base.fetch_assets", + return_value=[ + { + "latestLabel": { + "author": { + "id": "user-feat1-1", + "email": "test+admin+1@kili-technology.com", + "firstname": "Feat1", + "lastname": "Test Admin", + }, + "jsonResponse": { + "OBJECT_DETECTION_JOB": { + "annotations": [ + { + "children": {}, + "boundingPoly": [ + { + "normalizedVertices": [ + {"x": 4.1, "y": 52.2}, + {"x": 4.5, "y": 52.7}, + {"x": 4.5, "y": 52.3}, + {"x": 4.1, "y": 52.4}, + ] + } + ], + "categories": [{"name": "A"}], + "mid": "20230719110559896-2495", + "type": "rectangle", + } + ] + } + }, + "createdAt": "2023-07-19T09:06:03.028Z", + "isLatestLabelForUser": True, + "isSentBackToQueue": True, + "labelType": "DEFAULT", + "modelName": None, + }, + "resolution": None, + "pageResolutions": None, + "id": "clk9i0hn000002a68a2zcd1v7", + "externalId": "BoundingBox.png", + "content": ( + "https://storage.googleapis.com/label-public-staging/demo-projects/Computer_vision_tutorial/BoundingBox.png" + ), + "jsonContent": None, + "jsonMetadata": {}, + } + ], + ) + + process_and_save_mock = mocker.patch.object(KiliExporter, "process_and_save", return_value=None) + kili = mock_kili(mocker, with_data_connection=False) + kili.api_endpoint = "https://" # type: ignore + kili.api_key = "" # type: ignore + kili.graphql_client = mocker.MagicMock() # pyright: ignore[reportGeneralTypeIssues] + kili.http_client = mocker.MagicMock() # pyright: ignore[reportGeneralTypeIssues] + + # When + kili.export_labels(project_id="fake_proj_id", filename="exp.zip", fmt="kili") + + # Then + process_and_save_mock.assert_called_once()