Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(LAB-3088): aau i configure dogfooding projects models using the python #1774

Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
104 changes: 102 additions & 2 deletions src/kili/adapters/kili_api_gateway/model_configuration/mappers.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,111 @@
"""GraphQL payload data mappers for api keys operations."""

from kili.domain.project_model import ProjectModelFilters
from typing import Dict
RuellePaul marked this conversation as resolved.
Show resolved Hide resolved

from kili.domain.llm import (
AzureOpenAICredentials,
ModelToCreateInput,
ModelToUpdateInput,
ModelType,
OpenAISDKCredentials,
OrganizationModelFilters,
ProjectModelFilters,
ProjectModelToCreateInput,
ProjectModelToUpdateInput,
)

def project_model_where_mapper(filter: ProjectModelFilters):

def model_where_wrapper(filter: OrganizationModelFilters) -> Dict:
"""Build the GraphQL ProjectMapperWhere variable to be sent in an operation."""
return {
"organizationId": filter.organization_id,
}


def project_model_where_mapper(filter: ProjectModelFilters) -> Dict:
"""Build the GraphQL ProjectMapperWhere variable to be sent in an operation."""
return {
"projectId": filter.project_id,
"modelId": filter.model_id,
}


def map_create_model_input(data: ModelToCreateInput) -> Dict:
"""Build the GraphQL ModelInput variable to be sent in an operation."""
if data.type == ModelType.AZURE_OPEN_AI and isinstance(
data.credentials, AzureOpenAICredentials
):
credentials = {
"apiKey": data.credentials.api_key,
"deploymentId": data.credentials.deployment_id,
"endpoint": data.credentials.endpoint,
}
elif data.type == ModelType.OPEN_AI_SDK and isinstance(data.credentials, OpenAISDKCredentials):
credentials = {"apiKey": data.credentials.api_key, "endpoint": data.credentials.endpoint}
else:
raise ValueError(
f"Unsupported model type or credentials: {data.type}, {type(data.credentials)}"
)

return {
"credentials": credentials,
"name": data.name,
"type": data.type.value,
"organizationId": data.organization_id,
}


def map_update_model_input(data: ModelToUpdateInput) -> Dict:
"""Build the GraphQL UpdateModelInput variable to be sent in an operation."""
input_dict = {}
if data.name is not None:
input_dict["name"] = data.name

if data.credentials is not None:
if isinstance(data.credentials, AzureOpenAICredentials):
credentials = {
"apiKey": data.credentials.api_key,
"deploymentId": data.credentials.deployment_id,
"endpoint": data.credentials.endpoint,
}
elif isinstance(data.credentials, OpenAISDKCredentials):
credentials = {
"apiKey": data.credentials.api_key,
"endpoint": data.credentials.endpoint,
}
else:
raise ValueError(f"Unsupported credentials type: {type(data.credentials)}")
input_dict["credentials"] = credentials

return input_dict


def map_create_project_model_input(data: ProjectModelToCreateInput) -> Dict:
"""Build the GraphQL ModelInput variable to be sent in an operation."""
return {
"projectId": data.project_id,
"modelId": data.model_id,
"configuration": data.configuration,
}


def map_update_project_model_input(data: ProjectModelToUpdateInput) -> Dict:
"""Build the GraphQL UpdateProjectModelInput variable to be sent in an operation."""
input_dict = {}
if data.configuration is not None:
input_dict["configuration"] = data.configuration
return input_dict


def map_delete_model_input(model_id: str) -> Dict:
"""Map the input for the GraphQL deleteModel mutation."""
return {
"deleteModelId": model_id,
}


def map_delete_project_model_input(project_model_id: str) -> Dict:
"""Map the input for the GraphQL deleteProjectModel mutation."""
return {
"deleteProjectModelId": project_model_id,
}
Original file line number Diff line number Diff line change
@@ -1,6 +1,90 @@
"""GraphQL Asset operations."""


def get_models_query(fragment: str) -> str:
"""Return the GraphQL projectModels query."""
return f"""
query Models($where: ModelWhere!, $first: PageSize!, $skip: Int!) {{
data: models(where: $where, first: $first, skip: $skip) {{
{fragment}
}}
}}
"""


def get_model_query(fragment: str) -> str:
"""Return the GraphQL model query by ID."""
return f"""
query Model($modelId: ID!) {{
model(id: $modelId) {{
{fragment}
}}
}}
"""


def get_create_model_mutation(fragment: str) -> str:
"""Return the GraphQL createProjectModel mutation."""
return f"""
mutation CreateModel($input: CreateModelInput!) {{
createModel(input: $input) {{
{fragment}
}}
}}
"""


def get_update_model_mutation(fragment: str) -> str:
"""Return the GraphQL updateModel mutation."""
return f"""
mutation UpdateModel($id: ID!, $input: UpdateModelInput!) {{
updateModel(id: $id, input: $input) {{
{fragment}
}}
}}
"""


def get_delete_model_mutation() -> str:
"""Return the GraphQL deleteOrganizationModel mutation."""
return """
mutation DeleteModel($deleteModelId: ID!) {
deleteModel(id: $deleteModelId)
}
"""


def get_create_project_model_mutation(fragment: str) -> str:
"""Return the GraphQL createProjectModel mutation."""
return f"""
mutation CreateProjectModel($input: CreateProjectModelInput!) {{
createProjectModel(input: $input) {{
{fragment}
}}
}}
"""


def get_update_project_model_mutation(fragment: str) -> str:
"""Return the GraphQL updateProjectModel mutation."""
return f"""
mutation UpdateProjectModel($updateProjectModelId: ID!, $input: UpdateProjectModelInput!) {{
updateProjectModel(id: $updateProjectModelId, input: $input) {{
{fragment}
}}
}}
"""


def get_delete_project_model_mutation() -> str:
"""Return the GraphQL deleteProjectModel mutation."""
return """
mutation DeleteProjectModel($deleteProjectModelId: ID!) {
deleteProjectModel(id: $deleteProjectModelId)
}
"""


def get_project_models_query(fragment: str) -> str:
"""Return the GraphQL projectModels query."""
return f"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,22 +8,125 @@
QueryOptions,
fragment_builder,
)
from kili.adapters.kili_api_gateway.model_configuration.mappers import project_model_where_mapper
from kili.adapters.kili_api_gateway.model_configuration.operations import get_project_models_query
from kili.domain.project_model import ProjectModelFilters
from kili.adapters.kili_api_gateway.model_configuration.mappers import (
map_create_model_input,
map_create_project_model_input,
map_delete_model_input,
map_delete_project_model_input,
map_update_model_input,
map_update_project_model_input,
model_where_wrapper,
project_model_where_mapper,
)
from kili.adapters.kili_api_gateway.model_configuration.operations import (
get_create_model_mutation,
get_create_project_model_mutation,
get_delete_model_mutation,
get_delete_project_model_mutation,
get_model_query,
get_models_query,
get_project_models_query,
get_update_model_mutation,
get_update_project_model_mutation,
)
from kili.domain.llm import (
ModelToCreateInput,
ModelToUpdateInput,
OrganizationModelFilters,
ProjectModelFilters,
ProjectModelToCreateInput,
ProjectModelToUpdateInput,
)
from kili.domain.types import ListOrTuple


class ModelConfigurationOperationMixin(BaseOperationMixin):
"""Mixin extending Kili API Gateway class with model configuration related operations."""

def list_models(
self,
filters: OrganizationModelFilters,
fields: ListOrTuple[str],
options: Optional[QueryOptions] = None,
) -> Generator[Dict, None, None]:
"""List models with given options."""
fragment = fragment_builder(fields)
query = get_models_query(fragment)
where = model_where_wrapper(filters)
return PaginatedGraphQLQuery(self.graphql_client).execute_query_from_paginated_call(
query,
where,
options if options else QueryOptions(disable_tqdm=False),
"Retrieving organization models",
None,
)

def get_model(self, model_id: str, fields: ListOrTuple[str]) -> Dict:
"""Get a model by ID."""
fragment = fragment_builder(fields)
query = get_model_query(fragment)
variables = {"modelId": model_id}
result = self.graphql_client.execute(query, variables)
return result["model"]

def create_model(self, model: ModelToCreateInput) -> Dict:
"""Send a GraphQL request calling createModel resolver."""
payload = {"input": map_create_model_input(model)}
fragment = fragment_builder(["id"])
mutation = get_create_model_mutation(fragment)
result = self.graphql_client.execute(mutation, payload)
return result["createModel"]

def update_properties_in_model(self, model_id: str, model: ModelToUpdateInput) -> Dict:
"""Send a GraphQL request calling updateModel resolver."""
payload = {"id": model_id, "input": map_update_model_input(model)}
fragment = fragment_builder(["id"])
mutation = get_update_model_mutation(fragment)
result = self.graphql_client.execute(mutation, payload)
return result["updateModel"]

def delete_model(self, model_id: str) -> Dict:
"""Send a GraphQL request to delete an organization model."""
payload = map_delete_model_input(model_id)
mutation = get_delete_model_mutation()
result = self.graphql_client.execute(mutation, payload)
return result["deleteModel"]

def create_project_model(self, project_model: ProjectModelToCreateInput) -> Dict:
"""Send a GraphQL request calling createModel resolver."""
payload = {"input": map_create_project_model_input(project_model)}
fragment = fragment_builder(["id"])
mutation = get_create_project_model_mutation(fragment)
result = self.graphql_client.execute(mutation, payload)
return result["createProjectModel"]

def update_project_model(
self, project_model_id: str, project_model: ProjectModelToUpdateInput
) -> Dict:
"""Send a GraphQL request calling updateProjectModel resolver."""
payload = {
"updateProjectModelId": project_model_id,
"input": map_update_project_model_input(project_model),
}
fragment = fragment_builder(["id", "configuration"])
mutation = get_update_project_model_mutation(fragment)
result = self.graphql_client.execute(mutation, payload)
return result["updateProjectModel"]

def delete_project_model(self, project_model_id: str) -> Dict:
"""Send a GraphQL request to delete a project model."""
payload = map_delete_project_model_input(project_model_id)
mutation = get_delete_project_model_mutation()
result = self.graphql_client.execute(mutation, payload)
return result["deleteProjectModel"]

def list_project_models(
self,
filters: ProjectModelFilters,
fields: ListOrTuple[str],
options: Optional[QueryOptions] = None,
) -> Generator[Dict, None, None]:
"""List assets with given options."""
"""List project models with given options."""
fragment = fragment_builder(fields)
query = get_project_models_query(fragment)
where = project_model_where_mapper(filters)
Expand Down
Loading