Skip to content

Commit

Permalink
chore(lab-3088): fix pylint & pyright errors
Browse files Browse the repository at this point in the history
  • Loading branch information
paulruelle committed Sep 23, 2024
1 parent bf37c0a commit 13ae5a3
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def get_model(self, model_id: str, fields: ListOrTuple[str]) -> Dict:
query = get_model_query(fragment)
variables = {"modelId": model_id}
result = self.graphql_client.execute(query, variables)
return result.get("model")
return result["model"]

def create_model(self, model: ModelToCreateInput) -> Dict:
"""Send a GraphQL request calling createModel resolver."""
Expand Down
2 changes: 1 addition & 1 deletion src/kili/llm/presentation/client/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ def get_model(self, model_id: str, fields: Optional[List[str]] = None):
)

def create_model(self, organization_id: str, model: dict):
credentials_data = model.get("credentials")
credentials_data = model["credentials"]
model_type = ModelType(model["type"])

if model_type == ModelType.AZURE_OPEN_AI:
Expand Down
64 changes: 58 additions & 6 deletions tests/unit/llm/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
"type": "OPEN_AI_SDK",
},
]
mock_get_model = {
mock_get_model_open_ai_sdk = {
"id": "model_id",
"credentials": {
"apiKey": "***",
Expand All @@ -29,6 +29,16 @@
"name": "Jamba (created by SDK)",
"type": "OPEN_AI_SDK",
}
mock_get_model_azure_open_ai = {
"id": "model_id",
"credentials": {
"apiKey": "***",
"endpoint": "https://ai21-jamba-1-5-large-ykxca.eastus.models.ai.azure.com",
"deploymentId": "deployment_id",
},
"name": "Jamba (created by SDK)",
"type": "AZURE_OPEN_AI",
}
mock_create_model = {"id": "new_model_id"}
mock_update_model = {
"id": "model_id",
Expand All @@ -49,15 +59,15 @@ def test_list_models(mocker):

def test_get_model(mocker):
kili_api_gateway = mocker.MagicMock()
kili_api_gateway.get_model.return_value = mock_get_model
kili_api_gateway.get_model.return_value = mock_get_model_open_ai_sdk

kili_llm = LlmClientMethods(kili_api_gateway)
result = kili_llm.get_model(model_id="model_id")

assert result == mock_get_model
assert result == mock_get_model_open_ai_sdk


def test_create_model(mocker):
def test_create_model_open_ai_sdk(mocker):
kili_api_gateway = mocker.MagicMock()
kili_api_gateway.create_model.return_value = mock_create_model

Expand All @@ -77,9 +87,50 @@ def test_create_model(mocker):
assert result == mock_create_model


def test_update_model(mocker):
def test_create_model_azure_openai(mocker):
kili_api_gateway = mocker.MagicMock()
kili_api_gateway.create_model.return_value = mock_create_model

kili_llm = LlmClientMethods(kili_api_gateway)
result = kili_llm.create_model(
organization_id="organization_id",
model={
"name": "New Model",
"type": "AZURE_OPEN_AI",
"credentials": {
"api_key": "***",
"endpoint": "https://api.openai.com",
"deployment_id": "deployment_id",
},
},
)

assert result == mock_create_model


def test_update_model_open_ai_sdk(mocker):
kili_api_gateway = mocker.MagicMock()
kili_api_gateway.get_model.return_value = mock_get_model_open_ai_sdk
kili_api_gateway.update_model.return_value = mock_update_model

kili_llm = LlmClientMethods(kili_api_gateway)
result = kili_llm.update_model(
model_id="model_id",
model={
"name": "Updated Model",
"credentials": {
"api_key": "***",
"endpoint": "https://api.openai.com",
},
},
)

assert result == mock_update_model


def test_update_model_azure_open_ai(mocker):
kili_api_gateway = mocker.MagicMock()
kili_api_gateway.get_model.return_value = mock_get_model
kili_api_gateway.get_model.return_value = mock_get_model_azure_open_ai
kili_api_gateway.update_model.return_value = mock_update_model

kili_llm = LlmClientMethods(kili_api_gateway)
Expand All @@ -90,6 +141,7 @@ def test_update_model(mocker):
"credentials": {
"api_key": "***",
"endpoint": "https://api.openai.com",
"deployment_id": "deployment_id",
},
},
)
Expand Down

0 comments on commit 13ae5a3

Please sign in to comment.