Skip to content

Commit

Permalink
feat: export for llm dynamic
Browse files Browse the repository at this point in the history
  • Loading branch information
baptiste-olivier committed Aug 7, 2024
1 parent 04a0c1e commit e38c76f
Show file tree
Hide file tree
Showing 16 changed files with 850 additions and 28 deletions.
31 changes: 16 additions & 15 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -9,22 +9,22 @@ jobs:
- uses: actions/setup-python@v5
with:
python-version: 3.8
cache: "pip"
cache: 'pip'
- uses: pre-commit/[email protected]

pylint:
runs-on: ubuntu-latest
name: Pylint test
strategy:
matrix:
python-version: ["3.8", "3.12"]
python-version: ['3.8', '3.12']
steps:
- uses: actions/checkout@v4
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
cache: "pip"
cache: 'pip'
- name: Install dependencies
run: |
python -m pip install --upgrade pip
Expand All @@ -37,13 +37,13 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
version: ["3.8", "3.12"]
version: ['3.8', '3.12']
steps:
- uses: actions/checkout@v4
- uses: actions/setup-python@v5
with:
python-version: ${{ matrix.version }}
cache: "pip"
cache: 'pip'
- name: Install dependencies
run: |
python -m pip install --upgrade pip
Expand Down Expand Up @@ -72,7 +72,7 @@ jobs:
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
cache: "pip"
cache: 'pip'

- name: Install dependencies
run: |
Expand Down Expand Up @@ -102,7 +102,7 @@ jobs:
uses: actions/setup-python@v5
with:
python-version: 3.8
cache: "pip"
cache: 'pip'

- name: Install dependencies
run: |
Expand Down Expand Up @@ -132,7 +132,7 @@ jobs:
uses: actions/setup-python@v5
with:
python-version: 3.8
cache: "pip"
cache: 'pip'

- name: Install dependencies
run: |
Expand Down Expand Up @@ -163,7 +163,7 @@ jobs:
uses: actions/setup-python@v5
with:
python-version: 3.8
cache: "pip"
cache: 'pip'

- name: Install dependencies
run: |
Expand All @@ -184,15 +184,16 @@ jobs:
-e "is only referenced in tests" \
-e "_ is never read" \
-e "affected_rows" \
-e "src/kili/orm" \
-e "src/kili/entrypoints" \
-e "src/kili/services" \
-e "src/kili/types.py" \
-e "src/kili/core/graphql/ws_graphql_client" \
-e "internal" \
-e "src/kili/domain/ontology.py" \
-e "src/kili/client.py" \
-e "src/kili/core/graphql/ws_graphql_client" \
-e "src/kili/domain/annotation.py" \
-e "src/kili/domain/ontology.py" \
-e "src/kili/domain/project.py" \
-e "src/kili/entrypoints" \
-e "src/kili/orm" \
-e "src/kili/services" \
-e "src/kili/types.py" \
> dead_code_filtered.txt || true
- name: Crash if dead code found
Expand Down
8 changes: 4 additions & 4 deletions src/kili/adapters/kili_api_gateway/asset/operations_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def list_assets(
project_info = get_project(
self.graphql_client, filters.project_id, ("inputType", "jsonInterface")
)
if project_info["inputType"] in {"VIDEO", "LLM_RLHF"}:
if project_info["inputType"] in {"VIDEO", "LLM_RLHF", "LLM_INSTR_FOLLOWING"}:
yield from self.list_assets_split(filters, fields, options, project_info)
return

Expand Down Expand Up @@ -90,23 +90,23 @@ def list_assets_split(
assets_gen = (
load_asset_json_fields(asset, fields, self.http_client) for asset in assets_gen
)

converter = AnnotationsToJsonResponseConverter(
json_interface=project_info["jsonInterface"],
project_input_type=project_info["inputType"],
)
is_requesting_annotations = any("annotations." in element for element in fields)
for asset in assets_gen:
if "latestLabel.jsonResponse" in fields and asset.get("latestLabel"):
converter.patch_label_json_response(
asset["latestLabel"], asset["latestLabel"]["annotations"]
)
if "latestLabel.annotations" not in fields:
if not is_requesting_annotations:
asset["latestLabel"].pop("annotations")

if "labels.jsonResponse" in fields:
for label in asset.get("labels", []):
converter.patch_label_json_response(label, label["annotations"])
if "labels.annotations" not in fields:
if not is_requesting_annotations:
label.pop("annotations")
yield asset

Expand Down
4 changes: 4 additions & 0 deletions src/kili/adapters/kili_api_gateway/kili_api_gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@
from kili.adapters.kili_api_gateway.cloud_storage import CloudStorageOperationMixin
from kili.adapters.kili_api_gateway.issue import IssueOperationMixin
from kili.adapters.kili_api_gateway.label.operations_mixin import LabelOperationMixin
from kili.adapters.kili_api_gateway.model_configuration.operations_mixin import (
ModelConfigurationOperationMixin,
)
from kili.adapters.kili_api_gateway.notification.operations_mixin import (
NotificationOperationMixin,
)
Expand All @@ -24,6 +27,7 @@ class KiliAPIGateway(
CloudStorageOperationMixin,
IssueOperationMixin,
LabelOperationMixin,
ModelConfigurationOperationMixin,
NotificationOperationMixin,
OrganizationOperationMixin,
ProjectOperationMixin,
Expand Down
14 changes: 12 additions & 2 deletions src/kili/adapters/kili_api_gateway/label/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,24 @@ def get_annotation_fragment():
"""Generates a fragment to get all annotations and their values."""
return get_annotations_partial_query(
annotation_fragment=fragment_builder(("__typename", "id", "job", "path", "labelId")),
classification_annotation_fragment=fragment_builder(("annotationValue.categories",)),
classification_annotation_fragment=fragment_builder(
("annotationValue.categories", "chatItemId")
),
ranking_annotation_fragment=fragment_builder(
(
"annotationValue.orders.elements",
"annotationValue.orders.rank",
)
),
transcription_annotation_fragment=fragment_builder(("annotationValue.text",)),
comparison_annotation_fragment=fragment_builder(
(
"annotationValue.choice.code",
"annotationValue.choice.firstId",
"annotationValue.choice.secondId",
"chatItemId",
)
),
transcription_annotation_fragment=fragment_builder(("annotationValue.text", "chatItemId")),
video_annotation_fragment=fragment_builder(
(
"frames.start",
Expand Down
8 changes: 8 additions & 0 deletions src/kili/adapters/kili_api_gateway/label/operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ def get_annotations_partial_query(
annotation_fragment: str,
classification_annotation_fragment: str,
ranking_annotation_fragment: str,
comparison_annotation_fragment: str,
transcription_annotation_fragment: str,
video_annotation_fragment: str,
video_object_detection_annotation_fragment: str,
Expand All @@ -115,6 +116,13 @@ def get_annotations_partial_query(
}}
"""

if comparison_annotation_fragment.strip():
inline_fragments += f"""
... on ComparisonAnnotation {{
{comparison_annotation_fragment}
}}
"""

if transcription_annotation_fragment.strip():
inline_fragments += f"""
... on TranscriptionAnnotation {{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def list_labels(
project_info = get_project(
self.graphql_client, filters.project_id, ("inputType", "jsonInterface")
)
if project_info["inputType"] in {"VIDEO", "LLM_RLHF"}:
if project_info["inputType"] in {"VIDEO", "LLM_RLHF", "LLM_INSTR_FOLLOWING"}:
yield from self.list_labels_split(filters, fields, options, project_info)
return

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""Api key Kili Gateway module."""
11 changes: 11 additions & 0 deletions src/kili/adapters/kili_api_gateway/model_configuration/mappers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
"""GraphQL payload data mappers for api keys operations."""

from kili.domain.project_model import ProjectModelFilters


def project_model_where_mapper(filter: ProjectModelFilters):
"""Build the GraphQL ProjectMapperWhere variable to be sent in an operation."""
return {
"projectId": filter.project_id,
"modelId": filter.model_id,
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
"""GraphQL Asset operations."""


def get_project_models_query(fragment: str) -> str:
"""Return the GraphQL projectModels query."""
return f"""
query ProjectModels($where: ProjectModelWhere!, $first: PageSize!, $skip: Int!) {{
data: projectModels(where: $where, first: $first, skip: $skip) {{
{fragment}
}}
}}
"""
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
"""Mixin extending Kili API Gateway class with Api Keys related operations."""

from typing import Dict, Generator, Optional

from kili.adapters.kili_api_gateway.base import BaseOperationMixin
from kili.adapters.kili_api_gateway.helpers.queries import (
PaginatedGraphQLQuery,
QueryOptions,
fragment_builder,
)
from kili.adapters.kili_api_gateway.model_configuration.mappers import project_model_where_mapper
from kili.adapters.kili_api_gateway.model_configuration.operations import get_project_models_query
from kili.domain.project_model import ProjectModelFilters
from kili.domain.types import ListOrTuple


class ModelConfigurationOperationMixin(BaseOperationMixin):
"""Mixin extending Kili API Gateway class with model configuration related operations."""

def list_project_models(
self,
filters: ProjectModelFilters,
fields: ListOrTuple[str],
options: Optional[QueryOptions] = None,
) -> Generator[Dict, None, None]:
"""List assets with given options."""
fragment = fragment_builder(fields)
query = get_project_models_query(fragment)
where = project_model_where_mapper(filters)
return PaginatedGraphQLQuery(self.graphql_client).execute_query_from_paginated_call(
query,
where,
options if options else QueryOptions(disable_tqdm=False),
"Retrieving project models",
None,
)
12 changes: 12 additions & 0 deletions src/kili/domain/project_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
"""API Key domain."""

from dataclasses import dataclass
from typing import Optional


@dataclass
class ProjectModelFilters:
"""Project model filters for running a project model search."""

project_id: Optional[str] = None
model_id: Optional[str] = None
26 changes: 26 additions & 0 deletions src/kili/llm/presentation/client/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,21 @@
from kili.adapters.kili_api_gateway.kili_api_gateway import KiliAPIGateway
from kili.domain.asset import AssetExternalId, AssetFilters, AssetId
from kili.domain.project import ProjectId
from kili.domain.project_model import ProjectModelFilters
from kili.llm.services.export import export
from kili.services.export.exceptions import NoCompatibleJobError
from kili.use_cases.asset.utils import AssetUseCasesUtils
from kili.utils.logcontext import for_all_methods, log_call

DEFAULT_PROJECT_MODEL_FIELDS = [
"configuration",
"id",
"model.credentials",
"model.name",
"model.type",
"name",
]


@for_all_methods(log_call, exclude=["__init__"])
class LlmClientMethods:
Expand Down Expand Up @@ -73,3 +83,19 @@ def export(
except NoCompatibleJobError as excp:
warnings.warn(str(excp), stacklevel=2)
return None

def list_project_models(
self, project_id: str, filters: Optional[Dict] = None, fields: Optional[List[str]] = None
) -> List[Dict]:
"""List project models."""
converted_filters = ProjectModelFilters(
project_id=project_id,
model_id=filters["model_id"] if filters and "model_id" in filters else None,
)

return list(
self.kili_api_gateway.list_project_models(
filters=converted_filters,
fields=fields if fields else DEFAULT_PROJECT_MODEL_FIELDS,
)
)
5 changes: 5 additions & 0 deletions src/kili/llm/services/export/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from kili.domain.asset.asset import AssetFilters
from kili.domain.project import ProjectId

from .dynamic import LLMDynamicExporter
from .static import LLMStaticExporter


Expand All @@ -23,4 +24,8 @@ def export( # pylint: disable=too-many-arguments, too-many-locals
return LLMStaticExporter(kili_api_gateway, disable_tqdm).export(
project_id, asset_filter, project["jsonInterface"]
)
if input_type == "LLM_INSTR_FOLLOWING":
return LLMDynamicExporter(kili_api_gateway, disable_tqdm).export(
asset_filter, project["jsonInterface"]
)
raise ValueError(f'Project Input type "{input_type}" cannot be used for llm exports.')
Loading

0 comments on commit e38c76f

Please sign in to comment.