diff --git a/examples/client/python/simple_client_v2.py b/examples/client/python/simple_client_v2.py new file mode 100644 index 00000000..e3ccfc80 --- /dev/null +++ b/examples/client/python/simple_client_v2.py @@ -0,0 +1,108 @@ +# Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 +# Licensed under the Apache License, Version 2.0 https://aws.amazon.com/apache-2-0/ + +from api.connector import HTCGridConnector +from api.session import GridSession + +import time +import os +import json +import logging + +try: + client_config_file = os.environ['AGENT_CONFIG_FILE'] +except: + client_config_file = "/etc/agent/Agent_config.tfvars.json" + +with open(client_config_file, 'r') as file: + client_config_file = json.loads(file.read()) + + +TOTAL_COUNT = 0 +# Sample function callback +def sample_callback(worker_lambda_response): + global TOTAL_COUNT + TOTAL_COUNT += 1 + print(f"{TOTAL_COUNT}\tOK: {worker_lambda_response}") + + # do some computation + + pass + +if __name__ == "__main__": + + logging.info("Simple Client V2") + try: + username = os.environ['USERNAME'] + except KeyError: + username = "" + try: + password = os.environ['PASSWORD'] + except KeyError: + password = "" + + + # <1.> Establishes connection to one of many available HTC-Grids + grid_connector = HTCGridConnector(client_config_file, username=username, password=password) + + # <2.> Authentication based on the configuration file above + grid_connector.authenticate() + + + # <3.> Create session object with corresponding context & callback + context = { + "session_priority" : 1 + } + + grid_session = grid_connector.create_session( + service_name="MyService1", + context=context, + callback=sample_callback) + + + grid_session_2 = grid_connector.create_session( + service_name="MyService1", + context=context, + callback=sample_callback) + + + # <4.> Submit tasks for the session + task_1_definition = { + "worker_arguments": ["1000", "1", "1"] + } + + task_2_definition = { + "worker_arguments": ["2000", "1", "1"] + } + + grid_session.send([task_1_definition, task_2_definition]) + + grid_session_2.send([task_1_definition, task_2_definition]) + + # <5.> Submit additional tasks within the same session + time.sleep(1) + grid_session.send([task_1_definition, task_2_definition]) + + grid_session_2.send([task_1_definition, task_2_definition]) + + + # Blocking wait for completion + grid_session.wait_for_completion(timeout_ms=3000) + + print(grid_session.submitted_task_ids) + + print(grid_session.received_task_ids) + + grid_session.wait_for_completion() + grid_session_2.wait_for_completion() + + # grid_session.cancel() + + # Close session + grid_session.close() + grid_session_2.close() + + # Close connector and stop thread + grid_connector.close(wait_for_sessions_completion=True) + diff --git a/examples/submissions/k8s_jobs/Dockerfile.Submitter b/examples/submissions/k8s_jobs/Dockerfile.Submitter index bf2ceef6..dad4afd2 100644 --- a/examples/submissions/k8s_jobs/Dockerfile.Submitter +++ b/examples/submissions/k8s_jobs/Dockerfile.Submitter @@ -12,6 +12,7 @@ RUN pip install -r requirements.txt COPY ./examples/client/python/client.py . COPY ./examples/client/python/simple_client.py . +COPY ./examples/client/python/simple_client_v2.py . COPY ./examples/client/python/portfolio_pricing_client.py . COPY ./examples/client/python/sample_portfolio.json . diff --git a/examples/submissions/k8s_jobs/Makefile b/examples/submissions/k8s_jobs/Makefile index f68d4456..19d9dfa2 100644 --- a/examples/submissions/k8s_jobs/Makefile +++ b/examples/submissions/k8s_jobs/Makefile @@ -25,5 +25,7 @@ generated: mkdir -p $(GENERATED) && cat portfolio-pricing-book.yaml.tpl | sed "s/{{account_id}}/$(ACCOUNT_ID)/;s/{{region}}/$(REGION)/;s/{{image_name}}/$(SUBMITTER_IMAGE_NAME)/;s/{{image_tag}}/$(TAG)/" > $(GENERATED)/portfolio-pricing-book.yaml + mkdir -p $(GENERATED) && cat simple-task-v2-test.yaml.tpl | sed "s/{{account_id}}/$(ACCOUNT_ID)/;s/{{region}}/$(REGION)/;s/{{image_name}}/$(SUBMITTER_IMAGE_NAME)/;s/{{image_tag}}/$(TAG)/" > $(GENERATED)/simple-task-v2-test.yaml + clean: rm -rf $(GENERATED) \ No newline at end of file diff --git a/examples/submissions/k8s_jobs/simple-task-v2-test.yaml.tpl b/examples/submissions/k8s_jobs/simple-task-v2-test.yaml.tpl new file mode 100644 index 00000000..d4e0f75b --- /dev/null +++ b/examples/submissions/k8s_jobs/simple-task-v2-test.yaml.tpl @@ -0,0 +1,40 @@ +apiVersion: batch/v1 +kind: Job +metadata: + name: simple-client-v2-test +spec: + template: + spec: + containers: + - name: generator + securityContext: + {} + image: {{account_id}}.dkr.ecr.{{region}}.amazonaws.com/{{image_name}}:{{image_tag}} + imagePullPolicy: Always + resources: + limits: + cpu: 100m + memory: 128Mi + requests: + cpu: 100m + memory: 128Mi + command: ["python3","./simple_client_v2.py"] + volumeMounts: + - name: agent-config-volume + mountPath: /etc/agent + env: + - name: INTRA_VPC + value: "1" + restartPolicy: Never + nodeSelector: + grid/type: Operator + tolerations: + - effect: NoSchedule + key: grid/type + operator: Equal + value: Operator + volumes: + - name: agent-config-volume + configMap: + name: agent-configmap + backoffLimit: 0 diff --git a/source/client/python/api-v0.1/api/connector.py b/source/client/python/api-v0.1/api/connector.py index ee778e83..26ac0a66 100644 --- a/source/client/python/api-v0.1/api/connector.py +++ b/source/client/python/api-v0.1/api/connector.py @@ -10,7 +10,10 @@ import botocore import requests import logging +import threading +import traceback +from api.session import GridSession from api.in_out_manager import in_out_manager from utils.state_table_common import TASK_STATE_FINISHED from warrant_lite import WarrantLite @@ -23,8 +26,6 @@ from publicapi.api import default_api URLOPEN_LAMBDA_INVOKE_TIMEOUT_SEC = 120 # TODO Catch timout exception -TASK_TIMEOUT_SEC = 3600 -RETRY_COUNT = 5 TOKEK_REFRESH_INTERVAL_SEC = 200 working_path = os.path.dirname(os.path.realpath(__file__)) @@ -47,11 +48,25 @@ def get_safe_session_id(): return str(uuid.uuid1()) -class AWSConnector: +class HTCGridConnector: """This class implements the API for managing jobs""" in_out_manager = None - def __init__(self): + def __init__(self, agent_config_data, username="", password="", cognitoidp_client=None, s3_custom_resource=None, + redis_custom_connection=None): + """ + Args: + redis_custom_connection(object): override default redis connection + s3_custom_resource(object): override default s3 resource + cognitoidp_client(object): override default s3 cognito client + agent_config_data (dict): the HTC grid runtime configuration for the connector + username (string): the username used for authentication when the client run outside of a VPC + password (string): the password used for authentication when the client run outside of a VPC + + Returns: + Nothing + """ + # <1.> Setting defaults self.in_out_manager = None self.__api_gateway_endpoint = "" self.__public_api_gateway_endpoint = "" @@ -61,7 +76,7 @@ def __init__(self): self.__user_pool_client_id = "" self.__username = "" self.__password = "" - self.__dynamodb_results_pull_intervall = "" + self.__dynamodb_results_pull_interval = "" self.__task_input_passed_via_external_storage = "" self.__user_token_id = None self.__user_refresh_token = None @@ -74,36 +89,11 @@ def __init__(self): self.__api_client = None self.__default_api_client = None - def refresh(self): - """This method refreshes an expired JWT. The new JWT overrides the existing one""" - logging.info("starting cognito refresh") - try: - tokens = self.__cognito_client.initiate_auth( - ClientId=self.__user_pool_client_id, - AuthFlow='REFRESH_TOKEN_AUTH', - AuthParameters={ - 'REFRESH_TOKEN': self.__user_refresh_token, - } - ) - self.__user_token_id = tokens["AuthenticationResult"]["IdToken"] - logging.info("successfully cognito token refreshed") - except botocore.exceptions.ClientError: - logging.exception("Failed while refreshing cognito token") - - def init(self, agent_config_data, username="", password="", cognitoidp_client=None, s3_custom_resource=None, - redis_custom_connection=None): - """ - Args: - redis_custom_connection(object): override default redis connection - s3_custom_resource(object): override default s3 resource - cognitoidp_client(object): override default s3 cognito client - agent_config_data (dict): the HTC grid runtime configuration for the connector - username (string): the username used for authentication when the client run outside of a VPC - password (string): the password used for authentication when the client run outside of a VPC + self.active_sessions = {} + self.is_closed = False + self.wait_for_sessions_completion = True - Returns: - Nothing - """ + # <2.> Initialization logging.info("AGENT:", agent_config_data) self.in_out_manager = in_out_manager( agent_config_data['grid_storage_service'], @@ -120,7 +110,7 @@ def init(self, agent_config_data, username="", password="", cognitoidp_client=No self.__user_pool_client_id = agent_config_data['cognito_userpool_client_id'] self.__username = username self.__password = password - self.__dynamodb_results_pull_intervall = agent_config_data['dynamodb_results_pull_interval_sec'] + self.__dynamodb_results_pull_interval = agent_config_data['dynamodb_results_pull_interval_sec'] self.__task_input_passed_via_external_storage = agent_config_data['task_input_passed_via_external_storage'] self.__user_token_id = None if cognitoidp_client is None: @@ -143,13 +133,79 @@ def init(self, agent_config_data, username="", password="", cognitoidp_client=No self.__scheduler = BackgroundScheduler() logging.info("LAMBDA_ENDPOINT_URL:{}".format(self.__api_gateway_endpoint)) - logging.info("dynamodb_results_pull_interval_sec:{}".format(self.__dynamodb_results_pull_intervall)) + logging.info("dynamodb_results_pull_interval_sec:{}".format(self.__dynamodb_results_pull_interval)) logging.info("task_input_passed_via_external_storage:{}".format(self.__task_input_passed_via_external_storage)) logging.info("grid_storage_service:{}".format(agent_config_data['grid_storage_service'])) - logging.info("AWSConnector Initialized") + logging.info("HTCGridConnector Initialized") logging.info("init with {}".format(self.__user_pool_client_id)) logging.info("init with {}".format(self.__cognito_client)) + # <3.> Starting session management thread + self.t = threading.Thread(target=self.session_management_thread, args=(1,)) + self.t.start() + + + def create_session(self, service_name, context, callback): + + new_session = GridSession( + htc_grid_connector=self, + session_id=self.__get_safe_session_id(), + callback=callback + ) + + self.register_session(new_session) + + return new_session + + def close(self, wait_for_sessions_completion=True): + print("Connector: Closing HTC-Grid Connector") + self.is_closed = True + self.wait_for_sessions_completion = wait_for_sessions_completion + + def register_session(self, new_session): + assert(new_session.session_id not in self.active_sessions) + + self.active_sessions[new_session.session_id] = new_session + pass + + def diregister_session(self, session): + + # Check that all tasks of the session completed? + print(f"Connector: Diregistering session {session.session_id}") + del self.active_sessions[session.session_id] + pass + + def __get_safe_session_id(self): + session_id = uuid.uuid1() + return str(session_id) + + def session_management_thread(self, args): + + print("Connector: Thread started ", args) + + try: + while True: + print(f"Connector: Number of active sessions: {len(self.active_sessions)}") + + for session_id, session in self.active_sessions.items(): + session.check_tasks_states() + + time.sleep(1) + + if self.wait_for_sessions_completion: + # break if all sessions are completed + if len(self.active_sessions) == 0: + break + elif self.is_closed: + # break, ignore incompleted sessions if any + break + except Exception as e: + print("Unexpected error in session_management_thread {} [{}]".format( + e, traceback.format_exc())) + + def terminate(self): + self.t.join() + def authenticate(self): """This method authenticates against a Cognito User Pool. The JWT is stored as attribute of the class """ @@ -162,7 +218,7 @@ def authenticate(self): self.__user_token_id = tokens["AuthenticationResult"]["IdToken"] self.__user_refresh_token = tokens["AuthenticationResult"]["RefreshToken"] logging.info("authentication successful for user {}".format(self.__user_token_id)) - self.__scheduler.add_job(AWSConnector.refresh, 'interval', seconds=TOKEK_REFRESH_INTERVAL_SEC, + self.__scheduler.add_job(HTCGridConnector.refresh, 'interval', seconds=TOKEK_REFRESH_INTERVAL_SEC, args=[self]) self.__scheduler.start() except Exception as e: @@ -170,6 +226,30 @@ def authenticate(self): raise e self.__configuration.api_key['htc_cognito_authorizer'] = self.__user_token_id + def refresh(self): + """This method refreshes an expired JWT. The new JWT overrides the existing one""" + logging.info("starting cognito refresh") + try: + tokens = self.__cognito_client.initiate_auth( + ClientId=self.__user_pool_client_id, + AuthFlow='REFRESH_TOKEN_AUTH', + AuthParameters={ + 'REFRESH_TOKEN': self.__user_refresh_token, + } + ) + self.__user_token_id = tokens["AuthenticationResult"]["IdToken"] + logging.info("successfully cognito token refreshed") + except botocore.exceptions.ClientError: + logging.exception("Failed while refreshing cognito token") + + def is_task_input_passed_via_external_storage(self): + return self.__task_input_passed_via_external_storage + +####################################################################### +####################################################################### + + + def generate_user_task_json(self, tasks_list=None): """this methods returns from a list of tasks, a tasks object that can be submitted to the grid @@ -235,78 +315,22 @@ def generate_user_task_json(self, tasks_list=None): return user_task_json - # TODO implements this method - def cancel(self, session_id): - """ - Args: - session_id: - - Returns: - - """ - pass - - # TODO raise exception when the task list is above a given threshold - # TODO create a response object instead of dictionary - def send(self, tasks_list): # returns TaskID[] - """This method submits tasks to the HTC grid + def get_results(self, session_id): + """This methods get the result associated to a specific session_id Args: - tasks_list (list): the list of tasks to execute on the grid + session_id (list): session_id for which to retrieve results Returns: - dict: the response from the endpoint of the HTC grid + dict: the result of the submission """ - logging.info("Init send {} tasks".format(len(tasks_list))) - user_task_json_request = self.generate_user_task_json(tasks_list) - logging.info("user_task_json_request: {}".format(user_task_json_request)) - # print(user_task_json_request) - - json_response = self.submit(user_task_json_request) - logging.info("json_response = {}".format(json_response)) - return json_response - - def get_results(self, submission_response: dict, timeout_sec=0): - """This methods get the result associated to a specific submission - - Args: - submission_response (list): arrays storing the ids of the submission to check - timeout_sec (int): time after the connection between the client of the HTC grid is killed (Default value = 0) - - Returns: - str: the result of the submission - - """ - logging.info("Init get_results") - start_time = time.time() - - session_tasks_count: int = len(submission_response['task_ids']) - logging.info("session_tasks_count: {}".format(session_tasks_count)) - while True: - session_results = self.invoke_get_results_lambda({'session_id': submission_response['session_id']}) - logging.info("session_results: {}".format(session_results)) - # print("session_results: {}".format(session_results)) - - if 'metadata' in session_results \ - and session_results['metadata']['tasks_in_response'] == session_tasks_count: - break - elif 0 < timeout_sec < time.time() - start_time: - # We have timed out! - logging.error("Get Results Timed Out") - break - time.sleep(self.__dynamodb_results_pull_intervall) - - for i, completed_task in enumerate(session_results[TASK_STATE_FINISHED]): - stdout_bytes = self.in_out_manager.get_output_to_bytes(completed_task) - # print("stdout_bytes: {}".format(stdout_bytes)) - - output = base64.b64decode(stdout_bytes).decode('utf-8') + logging.info(f"Init get_results {session_id}") - session_results[TASK_STATE_FINISHED + '_OUTPUT'][i] = output + session_results = self.invoke_get_results_lambda({'session_id': session_id}) + logging.info("session_results: {}".format(session_results)) - logging.info("Finish get_results") return session_results # TODO this should be private diff --git a/source/client/python/api-v0.1/api/session.py b/source/client/python/api-v0.1/api/session.py new file mode 100644 index 00000000..a254cf85 --- /dev/null +++ b/source/client/python/api-v0.1/api/session.py @@ -0,0 +1,186 @@ +# Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 +# Licensed under the Apache License, Version 2.0 https://aws.amazon.com/apache-2-0/ + +import time +import logging +import json +import base64 +import traceback + +logging.basicConfig(format="%(asctime)s - %(levelname)s - %(filename)s - %(funcName)s - %(lineno)d - %(message)s", + datefmt='%H:%M:%S', level=logging.INFO) + +def get_time_now_ms(): + return int(round(time.time() * 1000)) + +class GridSession: + + def __init__(self, htc_grid_connector, session_id, callback): + + self.htc_grid_connector = htc_grid_connector + self.session_id = session_id + + self.submitted_task_ids = [] + self.submitted_tasks_count = 0 + + self.received_task_ids = {} + + self.callback = callback + self.time_send_was_invoked_ms = 0 + + self.in_out_manager = htc_grid_connector.in_out_manager + + # TODO: update constants + self.TASK_TIMEOUT_SEC = 3600 + self.RETRY_COUNT = 5 + + def send(self, tasks): # returns TaskID[] + """This method submits tasks to the HTC grid + + Args: + tasks_list (list): the list of tasks to execute on the grid + + Returns: + dict: the response from the endpoint of the HTC grid + + """ + try: + self.time_send_was_invoked_ms = get_time_now_ms() + + # <1.> Generate Task IDs based on the Session ID and the index of each task. + logging.info(f"Sending {len(tasks)} tasks for session {self.session_id}") + new_task_ids = [] + + for i, t in enumerate(tasks): + + task_index = i + self.submitted_tasks_count + task_id = self.__make_task_id_from_session_id(self.session_id, task_index) + new_task_ids.append(task_id) + + # <2.> Upload tasks into the Data Plane + serialized_tasks = [] + + for i, t in enumerate(tasks): + + data = json.dumps(t).encode('utf-8') + + b64data = base64.b64encode(data) + + self.in_out_manager.put_input_from_bytes(new_task_ids[i], b64data) + + serialized_tasks.append(b64data) + + # <3.> Construct submit_tasks Lambda invocation payload that will be passed via Data Plane + lambda_data_plane_payload = self.__construct_submit_tasks_lambda_invocation_payload(new_task_ids, serialized_tasks) + logging.info(f"lambda_data_plane_payload: {lambda_data_plane_payload}") + + # <4.> Invoke Submit Tasks Lambda in Control Plane + json_response = self.htc_grid_connector.submit(lambda_data_plane_payload) + logging.info(f"Submit Tasks Lambda json_response = {json_response}") + + # <5.> Bookkeeping + # TODO: check if submission is successful + self.submitted_tasks_count += len(tasks) + self.submitted_task_ids += new_task_ids + + return json_response + except Exception as e: + print("Unexpected error in sending {} [{}]".format( + e, traceback.format_exc())) + + + def __construct_submit_tasks_lambda_invocation_payload(self, new_task_ids, serialized_tasks): + lambda_payload = { + "session_id": self.session_id, + "scheduler_data": { + "task_timeout_sec": self.TASK_TIMEOUT_SEC, + "retry_count": self.RETRY_COUNT, + "tstamp_api_grid_connector_ms": 0, + "tstamp_agent_read_from_sqs_ms": 0 + }, + "stats": { + "stage1_grid_api_01_task_creation_tstmp": {"label": " ", "tstmp": self.time_send_was_invoked_ms}, + "stage1_grid_api_02_task_submission_tstmp": {"label": "upload_data_to_storage", + "tstmp": get_time_now_ms()}, + + "stage2_sbmtlmba_01_invocation_tstmp": {"label": "grid_api_2_lambda_ms", "tstmp": 0}, + "stage2_sbmtlmba_02_before_batch_write_tstmp": {"label": "task_construction_ms", "tstmp": 0}, + # "stage2_sbmtlmba_03_invocation_over_tstmp": {"label": "dynamo_db_submit_ms", "tstmp" : 0}, + + "stage3_agent_01_task_acquired_sqs_tstmp": {"label": "sqs_queuing_time_ms", "tstmp": 0}, + "stage3_agent_02_task_acquired_ddb_tstmp": {"label": "ddb_task_claiming_time_ms", "tstmp": 0}, + + "stage4_agent_01_user_code_finished_tstmp": {"label": "user_code_exec_time_ms", "tstmp": 0}, + "stage4_agent_02_S3_stdout_delivered_tstmp": {"label": "S3_stdout_upload_time_ms", "tstmp": 0} + }, + "tasks_list": { + "tasks": new_task_ids if self.htc_grid_connector.is_task_input_passed_via_external_storage() == 1 else serialized_tasks + } + } + + return lambda_payload + + + def check_tasks_states(self): + print(f"Session: Checking status {self.session_id}, submitted: {len(self.submitted_task_ids)}, received: {len(self.received_task_ids)}") + + try: + + res = self.htc_grid_connector.get_results(self.session_id) + + # Process only task_ids that we haven't seen before. + for t in res["finished"]: + if t not in self.received_task_ids: + + self.received_task_ids[t] = True + + stdout_bytes = self.in_out_manager.get_output_to_bytes(t) + logging.info("output_bytes: {}".format(stdout_bytes)) + + output = base64.b64decode(stdout_bytes).decode('utf-8') + logging.info("output_obj: {}".format(output)) + + # TODO: check if success or failure + + self.callback(output) + + except Exception as e: + print("Unexpected error in check_tasks_states {} [{}]".format( + e, traceback.format_exc())) + + + def wait_for_completion(self, timeout_ms=0, complete_and_close=False): + time_start_ms = get_time_now_ms() + + while len(self.received_task_ids) < len(self.submitted_task_ids): + print("Session: Main thread sleeps for results") + + time.sleep(1.0) + + if 0 < timeout_ms < get_time_now_ms() - time_start_ms: + logging.warning(f"{__name__} timeoud after {timeout_ms}") + break + + if len(self.received_task_ids) == len(self.submitted_task_ids): + # Unless more tasks will be submitted within this session + # Session can be considered completed. + if complete_and_close: + self.close() + + def cancel(self): + self.htc_grid_connector.cancel_sessions([self.session_id]) + + def close(self): + self.htc_grid_connector.diregister_session(session=self) + + + def __make_task_id_from_session_id(self, session_id, task_index): + return f"{session_id}_{task_index}" + + def __hash__(self): + return self.session_id + + def __eq__(self, other): + if not isinstance(other, type(self)): return NotImplemented + return self.session_id == other.session_id \ No newline at end of file