diff --git a/src/biokbase/narrative/jobs/appmanager.py b/src/biokbase/narrative/jobs/appmanager.py index 8480c57d9b..12a5aee0b0 100644 --- a/src/biokbase/narrative/jobs/appmanager.py +++ b/src/biokbase/narrative/jobs/appmanager.py @@ -6,7 +6,8 @@ import random import re import traceback -from typing import Callable, Optional +from collections.abc import Callable +from typing import Any from biokbase import auth from biokbase.narrative import clients @@ -19,14 +20,12 @@ ) from biokbase.narrative.common import kblogging from biokbase.narrative.exception_util import transform_job_exception -from biokbase.narrative.system import strict_system_variable, system_variable - -from biokbase.narrative.widgetmanager import WidgetManager - from biokbase.narrative.jobs import specmanager from biokbase.narrative.jobs.job import Job from biokbase.narrative.jobs.jobcomm import MESSAGE_TYPE, JobComm from biokbase.narrative.jobs.jobmanager import JobManager +from biokbase.narrative.system import strict_system_variable, system_variable +from biokbase.narrative.widgetmanager import WidgetManager """ A module for managing apps, specs, requirements, and for starting jobs. @@ -45,7 +44,7 @@ def timestamp() -> str: return datetime.datetime.utcnow().isoformat() + "Z" -def _app_error_wrapper(app_func: Callable) -> any: +def _app_error_wrapper(app_func: Callable) -> Callable: """ This is a decorator meant to wrap any of the `run_app*` methods here. It captures any raised exception, formats it into a message that can be sent @@ -122,7 +121,7 @@ def __new__(cls): AppManager.__instance._comm = None return AppManager.__instance - def reload(self): + def reload(self: "AppManager"): """ Reloads all app specs into memory from the App Catalog. Any outputs of app_usage, app_description, or available_apps @@ -130,7 +129,7 @@ def reload(self): """ self.spec_manager.reload() - def app_usage(self, app_id, tag="release"): + def app_usage(self: "AppManager", app_id: str, tag: str = "release"): """ This shows the list of inputs and outputs for a given app with a given tag. By default, this is done in a pretty HTML way, but this app can be @@ -149,7 +148,7 @@ def app_usage(self, app_id, tag="release"): """ return self.spec_manager.app_usage(app_id, tag) - def app_description(self, app_id, tag="release"): + def app_description(self: "AppManager", app_id: str, tag: str = "release"): """ Returns the app description in a printable HTML format. @@ -166,7 +165,7 @@ def app_description(self, app_id, tag="release"): """ return self.spec_manager.app_description(app_id, tag) - def available_apps(self, tag="release"): + def available_apps(self: "AppManager", tag: str = "release"): """ Lists the set of available apps for a given tag in a simple table. If the tag is not found, a ValueError will be raised. 
@@ -181,13 +180,13 @@ def available_apps(self, tag="release"): @_app_error_wrapper def run_legacy_batch_app( - self, - app_id, + self: "AppManager", + app_id: str, params, - tag="release", - version=None, - cell_id=None, - run_id=None, + tag: str = "release", + version: str | None = None, + cell_id: str | None = None, + run_id: str | None = None, dry_run=False, ): if params is None: @@ -320,15 +319,15 @@ def run_legacy_batch_app( @_app_error_wrapper def run_app( - self, - app_id, + self: "AppManager", + app_id: str, params, - tag="release", - version=None, - cell_id=None, - run_id=None, + tag: str = "release", + version: str | None = None, + cell_id: str | None = None, + run_id: str | None = None, dry_run=False, - ): + ) -> dict[str, Any] | None: """ Attempts to run the app, returns a Job with the running app info. If this is given a cell_id, then returns None. If not, it returns the @@ -409,12 +408,12 @@ def run_app( @_app_error_wrapper def run_app_batch( - self, + self: "AppManager", app_info: list, - cell_id: str = None, - run_id: str = None, + cell_id: str | None = None, + run_id: str | None = None, dry_run: bool = False, - ) -> Optional[dict]: + ) -> None | dict[str, Job | list[Job]] | dict[str, dict[str, str | int] | list]: """ Attempts to run a batch of apps in bulk using the Execution Engine's run_app_batch endpoint. If a cell_id is provided, this sends various job messages over the comm channel, and returns None. @@ -554,7 +553,7 @@ def run_app_batch( }, ) - child_jobs = Job.from_job_ids(child_ids, return_list=True) + child_jobs = Job.from_job_ids(child_ids) parent_job = Job.from_job_id( batch_id, children=child_jobs, @@ -568,7 +567,7 @@ def run_app_batch( if cell_id is None: return {"parent_job": parent_job, "child_jobs": child_jobs} - def _validate_bulk_app_info(self, app_info: dict): + def _validate_bulk_app_info(self: "AppManager", app_info: dict): """ Validation consists of: 1.
must have "app_id" with format xyz/abc @@ -610,7 +609,9 @@ def _validate_bulk_app_info(self, app_info: dict): f"an app version must be a string, not {app_info['version']}" ) - def _reconstitute_shared_params(self, app_info_el: dict) -> None: + def _reconstitute_shared_params( + self: "AppManager", app_info_el: dict[str, Any] + ) -> None: """ Mutate each params dict to include any shared_params app_info_el is structured like: @@ -639,14 +640,14 @@ def _reconstitute_shared_params(self, app_info_el: dict) -> None: param_set.setdefault(k, v) def _build_run_job_params( - self, - spec: dict, + self: "AppManager", + spec: dict[str, Any], tag: str, - param_set: dict, - version: Optional[str] = None, - cell_id: Optional[str] = None, - run_id: Optional[str] = None, - ws_id: Optional[int] = None, + param_set: dict[str, Any], + version: str | None = None, + cell_id: str | None = None, + run_id: str | None = None, + ws_id: int | None = None, ) -> dict: """ Builds the set of inputs for EE2.run_job and EE2.run_job_batch (RunJobParams) given a spec @@ -726,13 +727,13 @@ def _build_run_job_params( @_app_error_wrapper def run_local_app( - self, - app_id, - params, - tag="release", - version=None, - cell_id=None, - run_id=None, + self: "AppManager", + app_id: str, + params: dict[str, Any], + tag: str = "release", + version: str | None = None, + cell_id: str | None = None, + run_id: str | None = None, widget_state=None, ): """ @@ -811,14 +812,14 @@ def run_local_app( ) def run_local_app_advanced( - self, - app_id, - params, + self: "AppManager", + app_id: str, + params: dict[str, Any], widget_state, - tag="release", - version=None, - cell_id=None, - run_id=None, + tag: str = "release", + version: str | None = None, + cell_id: str | None = None, + run_id: str | None = None, ): return self.run_local_app( app_id, @@ -830,7 +831,9 @@ def run_local_app_advanced( run_id=run_id, ) - def _get_validated_app_spec(self, app_id, tag, is_long, version=None): + def _get_validated_app_spec( + self: "AppManager", app_id: str, tag: str, is_long, version: str | None = None + ): if ( version is not None and tag != "release" @@ -860,7 +863,12 @@ def _get_validated_app_spec(self, app_id, tag, is_long, version=None): ) return spec - def _map_group_inputs(self, value, spec_param, spec_params): + def _map_group_inputs( + self: "AppManager", + value: list | dict | None, + spec_param: dict[str, Any], + spec_params: dict[str, Any], + ): if isinstance(value, list): return [self._map_group_inputs(v, spec_param, spec_params) for v in value] @@ -891,7 +899,12 @@ def _map_group_inputs(self, value, spec_param, spec_params): mapped_value[target_key] = target_val return mapped_value - def _map_inputs(self, input_mapping, params, spec_params): + def _map_inputs( + self: "AppManager", + input_mapping: list[dict[str, Any]], + params: dict[str, Any], + spec_params: dict[str, Any], + ): """ Maps the dictionary of parameters and inputs based on rules provided in the input_mapping. This iterates over the list of input_mappings, and @@ -977,7 +990,7 @@ def _map_inputs(self, input_mapping, params, spec_params): inputs_list.append(inputs_dict[k]) return inputs_list - def _generate_input(self, generator): + def _generate_input(self: "AppManager", generator: dict[str, Any] | None): """ Generates an input value using rules given by NarrativeMethodStore.AutoGeneratedValue. 
@@ -1007,10 +1020,10 @@ def _generate_input(self, generator): return ret + str(generator["suffix"]) return ret - def _send_comm_message(self, msg_type, content): + def _send_comm_message(self: "AppManager", msg_type: str, content: dict[str, Any]): JobComm().send_comm_message(msg_type, content) - def _get_agent_token(self, name: str) -> auth.TokenInfo: + def _get_agent_token(self: "AppManager", name: str) -> auth.TokenInfo: """ Retrieves an agent token from the Auth service with a formatted name. This prepends "KBApp_" to the name for filtering, and trims to make sure the name diff --git a/src/biokbase/narrative/jobs/job.py b/src/biokbase/narrative/jobs/job.py index 0aacc5e06e..70980fe67d 100644 --- a/src/biokbase/narrative/jobs/job.py +++ b/src/biokbase/narrative/jobs/job.py @@ -2,13 +2,12 @@ import json import uuid from pprint import pprint -from typing import List +from typing import Any -from jinja2 import Template - -import biokbase.narrative.clients as clients +from biokbase.narrative import clients from biokbase.narrative.app_util import map_inputs_from_job, map_outputs_from_state from biokbase.narrative.exception_util import transform_job_exception +from jinja2 import Template from .specmanager import SpecManager @@ -29,9 +28,9 @@ "updated", ] -EXCLUDED_JOB_STATE_FIELDS = JOB_INIT_EXCLUDED_JOB_STATE_FIELDS + ["job_input"] +EXCLUDED_JOB_STATE_FIELDS = [*JOB_INIT_EXCLUDED_JOB_STATE_FIELDS, "job_input"] -OUTPUT_STATE_EXCLUDED_JOB_STATE_FIELDS = EXCLUDED_JOB_STATE_FIELDS + ["user", "wsid"] +OUTPUT_STATE_EXCLUDED_JOB_STATE_FIELDS = [*EXCLUDED_JOB_STATE_FIELDS, "user"] EXTRA_JOB_STATE_FIELDS = ["batch_id", "child_jobs"] @@ -47,9 +46,11 @@ # cell_id (str): ID of the cell that initiated the job (if applicable) # job_id (str): the ID of the job # params (dict): input parameters -# user (str): the user who started the job # run_id (str): unique run ID for the job +# status (str): EE2 job status # tag (str): the application tag (dev/beta/release) +# user (str): the user who started the job +# wsid (int): the workspace ID for the job JOB_ATTR_DEFAULTS = { "app_id": None, "app_version": None, @@ -61,8 +62,10 @@ "retry_ids": [], "retry_parent": None, "run_id": None, + "status": "created", "tag": "release", "user": None, + "wsid": None, } JOB_ATTRS = list(JOB_ATTR_DEFAULTS.keys()) @@ -80,14 +83,18 @@ "params", ] ALL_ATTRS = list(set(JOB_ATTRS + JOB_INPUT_ATTRS + NARR_CELL_INFO_ATTRS)) -STATE_ATTRS = list(set(JOB_ATTRS) - set(JOB_INPUT_ATTRS) - set(NARR_CELL_INFO_ATTRS)) class Job: - _job_logs = [] - _acc_state = None # accumulates state - - def __init__(self, ee2_state, extra_data=None, children=None): + _job_logs: list[dict[str, Any]] | None = None + _acc_state: dict[str, Any] | None = None # accumulates state + + def __init__( + self: "Job", + ee2_state: dict[str, Any], + extra_data: dict[str, Any] | None = None, + children: list["Job"] | None = None, + ): """ Parameters: ----------- @@ -101,6 +108,8 @@ def __init__(self, ee2_state, extra_data=None, children=None): children (list): applies to batch parent jobs, Job instances of this Job's child jobs """ # verify job_id ...
TODO validate ee2_state + self._job_logs = [] + self._acc_state = {} if ee2_state.get("job_id") is None: raise ValueError("Cannot create a job without a job ID!") @@ -113,49 +122,25 @@ def __init__(self, ee2_state, extra_data=None, children=None): self.children = children @classmethod - def from_job_id(cls, job_id, extra_data=None, children=None): - state = cls.query_ee2_state(job_id, init=True) - return cls(state, extra_data=extra_data, children=children) + def from_job_id( + cls, + job_id: str, + extra_data: dict[str, Any] | None = None, + children: list["Job"] | None = None, + ) -> "Job": + state = cls.query_ee2_states([job_id], init=True) + return cls(state[job_id], extra_data=extra_data, children=children) @classmethod - def from_job_ids(cls, job_ids, return_list=True): + def from_job_ids(cls, job_ids: list[str]) -> list["Job"]: states = cls.query_ee2_states(job_ids, init=True) - jobs = {} - for job_id, state in states.items(): - jobs[job_id] = cls(state) - - if return_list: - return list(jobs.values()) - return jobs - - @staticmethod - def _trim_ee2_state(state: dict, exclude_fields: list) -> None: - if exclude_fields: - for field in exclude_fields: - if field in state: - del state[field] - - @staticmethod - def query_ee2_state( - job_id: str, - init: bool = True, - ) -> dict: - return clients.get("execution_engine2").check_job( - { - "job_id": job_id, - "exclude_fields": ( - JOB_INIT_EXCLUDED_JOB_STATE_FIELDS - if init - else EXCLUDED_JOB_STATE_FIELDS - ), - } - ) + return [cls(state) for state in states.values()] @staticmethod def query_ee2_states( - job_ids: List[str], + job_ids: list[str], init: bool = True, - ) -> dict: + ) -> dict[str, dict]: if not job_ids: return {} @@ -171,10 +156,20 @@ def query_ee2_states( } ) - def __getattr__(self, name): + @staticmethod + def _trim_ee2_state( + state: dict[str, Any], exclude_fields: list[str] | None + ) -> None: + if exclude_fields: + for field in exclude_fields: + if field in state: + del state[field] + + def __getattr__(self: "Job", name: str) -> None | str | int | list | dict: """ Map expected job attributes to paths in stored ee2 state """ + attr = { "app_id": lambda: self._acc_state.get("job_input", {}).get( "app_id", JOB_ATTR_DEFAULTS["app_id"] @@ -193,44 +188,51 @@ def __getattr__(self, name): "cell_id": lambda: self._acc_state.get("job_input", {}) .get("narrative_cell_info", {}) .get("cell_id", JOB_ATTR_DEFAULTS["cell_id"]), - "child_jobs": lambda: copy.deepcopy( - # TODO - # Only batch container jobs have a child_jobs field - # and need the state refresh. 
- # But KBParallel/KB Batch App jobs may not have the - # batch_job field - self.refresh_state(force_refresh=True).get( - "child_jobs", JOB_ATTR_DEFAULTS["child_jobs"] - ) - if self.batch_job - else self._acc_state.get("child_jobs", JOB_ATTR_DEFAULTS["child_jobs"]) - ), "job_id": lambda: self._acc_state.get("job_id"), "params": lambda: copy.deepcopy( self._acc_state.get("job_input", {}).get( "params", JOB_ATTR_DEFAULTS["params"] ) ), - "retry_ids": lambda: copy.deepcopy( - # Batch container and retry jobs don't have a - # retry_ids field so skip the state refresh - self._acc_state.get("retry_ids", JOB_ATTR_DEFAULTS["retry_ids"]) - if self.batch_job or self.retry_parent - else self.refresh_state(force_refresh=True).get( - "retry_ids", JOB_ATTR_DEFAULTS["retry_ids"] - ) - ), "retry_parent": lambda: self._acc_state.get( "retry_parent", JOB_ATTR_DEFAULTS["retry_parent"] ), "run_id": lambda: self._acc_state.get("job_input", {}) .get("narrative_cell_info", {}) .get("run_id", JOB_ATTR_DEFAULTS["run_id"]), - # TODO: add the status attribute! "tag": lambda: self._acc_state.get("job_input", {}) .get("narrative_cell_info", {}) .get("tag", JOB_ATTR_DEFAULTS["tag"]), "user": lambda: self._acc_state.get("user", JOB_ATTR_DEFAULTS["user"]), + "wsid": lambda: self._acc_state.get("wsid", JOB_ATTR_DEFAULTS["wsid"]), + # the following properties can change whilst a job is in progress + "child_jobs": lambda force_refresh=True: copy.deepcopy( + # N.b. only batch container jobs have a child_jobs field + # and need the state refresh. + # But KBParallel/KB Batch App jobs do not have the + # batch_job field + self.refresh_state(force_refresh=force_refresh).get( + "child_jobs", JOB_ATTR_DEFAULTS["child_jobs"] + ) + if self.batch_job + else self._acc_state.get("child_jobs", JOB_ATTR_DEFAULTS["child_jobs"]) + ), + "retry_ids": lambda force_refresh=True: copy.deepcopy( + # Batch container and retry jobs don't have a + # retry_ids field so skip the state refresh + self._acc_state.get("retry_ids", JOB_ATTR_DEFAULTS["retry_ids"]) + if self.batch_job or self.retry_parent + else self.refresh_state(force_refresh=force_refresh).get( + "retry_ids", JOB_ATTR_DEFAULTS["retry_ids"] + ) + ), + "status": lambda force_refresh=True: copy.deepcopy( + # a terminal status can no longer change, so skip the refresh + self._acc_state.get("status", JOB_ATTR_DEFAULTS["status"]) + if self.in_terminal_state() + else self.refresh_state(force_refresh=force_refresh).get( + "status", JOB_ATTR_DEFAULTS["status"] + ) + ), } if name not in attr: @@ -238,7 +240,7 @@ def __getattr__(self, name): return attr[name]() - def __setattr__(self, name, value): + def __setattr__(self: "Job", name: str, value: None | str | int | list | dict): if name in ALL_ATTRS: raise AttributeError( "Job attributes must be updated using the `update_state` method" @@ -247,10 +249,10 @@ def __setattr__(self, name, value): object.__setattr__(self, name, value) @property - def app_name(self): + def app_name(self: "Job") -> str: return "batch" if self.batch_job else self.app_spec()["info"]["name"] - def was_terminal(self): + def in_terminal_state(self: "Job") -> bool: """ Checks if last queried ee2 state (or those of its children) was terminal. """ @@ -265,7 +267,7 @@ def was_terminal(self): return self._acc_state.get("status") in TERMINAL_STATUSES - def in_cells(self, cell_ids: List[str]) -> bool: + def in_cells(self: "Job", cell_ids: list[str]) -> bool: """ For job initialization.
See if job is associated with present cells @@ -276,14 +278,14 @@ def in_cells(self, cell_ids: List[str]) -> bool: if cell_ids is None: raise ValueError("cell_ids cannot be None") - if self.batch_job: + if self.batch_job and self.children: return any(child_job.cell_id in cell_ids for child_job in self.children) return self.cell_id in cell_ids - def app_spec(self): + def app_spec(self: "Job"): return SpecManager().get_spec(self.app_id, self.tag) - def parameters(self): + def parameters(self: "Job") -> dict[str, Any]: """ Returns the parameters used to start the job. Job tries to use its params field, but if that's None, then it makes a call to EE2. @@ -291,17 +293,18 @@ def parameters(self): If no exception is raised, this only returns the list of parameters, NOT the whole object fetched from ee2.check_job """ - if self.params is not None: - return self.params + if self.params is None: + try: + state = self.query_ee2_states([self.job_id], init=True) + self.update_state(state[self.job_id]) + except Exception as e: + raise Exception( + f"Unable to fetch parameters for job {self.job_id} - {e}" + ) from e - try: - state = self.query_ee2_state(self.job_id, init=True) - self.update_state(state) - return self.params - except Exception as e: - raise Exception(f"Unable to fetch parameters for job {self.job_id} - {e}") + return self.params - def update_state(self, state: dict) -> None: + def update_state(self: "Job", state: dict[str, Any]) -> None: """ Given a state data structure (as emitted by ee2), update the stored state in the job object. All updates to the job state should go through this function. @@ -316,35 +319,35 @@ def update_state(self, state: dict) -> None: + f"job ID: {self.job_id}; state ID: {state['job_id']}" ) - if self._acc_state is None: - self._acc_state = {} - self._acc_state = {**self._acc_state, **state} def refresh_state( - self, force_refresh=False, exclude_fields=JOB_INIT_EXCLUDED_JOB_STATE_FIELDS + self: "Job", + force_refresh: bool = False, + exclude_fields: list[str] | None = JOB_INIT_EXCLUDED_JOB_STATE_FIELDS, ): """ Queries the job service to see the state of the current job. """ - - if force_refresh or not self.was_terminal(): - state = self.query_ee2_state(self.job_id, init=False) - self.update_state(state) + if force_refresh or not self.in_terminal_state(): + state = self.query_ee2_states([self.job_id], init=False) + self.update_state(state[self.job_id]) return self.cached_state(exclude_fields) - def cached_state(self, exclude_fields=None): + def cached_state( + self: "Job", exclude_fields: list[str] | None = None + ) -> dict[str, Any]: """Wrapper for self._acc_state""" state = copy.deepcopy(self._acc_state) self._trim_ee2_state(state, exclude_fields) return state - def output_state(self, state=None, no_refresh=False) -> dict: + def output_state(self: "Job") -> dict[str, str | dict[str, Any]]: """ - :param state: Supplied when the state is queried beforehand from EE2 in bulk, - or when it is retrieved from a cache. If not supplied, must be - queried with self.refresh_state() or self.cached_state() + Request the current job state in a format suitable for sending to the front end. + N.b. this method does not perform a data update. 
:return: dict, with structure { @@ -396,24 +399,17 @@ def output_state(self, state=None, no_refresh=False) -> dict: } :rtype: dict """ - if not state: - state = self.cached_state() if no_refresh else self.refresh_state() - else: - self.update_state(state) - state = self.cached_state() - - if state is None: - return self._create_error_state( - "Unable to find current job state. Please try again later, or contact KBase.", - "Unable to return job state", - -1, - ) - + state = self.cached_state() self._trim_ee2_state(state, OUTPUT_STATE_EXCLUDED_JOB_STATE_FIELDS) + if "job_output" not in state: state["job_output"] = {} - for arg in EXTRA_JOB_STATE_FIELDS: - state[arg] = getattr(self, arg) + + if "batch_id" not in state: + state["batch_id"] = self.batch_id + + if "child_jobs" not in state: + state["child_jobs"] = copy.deepcopy(JOB_ATTR_DEFAULTS["child_jobs"]) widget_info = None if state.get("finished"): @@ -440,7 +436,9 @@ def output_state(self, state=None, no_refresh=False) -> dict: "outputWidgetInfo": widget_info, } - def show_output_widget(self, state=None): + def show_output_widget( + self: "Job", state: dict[str, Any] | None = None + ) -> str | None: """ For a complete job, returns the job results. An incomplete job throws an exception @@ -460,7 +458,7 @@ def show_output_widget(self, state=None): ) return f"Job is incomplete! It has status '{state['status']}'" - def get_viewer_params(self, state): + def get_viewer_params(self: "Job", state: dict[str, Any]) -> dict[str, Any] | None: """ Maps job state 'result' onto the inputs for a viewer. """ @@ -469,13 +467,15 @@ def get_viewer_params(self, state): (output_widget, widget_params) = self._get_output_info(state) return {"name": output_widget, "tag": self.tag, "params": widget_params} - def _get_output_info(self, state): + def _get_output_info(self: "Job", state: dict[str, Any]): spec = self.app_spec() return map_outputs_from_state( state, map_inputs_from_job(self.parameters(), spec), spec ) - def log(self, first_line=0, num_lines=None): + def log( + self: "Job", first_line: int = 0, num_lines: int | None = None + ) -> tuple[int, list[dict[str, Any]]]: """ Fetch a list of Job logs from the Job Service.
This returns a 2-tuple (number of available log lines, list of log lines) @@ -516,16 +516,17 @@ def log(self, first_line=0, num_lines=None): self._job_logs[first_line : first_line + num_lines], ) - def _update_log(self): - log_update = clients.get("execution_engine2").get_job_logs( + def _update_log(self: "Job") -> None: + log_update: dict[str, Any] = clients.get("execution_engine2").get_job_logs( {"job_id": self.job_id, "skip_lines": len(self._job_logs)} ) if log_update["lines"]: self._job_logs = self._job_logs + log_update["lines"] - def _verify_children(self, children: List["Job"]) -> None: + def _verify_children(self: "Job", children: list["Job"] | None) -> None: if not self.batch_job: raise ValueError("Not a batch container job") + if children is None: raise ValueError( "Must supply children when setting children of batch job parent" @@ -535,16 +536,16 @@ def _verify_children(self, children: List["Job"]) -> None: if sorted(inst_child_ids) != sorted(self._acc_state.get("child_jobs")): raise ValueError("Child job id mismatch") - def update_children(self, children: List["Job"]) -> None: + def update_children(self: "Job", children: list["Job"]) -> None: self._verify_children(children) self.children = children def _create_error_state( - self, + self: "Job", error: str, error_msg: str, code: int, - ) -> dict: + ) -> dict[str, Any]: """ Creates an error state to return if 1. the state is missing or unretrievable @@ -571,10 +572,10 @@ def _create_error_state( "created": 0, } - def __repr__(self): + def __repr__(self: "Job") -> str: return "KBase Narrative Job - " + str(self.job_id) - def info(self): + def info(self: "Job") -> None: """ Printed job info """ @@ -590,7 +591,7 @@ def info(self): except BaseException: print("Unable to retrieve current running state!") - def _repr_javascript_(self): + def _repr_javascript_(self: "Job") -> str: """ Called by Jupyter when a Job object is entered into a code cell """ @@ -625,7 +626,7 @@ def _repr_javascript_(self): output_widget_info=json.dumps(output_widget_info), ) - def dump(self): + def dump(self: "Job") -> dict[str, Any]: """ Display job info without having to iterate through the attributes """ diff --git a/src/biokbase/narrative/jobs/jobcomm.py b/src/biokbase/narrative/jobs/jobcomm.py index e674dfece1..e3c3c4f60a 100644 --- a/src/biokbase/narrative/jobs/jobcomm.py +++ b/src/biokbase/narrative/jobs/jobcomm.py @@ -1,13 +1,12 @@ import copy import threading -from typing import List, Union - -from ipykernel.comm import Comm +from typing import Union from biokbase.narrative.common import kblogging from biokbase.narrative.exception_util import JobRequestException, NarrativeException from biokbase.narrative.jobs.jobmanager import JobManager from biokbase.narrative.jobs.util import load_job_constants +from ipykernel.comm import Comm (PARAM, MESSAGE_TYPE) = load_job_constants() @@ -21,6 +20,8 @@ LOOKUP_TIMER_INTERVAL = 5 +INPUT_TYPES = [PARAM["JOB_ID"], PARAM["JOB_ID_LIST"], PARAM["BATCH_ID"]] + class JobRequest: """ @@ -65,9 +66,7 @@ class JobRequest: batch_id str """ - INPUT_TYPES = [PARAM["JOB_ID"], PARAM["JOB_ID_LIST"], PARAM["BATCH_ID"]] - - def __init__(self, rq: dict): + def __init__(self: "JobRequest", rq: dict) -> None: rq = copy.deepcopy(rq) self.raw_request = rq self.msg_id = rq.get("msg_id") # might be useful later? 
@@ -79,20 +78,20 @@ def __init__(self, rq: dict): raise JobRequestException(MISSING_REQUEST_TYPE_ERR) input_type_count = 0 - for input_type in self.INPUT_TYPES: + for input_type in INPUT_TYPES: if input_type in self.rq_data: input_type_count += 1 if input_type_count > 1: raise JobRequestException(ONE_INPUT_TYPE_ONLY_ERR) @property - def job_id(self): + def job_id(self: "JobRequest"): if PARAM["JOB_ID"] in self.rq_data: return self.rq_data[PARAM["JOB_ID"]] raise JobRequestException(ONE_INPUT_TYPE_ONLY_ERR) @property - def job_id_list(self): + def job_id_list(self: "JobRequest"): if PARAM["JOB_ID_LIST"] in self.rq_data: return self.rq_data[PARAM["JOB_ID_LIST"]] if PARAM["JOB_ID"] in self.rq_data: @@ -100,22 +99,22 @@ def job_id_list(self): raise JobRequestException(ONE_INPUT_TYPE_ONLY_ERR) @property - def batch_id(self): + def batch_id(self: "JobRequest"): if PARAM["BATCH_ID"] in self.rq_data: return self.rq_data[PARAM["BATCH_ID"]] raise JobRequestException(ONE_INPUT_TYPE_ONLY_ERR) - def has_batch_id(self): + def has_batch_id(self: "JobRequest"): return PARAM["BATCH_ID"] in self.rq_data @property - def cell_id_list(self): + def cell_id_list(self: "JobRequest"): if PARAM["CELL_ID_LIST"] in self.rq_data: return self.rq_data[PARAM["CELL_ID_LIST"]] raise JobRequestException(CELLS_NOT_PROVIDED_ERR) @property - def ts(self): + def ts(self: "JobRequest"): """ Optional field sent with STATUS requests indicating to filter out job states in the STATUS response that have not been updated since @@ -171,7 +170,7 @@ def __new__(cls): JobComm.__instance = object.__new__(cls) return JobComm.__instance - def __init__(self): + def __init__(self: "JobComm") -> None: if self._comm is None: self._comm = Comm(target_name="KBaseJobs", data={}) self._comm.on_msg(self._handle_comm_message) @@ -190,7 +189,7 @@ def __init__(self): MESSAGE_TYPE["STOP_UPDATE"]: self._modify_job_updates, } - def _get_job_ids(self, req: JobRequest) -> List[str]: + def _get_job_ids(self: "JobComm", req: JobRequest) -> list[str]: """ Extract the job IDs from a job request object @@ -198,7 +197,7 @@ def _get_job_ids(self, req: JobRequest) -> List[str]: :type req: JobRequest :return: list of job IDs - :rtype: List[str] + :rtype: list[str] """ if req.has_batch_id(): return self._jm.update_batch_job(req.batch_id) @@ -206,9 +205,9 @@ def _get_job_ids(self, req: JobRequest) -> List[str]: return req.job_id_list def start_job_status_loop( - self, + self: "JobComm", init_jobs: bool = False, - cell_list: List[str] = None, + cell_list: list[str] | None = None, ) -> None: """ Starts the job status lookup loop, which runs every LOOKUP_TIMER_INTERVAL seconds. @@ -237,7 +236,7 @@ def start_job_status_loop( if self._lookup_timer is None: self._lookup_job_status_loop() - def stop_job_status_loop(self) -> None: + def stop_job_status_loop(self: "JobComm") -> None: """ Stops the job status lookup loop if it's running. Otherwise, this effectively does nothing. @@ -247,7 +246,7 @@ def stop_job_status_loop(self) -> None: self._lookup_timer = None self._running_lookup_loop = False - def _lookup_job_status_loop(self) -> None: + def _lookup_job_status_loop(self: "JobComm") -> None: """ Run a loop that will look up job info. After running, this spawns a Timer thread on a loop to run itself again. LOOKUP_TIMER_INTERVAL sets the frequency at which the loop runs. 
@@ -262,7 +261,9 @@ def _lookup_job_status_loop(self) -> None: self._lookup_timer.start() def get_all_job_states( - self, req: JobRequest = None, ignore_refresh_flag: bool = False + self: "JobComm", + req: JobRequest | None = None, + ignore_refresh_flag: bool = False, ) -> dict: """ Fetches status of all jobs in the current workspace and sends them to the front end. @@ -274,7 +275,7 @@ def get_all_job_states( self.send_comm_message(MESSAGE_TYPE["STATUS_ALL"], all_job_states) return all_job_states - def get_job_states_by_cell_id(self, req: JobRequest) -> dict: + def get_job_states_by_cell_id(self: "JobComm", req: JobRequest) -> dict: """ Fetches status of all jobs associated with the given cell ID(s). @@ -303,7 +304,7 @@ def get_job_states_by_cell_id(self, req: JobRequest) -> dict: self.send_comm_message(MESSAGE_TYPE["CELL_JOB_STATUS"], cell_job_states) return cell_job_states - def get_job_info(self, req: JobRequest) -> dict: + def get_job_info(self: "JobComm", req: JobRequest) -> dict: """ Gets job information for a list of job IDs. @@ -332,7 +333,7 @@ def get_job_info(self, req: JobRequest) -> dict: self.send_comm_message(MESSAGE_TYPE["INFO"], job_info) return job_info - def _get_job_states(self, job_id_list: list, ts: int = None) -> dict: + def _get_job_states(self: "JobComm", job_id_list: list, ts: int = None) -> dict: """ Retrieves the job states for the supplied job_ids. @@ -358,7 +359,7 @@ def _get_job_states(self, job_id_list: list, ts: int = None) -> dict: self.send_comm_message(MESSAGE_TYPE["STATUS"], output_states) return output_states - def get_job_state(self, job_id: str) -> dict: + def get_job_state(self: "JobComm", job_id: str) -> dict: """ Retrieve the job state for a single job. @@ -373,7 +374,7 @@ def get_job_state(self, job_id: str) -> dict: """ return self._get_job_states([job_id]) - def get_job_states(self, req: JobRequest) -> dict: + def get_job_states(self: "JobComm", req: JobRequest) -> dict: """ Retrieves the job states for the supplied job_ids. @@ -388,7 +389,7 @@ def get_job_states(self, req: JobRequest) -> dict: job_id_list = self._get_job_ids(req) return self._get_job_states(job_id_list, req.ts) - def _modify_job_updates(self, req: JobRequest) -> dict: + def _modify_job_updates(self: "JobComm", req: JobRequest) -> dict: """ Modifies how many things want to listen to a job update. If this is a request to start a job update, then this starts the update loop that @@ -419,7 +420,7 @@ def _modify_job_updates(self, req: JobRequest) -> dict: self.send_comm_message(MESSAGE_TYPE["STATUS"], output_states) return output_states - def cancel_jobs(self, req: JobRequest) -> dict: + def cancel_jobs(self: "JobComm", req: JobRequest) -> dict: """ Cancel a job or list of jobs. After sending the cancellation request, the job states are refreshed and their new output states returned. @@ -437,7 +438,7 @@ def cancel_jobs(self, req: JobRequest) -> dict: self.send_comm_message(MESSAGE_TYPE["STATUS"], cancel_results) return cancel_results - def retry_jobs(self, req: JobRequest) -> dict: + def retry_jobs(self: "JobComm", req: JobRequest) -> dict: """ Retry a job or list of jobs. @@ -454,7 +455,7 @@ def retry_jobs(self, req: JobRequest) -> dict: self.send_comm_message(MESSAGE_TYPE["RETRY"], retry_results) return retry_results - def get_job_logs(self, req: JobRequest) -> dict: + def get_job_logs(self: "JobComm", req: JobRequest) -> dict: """ Fetch the logs for a job or list of jobs. 
@@ -476,7 +477,7 @@ def get_job_logs(self, req: JobRequest) -> dict: self.send_comm_message(MESSAGE_TYPE["LOGS"], log_output) return log_output - def _handle_comm_message(self, msg: dict) -> dict: + def _handle_comm_message(self: "JobComm", msg: dict) -> dict: """ Handle incoming messages on the KBaseJobs channel. @@ -510,7 +511,7 @@ def _handle_comm_message(self, msg: dict) -> dict: return self._msg_map[request.request_type](request) - def send_comm_message(self, msg_type: str, content: dict) -> None: + def send_comm_message(self: "JobComm", msg_type: str, content: dict) -> None: """ Sends a ipykernel.Comm message to the KBaseJobs channel with the given msg_type and content. These just get encoded into the message itself. @@ -572,7 +573,7 @@ class exc_to_msg: jc = JobComm() - def __init__(self, req: Union[JobRequest, dict, str] = None): + def __init__(self: "exc_to_msg", req: Union[JobRequest, dict, str] = None): """ req can be several different things because this context manager supports being used in several different places. Generally it is @@ -582,10 +583,10 @@ def __init__(self, req: Union[JobRequest, dict, str] = None): """ self.req = req - def __enter__(self): + def __enter__(self: "exc_to_msg"): pass # nothing to do here - def __exit__(self, exc_type, exc_value, exc_tb): + def __exit__(self: "exc_to_msg", exc_type, exc_value, exc_tb): """ If an exception is caught during execution in the JobComm code, this will send back a comm error message like: diff --git a/src/biokbase/narrative/jobs/jobmanager.py b/src/biokbase/narrative/jobs/jobmanager.py index 30b6be950b..e28d876a40 100644 --- a/src/biokbase/narrative/jobs/jobmanager.py +++ b/src/biokbase/narrative/jobs/jobmanager.py @@ -1,17 +1,16 @@ import copy from datetime import datetime, timedelta, timezone -from typing import List, Tuple +from typing import Any -from IPython.display import HTML -from jinja2 import Template - -import biokbase.narrative.clients as clients -from biokbase.narrative.system import system_variable +from biokbase.narrative import clients from biokbase.narrative.common import kblogging from biokbase.narrative.exception_util import ( JobRequestException, transform_job_exception, ) +from biokbase.narrative.system import system_variable +from IPython.display import HTML +from jinja2 import Template from .job import JOB_INIT_EXCLUDED_JOB_STATE_FIELDS, Job @@ -54,7 +53,7 @@ class JobManager: _log = kblogging.get_logger(__name__) - def __new__(cls): + def __new__(cls) -> "JobManager": if JobManager.__instance is None: JobManager.__instance = object.__new__(cls) return JobManager.__instance @@ -70,14 +69,14 @@ def _reorder_parents_children(states: dict) -> dict: return {job_id: states[job_id] for job_id in ordering} def _check_job_list( - self, input_ids: List[str] = None - ) -> Tuple[List[str], List[str]]: + self: "JobManager", input_ids: list[str] | None = None + ) -> tuple[list[str], list[str]]: """ Deduplicates the input job list, maintaining insertion order. 
Any jobs not present in self._running_jobs are added to an error list :param input_ids: list of putative job IDs, defaults to [] - :type input_ids: List[str], optional + :type input_ids: list[str], optional :raises JobRequestException: if the input_ids parameter is not a list or if there are no valid job IDs supplied :return: tuple with items job_ids - valid job IDs error_ids - jobs that the narrative backend does not know about - :rtype: Tuple[List[str], List[str]] + :rtype: tuple[list[str], list[str]] """ if not input_ids: raise JobRequestException(JOBS_MISSING_ERR, input_ids) @@ -107,7 +106,9 @@ def _check_job_list( return job_ids, error_ids - def register_new_job(self, job: Job, refresh: bool = None) -> None: + def register_new_job( + self: "JobManager", job: Job, refresh: bool | None = None + ) -> None: """ Registers a new Job with the manager and stores the job locally. This should only be invoked when a new Job gets started. @@ -120,19 +121,19 @@ def register_new_job(self, job: Job, refresh: bool = None) -> None: kblogging.log_event(self._log, "register_new_job", {"job_id": job.job_id}) if refresh is None: - refresh = not job.was_terminal() + refresh = not job.in_terminal_state() self._running_jobs[job.job_id] = {"job": job, "refresh": refresh} # add the new job to the _jobs_by_cell_id mapping if there is a cell_id present if job.cell_id: - if job.cell_id not in self._jobs_by_cell_id.keys(): + if job.cell_id not in self._jobs_by_cell_id: self._jobs_by_cell_id[job.cell_id] = set() self._jobs_by_cell_id[job.cell_id].add(job.job_id) if job.batch_id: self._jobs_by_cell_id[job.cell_id].add(job.batch_id) - def initialize_jobs(self, cell_ids: List[str] = None) -> None: + def initialize_jobs(self: "JobManager", cell_ids: list[str] | None = None) -> None: """ Initializes this JobManager. This is expected to be run by a running Narrative, and naturally linked to a workspace. @@ -143,7 +144,7 @@ def initialize_jobs(self, cell_ids: List[str] = None) -> None: 4. start the status lookup loop. :param cell_ids: list of cell IDs to filter the existing jobs for, defaults to None - :type cell_ids: List[str], optional + :type cell_ids: list[str], optional :raises NarrativeException: if the call to ee2 fails """ @@ -161,10 +162,10 @@ def initialize_jobs(self, cell_ids: List[str] = None) -> None: except Exception as e: kblogging.log_event(self._log, "init_error", {"err": str(e)}) new_e = transform_job_exception(e, "Unable to initialize jobs") - raise new_e + raise new_e from e self._running_jobs = {} - job_states = self._reorder_parents_children(job_states) + job_states: dict[str, Any] = self._reorder_parents_children(job_states) for job_state in job_states.values(): child_jobs = None if job_state.get("batch_job"): @@ -178,19 +179,19 @@ def initialize_jobs(self, cell_ids: List[str] = None) -> None: # Set to refresh when job is not in terminal state # and when job is present in cells (if given) # and when it is not part of a batch - refresh = not job.was_terminal() and not job.batch_id + refresh = not job.in_terminal_state() and not job.batch_id if cell_ids is not None: refresh = refresh and job.in_cells(cell_ids) self.register_new_job(job, refresh) - def _create_jobs(self, job_ids: List[str]) -> dict: + def _create_jobs(self: "JobManager", job_ids: list[str]) -> dict: """ Given a list of job IDs, creates job objects for them and populates the _running_jobs dictionary.
TODO: error handling :param job_ids: job IDs to create job objects for - :type job_ids: List[str] + :type job_ids: list[str] :return: dictionary of job states indexed by job ID :rtype: dict @@ -207,7 +208,7 @@ def _create_jobs(self, job_ids: List[str]) -> dict: return job_states - def get_job(self, job_id: str) -> Job: + def get_job(self: "JobManager", job_id: str) -> Job: """ Retrieve a job from the Job Manager's _running_jobs index. @@ -224,8 +225,8 @@ def get_job(self, job_id: str) -> Job: return self._running_jobs[job_id]["job"] def _construct_job_output_state_set( - self, job_ids: List[str], states: dict = None - ) -> dict: + self: "JobManager", job_ids: list[str], states: dict[str, Any] | None = None + ) -> dict[str, dict[str, Any]]: """ Builds a set of job states for the list of job ids. @@ -263,7 +264,7 @@ def _construct_job_output_state_set( } :param job_ids: list of job IDs - :type job_ids: List[str] + :type job_ids: list[str] :param states: dict of job state data from EE2, indexed by job ID, defaults to None :type states: dict, optional @@ -278,6 +279,9 @@ def _construct_job_output_state_set( if not job_ids: return {} + # ensure states is initialised + if not states: + states = {} output_states = {} jobs_to_lookup = [] @@ -285,17 +289,19 @@ def _construct_job_output_state_set( # These are already post-processed and ready to return. for job_id in job_ids: job = self.get_job(job_id) - if job.was_terminal(): + if job.in_terminal_state(): + # job is already finished, will not change + output_states[job_id] = job.output_state() + elif job_id in states: + job.update_state(states[job_id]) output_states[job_id] = job.output_state() - elif states and job_id in states: - state = states[job_id] - output_states[job_id] = job.output_state(state) else: jobs_to_lookup.append(job_id) fetched_states = {} # Get the rest of states direct from EE2. if jobs_to_lookup: + error_message = "" try: fetched_states = Job.query_ee2_states(jobs_to_lookup, init=False) except Exception as e: @@ -312,16 +318,19 @@ def _construct_job_output_state_set( for job_id in jobs_to_lookup: job = self.get_job(job_id) if job_id in fetched_states: - output_states[job_id] = job.output_state(fetched_states[job_id]) + job.update_state(fetched_states[job_id]) + output_states[job_id] = job.output_state() else: # fetch the current state without updating it - output_states[job_id] = job.output_state({}) + output_states[job_id] = job.output_state() # add an error field with the error message from the failed look up output_states[job_id]["error"] = error_message return output_states - def get_job_states(self, job_ids: List[str], ts: int = None) -> dict: + def get_job_states( + self: "JobManager", job_ids: list[str], ts: int | None = None + ) -> dict: """ Retrieves the job states for the supplied job_ids. @@ -332,7 +341,7 @@ def get_job_states(self, job_ids: List[str], ts: int = None) -> dict: } :param job_ids: job IDs to retrieve job state data for - :type job_ids: List[str] + :type job_ids: list[str] :param ts: timestamp (as generated by time.time_ns()) to filter the jobs, defaults to None :type ts: int, optional @@ -343,7 +352,9 @@ def get_job_states(self, job_ids: List[str], ts: int = None) -> dict: output_states = self._construct_job_output_state_set(job_ids) return self.add_errors_to_results(output_states, error_ids) - def get_all_job_states(self, ignore_refresh_flag=False) -> dict: + def get_all_job_states( + self: "JobManager", ignore_refresh_flag: bool = False + ) -> dict: """ Fetches states for all running jobs. 
If ignore_refresh_flag is True, then returns states for all jobs this @@ -356,23 +367,26 @@ def get_all_job_states(self, ignore_refresh_flag=False) -> dict: :return: dictionary of job states, indexed by job ID :rtype: dict """ - jobs_to_lookup = [] - # grab the list of running job ids, so we don't run into update-while-iterating problems. - for job_id in self._running_jobs: - if self._running_jobs[job_id]["refresh"] or ignore_refresh_flag: - jobs_to_lookup.append(job_id) - if len(jobs_to_lookup) > 0: + jobs_to_lookup = [ + job_id + for job_id in self._running_jobs + if self._running_jobs[job_id]["refresh"] or ignore_refresh_flag + ] + + if jobs_to_lookup: return self._construct_job_output_state_set(jobs_to_lookup) return {} - def _get_job_ids_by_cell_id(self, cell_id_list: List[str] = None) -> tuple: + def _get_job_ids_by_cell_id( + self: "JobManager", cell_id_list: list[str] | None = None + ) -> tuple: """ Finds jobs with a cell_id in cell_id_list. Mappings of job ID to cell ID are added when new jobs are registered. :param cell_id_list: cell IDs to retrieve job state data for - :type cell_id_list: List[str] + :type cell_id_list: list[str] :return: tuple with two components: job_id_list: list of job IDs associated with the cell IDs supplied @@ -392,12 +406,14 @@ def _get_job_ids_by_cell_id(self, cell_id_list: List[str] = None) -> tuple: job_id_list = set().union(*cell_to_job_mapping.values()) return (job_id_list, cell_to_job_mapping) - def get_job_states_by_cell_id(self, cell_id_list: List[str] = None) -> dict: + def get_job_states_by_cell_id( + self: "JobManager", cell_id_list: list[str] | None = None + ) -> dict: """ Retrieves the job states for jobs associated with the cell_id_list supplied. :param cell_id_list: cell IDs to retrieve job state data for - :type cell_id_list: List[str] + :type cell_id_list: list[str] :return: dictionary with two keys: 'jobs': job states, indexed by job ID @@ -413,7 +429,7 @@ def get_job_states_by_cell_id(self, cell_id_list: List[str] = None) -> dict: return {"jobs": job_states, "mapping": cell_to_job_mapping} - def get_job_info(self, job_ids: List[str]) -> dict: + def get_job_info(self: "JobManager", job_ids: list[str]) -> dict: """ Gets job information for a list of job IDs. 
@@ -433,7 +449,7 @@ def get_job_info(self, job_ids: List[str]) -> dict: } :param job_ids: job IDs to retrieve job info for - :type job_ids: List[str] + :type job_ids: list[str] :return: job info for each job, indexed by job ID :rtype: dict """ @@ -452,10 +468,10 @@ def get_job_info(self, job_ids: List[str]) -> dict: return self.add_errors_to_results(infos, error_ids) def get_job_logs( - self, + self: "JobManager", job_id: str, first_line: int = 0, - num_lines: int = None, + num_lines: int | None = None, latest: bool = False, ) -> dict: """ @@ -518,27 +534,27 @@ def get_job_logs( logs = logs[first_line:] else: (max_lines, logs) = job.log(first_line=first_line, num_lines=num_lines) - + except Exception as e: return { "job_id": job.job_id, "batch_id": job.batch_id, - "first": first_line, - "latest": latest, - "max_lines": max_lines, - "lines": logs, + "error": e.message, } - except Exception as e: + else: return { "job_id": job.job_id, "batch_id": job.batch_id, - "error": e.message, + "first": first_line, + "latest": latest, + "max_lines": max_lines, + "lines": logs, } def get_job_logs_for_list( - self, - job_id_list: List[str], + self: "JobManager", + job_id_list: list[str], first_line: int = 0, - num_lines: int = None, + num_lines: int | None = None, latest: bool = False, ) -> dict: """ @@ -551,7 +567,7 @@ def get_job_logs_for_list( } :param job_id_list: list of jobs to fetch logs for - :type job_id_list: List[str] + :type job_id_list: list[str] :param first_line: the first line to be returned, defaults to 0 :type first_line: int, optional :param num_lines: number of lines to be returned, defaults to None @@ -570,7 +586,7 @@ def get_job_logs_for_list( return self.add_errors_to_results(output, error_ids) - def cancel_jobs(self, job_id_list: List[str]) -> dict: + def cancel_jobs(self: "JobManager", job_id_list: list[str]) -> dict: """ Cancel a list of jobs and return their new state. After sending the cancellation request, the job states are refreshed and their new output states returned. @@ -588,7 +604,7 @@ def cancel_jobs(self, job_id_list: List[str]) -> dict: } :param job_id_list: job IDs to cancel - :type job_id_list: List[str] + :type job_id_list: list[str] :return: job output states, indexed by job ID :rtype: dict @@ -596,7 +612,7 @@ def cancel_jobs(self, job_id_list: List[str]) -> dict: job_ids, error_ids = self._check_job_list(job_id_list) error_states = {} for job_id in job_ids: - if not self.get_job(job_id).was_terminal(): + if not self.get_job(job_id).in_terminal_state(): error = self._cancel_job(job_id) if error: error_states[job_id] = error.message @@ -608,7 +624,7 @@ def cancel_jobs(self, job_id_list: List[str]) -> dict: return self.add_errors_to_results(job_states, error_ids) - def _cancel_job(self, job_id: str) -> None: + def _cancel_job(self: "JobManager", job_id: str) -> Exception | None: """ Cancel a single job. If an error occurs during cancellation, that error is converted into a NarrativeException and returned to the caller. @@ -633,7 +649,7 @@ def _cancel_job(self, job_id: str) -> None: del self._running_jobs[job_id]["canceling"] return error - def retry_jobs(self, job_id_list: List[str]) -> dict: + def retry_jobs(self: "JobManager", job_id_list: list[str]) -> dict: """ Retry a list of job IDs, returning job output states for the jobs to be retried and the new jobs created by the retry command. 
@@ -667,7 +683,7 @@ def retry_jobs(self, job_id_list: List[str]) -> dict: } :param job_id_list: list of job IDs - :type job_id_list: List[str] + :type job_id_list: list[str] :raises NarrativeException: if EE2 returns an error from the retry request @@ -676,9 +692,9 @@ def retry_jobs(self, job_id_list: List[str]) -> dict: """ job_ids, error_ids = self._check_job_list(job_id_list) try: - retry_results = clients.get("execution_engine2").retry_jobs( - {"job_ids": job_ids} - ) + retry_results: list[dict[str, Any]] = clients.get( + "execution_engine2" + ).retry_jobs({"job_ids": job_ids}) except Exception as e: raise transform_job_exception(e, "Unable to retry job(s)") from e @@ -706,14 +722,16 @@ def retry_jobs(self, job_id_list: List[str]) -> dict: results_by_job_id[job_id]["error"] = result["error"] return self.add_errors_to_results(results_by_job_id, error_ids) - def add_errors_to_results(self, results: dict, error_ids: List[str]) -> dict: + def add_errors_to_results( + self: "JobManager", results: dict, error_ids: list[str] + ) -> dict: """ Add the generic "not found" error for each job_id in error_ids. :param results: dictionary of job data (output state, info, retry, etc.) indexed by job ID :type results: dict :param error_ids: list of IDs that could not be found - :type error_ids: List[str] + :type error_ids: list[str] :return: input results dictionary augmented by a dictionary containing job ID and a short not found error message for every ID in the error_ids list @@ -726,7 +744,9 @@ def add_errors_to_results(self, results: dict, error_ids: List[str]) -> dict: } return results - def modify_job_refresh(self, job_ids: List[str], update_refresh: bool) -> None: + def modify_job_refresh( + self: "JobManager", job_ids: list[str], update_refresh: bool + ) -> None: """ Modifies how many things want to get the job updated. If this sets the current "refresh" key to be less than 0, it gets reset to 0. 
@@ -737,7 +757,7 @@ def modify_job_refresh(self, job_ids: List[str], update_refresh: bool) -> None: for job_id in job_ids: self._running_jobs[job_id]["refresh"] = update_refresh - def update_batch_job(self, batch_id: str) -> List[str]: + def update_batch_job(self: "JobManager", batch_id: str) -> list[str]: """ Update a batch job and create child jobs if necessary """ @@ -755,13 +775,13 @@ def update_batch_job(self, batch_id: str) -> List[str]: else: unreg_child_ids.append(job_id) - unreg_child_jobs = [] + unreg_child_jobs: list[Job] = [] if unreg_child_ids: unreg_child_jobs = Job.from_job_ids(unreg_child_ids) for job in unreg_child_jobs: self.register_new_job( job=job, - refresh=not job.was_terminal(), + refresh=not job.in_terminal_state(), ) batch_job.update_children(reg_child_jobs + unreg_child_jobs) diff --git a/src/biokbase/narrative/tests/data/response_data.json b/src/biokbase/narrative/tests/data/response_data.json index ddd42028aa..cfde19e2a8 100644 --- a/src/biokbase/narrative/tests/data/response_data.json +++ b/src/biokbase/narrative/tests/data/response_data.json @@ -408,7 +408,8 @@ "wsName": "wjriehl:1475006266615" }, "tag": "release" - } + }, + "wsid": 10538 }, "job_id": "BATCH_COMPLETED", "outputWidgetInfo": { @@ -446,7 +447,8 @@ "BATCH_RETRY_ERROR" ], "running": 1625755952628, - "status": "error" + "status": "error", + "wsid": 10538 }, "job_id": "BATCH_ERROR_RETRIED", "outputWidgetInfo": null @@ -469,7 +471,8 @@ "job_output": {}, "retry_count": 0, "retry_ids": [], - "status": "created" + "status": "created", + "wsid": 10538 }, "job_id": "BATCH_PARENT", "outputWidgetInfo": null @@ -509,7 +512,8 @@ "wsName": "wjriehl:1475006266615" }, "tag": "release" - } + }, + "wsid": 10538 }, "job_id": "BATCH_RETRY_COMPLETED", "outputWidgetInfo": { @@ -546,7 +550,8 @@ "retry_ids": [], "retry_parent": "BATCH_ERROR_RETRIED", "running": 1625757662469, - "status": "error" + "status": "error", + "wsid": 10538 }, "job_id": "BATCH_RETRY_ERROR", "outputWidgetInfo": null @@ -564,7 +569,8 @@ "retry_ids": [], "retry_parent": "BATCH_TERMINATED_RETRIED", "running": 1625757285899, - "status": "running" + "status": "running", + "wsid": 10538 }, "job_id": "BATCH_RETRY_RUNNING", "outputWidgetInfo": null @@ -583,7 +589,8 @@ "retry_ids": [], "running": 1625755952599, "status": "terminated", - "terminated_code": 0 + "terminated_code": 0, + "wsid": 10538 }, "job_id": "BATCH_TERMINATED", "outputWidgetInfo": null @@ -605,7 +612,8 @@ ], "running": 1625755952563, "status": "terminated", - "terminated_code": 0 + "terminated_code": 0, + "wsid": 10538 }, "job_id": "BATCH_TERMINATED_RETRIED", "outputWidgetInfo": null @@ -641,7 +649,8 @@ "report_window_line_height": "16" }, "tag": "dev" - } + }, + "wsid": 54321 }, "job_id": "JOB_COMPLETED", "outputWidgetInfo": { @@ -664,7 +673,8 @@ "job_output": {}, "retry_count": 0, "retry_ids": [], - "status": "created" + "status": "created", + "wsid": 55555 }, "job_id": "JOB_CREATED", "outputWidgetInfo": null @@ -690,7 +700,8 @@ "retry_count": 0, "retry_ids": [], "running": 1642431847689, - "status": "error" + "status": "error", + "wsid": 65854 }, "job_id": "JOB_ERROR", "outputWidgetInfo": null @@ -707,7 +718,8 @@ "retry_count": 0, "retry_ids": [], "running": 1642521775867, - "status": "running" + "status": "running", + "wsid": 54321 }, "job_id": "JOB_RUNNING", "outputWidgetInfo": null @@ -725,7 +737,8 @@ "retry_count": 0, "retry_ids": [], "status": "terminated", - "terminated_code": 0 + "terminated_code": 0, + "wsid": 54321 }, "job_id": "JOB_TERMINATED", "outputWidgetInfo": null @@ 
-780,7 +793,8 @@ "wsName": "wjriehl:1475006266615" }, "tag": "release" - } + }, + "wsid": 10538 }, "job_id": "BATCH_COMPLETED", "outputWidgetInfo": { @@ -821,7 +835,8 @@ "BATCH_RETRY_ERROR" ], "running": 1625755952628, - "status": "error" + "status": "error", + "wsid": 10538 }, "job_id": "BATCH_ERROR_RETRIED", "outputWidgetInfo": null @@ -849,7 +864,8 @@ "retry_ids": [], "retry_parent": "BATCH_ERROR_RETRIED", "running": 1625757662469, - "status": "error" + "status": "error", + "wsid": 10538 }, "job_id": "BATCH_RETRY_ERROR", "outputWidgetInfo": null @@ -876,7 +892,8 @@ "job_output": {}, "retry_count": 0, "retry_ids": [], - "status": "created" + "status": "created", + "wsid": 10538 }, "job_id": "BATCH_PARENT", "outputWidgetInfo": null @@ -920,7 +937,8 @@ "wsName": "wjriehl:1475006266615" }, "tag": "release" - } + }, + "wsid": 10538 }, "job_id": "BATCH_RETRY_COMPLETED", "outputWidgetInfo": { @@ -961,7 +979,8 @@ "retry_ids": [], "retry_parent": "BATCH_ERROR_RETRIED", "running": 1625757662469, - "status": "error" + "status": "error", + "wsid": 10538 }, "job_id": "BATCH_RETRY_ERROR", "outputWidgetInfo": null @@ -983,7 +1002,8 @@ "retry_ids": [], "retry_parent": "BATCH_TERMINATED_RETRIED", "running": 1625757285899, - "status": "running" + "status": "running", + "wsid": 10538 }, "job_id": "BATCH_RETRY_RUNNING", "outputWidgetInfo": null @@ -1006,7 +1026,8 @@ "retry_ids": [], "running": 1625755952599, "status": "terminated", - "terminated_code": 0 + "terminated_code": 0, + "wsid": 10538 }, "job_id": "BATCH_TERMINATED", "outputWidgetInfo": null @@ -1031,7 +1052,8 @@ ], "running": 1625755952563, "status": "terminated", - "terminated_code": 0 + "terminated_code": 0, + "wsid": 10538 }, "job_id": "BATCH_TERMINATED_RETRIED", "outputWidgetInfo": null @@ -1072,7 +1094,8 @@ "wsName": "wjriehl:1475006266615" }, "tag": "release" - } + }, + "wsid": 10538 }, "job_id": "BATCH_RETRY_COMPLETED", "outputWidgetInfo": { @@ -1122,7 +1145,8 @@ "report_window_line_height": "16" }, "tag": "dev" - } + }, + "wsid": 54321 }, "job_id": "JOB_COMPLETED", "outputWidgetInfo": { @@ -1149,7 +1173,8 @@ "job_output": {}, "retry_count": 0, "retry_ids": [], - "status": "created" + "status": "created", + "wsid": 55555 }, "job_id": "JOB_CREATED", "outputWidgetInfo": null @@ -1179,7 +1204,8 @@ "retry_count": 0, "retry_ids": [], "running": 1642431847689, - "status": "error" + "status": "error", + "wsid": 65854 }, "job_id": "JOB_ERROR", "outputWidgetInfo": null @@ -1200,7 +1226,8 @@ "retry_count": 0, "retry_ids": [], "running": 1642521775867, - "status": "running" + "status": "running", + "wsid": 54321 }, "job_id": "JOB_RUNNING", "outputWidgetInfo": null @@ -1222,7 +1249,8 @@ "retry_count": 0, "retry_ids": [], "status": "terminated", - "terminated_code": 0 + "terminated_code": 0, + "wsid": 54321 }, "job_id": "JOB_TERMINATED", "outputWidgetInfo": null diff --git a/src/biokbase/narrative/tests/generate_test_results.py b/src/biokbase/narrative/tests/generate_test_results.py index 0c389bdb1c..ee1e1fd5b6 100644 --- a/src/biokbase/narrative/tests/generate_test_results.py +++ b/src/biokbase/narrative/tests/generate_test_results.py @@ -37,11 +37,13 @@ TEST_SPECS[tag] = spec_dict -def get_test_spec(tag, app_id): +def get_test_spec(tag: str, app_id: str) -> dict[str, dict]: return copy.deepcopy(TEST_SPECS[tag][app_id]) -def generate_mappings(all_jobs): +def generate_mappings( + all_jobs: dict[str, dict] +) -> tuple[dict[str, str], dict[str, set], set[str]]: # collect retried jobs and generate the cell-to-job mapping retried_jobs = {}
     jobs_by_cell_id = {}

@@ -57,7 +59,7 @@ def generate_mappings(all_jobs):

         # add the new job to the jobs_by_cell_id mapping if there is a cell_id present
         if cell_id:
-            if cell_id not in jobs_by_cell_id.keys():
+            if cell_id not in jobs_by_cell_id:
                 jobs_by_cell_id[cell_id] = set()
             jobs_by_cell_id[cell_id].add(job["job_id"])

@@ -67,7 +69,7 @@ def generate_mappings(all_jobs):
     return (retried_jobs, jobs_by_cell_id, batch_jobs)


-def _generate_job_output(job_id):
+def _generate_job_output(job_id: str) -> dict[str, str | dict]:
     state = get_test_job(job_id)

     widget_info = state.get("widget_info")
@@ -93,14 +95,14 @@ def _generate_job_output(job_id):
     return {"job_id": job_id, "jobState": state, "outputWidgetInfo": widget_info}


-def generate_bad_jobs():
+def generate_bad_jobs() -> dict[str, dict]:
     return {
         job_id: {"job_id": job_id, "error": generate_error(job_id, "not_found")}
         for job_id in BAD_JOBS
     }


-def generate_job_output_state(all_jobs):
+def generate_job_output_state(all_jobs: dict[str, dict]) -> dict[str, dict]:
     """
     Generate the expected output from a `job_status` request
     """
@@ -110,14 +112,13 @@ def generate_job_output_state(all_jobs):
     return job_status


-def generate_job_info(all_jobs):
+def generate_job_info(all_jobs: dict[str, dict]) -> dict[str, dict]:
     """
     Expected output from a `job_info` request
     """
     job_info = generate_bad_jobs()
     for job_id in all_jobs:
         test_job = get_test_job(job_id)
-        job_id = test_job.get("job_id")
         app_id = test_job.get("job_input", {}).get("app_id")
         tag = (
             test_job.get("job_input", {})
@@ -145,7 +146,9 @@ def generate_job_info(all_jobs):
     return job_info


-def generate_job_retries(all_jobs, retried_jobs):
+def generate_job_retries(
+    all_jobs: dict[str, dict], retried_jobs: dict[str, str]
+) -> dict[str, dict]:
     """
     Expected output from a `retry_job` request
     """
@@ -173,14 +176,11 @@ def generate_job_retries(all_jobs, retried_jobs):
     return job_retries


-def log_gen(n_lines):
-    lines = []
-    for i in range(n_lines):
-        lines.append({"is_error": 0, "line": f"This is line {str(i+1)}"})
-    return lines
+def log_gen(n_lines: int) -> list[dict[str, int | str]]:
+    return [{"is_error": 0, "line": f"This is line {i+1}"} for i in range(n_lines)]


-def generate_job_logs(all_jobs):
+def generate_job_logs(all_jobs: dict[str, dict]) -> dict[str, dict]:
     """
     Expected output from a `job_logs` request. Note that only completed jobs have logs in this case.
""" @@ -211,11 +211,9 @@ def generate_job_logs(all_jobs): # mapping of cell IDs to jobs INVALID_CELL_ID = "invalid_cell_id" -TEST_CELL_ID_LIST = list(JOBS_BY_CELL_ID.keys()) + [INVALID_CELL_ID] +TEST_CELL_ID_LIST = [*list(JOBS_BY_CELL_ID.keys()), INVALID_CELL_ID] # mapping expected as output from get_job_states_by_cell_id -TEST_CELL_IDs = { - cell_id: list(JOBS_BY_CELL_ID[cell_id]) for cell_id in JOBS_BY_CELL_ID.keys() -} +TEST_CELL_IDs = {cell_id: list(JOBS_BY_CELL_ID[cell_id]) for cell_id in JOBS_BY_CELL_ID} TEST_CELL_IDs[INVALID_CELL_ID] = [] @@ -230,7 +228,7 @@ def generate_job_logs(all_jobs): config.write_json_file(RESPONSE_DATA_FILE, ALL_RESPONSE_DATA) -def main(args=None): +def main(args: list[str] | None = None) -> None: if args and args[0] == "--force" or not os.path.exists(RESPONSE_DATA_FILE): config.write_json_file(RESPONSE_DATA_FILE, ALL_RESPONSE_DATA) diff --git a/src/biokbase/narrative/tests/job_test_constants.py b/src/biokbase/narrative/tests/job_test_constants.py index 3fa741d867..721f73ca1a 100644 --- a/src/biokbase/narrative/tests/job_test_constants.py +++ b/src/biokbase/narrative/tests/job_test_constants.py @@ -95,7 +95,7 @@ def get_test_jobs(job_ids): BATCH_PARENT_CHILDREN = [BATCH_PARENT] + BATCH_CHILDREN -JOBS_TERMINALITY = { +JOB_TERMINAL_STATE = { job_id: TEST_JOBS[job_id]["status"] in TERMINAL_STATUSES for job_id in TEST_JOBS.keys() } @@ -103,7 +103,7 @@ def get_test_jobs(job_ids): TERMINAL_JOBS = [] ACTIVE_JOBS = [] REFRESH_STATE = {} -for key, value in JOBS_TERMINALITY.items(): +for key, value in JOB_TERMINAL_STATE.items(): if value: TERMINAL_JOBS.append(key) else: diff --git a/src/biokbase/narrative/tests/narrative_mock/mockclients.py b/src/biokbase/narrative/tests/narrative_mock/mockclients.py index e86e912af8..a9a97b2175 100644 --- a/src/biokbase/narrative/tests/narrative_mock/mockclients.py +++ b/src/biokbase/narrative/tests/narrative_mock/mockclients.py @@ -18,7 +18,7 @@ ) from biokbase.workspace.baseclient import ServerError -from ..util import ConfigTests +from src.biokbase.narrative.tests.util import ConfigTests RANDOM_DATE = "2018-08-10T16:47:36+0000" RANDOM_TYPE = "ModuleA.TypeA-1.0" @@ -29,11 +29,11 @@ NARR_HASH = "278abf8f0dbf8ab5ce349598a8674a6e" -def generate_ee2_error(fn): +def generate_ee2_error(fn: str) -> EEServerError: return EEServerError("JSONRPCError", -32000, fn + " failed") -def get_nar_obj(i): +def get_nar_obj(i: int) -> list[str | int | dict[str, str]]: return [ i, "My_Test_Narrative", @@ -291,9 +291,7 @@ def check_job(self, params): job_id = params.get("job_id") if not job_id: return {} - job_state = self.job_state_data.get( - job_id, {"job_id": job_id, "status": "unmocked"} - ) + job_state = self.job_state_data.get(job_id, {"job_id": job_id}) if "exclude_fields" in params: for f in params["exclude_fields"]: if f in job_state: @@ -546,6 +544,9 @@ def check_workspace_jobs(self, params): def check_job(self, params): raise generate_ee2_error("check_job") + def check_jobs(self, params): + raise generate_ee2_error("check_jobs") + def cancel_job(self, params): raise generate_ee2_error(MESSAGE_TYPE["CANCEL"]) @@ -618,7 +619,7 @@ def __exit__(self, exc_type, exc_value, traceback): assert ( getattr(self.target, self.method_name) == self.called - ), f"Method {self.target.__name__}.{self.method_name} was modified during context managment with {self.__class__.name}" + ), f"Method {self.target.__name__}.{self.method_name} was modified during context management with {self.__class__.name}" setattr(self.target, self.method_name, self.orig_method) 
         self.assert_called(self.call_status)
diff --git a/src/biokbase/narrative/tests/test_appmanager.py b/src/biokbase/narrative/tests/test_appmanager.py
index f99341ef1c..ef86f10110 100644
--- a/src/biokbase/narrative/tests/test_appmanager.py
+++ b/src/biokbase/narrative/tests/test_appmanager.py
@@ -316,6 +316,8 @@ def test_run_app__from_gui_cell(self, auth, c):
             c.return_value.send_comm_message, False, cell_id=cell_id
         )

+    # N.b. the following three tests contact the workspace
+    # src/config/config.json must be set to use the CI configuration
     @mock.patch(JOB_COMM_MOCK)
     def test_run_app__bad_id(self, c):
         c.return_value.send_comm_message = MagicMock()
@@ -507,6 +509,8 @@ def test_run_legacy_batch_app__gui_cell(self, auth, c):
             c.return_value.send_comm_message, False, cell_id=cell_id
         )

+    # N.b. the following three tests contact the workspace
+    # src/config/config.json must be set to use the CI configuration
     @mock.patch(JOB_COMM_MOCK)
     def test_run_legacy_batch_app__bad_id(self, c):
         c.return_value.send_comm_message = MagicMock()
@@ -1094,6 +1098,8 @@ def test_generate_input_bad(self):
         with self.assertRaises(ValueError):
             self.am._generate_input({"symbols": -1})

+    # N.b. the following test contacts the workspace
+    # src/config/config.json must be set to use the CI configuration
     def test_transform_input_good(self):
         ws_name = self.public_ws
         test_data = [
diff --git a/src/biokbase/narrative/tests/test_job.py b/src/biokbase/narrative/tests/test_job.py
index 52a7a755b9..2f33cf0002 100644
--- a/src/biokbase/narrative/tests/test_job.py
+++ b/src/biokbase/narrative/tests/test_job.py
@@ -30,7 +30,7 @@
     JOB_CREATED,
     JOB_RUNNING,
     JOB_TERMINATED,
-    JOBS_TERMINALITY,
+    JOB_TERMINAL_STATE,
     MAX_LOG_LINES,
     TERMINAL_JOBS,
     get_test_job,
@@ -109,6 +109,7 @@ def create_attrs_from_ee2(job_id):
         "run_id": narr_cell_info.get("run_id", JOB_ATTR_DEFAULTS["run_id"]),
         "tag": narr_cell_info.get("tag", JOB_ATTR_DEFAULTS["tag"]),
         "user": state.get("user", JOB_ATTR_DEFAULTS["user"]),
+        "wsid": state.get("wsid", JOB_ATTR_DEFAULTS["wsid"]),
     }

@@ -140,17 +141,14 @@ def get_batch_family_jobs(return_list=False):
     As invoked in appmanager's run_app_batch, i.e., with from_job_id(s)
     """
-    child_jobs = Job.from_job_ids(BATCH_CHILDREN, return_list=True)
+    child_jobs = Job.from_job_ids(BATCH_CHILDREN)
     batch_job = Job.from_job_id(BATCH_PARENT, children=child_jobs)
     if return_list:
-        return [batch_job] + child_jobs
+        return [batch_job, *child_jobs]
     return {
         BATCH_PARENT: batch_job,
-        **{
-            child_id: child_job
-            for child_id, child_job in zip(BATCH_CHILDREN, child_jobs)
-        },
+        **{job.job_id: job for job in child_jobs},
     }

@@ -186,11 +184,11 @@ def check_jobs_equal(self, jobl, jobr):
     def check_job_attrs_custom(self, job, exp_attr=None):
         if not exp_attr:
             exp_attr = {}
-        attr = dict(JOB_ATTR_DEFAULTS)
+        attr = copy.deepcopy(JOB_ATTR_DEFAULTS)
         attr.update(exp_attr)
         with mock.patch(CLIENTS, get_mock_client):
             for name, value in attr.items():
-                self.assertEqual(value, getattr(job, name))
+                assert value == getattr(job, name)

     def check_job_attrs(self, job, job_id, exp_attrs=None, skip_state=False):
         # TODO check _acc_state full vs pruned, extra_data
@@ -240,10 +238,10 @@ def test_job_init__from_job_ids(self):
         job_ids.remove(BATCH_PARENT)

         with mock.patch(CLIENTS, get_mock_client):
-            jobs = Job.from_job_ids(job_ids, return_list=False)
+            jobs = Job.from_job_ids(job_ids)

-        for job_id, job in jobs.items():
-            self.check_job_attrs(job, job_id)
+        for job in jobs:
+            self.check_job_attrs(job, job.job_id)

     def test_job_init__extra_state(self):
         """
@@ -311,6 +309,7 @@ def test_job_from_state__custom(self):
             "job_id": "0123456789abcdef",
             "params": params,
             "run_id": JOB_ATTR_DEFAULTS["run_id"],
+            "status": JOB_ATTR_DEFAULTS["status"],
             "tag": JOB_ATTR_DEFAULTS["tag"],
             "user": "the_user",
         }
@@ -344,9 +343,9 @@ def test_refresh_state__non_terminal(self):
         """
         # ee2_state is fully populated (includes job_input, no job_output)
         job = create_job_from_ee2(JOB_CREATED)
-        self.assertFalse(job.was_terminal())
+        self.assertFalse(job.in_terminal_state())
         state = job.refresh_state()
-        self.assertFalse(job.was_terminal())
+        self.assertFalse(job.in_terminal_state())
         self.assertEqual(state["status"], "created")

         expected_state = create_state_from_ee2(JOB_CREATED)
@@ -357,7 +356,7 @@
         test that a completed job emits its state without calling check_job
         """
         job = create_job_from_ee2(JOB_COMPLETED)
-        self.assertTrue(job.was_terminal())
+        self.assertTrue(job.in_terminal_state())
         expected = create_state_from_ee2(JOB_COMPLETED)

         with assert_obj_method_called(MockClients, "check_job", call_status=False):
@@ -368,53 +367,33 @@
     @mock.patch(CLIENTS, get_failing_mock_client)
     def test_refresh_state__raise_exception(self):
         """
-        test that the correct exception is thrown if check_job cannot be called
+        test that the correct exception is thrown if check_jobs cannot be called
         """
         job = create_job_from_ee2(JOB_CREATED)
-        self.assertFalse(job.was_terminal())
-        with self.assertRaisesRegex(ServerError, "check_job failed"):
+        self.assertFalse(job.in_terminal_state())
+        with self.assertRaisesRegex(ServerError, "check_jobs failed"):
             job.refresh_state()

-    def test_refresh_state__returns_none(self):
-        def mock_state(self, state=None):
-            return None
-
-        job = create_job_from_ee2(JOB_CREATED)
-        expected = {
-            "status": "error",
-            "error": {
-                "code": -1,
-                "name": "Job Error",
-                "message": "Unable to return job state",
-                "error": "Unable to find current job state. Please try again later, or contact KBase.",
-            },
-            "errormsg": "Unable to return job state",
-            "error_code": -1,
-            "job_id": job.job_id,
-            "cell_id": job.cell_id,
-            "run_id": job.run_id,
-            "created": 0,
-        }
-
-        with mock.patch.object(Job, "refresh_state", mock_state):
-            state = job.output_state()
-        self.assertEqual(expected, state)
-
     # TODO: improve this test
     def test_job_update__no_state(self):
         """
         test that without a state object supplied, the job state is unchanged
         """
         job = create_job_from_ee2(JOB_CREATED)
-        self.assertFalse(job.was_terminal())
+        job_copy = copy.deepcopy(job)
+        assert job.cached_state() == job_copy.cached_state()

         # should fail with error 'state must be a dict'
         with self.assertRaisesRegex(TypeError, "state must be a dict"):
             job.update_state(None)
-        self.assertFalse(job.was_terminal())
+        assert job.cached_state() == job_copy.cached_state()

         job.update_state({})
-        self.assertFalse(job.was_terminal())
+        assert job.cached_state() == job_copy.cached_state()
+
+        job.update_state({"wsid": "something or other"})
+        assert job.cached_state() != job_copy.cached_state()
+        assert job.wsid == "something or other"

     @mock.patch(CLIENTS, get_mock_client)
     def test_job_update__invalid_job_id(self):
@@ -574,7 +553,7 @@ def test_parent_children__ok(self):
             children=child_jobs,
         )

-        self.assertFalse(parent_job.was_terminal())
+        self.assertFalse(parent_job.in_terminal_state())

         # Make all child jobs completed
         with mock.patch.object(
@@ -585,7 +564,7 @@ def test_parent_children__ok(self):
             for child_job in child_jobs:
                 child_job.refresh_state(force_refresh=True)

-        self.assertTrue(parent_job.was_terminal())
+        self.assertTrue(parent_job.in_terminal_state())

     def test_parent_children__fail(self):
         parent_state = create_state_from_ee2(BATCH_PARENT)
@@ -645,19 +624,19 @@ def test_get_viewer_params__batch_parent(self):
         self.assertIsNone(out)

     @mock.patch(CLIENTS, get_mock_client)
-    def test_query_job_state(self):
+    def test_query_job_states_single_job(self):
         for job_id in ALL_JOBS:
             exp = create_state_from_ee2(
                 job_id, exclude_fields=JOB_INIT_EXCLUDED_JOB_STATE_FIELDS
             )
-            got = Job.query_ee2_state(job_id, init=True)
-            self.assertEqual(exp, got)
+            got = Job.query_ee2_states([job_id], init=True)
+            assert exp == got[job_id]

             exp = create_state_from_ee2(
                 job_id, exclude_fields=EXCLUDED_JOB_STATE_FIELDS
             )
-            got = Job.query_ee2_state(job_id, init=False)
-            self.assertEqual(exp, got)
+            got = Job.query_ee2_states([job_id], init=False)
+            assert exp == got[job_id]

     @mock.patch(CLIENTS, get_mock_client)
     def test_query_job_states(self):
@@ -731,18 +710,18 @@ def mock_check_job(self_, params):
         with mock.patch.object(MockClients, "check_job", mock_check_job):
             self.check_job_attrs(job, job_id, {"child_jobs": self.NEW_CHILD_JOBS})

-    def test_was_terminal(self):
+    def test_in_terminal_state(self):
         all_jobs = get_all_jobs()

         for job_id, job in all_jobs.items():
-            self.assertEqual(JOBS_TERMINALITY[job_id], job.was_terminal())
+            self.assertEqual(JOB_TERMINAL_STATE[job_id], job.in_terminal_state())

     @mock.patch(CLIENTS, get_mock_client)
-    def test_was_terminal__batch(self):
+    def test_in_terminal_state__batch(self):
         batch_fam = get_batch_family_jobs(return_list=True)
         batch_job, child_jobs = batch_fam[0], batch_fam[1:]

-        self.assertFalse(batch_job.was_terminal())
+        self.assertFalse(batch_job.in_terminal_state())

         def mock_check_job(self_, params):
             self.assertTrue(params["job_id"] in BATCH_CHILDREN)
@@ -752,7 +731,7 @@ def mock_check_job(self_, params):
             for job in child_jobs:
                 job.refresh_state(force_refresh=True)

-        self.assertTrue(batch_job.was_terminal())
+        self.assertTrue(batch_job.in_terminal_state())

     def test_in_cells(self):
         all_jobs = get_all_jobs()
diff --git a/src/biokbase/narrative/tests/test_jobcomm.py b/src/biokbase/narrative/tests/test_jobcomm.py
index ce0fc1fea9..ca49991b2e 100644
--- a/src/biokbase/narrative/tests/test_jobcomm.py
+++ b/src/biokbase/narrative/tests/test_jobcomm.py
@@ -618,6 +618,14 @@ def test_get_job_states__job_id_list__ee2_error(self):
         exc = Exception("Test exception")
         exc_message = str(exc)

+        expected = {
+            job_id: copy.deepcopy(ALL_RESPONSE_DATA[STATUS][job_id])
+            for job_id in ALL_JOBS
+        }
+        for job_id in ACTIVE_JOBS:
+            # add in the ee2_error message
+            expected[job_id]["error"] = exc_message
+
         def mock_check_jobs(params):
             raise exc

@@ -628,14 +636,6 @@ def test_get_job_states__job_id_list__ee2_error(self):
             self.jc._handle_comm_message(req_dict)
         msg = self.jc._comm.last_message

-        expected = {
-            job_id: copy.deepcopy(ALL_RESPONSE_DATA[STATUS][job_id])
-            for job_id in ALL_JOBS
-        }
-        for job_id in ACTIVE_JOBS:
-            # add in the ee2_error message
-            expected[job_id]["error"] = exc_message
-
         self.assertEqual(
             {
                 "msg_type": STATUS,
diff --git a/src/biokbase/narrative/tests/test_jobmanager.py b/src/biokbase/narrative/tests/test_jobmanager.py
index d40a9e48c1..5f7be32d0f 100644
--- a/src/biokbase/narrative/tests/test_jobmanager.py
+++ b/src/biokbase/narrative/tests/test_jobmanager.py
@@ -107,7 +107,7 @@ def test_initialize_jobs(self):
         terminal_ids = [
             job_id
             for job_id, d in self.jm._running_jobs.items()
-            if d["job"].was_terminal()
+            if d["job"].in_terminal_state()
         ]
         self.assertEqual(
             set(TERMINAL_JOBS),
@@ -233,12 +233,6 @@ def test__construct_job_output_state_set__ee2_error(self):
         exc = Exception("Test exception")
         exc_msg = str(exc)

-        def mock_check_jobs(params):
-            raise exc
-
-        with mock.patch.object(MockClients, "check_jobs", side_effect=mock_check_jobs):
-            job_states = self.jm._construct_job_output_state_set(ALL_JOBS)
-
         expected = {
             job_id: copy.deepcopy(ALL_RESPONSE_DATA[MESSAGE_TYPE["STATUS"]][job_id])
             for job_id in ALL_JOBS
@@ -248,6 +242,12 @@ def mock_check_jobs(params):
             # expect there to be an error message added
             expected[job_id]["error"] = exc_msg

+        def mock_check_jobs(params):
+            raise exc
+
+        with mock.patch.object(MockClients, "check_jobs", side_effect=mock_check_jobs):
+            job_states = self.jm._construct_job_output_state_set(ALL_JOBS)
+
         self.assertEqual(
             expected,
             job_states,
@@ -403,8 +403,8 @@ def test_cancel_jobs__bad_inputs(self):
     def test_cancel_jobs__job_already_finished(self):
         self.assertEqual(get_test_job(JOB_COMPLETED)["status"], "completed")
         self.assertEqual(get_test_job(JOB_TERMINATED)["status"], "terminated")
-        self.assertTrue(self.jm.get_job(JOB_COMPLETED).was_terminal())
-        self.assertTrue(self.jm.get_job(JOB_TERMINATED).was_terminal())
+        self.assertTrue(self.jm.get_job(JOB_COMPLETED).in_terminal_state())
+        self.assertTrue(self.jm.get_job(JOB_TERMINATED).in_terminal_state())
         job_id_list = [JOB_COMPLETED, JOB_TERMINATED]
         with mock.patch(
             "biokbase.narrative.jobs.jobmanager.JobManager._cancel_job"
diff --git a/src/biokbase/narrative/tests/test_narrativeio.py b/src/biokbase/narrative/tests/test_narrativeio.py
index 428f3556af..cfb8e962a6 100644
--- a/src/biokbase/narrative/tests/test_narrativeio.py
+++ b/src/biokbase/narrative/tests/test_narrativeio.py
@@ -5,10 +5,8 @@
 import unittest
 from unittest.mock import patch

-from tornado.web import HTTPError
-
 import biokbase.auth
-import biokbase.narrative.clients as clients
+from biokbase.narrative import clients
 from biokbase.narrative.common.exceptions import WorkspaceError
 from biokbase.narrative.common.narrative_ref import NarrativeRef
 from biokbase.narrative.common.url_config import URLS
@@ -16,27 +14,26 @@
     LIST_OBJECTS_FIELDS,
     KBaseWSManagerMixin,
 )
+from tornado.web import HTTPError

 from . import util
 from .narrative_mock.mockclients import MockClients, get_mock_client, get_nar_obj

 __author__ = "Bill Riehl "

-metadata_fields = set(
-    [
-        "objid",
-        "name",
-        "type",
-        "save_date",
-        "ver",
-        "saved_by",
-        "wsid",
-        "workspace",
-        "chsum",
-        "size",
-        "meta",
-    ]
-)
+metadata_fields = {
+    "objid",
+    "name",
+    "type",
+    "save_date",
+    "ver",
+    "saved_by",
+    "wsid",
+    "workspace",
+    "chsum",
+    "size",
+    "meta",
+}


 HAS_TEST_TOKEN = False
diff --git a/src/biokbase/narrative/tests/test_widgetmanager.py b/src/biokbase/narrative/tests/test_widgetmanager.py
index 8fd73085fd..b075cfc457 100644
--- a/src/biokbase/narrative/tests/test_widgetmanager.py
+++ b/src/biokbase/narrative/tests/test_widgetmanager.py
@@ -74,6 +74,8 @@ def test_show_output_widget_bad(self):
             check_widget=True,
         )

+    # N.b. the following test contacts the workspace
+    # src/config/config.json must be set to use the CI configuration
     def test_show_advanced_viewer_widget(self):
         title = "Widget Viewer"
         cell_id = "abcde"