feat: add an export format for dynamic llm annotations (#1738)
baptiste-olivier committed Jul 10, 2024
1 parent 3e8ba39 commit 86b2b79
Showing 5 changed files with 554 additions and 34 deletions.
1 change: 1 addition & 0 deletions src/kili/services/export/__init__.py
@@ -72,6 +72,7 @@ def export_labels( # pylint: disable=too-many-arguments, too-many-locals
"pascal_voc": VocExporter,
"geojson": GeoJsonExporter,
"llm_v1": LLMExporter,
"llm_dynamic_v1": LLMExporter,
}
assert set(format_exporter_selector_mapping.keys()) == set(
get_args(LabelFormat)
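
A minimal usage sketch of the new format key, assuming the public Kili client exposes export_labels(project_id, filename, fmt) and forwards fmt to the mapping above; all names and values below are placeholders, not taken from this commit:

from kili.client import Kili

kili = Kili(api_key="YOUR_API_KEY")  # placeholder credentials
kili.export_labels(
    project_id="my_project_id",   # placeholder project id
    filename="llm_export.json",   # output file path
    fmt="llm_dynamic_v1",         # format key added by this commit
)
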
178 changes: 144 additions & 34 deletions src/kili/services/export/format/llm/__init__.py
@@ -2,12 +2,14 @@

import json
import logging
from ast import literal_eval
from pathlib import Path
from typing import Dict, List, Optional, Union

from kili.services.asset_import.helpers import SEPARATOR
from kili.services.export.exceptions import NotCompatibleInputType
from kili.services.export.format.base import AbstractExporter
from kili.services.export.format.llm.types import ExportLLMItem, RankingValue
from kili.services.types import Job


@@ -44,14 +46,48 @@ def process_and_save(
self, assets: List[Dict], output_filename: Path
) -> Optional[List[Dict[str, Union[List[str], str]]]]:
"""LLM specific process and save."""
result = self._process(assets)
result = self.process(assets)
self._save_assets_export(result, output_filename)

def process(self, assets: List[Dict]) -> List[Dict[str, Union[List[str], str]]]:
"""LLM specific process."""
return self._process(assets)
if self.label_format == "llm_v1":
return self._process_llm_v1(assets)
return self._process_llm_dynamic_v1(assets)

def _process(self, assets: List[Dict]) -> List[Dict[str, Union[List[str], str]]]:
def _process_llm_dynamic_v1(self, assets: List[Dict]) -> List[Dict[str, Union[List[str], str]]]:
result = []
for asset in assets:
step_number = _count_step(asset)
label = asset["latestLabel"]
steps = {}
context = []
formatted_asset = _format_raw_data(asset, all_model_keys=True)
for i in range(step_number):
steps[f"{i}"] = {
"raw_data": context + _format_raw_data(asset, i),
"status": asset["status"],
"external_id": asset["externalId"],
"metadata": asset["jsonMetadata"],
"labels": [
{
"author": label["author"]["email"],
"created_at": label["createdAt"],
"label_type": label["labelType"],
"label": _format_json_response_dynamic(
self.project["jsonInterface"]["jobs"], label["jsonResponse"], i
),
}
],
}
next_context = _get_next_step_context(formatted_asset, label["jsonResponse"], i)
context = context + next_context

if step_number > 0:
result.append(steps)
return result

def _process_llm_v1(self, assets: List[Dict]) -> List[Dict[str, Union[List[str], str]]]:
result = []
for asset in assets:
result.append(
@@ -60,26 +96,87 @@ def _process(self, assets: List[Dict]) -> List[Dict[str, Union[List[str], str]]]:
"status": asset["status"],
"external_id": asset["externalId"],
"metadata": asset["jsonMetadata"],
"labels": [
list(
map(
lambda label: {
"author": label["author"]["email"],
"created_at": label["createdAt"],
"label_type": label["labelType"],
"label": _format_json_response(
self.project["jsonInterface"]["jobs"], label["jsonResponse"]
),
},
asset["labels"],
)
"labels": list(
map(
lambda label: {
"author": label["author"]["email"],
"created_at": label["createdAt"],
"label_type": label["labelType"],
"label": _format_json_response(
self.project["jsonInterface"]["jobs"], label["jsonResponse"]
),
},
asset["labels"],
)
],
),
}
)
return result


def _get_step_ranking_value(json_response: Dict, step_number: int) -> RankingValue:
prefix = f"STEP_{step_number+1}_"
for category in json_response["CLASSIFICATION_JOB"]["categories"]:
if category["name"] != f"STEP_{step_number+1}":
continue

for children_name, children_value in category["children"].items():
if children_name == f"STEP_{step_number+1}_RANKING":
raw_value = children_value["categories"][0]["name"]
return raw_value[len(prefix) :]
return RankingValue.TIE


def _get_next_step_context(
formatted_asset: List[ExportLLMItem], json_response: Dict, step_number: int
) -> List[ExportLLMItem]:
context = []
skipped_context = 0
completion_index = 0
ranking = _get_step_ranking_value(json_response, step_number)
for item in formatted_asset:
if skipped_context > step_number:
break

if skipped_context == step_number:
if item["role"] == "user":
context.append(item)
else:
if completion_index == 0 and ranking in ["A_1", "A_2", "A_3", "TIE"]:
context.append(item)
break
if completion_index == 1 and ranking in ["B_1", "B_2", "B_3"]:
context.append(item)
break
completion_index += 1

if item["role"] == "assistant":
skipped_context += 1

return context


def _count_step(asset: Dict) -> int:
label = asset["latestLabel"]
if "jsonResponse" not in label and "CLASSIFICATION_JOB" not in label["jsonResponse"]:
return 0
return len(label["jsonResponse"]["CLASSIFICATION_JOB"]["categories"])


def _format_json_response_dynamic(
jobs_config: Dict, json_response: Dict, step_number: int
) -> Dict[str, Union[str, List[str]]]:
# check subjobs of the step
job_step = f"STEP_{step_number+1}"
for item in json_response["CLASSIFICATION_JOB"]["categories"]:
if item["name"] != job_step:
continue
response_step = _format_json_response(jobs_config, item["children"])
formatted_response = literal_eval(str(response_step).replace(job_step + "_", ""))
return formatted_response
return {}


def _format_json_response(
jobs_config: Dict, json_response: Dict
) -> Dict[str, Union[str, List[str]]]:
@@ -104,7 +201,9 @@ def _format_json_response(
return result


def _format_raw_data(asset) -> List[Dict]:
def _format_raw_data(
asset, step_number: Optional[int] = None, all_model_keys: Optional[bool] = False
) -> List[ExportLLMItem]:
raw_data = []

chat_id = asset["jsonMetadata"].get("chat_id", None)
@@ -115,6 +214,8 @@ def _format_raw_data(asset) -> List[Dict]:
and len(asset["jsonMetadata"]["chat_item_ids"]) > 0
):
chat_items_ids = str.split(asset["jsonMetadata"]["chat_item_ids"], SEPARATOR)
if step_number is not None:
chat_items_ids = chat_items_ids[step_number * 3 :]
else:
chat_items_ids = []

@@ -131,25 +232,34 @@
data = json.load(file)
version = data.get("version", None)
if version == "0.1":
for index, prompt in enumerate(data["prompts"]):
prompts = data["prompts"]
if step_number is not None:
prompts = [prompts[step_number]]
for index, prompt in enumerate(prompts):
raw_data.append(
{
"role": prompt.get("title", "user"),
"content": prompt["prompt"],
"id": _safe_pop(chat_items_ids),
"chat_id": chat_id,
"model": None,
}
ExportLLMItem(
{
"role": prompt.get("title", "user"),
"content": prompt["prompt"],
"id": _safe_pop(chat_items_ids),
"chat_id": chat_id,
"model": None,
}
)
)
raw_data.extend(
{
"role": completion.get("title", "assistant"),
"content": completion["content"],
"id": _safe_pop(chat_items_ids),
"chat_id": chat_id,
"model": _safe_pop(models) if index == len(data["prompts"]) - 1 else None,
}
for completion in prompt["completions"]
ExportLLMItem(
{
"role": completion.get("title", "assistant"),
"content": completion["content"],
"id": _safe_pop(chat_items_ids),
"chat_id": chat_id,
"model": models[index_completion]
if (index == len(prompts) - 1 or all_model_keys)
else None,
}
)
for index_completion, completion in enumerate(prompt["completions"])
)
else:
raise ValueError(f"Version {version} not supported")
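
For reference, a sketch of the per-asset structure built by _process_llm_dynamic_v1 above: one entry per annotated step, keyed by the step index as a string. raw_data accumulates the conversation so far (the completion kept by the previous step's ranking, then the current step's prompt and completions), and label holds the step's sub-jobs with the STEP_<n>_ prefix stripped. All values below are illustrative placeholders:

{
    "0": {
        "raw_data": [
            {"role": "user", "content": "first prompt", "id": "...", "chat_id": "...", "model": None},
            {"role": "assistant", "content": "completion A", "id": "...", "chat_id": "...", "model": "model_a"},
            {"role": "assistant", "content": "completion B", "id": "...", "chat_id": "...", "model": "model_b"},
        ],
        "status": "LABELED",
        "external_id": "asset-1",
        "metadata": {},
        "labels": [
            {
                "author": "annotator@example.com",
                "created_at": "2024-07-10T00:00:00Z",
                "label_type": "DEFAULT",
                "label": {"RANKING": "B_1"},  # illustrative; actual shape comes from _format_json_response on the step's children
            }
        ],
    },
    "1": {"...": "..."},  # next step: raw_data starts with the completion selected by the previous step's ranking
}
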
32 changes: 32 additions & 0 deletions src/kili/services/export/format/llm/types.py
@@ -0,0 +1,32 @@
"""Custom Types."""

from enum import Enum
from typing import List, Optional, TypedDict


class ExportLLMItem(TypedDict):
"""LLM asset chat part."""

role: str
content: str
id: Optional[str]
chat_id: Optional[str]
model: Optional[str]


class ExportLLMAsset(TypedDict):
"""LLM export asset format."""

raw_data: List[ExportLLMItem]


class RankingValue(str, Enum):
"""Possible value for ranking."""

A_3 = "A_3"
A_2 = "A_2"
A_1 = "A_1"
TIE = "TIE"
B_1 = "B_1"
B_2 = "B_2"
B_3 = "B_3"
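
These enum members mirror the suffix of the category names that _get_step_ranking_value extracts under the STEP_<n>_<value> naming convention. A minimal, hypothetical jsonResponse fragment for the second step (step_number == 1) shows how the two pieces line up:

json_response = {
    "CLASSIFICATION_JOB": {
        "categories": [
            {
                "name": "STEP_2",
                "children": {
                    "STEP_2_RANKING": {"categories": [{"name": "STEP_2_B_1"}]},
                },
            }
        ]
    }
}
# Stripping the "STEP_2_" prefix yields "B_1": completion B won this step, so
# _get_next_step_context carries completion B forward as context for the next
# step (a TIE falls back to completion A).
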
1 change: 1 addition & 0 deletions src/kili/services/export/types.py
@@ -15,6 +15,7 @@
"pascal_voc",
"geojson",
"llm_v1",
"llm_dynamic_v1",
]

