Skip to content

Commit

Permalink
Search task added (#177)
Browse files Browse the repository at this point in the history
* Search task added
Function accepts json string or in dict format. Test changes are added to check whether function will accept string and dict as input request body parameters.

Signed-off-by: Nurlan <[email protected]>

* Another test case added
Added "else" block test case.

Signed-off-by: Nurlan <[email protected]>

* Test corrected

Signed-off-by: Nurlan <[email protected]>

* Search model added and code was corrected little bit

Signed-off-by: Nurlan <[email protected]>

* Linting

Signed-off-by: Nurlan <[email protected]>

* Additional test case added for each search api: task and model

Signed-off-by: Nurlan <[email protected]>

---------

Signed-off-by: Nurlan <[email protected]>
(cherry picked from commit 57f0cf8)
  • Loading branch information
Nurlanprog authored and github-actions[bot] committed Jul 20, 2023
1 parent 4e0ba52 commit 49ba275
Show file tree
Hide file tree
Showing 2 changed files with 158 additions and 0 deletions.
61 changes: 61 additions & 0 deletions opensearch_py_ml/ml_commons/ml_commons_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
# GitHub history for details.


import json
import time
from typing import Any, List, Union

Expand Down Expand Up @@ -392,6 +393,66 @@ def _get_task_info(self, task_id: str):
url=API_URL,
)

def search_task(self, input_json) -> object:
"""
This method searches a task from opensearch cluster (using ml commons api)
:param json: json input for the search request
:type json: string or dict
:return: returns a json object, with detailed information about the searched task
:rtype: object
"""

API_URL = f"{ML_BASE_URI}/tasks/_search"

if isinstance(input_json, str):
try:
json_obj = json.loads(input_json)
if not isinstance(json_obj, dict):
return "Invalid JSON object passed as argument."
API_BODY = json.dumps(json_obj)
except json.JSONDecodeError:
return "Invalid JSON string passed as argument."
elif isinstance(input_json, dict):
API_BODY = json.dumps(input_json)
else:
return "Invalid JSON object passed as argument."

return self._client.transport.perform_request(
method="GET",
url=API_URL,
body=API_BODY,
)

def search_model(self, input_json) -> object:
"""
This method searches a task from opensearch cluster (using ml commons api)
:param json: json input for the search request
:type json: string or dict
:return: returns a json object, with detailed information about the searched task
:rtype: object
"""

API_URL = f"{ML_BASE_URI}/models/_search"

if isinstance(input_json, str):
try:
json_obj = json.loads(input_json)
if not isinstance(json_obj, dict):
return "Invalid JSON object passed as argument."
API_BODY = json.dumps(json_obj)
except json.JSONDecodeError:
return "Invalid JSON string passed as argument."
elif isinstance(input_json, dict):
API_BODY = json.dumps(input_json)
else:
return "Invalid JSON object passed as argument."

return self._client.transport.perform_request(
method="POST",
url=API_URL,
body=API_BODY,
)

def get_model_info(self, model_id: str) -> object:
"""
This method return information about a model registered in the opensearch cluster (using ml commons api)
Expand Down
97 changes: 97 additions & 0 deletions tests/ml_commons/test_ml_commons_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,102 @@ def test_execute():
), "Raised Exception during execute API testing with JSON string"


def test_search():
# Search task cases
raised = False
try:
search_task_obj = ml_client.search_task(
input_json='{"query": {"match_all": {}},"size": 1}'
)
assert search_task_obj["hits"]["hits"] != []
except: # noqa: E722
raised = True
assert raised == False, "Raised Exception in searching task"

raised = False
try:
search_task_obj = ml_client.search_task(
input_json={"query": {"match_all": {}}, "size": 1}
)
assert search_task_obj["hits"]["hits"] != []
except: # noqa: E722
raised = True
assert raised == False, "Raised Exception in searching task"

raised = False
try:
search_task_obj = ml_client.search_task(input_json=15)
assert search_task_obj == "Invalid JSON object passed as argument."
except: # noqa: E722
raised = True
assert raised == False, "Raised Exception in searching task"

raised = False
try:
search_task_obj = ml_client.search_task(input_json="15")
assert search_task_obj == "Invalid JSON object passed as argument."
except: # noqa: E722
raised = True
assert raised == False, "Raised Exception in searching task"

raised = False
try:
search_task_obj = ml_client.search_task(
input_json='{"query": {"match_all": {}},size: 1}'
)
assert search_task_obj == "Invalid JSON string passed as argument."
except: # noqa: E722
raised = True
assert raised == False, "Raised Exception in searching task"

# Search model cases
raised = False
try:
search_model_obj = ml_client.search_model(
input_json='{"query": {"match_all": {}},"size": 1}'
)
assert search_model_obj["hits"]["hits"] != []
except: # noqa: E722
raised = True
assert raised == False, "Raised Exception in searching model"

raised = False
try:
search_model_obj = ml_client.search_model(
input_json={"query": {"match_all": {}}, "size": 1}
)
assert search_model_obj["hits"]["hits"] != []
except: # noqa: E722
raised = True
assert raised == False, "Raised Exception in searching model"

raised = False
try:
search_model_obj = ml_client.search_model(input_json=15)
assert search_model_obj == "Invalid JSON object passed as argument."
except: # noqa: E722
raised = True
assert raised == False, "Raised Exception in searching model"

raised = False
try:
search_model_obj = ml_client.search_model(input_json="15")
assert search_model_obj == "Invalid JSON object passed as argument."
except: # noqa: E722
raised = True
assert raised == False, "Raised Exception in searching model"

raised = False
try:
search_model_obj = ml_client.search_model(
input_json='{"query": {"match_all": {}},size: 1}'
)
assert search_model_obj == "Invalid JSON string passed as argument."
except: # noqa: E722
raised = True
assert raised == False, "Raised Exception in searching model"


def test_DEPRECATED_integration_pretrained_model_upload_unload_delete():
raised = False
try:
Expand Down Expand Up @@ -379,6 +475,7 @@ def test_integration_model_train_register_full_cycle():
print("Model Task Status:", ml_task_status)
raised = True
assert raised == False, "Raised Exception in pulling task info"

# This is test is being flaky. Sometimes the test is passing and sometimes showing 500 error
# due to memory circuit breaker.
# Todo: We need to revisit this test.
Expand Down

0 comments on commit 49ba275

Please sign in to comment.