diff --git a/opensearch_py_ml/ml_commons/ml_commons_client.py b/opensearch_py_ml/ml_commons/ml_commons_client.py index c7c08f1d..41ef4275 100644 --- a/opensearch_py_ml/ml_commons/ml_commons_client.py +++ b/opensearch_py_ml/ml_commons/ml_commons_client.py @@ -6,6 +6,7 @@ # GitHub history for details. +import json import time from typing import Any, List, Union @@ -392,6 +393,66 @@ def _get_task_info(self, task_id: str): url=API_URL, ) + def search_task(self, input_json) -> object: + """ + This method searches a task from opensearch cluster (using ml commons api) + :param json: json input for the search request + :type json: string or dict + :return: returns a json object, with detailed information about the searched task + :rtype: object + """ + + API_URL = f"{ML_BASE_URI}/tasks/_search" + + if isinstance(input_json, str): + try: + json_obj = json.loads(input_json) + if not isinstance(json_obj, dict): + return "Invalid JSON object passed as argument." + API_BODY = json.dumps(json_obj) + except json.JSONDecodeError: + return "Invalid JSON string passed as argument." + elif isinstance(input_json, dict): + API_BODY = json.dumps(input_json) + else: + return "Invalid JSON object passed as argument." + + return self._client.transport.perform_request( + method="GET", + url=API_URL, + body=API_BODY, + ) + + def search_model(self, input_json) -> object: + """ + This method searches a task from opensearch cluster (using ml commons api) + :param json: json input for the search request + :type json: string or dict + :return: returns a json object, with detailed information about the searched task + :rtype: object + """ + + API_URL = f"{ML_BASE_URI}/models/_search" + + if isinstance(input_json, str): + try: + json_obj = json.loads(input_json) + if not isinstance(json_obj, dict): + return "Invalid JSON object passed as argument." + API_BODY = json.dumps(json_obj) + except json.JSONDecodeError: + return "Invalid JSON string passed as argument." + elif isinstance(input_json, dict): + API_BODY = json.dumps(input_json) + else: + return "Invalid JSON object passed as argument." + + return self._client.transport.perform_request( + method="POST", + url=API_URL, + body=API_BODY, + ) + def get_model_info(self, model_id: str) -> object: """ This method return information about a model registered in the opensearch cluster (using ml commons api) diff --git a/tests/ml_commons/test_ml_commons_client.py b/tests/ml_commons/test_ml_commons_client.py index 73868ba4..ae32edd9 100644 --- a/tests/ml_commons/test_ml_commons_client.py +++ b/tests/ml_commons/test_ml_commons_client.py @@ -100,6 +100,102 @@ def test_execute(): ), "Raised Exception during execute API testing with JSON string" +def test_search(): + # Search task cases + raised = False + try: + search_task_obj = ml_client.search_task( + input_json='{"query": {"match_all": {}},"size": 1}' + ) + assert search_task_obj["hits"]["hits"] != [] + except: # noqa: E722 + raised = True + assert raised == False, "Raised Exception in searching task" + + raised = False + try: + search_task_obj = ml_client.search_task( + input_json={"query": {"match_all": {}}, "size": 1} + ) + assert search_task_obj["hits"]["hits"] != [] + except: # noqa: E722 + raised = True + assert raised == False, "Raised Exception in searching task" + + raised = False + try: + search_task_obj = ml_client.search_task(input_json=15) + assert search_task_obj == "Invalid JSON object passed as argument." + except: # noqa: E722 + raised = True + assert raised == False, "Raised Exception in searching task" + + raised = False + try: + search_task_obj = ml_client.search_task(input_json="15") + assert search_task_obj == "Invalid JSON object passed as argument." + except: # noqa: E722 + raised = True + assert raised == False, "Raised Exception in searching task" + + raised = False + try: + search_task_obj = ml_client.search_task( + input_json='{"query": {"match_all": {}},size: 1}' + ) + assert search_task_obj == "Invalid JSON string passed as argument." + except: # noqa: E722 + raised = True + assert raised == False, "Raised Exception in searching task" + + # Search model cases + raised = False + try: + search_model_obj = ml_client.search_model( + input_json='{"query": {"match_all": {}},"size": 1}' + ) + assert search_model_obj["hits"]["hits"] != [] + except: # noqa: E722 + raised = True + assert raised == False, "Raised Exception in searching model" + + raised = False + try: + search_model_obj = ml_client.search_model( + input_json={"query": {"match_all": {}}, "size": 1} + ) + assert search_model_obj["hits"]["hits"] != [] + except: # noqa: E722 + raised = True + assert raised == False, "Raised Exception in searching model" + + raised = False + try: + search_model_obj = ml_client.search_model(input_json=15) + assert search_model_obj == "Invalid JSON object passed as argument." + except: # noqa: E722 + raised = True + assert raised == False, "Raised Exception in searching model" + + raised = False + try: + search_model_obj = ml_client.search_model(input_json="15") + assert search_model_obj == "Invalid JSON object passed as argument." + except: # noqa: E722 + raised = True + assert raised == False, "Raised Exception in searching model" + + raised = False + try: + search_model_obj = ml_client.search_model( + input_json='{"query": {"match_all": {}},size: 1}' + ) + assert search_model_obj == "Invalid JSON string passed as argument." + except: # noqa: E722 + raised = True + assert raised == False, "Raised Exception in searching model" + + def test_DEPRECATED_integration_pretrained_model_upload_unload_delete(): raised = False try: @@ -379,6 +475,7 @@ def test_integration_model_train_register_full_cycle(): print("Model Task Status:", ml_task_status) raised = True assert raised == False, "Raised Exception in pulling task info" + # This is test is being flaky. Sometimes the test is passing and sometimes showing 500 error # due to memory circuit breaker. # Todo: We need to revisit this test.