chore(lab-3088): fix pylint & pyright errors
paulruelle committed Sep 23, 2024
1 parent 4faca36 commit 1ae9a34
Showing 7 changed files with 168 additions and 23 deletions.
@@ -32,16 +32,20 @@ def project_model_where_mapper(filter: ProjectModelFilters) -> Dict:

def map_create_model_input(data: ModelToCreateInput) -> Dict:
"""Build the GraphQL ModelInput variable to be sent in an operation."""
if data.type == ModelType.AZURE_OPEN_AI:
if data.type == ModelType.AZURE_OPEN_AI and isinstance(
data.credentials, AzureOpenAICredentials
):
credentials = {
"apiKey": data.credentials.api_key,
"deploymentId": data.credentials.deployment_id,
"endpoint": data.credentials.endpoint,
}
elif data.type == ModelType.OPEN_AI_SDK:
elif data.type == ModelType.OPEN_AI_SDK and isinstance(data.credentials, OpenAISDKCredentials):
credentials = {"apiKey": data.credentials.api_key, "endpoint": data.credentials.endpoint}
else:
raise ValueError(f"Unsupported model type: {data.type}")
raise ValueError(
f"Unsupported model type or credentials: {data.type}, {type(data.credentials)}"
)

return {
"credentials": credentials,
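
For readers unfamiliar with how the added isinstance checks silence pyright: isinstance() narrows the credentials union, so attribute access in each branch type-checks. A minimal, self-contained sketch of the pattern (the build_payload helper and the simplified dataclasses below are illustrative, not the repository's actual code):

from dataclasses import dataclass
from typing import Dict, Union

@dataclass
class AzureOpenAICredentials:
    api_key: str
    deployment_id: str
    endpoint: str

@dataclass
class OpenAISDKCredentials:
    api_key: str
    endpoint: str

def build_payload(credentials: Union[AzureOpenAICredentials, OpenAISDKCredentials]) -> Dict:
    # isinstance() narrows the union, so pyright knows deployment_id exists in this branch
    if isinstance(credentials, AzureOpenAICredentials):
        return {
            "apiKey": credentials.api_key,
            "deploymentId": credentials.deployment_id,
            "endpoint": credentials.endpoint,
        }
    # here the type is narrowed to OpenAISDKCredentials
    return {"apiKey": credentials.api_key, "endpoint": credentials.endpoint}
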
@@ -61,46 +61,48 @@ def list_models(
None,
)

def get_model(self, model_id: str, fields: ListOrTuple[str]) -> Optional[Dict]:
def get_model(self, model_id: str, fields: ListOrTuple[str]) -> Dict:
"""Get a model by ID."""
fragment = fragment_builder(fields)
query = get_model_query(fragment)
variables = {"modelId": model_id}
result = self.graphql_client.execute(query, variables)
return result.get("model")
return result["model"]

def create_model(self, model: ModelToCreateInput):
def create_model(self, model: ModelToCreateInput) -> Dict:
"""Send a GraphQL request calling createModel resolver."""
payload = {"input": map_create_model_input(model)}
fragment = fragment_builder(["id"])
mutation = get_create_model_mutation(fragment)
result = self.graphql_client.execute(mutation, payload)
return result["createModel"]

def update_model(self, model_id: str, model: ModelToUpdateInput):
def update_model(self, model_id: str, model: ModelToUpdateInput) -> Dict:
"""Send a GraphQL request calling updateModel resolver."""
payload = {"id": model_id, "input": map_update_model_input(model)}
fragment = fragment_builder(["id"])
mutation = get_update_model_mutation(fragment)
result = self.graphql_client.execute(mutation, payload)
return result["updateModel"]

def delete_model(self, model_id: str):
def delete_model(self, model_id: str) -> Dict:
"""Send a GraphQL request to delete an organization model."""
payload = map_delete_model_input(model_id)
mutation = get_delete_model_mutation()
result = self.graphql_client.execute(mutation, payload)
return result["deleteModel"]

def create_project_model(self, project_model: ProjectModelToCreateInput):
def create_project_model(self, project_model: ProjectModelToCreateInput) -> Dict:
"""Send a GraphQL request calling createModel resolver."""
payload = {"input": map_create_project_model_input(project_model)}
fragment = fragment_builder(["id"])
mutation = get_create_project_model_mutation(fragment)
result = self.graphql_client.execute(mutation, payload)
return result["createProjectModel"]

def update_project_model(self, project_model_id: str, project_model: ProjectModelToUpdateInput):
def update_project_model(
self, project_model_id: str, project_model: ProjectModelToUpdateInput
) -> Dict:
"""Send a GraphQL request calling updateProjectModel resolver."""
payload = {
"updateProjectModelId": project_model_id,
@@ -111,7 +113,7 @@ def update_project_model(project_model_id: str, project_model: ProjectMode
result = self.graphql_client.execute(mutation, payload)
return result["updateProjectModel"]

def delete_project_model(self, project_model_id: str):
def delete_project_model(self, project_model_id: str) -> Dict:
"""Send a GraphQL request to delete a project model."""
payload = map_delete_project_model_input(project_model_id)
mutation = get_delete_project_model_mutation()
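
A short illustration of why get_model can now drop the Optional return type: indexing into the GraphQL result raises KeyError on a missing key instead of silently returning None, so callers no longer have to narrow an Optional. The helper names below are hypothetical:

from typing import Dict, Optional

def get_model_lenient(result: Dict) -> Optional[Dict]:
    return result.get("model")  # may be None, so every caller must handle Optional

def get_model_strict(result: Dict) -> Dict:
    return result["model"]  # raises KeyError if absent, so the annotation stays Dict
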
21 changes: 21 additions & 0 deletions src/kili/domain/llm.py
@@ -13,19 +13,40 @@ class OrganizationModelFilters:


class ModelType(str, Enum):
"""Enumeration of the supported model types.
- `AZURE_OPEN_AI`: Models hosted on Microsoft Azure's OpenAI service.
- `OPEN_AI_SDK`: Models provided via OpenAI's SDK.
"""

AZURE_OPEN_AI = "AZURE_OPEN_AI"
OPEN_AI_SDK = "OPEN_AI_SDK"


@dataclass
class AzureOpenAICredentials:
"""Credentials for accessing Azure OpenAI models.
Attributes:
- `api_key`: The API key required for authentication to Azure OpenAI.
- `deployment_id`: The specific deployment of the model within Azure.
- `endpoint`: The endpoint URL where the Azure OpenAI service is hosted.
"""

api_key: str
deployment_id: str
endpoint: str


@dataclass
class OpenAISDKCredentials:
"""Credentials for accessing OpenAI SDK models.
Attributes:
- `api_key`: The API key required for authentication to OpenAI's SDK.
- `endpoint`: The endpoint URL where the OpenAI SDK service is hosted.
"""

api_key: str
endpoint: str

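
A hedged usage sketch of the new credential dataclasses (all values are placeholders), assuming the package is importable as kili.domain.llm as the file path above suggests:

from kili.domain.llm import AzureOpenAICredentials, OpenAISDKCredentials

azure_credentials = AzureOpenAICredentials(
    api_key="***",
    deployment_id="deployment_id",
    endpoint="https://example.openai.azure.com",
)
openai_credentials = OpenAISDKCredentials(
    api_key="***",
    endpoint="https://api.openai.com",
)
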
8 changes: 4 additions & 4 deletions src/kili/llm/presentation/client/llm.py
@@ -100,7 +100,7 @@ def export(
warnings.warn(str(excp), stacklevel=2)
return None

def list_models(self, organization_id: str, fields: Optional[List[str]] = None) -> List[Dict]:
def list_models(self, organization_id: str, fields: Optional[List[str]] = None):
"""List models of given organization."""
converted_filters = OrganizationModelFilters(
organization_id=organization_id,
@@ -113,14 +113,14 @@ def list_models(self, organization_id: str, fields: Optional[List[str]] = None)
)
)

def get_model(self, model_id: str, fields: Optional[List[str]] = None) -> Dict:
def get_model(self, model_id: str, fields: Optional[List[str]] = None):
return self.kili_api_gateway.get_model(
model_id=model_id,
fields=fields if fields else DEFAULT_ORGANIZATION_MODEL_FIELDS,
)

def create_model(self, organization_id: str, model: dict):
credentials_data = model.get("credentials")
credentials_data = model["credentials"]
model_type = ModelType(model["type"])

if model_type == ModelType.AZURE_OPEN_AI:
@@ -168,7 +168,7 @@ def delete_model(self, model_id: str):

def list_project_models(
self, project_id: str, filters: Optional[Dict] = None, fields: Optional[List[str]] = None
) -> List[Dict]:
):
"""List project models of given project."""
converted_filters = ProjectModelFilters(
project_id=project_id,
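
To see the client methods above in context, here is a sketch that mirrors the unit tests further down; the gateway is replaced with a MagicMock, so no real GraphQL calls are made:

from unittest.mock import MagicMock

from kili.llm.presentation.client.llm import LlmClientMethods

kili_api_gateway = MagicMock()  # stand-in for the real GraphQL gateway
kili_llm = LlmClientMethods(kili_api_gateway)

created = kili_llm.create_model(
    organization_id="organization_id",
    model={
        "name": "New Model",
        "type": "AZURE_OPEN_AI",
        "credentials": {
            "api_key": "***",
            "endpoint": "https://example.openai.azure.com",
            "deployment_id": "deployment_id",
        },
    },
)
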
2 changes: 0 additions & 2 deletions src/kili/services/plugins/model.py
@@ -68,7 +68,6 @@ def on_submit(self, label: Dict, asset_id: str):
"""
# pylint: disable=unused-argument
self.logger.warning("Method not implemented. Define a custom on_submit on your plugin")
# pylint: disable=unnecessary-ellipsis

def on_review(
self,
@@ -99,7 +98,6 @@ def on_review(self, label: Dict, asset_id: str):
"""
# pylint: disable=unused-argument
self.logger.warning("Method not implemented. Define a custom on_review on your plugin")
# pylint: disable=unnecessary-ellipsis

def on_custom_interface_click(
self,
20 changes: 20 additions & 0 deletions tests/unit/llm/services/export/test_dynamic.py
@@ -1,3 +1,5 @@
import pytest

from kili.llm.presentation.client.llm import LlmClientMethods

mock_json_interface = {
@@ -39,6 +41,8 @@
}
}

mock_empty_json_interface = {"jobs": {}}

mock_fetch_assets = [
{
"labels": [
@@ -435,3 +439,19 @@ def test_export_dynamic(mocker):
project_id="project_id",
)
assert result == expected_export


def test_export_dynamic_empty_json_interface(mocker):
get_project_return_val = {
"jsonInterface": mock_empty_json_interface,
"inputType": "LLM_INSTR_FOLLOWING",
"title": "Test project",
"id": "project_id",
"dataConnections": None,
}
kili_api_gateway = mocker.MagicMock()
kili_llm = LlmClientMethods(kili_api_gateway)
with pytest.raises(ValueError):
kili_llm.export(
project_id="project_id",
)
112 changes: 106 additions & 6 deletions tests/unit/llm/test_model.py
@@ -1,3 +1,5 @@
import pytest

from kili.llm.presentation.client.llm import LlmClientMethods

mock_list_models = [
@@ -20,7 +22,7 @@
"type": "OPEN_AI_SDK",
},
]
mock_get_model = {
mock_get_model_open_ai_sdk = {
"id": "model_id",
"credentials": {
"apiKey": "***",
@@ -29,6 +31,16 @@
"name": "Jamba (created by SDK)",
"type": "OPEN_AI_SDK",
}
mock_get_model_azure_open_ai = {
"id": "model_id",
"credentials": {
"apiKey": "***",
"endpoint": "https://ai21-jamba-1-5-large-ykxca.eastus.models.ai.azure.com",
"deploymentId": "deployment_id",
},
"name": "Jamba (created by SDK)",
"type": "AZURE_OPEN_AI",
}
mock_create_model = {"id": "new_model_id"}
mock_update_model = {
"id": "model_id",
@@ -49,15 +61,15 @@ def test_list_models(mocker):

def test_get_model(mocker):
kili_api_gateway = mocker.MagicMock()
kili_api_gateway.get_model.return_value = mock_get_model
kili_api_gateway.get_model.return_value = mock_get_model_open_ai_sdk

kili_llm = LlmClientMethods(kili_api_gateway)
result = kili_llm.get_model(model_id="model_id")

assert result == mock_get_model
assert result == mock_get_model_open_ai_sdk


def test_create_model(mocker):
def test_create_model_open_ai_sdk(mocker):
kili_api_gateway = mocker.MagicMock()
kili_api_gateway.create_model.return_value = mock_create_model

@@ -77,9 +89,65 @@ def test_create_model(mocker):
assert result == mock_create_model


def test_update_model(mocker):
def test_create_model_azure_openai(mocker):
kili_api_gateway = mocker.MagicMock()
kili_api_gateway.get_model.return_value = mock_get_model
kili_api_gateway.create_model.return_value = mock_create_model

kili_llm = LlmClientMethods(kili_api_gateway)
result = kili_llm.create_model(
organization_id="organization_id",
model={
"name": "New Model",
"type": "AZURE_OPEN_AI",
"credentials": {
"api_key": "***",
"endpoint": "https://api.openai.com",
"deployment_id": "deployment_id",
},
},
)

assert result == mock_create_model


def test_create_invalid_model(mocker):
kili_api_gateway = mocker.MagicMock()
kili_llm = LlmClientMethods(kili_api_gateway)

with pytest.raises(ValueError):
kili_llm.create_model(
organization_id="organization_id",
model={
"name": "New Model",
"type": "Wrong type",
"credentials": {"api_key": "***", "endpoint": "https://api.openai.com"},
},
)


def test_update_model_open_ai_sdk(mocker):
kili_api_gateway = mocker.MagicMock()
kili_api_gateway.get_model.return_value = mock_get_model_open_ai_sdk
kili_api_gateway.update_model.return_value = mock_update_model

kili_llm = LlmClientMethods(kili_api_gateway)
result = kili_llm.update_model(
model_id="model_id",
model={
"name": "Updated Model",
"credentials": {
"api_key": "***",
"endpoint": "https://api.openai.com",
},
},
)

assert result == mock_update_model


def test_update_model_azure_open_ai(mocker):
kili_api_gateway = mocker.MagicMock()
kili_api_gateway.get_model.return_value = mock_get_model_azure_open_ai
kili_api_gateway.update_model.return_value = mock_update_model

kili_llm = LlmClientMethods(kili_api_gateway)
@@ -90,13 +158,45 @@ def test_update_model(mocker):
"credentials": {
"api_key": "***",
"endpoint": "https://api.openai.com",
"deployment_id": "deployment_id",
},
},
)

assert result == mock_update_model


def test_update_invalid_model(mocker):
kili_api_gateway = mocker.MagicMock()
kili_llm = LlmClientMethods(kili_api_gateway)

with pytest.raises(ValueError):
kili_llm.update_model(
model_id="model_id",
model={
"name": "New Model",
"type": "Wrong type",
"credentials": {"api_key": "***", "endpoint": "https://api.openai.com"},
},
)


def test_update_non_existing_model(mocker):
kili_api_gateway = mocker.MagicMock()
kili_llm = LlmClientMethods(kili_api_gateway)
kili_api_gateway.get_model.return_value = None

with pytest.raises(ValueError):
kili_llm.update_model(
model_id="model_id",
model={
"name": "New Model",
"type": "Wrong type",
"credentials": {"api_key": "***", "endpoint": "https://api.openai.com"},
},
)


def test_delete_model(mocker):
kili_api_gateway = mocker.MagicMock()
kili_api_gateway.delete_model.return_value = mock_delete_model
