Skip to content

Commit

Permalink
feat(LAB-3105): add param for new llm.export path
Browse files Browse the repository at this point in the history
  • Loading branch information
FannyGaudin committed Sep 18, 2024
1 parent 3691f6f commit 6a30a63
Show file tree
Hide file tree
Showing 7 changed files with 111 additions and 139 deletions.
3 changes: 3 additions & 0 deletions src/kili/llm/presentation/client/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ def export(
disable_tqdm: Optional[bool] = False,
asset_ids: Optional[List[str]] = None,
external_ids: Optional[List[str]] = None,
include_sent_back_labels: Optional[bool] = False,
) -> Optional[List[Dict[str, Union[List[str], str]]]]:
# pylint: disable=line-too-long
"""Returns an export of llm assets with valid labels.
Expand All @@ -49,6 +50,7 @@ def export(
asset_ids: Optional list of the assets internal IDs from which to export the labels.
disable_tqdm: Disable the progress bar if True.
external_ids: Optional list of the assets external IDs from which to export the labels.
include_sent_back_labels: Include sent back labels if True.
!!! Example
```python
Expand Down Expand Up @@ -79,6 +81,7 @@ def export(
project_id=ProjectId(project_id),
asset_filter=asset_filter,
disable_tqdm=disable_tqdm,
include_sent_back_labels=include_sent_back_labels,
)
except NoCompatibleJobError as excp:
warnings.warn(str(excp), stacklevel=2)
Expand Down
99 changes: 94 additions & 5 deletions src/kili/llm/services/export/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,30 +2,119 @@

from typing import Dict, List, Optional, Union

from kili.adapters.kili_api_gateway.helpers.queries import QueryOptions
from kili.adapters.kili_api_gateway.kili_api_gateway import KiliAPIGateway
from kili.domain.asset.asset import AssetFilters
from kili.domain.project import ProjectId

from .dynamic import LLMDynamicExporter
from .static import LLMStaticExporter

# Fields fetched for every chat item attached to a label.
CHAT_ITEMS_NEEDED_FIELDS = ["id", "content", "createdAt", "modelId", "parentId", "role"]

# Fields fetched for every label; chat item fields are nested under "chatItems.".
LABELS_NEEDED_FIELDS = (
    [
        "annotations.id",
        "author.id",
        "author.email",
        "author.firstname",
        "author.lastname",
    ]
    + [f"chatItems.{chat_item_field}" for chat_item_field in CHAT_ITEMS_NEEDED_FIELDS]
    + [
        "createdAt",
        "id",
        "isLatestLabelForUser",
        "isSentBackToQueue",
        "jsonResponse",  # This is needed to keep annotations
        "labelType",
        "modelName",
    ]
)

# Fields fetched for every asset when get_fields_to_fetch does not select the
# static list; label fields are nested under "labels.".
ASSET_DYNAMIC_NEEDED_FIELDS = (
    [
        "assetProjectModels.id",
        "assetProjectModels.configuration",
        "assetProjectModels.name",
        "content",
        "externalId",
        "jsonMetadata",
    ]
    + [f"labels.{label_field}" for label_field in LABELS_NEEDED_FIELDS]
    + ["status"]
)

# Fields fetched for every asset of an LLM_RLHF (static) project; the label
# fields are listed once and nested under the "labels." prefix.
ASSET_STATIC_NEEDED_FIELDS = [
    "content",
    "externalId",
    "jsonMetadata",
    *(
        f"labels.{label_field}"
        for label_field in (
            "jsonResponse",
            "author.id",
            "author.email",
            "author.firstname",
            "author.lastname",
            "createdAt",
            "isLatestLabelForUser",
            "isSentBackToQueue",
            "labelType",
            "modelName",
        )
    ),
    "status",
]


def export(  # pylint: disable=too-many-arguments, too-many-locals
    kili_api_gateway: KiliAPIGateway,
    project_id: ProjectId,
    asset_filter: AssetFilters,
    disable_tqdm: Optional[bool],
    include_sent_back_labels: Optional[bool],
) -> Optional[List[Dict[str, Union[List[str], str]]]]:
    """Export the selected assets with their labels using the exporter matching the project.

    Args:
        kili_api_gateway: Gateway used to read the project and list its assets.
        project_id: Identifier of the project to export.
        asset_filter: Filter selecting the assets to export.
        disable_tqdm: Forwarded to the asset listing query options.
        include_sent_back_labels: When falsy (including None), labels sent back
            to the queue are removed before exporting.

    Returns:
        The result of the static exporter for "LLM_RLHF" projects, or of the
        dynamic exporter for "LLM_INSTR_FOLLOWING" projects.

    Raises:
        ValueError: If the project input type is not one of the two above.
    """
    project = kili_api_gateway.get_project(project_id, ["id", "inputType", "jsonInterface"])
    input_type = project["inputType"]

    # Fail fast: reject unsupported projects before listing every asset.
    if input_type not in ("LLM_RLHF", "LLM_INSTR_FOLLOWING"):
        raise ValueError(f'Project Input type "{input_type}" cannot be used for llm exports.')

    fields = get_fields_to_fetch(input_type)
    assets = list(
        kili_api_gateway.list_assets(asset_filter, fields, QueryOptions(disable_tqdm=disable_tqdm))
    )
    # None is treated as False so the filter always receives a real boolean.
    cleaned_assets = preprocess_assets(assets, include_sent_back_labels or False)

    if input_type == "LLM_RLHF":
        return LLMStaticExporter(kili_api_gateway).export(
            cleaned_assets, project_id, project["jsonInterface"]
        )
    return LLMDynamicExporter(kili_api_gateway).export(cleaned_assets, project["jsonInterface"])


def get_fields_to_fetch(input_type: str) -> List[str]:
    """Select the asset fields to query for the given project input type.

    "LLM_RLHF" gets the static field list; any other value gets the dynamic one.
    """
    is_static_project = input_type == "LLM_RLHF"
    return ASSET_STATIC_NEEDED_FIELDS if is_static_project else ASSET_DYNAMIC_NEEDED_FIELDS


def preprocess_assets(assets: List[Dict], include_sent_back_labels: bool) -> List[Dict]:
    """Keep only the assets that still have at least one exportable label.

    A label is exportable when ``include_sent_back_labels`` is True, or when its
    ``isSentBackToQueue`` value is exactly ``False``.

    Args:
        assets: Asset dictionaries, each optionally holding a ``labels`` list
            and/or a ``latestLabel`` entry.
        include_sent_back_labels: Whether labels sent back to the queue are kept.

    Returns:
        The retained assets, in input order, each appearing at most once. The
        ``labels`` list of a retained asset is replaced in place by its
        exportable labels; assets without any exportable label are dropped.
    """

    def is_exportable(label: Dict) -> bool:
        # Short-circuit so the flag is only read when filtering is requested.
        return include_sent_back_labels or label["isSentBackToQueue"] is False

    kept_assets = []
    for asset in assets:
        keep_asset = False

        if "labels" in asset:
            exportable_labels = [label for label in asset["labels"] if is_exportable(label)]
            if exportable_labels:
                asset["labels"] = exportable_labels
                keep_asset = True

        # A missing or None latest label contributes nothing to the decision.
        latest_label = asset.get("latestLabel")
        if latest_label is not None and is_exportable(latest_label):
            keep_asset = True

        # Single append, so an asset with both keys is never exported twice.
        if keep_asset:
            kept_assets.append(asset)
    return kept_assets
46 changes: 3 additions & 43 deletions src/kili/llm/services/export/dynamic.py
Original file line number Diff line number Diff line change
@@ -1,61 +1,21 @@
"""Handle LLM_INSTR_FOLLOWING project exports."""

import logging
from typing import Dict, List, Optional, Union
from typing import Dict, List, Union

from kili.adapters.kili_api_gateway.helpers.queries import QueryOptions
from kili.adapters.kili_api_gateway.kili_api_gateway import KiliAPIGateway
from kili.domain.asset.asset import AssetFilters

# Chat item fields; each entry is prefixed with "chatItems." when
# LABELS_NEEDED_FIELDS is built below.
CHAT_ITEMS_NEEDED_FIELDS = [
"id",
"content",
"createdAt",
"modelId",
"parentId",
"role",
]

# Label fields; each entry is prefixed with "labels." when
# ASSET_NEEDED_FIELDS is built below.
LABELS_NEEDED_FIELDS = [
"annotations.id",
"author.id",
"author.email",
"author.firstname",
"author.lastname",
*(f"chatItems.{field}" for field in CHAT_ITEMS_NEEDED_FIELDS),
"createdAt",
"id",
"isLatestLabelForUser",
"jsonResponse", # This is needed to keep annotations
"labelType",
"modelName",
]

# Asset fields passed to list_assets by LLMDynamicExporter.export.
ASSET_NEEDED_FIELDS = [
"assetProjectModels.id",
"assetProjectModels.configuration",
"assetProjectModels.name",
"content",
"externalId",
"jsonMetadata",
*(f"labels.{field}" for field in LABELS_NEEDED_FIELDS),
"status",
]


class LLMDynamicExporter:
"""Handle exports of LLM_INSTR_FOLLOWING projects."""

def __init__(self, kili_api_gateway: KiliAPIGateway, disable_tqdm: Optional[bool]):
def __init__(self, kili_api_gateway: KiliAPIGateway):
self.kili_api_gateway = kili_api_gateway
self.disable_tqdm = disable_tqdm

def export(
self, asset_filter: AssetFilters, json_interface: Dict
self, assets: List[Dict], json_interface: Dict
) -> List[Dict[str, Union[List[str], str]]]:
"""Asset content depends on each label."""
options = QueryOptions(disable_tqdm=self.disable_tqdm)
assets = self.kili_api_gateway.list_assets(asset_filter, ASSET_NEEDED_FIELDS, options)
export_res = []
for asset in assets:
# obfuscate models here
Expand Down
28 changes: 2 additions & 26 deletions src/kili/llm/services/export/static.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,48 +5,24 @@
from pathlib import Path
from typing import Dict, List, Optional, Union

from kili.adapters.kili_api_gateway.helpers.queries import QueryOptions
from kili.adapters.kili_api_gateway.kili_api_gateway import KiliAPIGateway
from kili.domain.asset.asset import AssetFilters
from kili.domain.project import ProjectId
from kili.services.asset_import.helpers import SEPARATOR
from kili.services.export.format.llm.types import ExportLLMItem
from kili.use_cases.asset.media_downloader import MediaDownloader
from kili.utils.tempfile import TemporaryDirectory

# Asset fields passed to list_assets by LLMStaticExporter.export; label fields
# are nested under the "labels." prefix.
ASSET_NEEDED_FIELDS = [
"content",
"externalId",
"jsonMetadata",
"labels.jsonResponse",
"labels.author.id",
"labels.author.email",
"labels.author.firstname",
"labels.author.lastname",
"labels.createdAt",
"labels.isLatestLabelForUser",
"labels.labelType",
"labels.modelName",
"status",
]


class LLMStaticExporter:
"""Handle exports of LLM_RLHF projects."""

def __init__(self, kili_api_gateway: KiliAPIGateway, disable_tqdm: Optional[bool]):
def __init__(self, kili_api_gateway: KiliAPIGateway):
self.kili_api_gateway = kili_api_gateway
self.disable_tqdm = disable_tqdm

def export(
self, project_id: ProjectId, asset_filter: AssetFilters, json_interface: Dict
self, assets: List[Dict], project_id: ProjectId, json_interface: Dict
) -> List[Dict[str, Union[List[str], str]]]:
"""Assets are static, with n labels."""
assets = list(
self.kili_api_gateway.list_assets(
asset_filter, ASSET_NEEDED_FIELDS, QueryOptions(disable_tqdm=self.disable_tqdm)
)
)
with TemporaryDirectory() as tmpdirname:
assets = MediaDownloader(
tmpdirname,
Expand Down
2 changes: 1 addition & 1 deletion src/kili/services/export/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def export_labels( # pylint: disable=too-many-arguments, too-many-locals
annotation_modifier=annotation_modifier,
asset_filter_kwargs=asset_filter_kwargs,
normalized_coordinates=normalized_coordinates,
include_sent_back_labels=include_sent_back_labels,
include_sent_back_labels=include_sent_back_labels if label_format != "llm_v1" else False,
)

logger = get_logger(log_level)
Expand Down
9 changes: 5 additions & 4 deletions tests/unit/llm/services/export/test_dynamic.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,9 +181,10 @@
],
"createdAt": "2024-08-06T12:30:42.122Z",
"isLatestLabelForUser": True,
"isSentBackToQueue": False,
"id": "clzief6q2003e7tc91jm46uii",
"jsonResponse": {},
"labelType": "AUTOSAVE",
"labelType": "DEFAULT",
"modelName": None,
}
],
Expand Down Expand Up @@ -257,7 +258,7 @@
{
"author": "[email protected]",
"created_at": "2024-08-06T12:30:42.122Z",
"label_type": "AUTOSAVE",
"label_type": "DEFAULT",
"label": {"COMPARISON_JOB": "A_3", "CLASSIFICATION_JOB": ["BOTH_ARE_GOOD"]},
}
],
Expand Down Expand Up @@ -325,7 +326,7 @@
{
"author": "[email protected]",
"created_at": "2024-08-06T12:30:42.122Z",
"label_type": "AUTOSAVE",
"label_type": "DEFAULT",
"label": {"COMPARISON_JOB": "B_1"},
}
],
Expand Down Expand Up @@ -407,7 +408,7 @@
{
"author": "[email protected]",
"created_at": "2024-08-06T12:30:42.122Z",
"label_type": "AUTOSAVE",
"label_type": "DEFAULT",
"label": {"COMPARISON_JOB": "A_2"},
}
],
Expand Down
Loading

0 comments on commit 6a30a63

Please sign in to comment.