From ecf0c31cf4be9ceb60b581c0e0c640b91856b3ee Mon Sep 17 00:00:00 2001 From: ialarmedalien Date: Thu, 9 Nov 2023 13:55:36 -0800 Subject: [PATCH 1/3] add user and wsid to job output Removing some comments and an extraneous method; general tidying --- src/biokbase/narrative/jobs/appmanager.py | 127 +++++---- src/biokbase/narrative/jobs/job.py | 259 +++++++++--------- src/biokbase/narrative/jobs/jobcomm.py | 71 ++--- src/biokbase/narrative/jobs/jobmanager.py | 168 +++++++----- .../narrative/tests/data/response_data.json | 84 ++++-- .../narrative/tests/generate_test_results.py | 38 ++- .../narrative/tests/job_test_constants.py | 4 +- .../tests/narrative_mock/mockclients.py | 15 +- .../narrative/tests/test_appmanager.py | 6 + src/biokbase/narrative/tests/test_job.py | 95 +++---- src/biokbase/narrative/tests/test_jobcomm.py | 16 +- .../narrative/tests/test_jobmanager.py | 18 +- .../narrative/tests/test_narrativeio.py | 33 +-- .../narrative/tests/test_widgetmanager.py | 2 + 14 files changed, 491 insertions(+), 445 deletions(-) diff --git a/src/biokbase/narrative/jobs/appmanager.py b/src/biokbase/narrative/jobs/appmanager.py index 8480c57d9b..12a5aee0b0 100644 --- a/src/biokbase/narrative/jobs/appmanager.py +++ b/src/biokbase/narrative/jobs/appmanager.py @@ -6,7 +6,8 @@ import random import re import traceback -from typing import Callable, Optional +from collections.abc import Callable +from typing import Any from biokbase import auth from biokbase.narrative import clients @@ -19,14 +20,12 @@ ) from biokbase.narrative.common import kblogging from biokbase.narrative.exception_util import transform_job_exception -from biokbase.narrative.system import strict_system_variable, system_variable - -from biokbase.narrative.widgetmanager import WidgetManager - from biokbase.narrative.jobs import specmanager from biokbase.narrative.jobs.job import Job from biokbase.narrative.jobs.jobcomm import MESSAGE_TYPE, JobComm from biokbase.narrative.jobs.jobmanager import JobManager +from biokbase.narrative.system import strict_system_variable, system_variable +from biokbase.narrative.widgetmanager import WidgetManager """ A module for managing apps, specs, requirements, and for starting jobs. @@ -45,7 +44,7 @@ def timestamp() -> str: return datetime.datetime.utcnow().isoformat() + "Z" -def _app_error_wrapper(app_func: Callable) -> any: +def _app_error_wrapper(app_func: Callable) -> Callable: """ This is a decorator meant to wrap any of the `run_app*` methods here. It captures any raised exception, formats it into a message that can be sent @@ -122,7 +121,7 @@ def __new__(cls): AppManager.__instance._comm = None return AppManager.__instance - def reload(self): + def reload(self: "AppManager"): """ Reloads all app specs into memory from the App Catalog. Any outputs of app_usage, app_description, or available_apps @@ -130,7 +129,7 @@ def reload(self): """ self.spec_manager.reload() - def app_usage(self, app_id, tag="release"): + def app_usage(self: "AppManager", app_id: str, tag: str = "release"): """ This shows the list of inputs and outputs for a given app with a given tag. By default, this is done in a pretty HTML way, but this app can be @@ -149,7 +148,7 @@ def app_usage(self, app_id, tag="release"): """ return self.spec_manager.app_usage(app_id, tag) - def app_description(self, app_id, tag="release"): + def app_description(self: "AppManager", app_id: str, tag: str = "release"): """ Returns the app description in a printable HTML format. @@ -166,7 +165,7 @@ def app_description(self, app_id, tag="release"): """ return self.spec_manager.app_description(app_id, tag) - def available_apps(self, tag="release"): + def available_apps(self: "AppManager", tag: str = "release"): """ Lists the set of available apps for a given tag in a simple table. If the tag is not found, a ValueError will be raised. @@ -181,13 +180,13 @@ def available_apps(self, tag="release"): @_app_error_wrapper def run_legacy_batch_app( - self, - app_id, + self: "AppManager", + app_id: str, params, - tag="release", - version=None, - cell_id=None, - run_id=None, + tag: str = "release", + version: str | None = None, + cell_id: str | None = None, + run_id: str | None = None, dry_run=False, ): if params is None: @@ -320,15 +319,15 @@ def run_legacy_batch_app( @_app_error_wrapper def run_app( - self, - app_id, + self: "AppManager", + app_id: str, params, - tag="release", - version=None, - cell_id=None, - run_id=None, + tag: str = "release", + version: str | None = None, + cell_id: str | None = None, + run_id: str | None = None, dry_run=False, - ): + ) -> dict[str, Any]: """ Attempts to run the app, returns a Job with the running app info. If this is given a cell_id, then returns None. If not, it returns the @@ -409,12 +408,12 @@ def run_app( @_app_error_wrapper def run_app_batch( - self, + self: "AppManager", app_info: list, - cell_id: str = None, - run_id: str = None, + cell_id: str | None = None, + run_id: str | None = None, dry_run: bool = False, - ) -> Optional[dict]: + ) -> None | dict[str, Job | list[Job]] | dict[str, dict[str, str | int] | list]: """ Attempts to run a batch of apps in bulk using the Execution Engine's run_app_batch endpoint. If a cell_id is provided, this sends various job messages over the comm channel, and returns None. @@ -554,7 +553,7 @@ def run_app_batch( }, ) - child_jobs = Job.from_job_ids(child_ids, return_list=True) + child_jobs = Job.from_job_ids(child_ids) parent_job = Job.from_job_id( batch_id, children=child_jobs, @@ -568,7 +567,7 @@ def run_app_batch( if cell_id is None: return {"parent_job": parent_job, "child_jobs": child_jobs} - def _validate_bulk_app_info(self, app_info: dict): + def _validate_bulk_app_info(self: "AppManager", app_info: dict): """ Validation consists of: 1. must have "app_id" with format xyz/abc @@ -610,7 +609,9 @@ def _validate_bulk_app_info(self, app_info: dict): f"an app version must be a string, not {app_info['version']}" ) - def _reconstitute_shared_params(self, app_info_el: dict) -> None: + def _reconstitute_shared_params( + self: "AppManager", app_info_el: dict[str, Any] + ) -> None: """ Mutate each params dict to include any shared_params app_info_el is structured like: @@ -639,14 +640,14 @@ def _reconstitute_shared_params(self, app_info_el: dict) -> None: param_set.setdefault(k, v) def _build_run_job_params( - self, - spec: dict, + self: "AppManager", + spec: dict[str, Any], tag: str, - param_set: dict, - version: Optional[str] = None, - cell_id: Optional[str] = None, - run_id: Optional[str] = None, - ws_id: Optional[int] = None, + param_set: dict[str, Any], + version: str | None = None, + cell_id: str | None = None, + run_id: str | None = None, + ws_id: int | None = None, ) -> dict: """ Builds the set of inputs for EE2.run_job and EE2.run_job_batch (RunJobParams) given a spec @@ -726,13 +727,13 @@ def _build_run_job_params( @_app_error_wrapper def run_local_app( - self, - app_id, - params, - tag="release", - version=None, - cell_id=None, - run_id=None, + self: "AppManager", + app_id: str, + params: dict[str, Any], + tag: str = "release", + version: str | None = None, + cell_id: str | None = None, + run_id: str | None = None, widget_state=None, ): """ @@ -811,14 +812,14 @@ def run_local_app( ) def run_local_app_advanced( - self, - app_id, - params, + self: "AppManager", + app_id: str, + params: dict[str, Any], widget_state, - tag="release", - version=None, - cell_id=None, - run_id=None, + tag: str = "release", + version: str | None = None, + cell_id: str | None = None, + run_id: str | None = None, ): return self.run_local_app( app_id, @@ -830,7 +831,9 @@ def run_local_app_advanced( run_id=run_id, ) - def _get_validated_app_spec(self, app_id, tag, is_long, version=None): + def _get_validated_app_spec( + self: "AppManager", app_id: str, tag: str, is_long, version: str | None = None + ): if ( version is not None and tag != "release" @@ -860,7 +863,12 @@ def _get_validated_app_spec(self, app_id, tag, is_long, version=None): ) return spec - def _map_group_inputs(self, value, spec_param, spec_params): + def _map_group_inputs( + self: "AppManager", + value: list | dict | None, + spec_param: dict[str, Any], + spec_params: dict[str, Any], + ): if isinstance(value, list): return [self._map_group_inputs(v, spec_param, spec_params) for v in value] @@ -891,7 +899,12 @@ def _map_group_inputs(self, value, spec_param, spec_params): mapped_value[target_key] = target_val return mapped_value - def _map_inputs(self, input_mapping, params, spec_params): + def _map_inputs( + self: "AppManager", + input_mapping: list[dict[str, Any]], + params: dict[str, Any], + spec_params: dict[str, Any], + ): """ Maps the dictionary of parameters and inputs based on rules provided in the input_mapping. This iterates over the list of input_mappings, and @@ -977,7 +990,7 @@ def _map_inputs(self, input_mapping, params, spec_params): inputs_list.append(inputs_dict[k]) return inputs_list - def _generate_input(self, generator): + def _generate_input(self: "AppManager", generator: dict[str, Any] | None): """ Generates an input value using rules given by NarrativeMethodStore.AutoGeneratedValue. @@ -1007,10 +1020,10 @@ def _generate_input(self, generator): return ret + str(generator["suffix"]) return ret - def _send_comm_message(self, msg_type, content): + def _send_comm_message(self: "AppManager", msg_type: str, content: dict[str, Any]): JobComm().send_comm_message(msg_type, content) - def _get_agent_token(self, name: str) -> auth.TokenInfo: + def _get_agent_token(self: "AppManager", name: str) -> auth.TokenInfo: """ Retrieves an agent token from the Auth service with a formatted name. This prepends "KBApp_" to the name for filtering, and trims to make sure the name diff --git a/src/biokbase/narrative/jobs/job.py b/src/biokbase/narrative/jobs/job.py index 0aacc5e06e..70980fe67d 100644 --- a/src/biokbase/narrative/jobs/job.py +++ b/src/biokbase/narrative/jobs/job.py @@ -2,13 +2,12 @@ import json import uuid from pprint import pprint -from typing import List +from typing import Any -from jinja2 import Template - -import biokbase.narrative.clients as clients +from biokbase.narrative import clients from biokbase.narrative.app_util import map_inputs_from_job, map_outputs_from_state from biokbase.narrative.exception_util import transform_job_exception +from jinja2 import Template from .specmanager import SpecManager @@ -29,9 +28,9 @@ "updated", ] -EXCLUDED_JOB_STATE_FIELDS = JOB_INIT_EXCLUDED_JOB_STATE_FIELDS + ["job_input"] +EXCLUDED_JOB_STATE_FIELDS = [*JOB_INIT_EXCLUDED_JOB_STATE_FIELDS, "job_input"] -OUTPUT_STATE_EXCLUDED_JOB_STATE_FIELDS = EXCLUDED_JOB_STATE_FIELDS + ["user", "wsid"] +OUTPUT_STATE_EXCLUDED_JOB_STATE_FIELDS = [*EXCLUDED_JOB_STATE_FIELDS, "user"] EXTRA_JOB_STATE_FIELDS = ["batch_id", "child_jobs"] @@ -47,9 +46,11 @@ # cell_id (str): ID of the cell that initiated the job (if applicable) # job_id (str): the ID of the job # params (dict): input parameters -# user (str): the user who started the job # run_id (str): unique run ID for the job +# status (str): EE2 job status # tag (str): the application tag (dev/beta/release) +# user (str): the user who started the job +# wsid (str): the workspace ID for the job JOB_ATTR_DEFAULTS = { "app_id": None, "app_version": None, @@ -61,8 +62,10 @@ "retry_ids": [], "retry_parent": None, "run_id": None, + "status": "created", "tag": "release", "user": None, + "wsid": None, } JOB_ATTRS = list(JOB_ATTR_DEFAULTS.keys()) @@ -80,14 +83,18 @@ "params", ] ALL_ATTRS = list(set(JOB_ATTRS + JOB_INPUT_ATTRS + NARR_CELL_INFO_ATTRS)) -STATE_ATTRS = list(set(JOB_ATTRS) - set(JOB_INPUT_ATTRS) - set(NARR_CELL_INFO_ATTRS)) class Job: - _job_logs = [] - _acc_state = None # accumulates state - - def __init__(self, ee2_state, extra_data=None, children=None): + _job_logs: list[dict[str, Any]] | None = None + _acc_state: dict[str, Any] | None = None # accumulates state + + def __init__( + self: "Job", + ee2_state: dict[str, Any], + extra_data: dict[str, Any] | None = None, + children: list["Job"] | None = None, + ): """ Parameters: ----------- @@ -101,6 +108,8 @@ def __init__(self, ee2_state, extra_data=None, children=None): children (list): applies to batch parent jobs, Job instances of this Job's child jobs """ # verify job_id ... TODO validate ee2_state + self._job_logs = [] + self._acc_state = {} if ee2_state.get("job_id") is None: raise ValueError("Cannot create a job without a job ID!") @@ -113,49 +122,25 @@ def __init__(self, ee2_state, extra_data=None, children=None): self.children = children @classmethod - def from_job_id(cls, job_id, extra_data=None, children=None): - state = cls.query_ee2_state(job_id, init=True) - return cls(state, extra_data=extra_data, children=children) + def from_job_id( + cls, + job_id: str, + extra_data: dict[str, Any] | None = None, + children: list["Job"] | None = None, + ) -> "Job": + state = cls.query_ee2_states([job_id], init=True) + return cls(state[job_id], extra_data=extra_data, children=children) @classmethod - def from_job_ids(cls, job_ids, return_list=True): + def from_job_ids(cls, job_ids: list[str]) -> list["Job"]: states = cls.query_ee2_states(job_ids, init=True) - jobs = {} - for job_id, state in states.items(): - jobs[job_id] = cls(state) - - if return_list: - return list(jobs.values()) - return jobs - - @staticmethod - def _trim_ee2_state(state: dict, exclude_fields: list) -> None: - if exclude_fields: - for field in exclude_fields: - if field in state: - del state[field] - - @staticmethod - def query_ee2_state( - job_id: str, - init: bool = True, - ) -> dict: - return clients.get("execution_engine2").check_job( - { - "job_id": job_id, - "exclude_fields": ( - JOB_INIT_EXCLUDED_JOB_STATE_FIELDS - if init - else EXCLUDED_JOB_STATE_FIELDS - ), - } - ) + return [cls(state) for state in states.values()] @staticmethod def query_ee2_states( - job_ids: List[str], + job_ids: list[str], init: bool = True, - ) -> dict: + ) -> dict[str, dict]: if not job_ids: return {} @@ -171,10 +156,20 @@ def query_ee2_states( } ) - def __getattr__(self, name): + @staticmethod + def _trim_ee2_state( + state: dict[str, Any], exclude_fields: list[str] | None + ) -> None: + if exclude_fields: + for field in exclude_fields: + if field in state: + del state[field] + + def __getattr__(self: "Job", name: str) -> None | str | int | list | dict: """ Map expected job attributes to paths in stored ee2 state """ + attr = { "app_id": lambda: self._acc_state.get("job_input", {}).get( "app_id", JOB_ATTR_DEFAULTS["app_id"] @@ -193,44 +188,51 @@ def __getattr__(self, name): "cell_id": lambda: self._acc_state.get("job_input", {}) .get("narrative_cell_info", {}) .get("cell_id", JOB_ATTR_DEFAULTS["cell_id"]), - "child_jobs": lambda: copy.deepcopy( - # TODO - # Only batch container jobs have a child_jobs field - # and need the state refresh. - # But KBParallel/KB Batch App jobs may not have the - # batch_job field - self.refresh_state(force_refresh=True).get( - "child_jobs", JOB_ATTR_DEFAULTS["child_jobs"] - ) - if self.batch_job - else self._acc_state.get("child_jobs", JOB_ATTR_DEFAULTS["child_jobs"]) - ), "job_id": lambda: self._acc_state.get("job_id"), "params": lambda: copy.deepcopy( self._acc_state.get("job_input", {}).get( "params", JOB_ATTR_DEFAULTS["params"] ) ), - "retry_ids": lambda: copy.deepcopy( - # Batch container and retry jobs don't have a - # retry_ids field so skip the state refresh - self._acc_state.get("retry_ids", JOB_ATTR_DEFAULTS["retry_ids"]) - if self.batch_job or self.retry_parent - else self.refresh_state(force_refresh=True).get( - "retry_ids", JOB_ATTR_DEFAULTS["retry_ids"] - ) - ), "retry_parent": lambda: self._acc_state.get( "retry_parent", JOB_ATTR_DEFAULTS["retry_parent"] ), "run_id": lambda: self._acc_state.get("job_input", {}) .get("narrative_cell_info", {}) .get("run_id", JOB_ATTR_DEFAULTS["run_id"]), - # TODO: add the status attribute! "tag": lambda: self._acc_state.get("job_input", {}) .get("narrative_cell_info", {}) .get("tag", JOB_ATTR_DEFAULTS["tag"]), "user": lambda: self._acc_state.get("user", JOB_ATTR_DEFAULTS["user"]), + "wsid": lambda: self._acc_state.get("wsid", JOB_ATTR_DEFAULTS["wsid"]), + # the following properties can change whilst a job is in progress + "child_jobs": lambda force_refresh=True: copy.deepcopy( + # N.b. only batch container jobs have a child_jobs field + # and need the state refresh. + # But KBParallel/KB Batch App jobs do not have the + # batch_job field + self.refresh_state(force_refresh=force_refresh).get( + "child_jobs", JOB_ATTR_DEFAULTS["child_jobs"] + ) + if self.batch_job + else self._acc_state.get("child_jobs", JOB_ATTR_DEFAULTS["child_jobs"]) + ), + "retry_ids": lambda force_refresh=True: copy.deepcopy( + # Batch container and retry jobs don't have a + # retry_ids field so skip the state refresh + self._acc_state.get("retry_ids", JOB_ATTR_DEFAULTS["retry_ids"]) + if self.batch_job or self.retry_parent + else self.refresh_state(force_refresh=force_refresh).get( + "retry_ids", JOB_ATTR_DEFAULTS["retry_ids"] + ) + ), + "status": lambda force_refresh=True: copy.deepcopy( + self.refresh_state(force_refresh=force_refresh).get( + "status", JOB_ATTR_DEFAULTS["status"] + ) + if self.in_terminal_state() + else self._acc_state.get("status", JOB_ATTR_DEFAULTS["status"]) + ), } if name not in attr: @@ -238,7 +240,7 @@ def __getattr__(self, name): return attr[name]() - def __setattr__(self, name, value): + def __setattr__(self: "Job", name: str, value: None | str | int | list | dict): if name in ALL_ATTRS: raise AttributeError( "Job attributes must be updated using the `update_state` method" @@ -247,10 +249,10 @@ def __setattr__(self, name, value): object.__setattr__(self, name, value) @property - def app_name(self): + def app_name(self: "Job") -> str: return "batch" if self.batch_job else self.app_spec()["info"]["name"] - def was_terminal(self): + def in_terminal_state(self: "Job") -> bool: """ Checks if last queried ee2 state (or those of its children) was terminal. """ @@ -265,7 +267,7 @@ def was_terminal(self): return self._acc_state.get("status") in TERMINAL_STATUSES - def in_cells(self, cell_ids: List[str]) -> bool: + def in_cells(self: "Job", cell_ids: list[str]) -> bool: """ For job initialization. See if job is associated with present cells @@ -276,14 +278,14 @@ def in_cells(self, cell_ids: List[str]) -> bool: if cell_ids is None: raise ValueError("cell_ids cannot be None") - if self.batch_job: + if self.batch_job and self.children: return any(child_job.cell_id in cell_ids for child_job in self.children) return self.cell_id in cell_ids - def app_spec(self): + def app_spec(self: "Job"): return SpecManager().get_spec(self.app_id, self.tag) - def parameters(self): + def parameters(self: "Job") -> dict[str, Any]: """ Returns the parameters used to start the job. Job tries to use its params field, but if that's None, then it makes a call to EE2. @@ -291,17 +293,18 @@ def parameters(self): If no exception is raised, this only returns the list of parameters, NOT the whole object fetched from ee2.check_job """ - if self.params is not None: - return self.params + if self.params is None: + try: + state = self.query_ee2_states([self.job_id], init=True) + self.update_state(state[self.job_id]) + except Exception as e: + raise Exception( + f"Unable to fetch parameters for job {self.job_id} - {e}" + ) from e - try: - state = self.query_ee2_state(self.job_id, init=True) - self.update_state(state) - return self.params - except Exception as e: - raise Exception(f"Unable to fetch parameters for job {self.job_id} - {e}") + return self.params - def update_state(self, state: dict) -> None: + def update_state(self: "Job", state: dict[str, Any]) -> None: """ Given a state data structure (as emitted by ee2), update the stored state in the job object. All updates to the job state should go through this function. @@ -316,35 +319,35 @@ def update_state(self, state: dict) -> None: + f"job ID: {self.job_id}; state ID: {state['job_id']}" ) - if self._acc_state is None: - self._acc_state = {} - self._acc_state = {**self._acc_state, **state} def refresh_state( - self, force_refresh=False, exclude_fields=JOB_INIT_EXCLUDED_JOB_STATE_FIELDS + self: "Job", + force_refresh: bool = False, + exclude_fields: list[str] | None = JOB_INIT_EXCLUDED_JOB_STATE_FIELDS, ): """ Queries the job service to see the state of the current job. """ - - if force_refresh or not self.was_terminal(): - state = self.query_ee2_state(self.job_id, init=False) - self.update_state(state) + if force_refresh or not self.in_terminal_state(): + state = self.query_ee2_states([self.job_id], init=False) + self.update_state(state[self.job_id]) return self.cached_state(exclude_fields) - def cached_state(self, exclude_fields=None): + def cached_state( + self: "Job", exclude_fields: list[str] | None = None + ) -> dict[str, Any]: """Wrapper for self._acc_state""" state = copy.deepcopy(self._acc_state) self._trim_ee2_state(state, exclude_fields) return state - def output_state(self, state=None, no_refresh=False) -> dict: + def output_state(self: "Job") -> dict[str, str | dict[str, Any]]: """ - :param state: Supplied when the state is queried beforehand from EE2 in bulk, - or when it is retrieved from a cache. If not supplied, must be - queried with self.refresh_state() or self.cached_state() + Request the current job state in a format suitable for sending to the front end. + N.b. this method does not perform a data update. + :return: dict, with structure { @@ -396,24 +399,17 @@ def output_state(self, state=None, no_refresh=False) -> dict: } :rtype: dict """ - if not state: - state = self.cached_state() if no_refresh else self.refresh_state() - else: - self.update_state(state) - state = self.cached_state() - - if state is None: - return self._create_error_state( - "Unable to find current job state. Please try again later, or contact KBase.", - "Unable to return job state", - -1, - ) - + state = self.cached_state() self._trim_ee2_state(state, OUTPUT_STATE_EXCLUDED_JOB_STATE_FIELDS) + if "job_output" not in state: state["job_output"] = {} - for arg in EXTRA_JOB_STATE_FIELDS: - state[arg] = getattr(self, arg) + + if "batch_id" not in state: + state["batch_id"] = self.batch_id + + if "child_jobs" not in state: + state["child_jobs"] = JOB_ATTR_DEFAULTS["child_jobs"] widget_info = None if state.get("finished"): @@ -440,7 +436,9 @@ def output_state(self, state=None, no_refresh=False) -> dict: "outputWidgetInfo": widget_info, } - def show_output_widget(self, state=None): + def show_output_widget( + self: "Job", state: dict[str, Any] | None = None + ) -> str | None: """ For a complete job, returns the job results. An incomplete job throws an exception @@ -460,7 +458,7 @@ def show_output_widget(self, state=None): ) return f"Job is incomplete! It has status '{state['status']}'" - def get_viewer_params(self, state): + def get_viewer_params(self: "Job", state: dict[str, Any]) -> dict[str, Any] | None: """ Maps job state 'result' onto the inputs for a viewer. """ @@ -469,13 +467,15 @@ def get_viewer_params(self, state): (output_widget, widget_params) = self._get_output_info(state) return {"name": output_widget, "tag": self.tag, "params": widget_params} - def _get_output_info(self, state): + def _get_output_info(self: "Job", state: dict[str, Any]): spec = self.app_spec() return map_outputs_from_state( state, map_inputs_from_job(self.parameters(), spec), spec ) - def log(self, first_line=0, num_lines=None): + def log( + self: "Job", first_line: int = 0, num_lines: int | None = None + ) -> tuple[int, list[dict[str, Any]]]: """ Fetch a list of Job logs from the Job Service. This returns a 2-tuple (number of available log lines, list of log lines) @@ -516,16 +516,17 @@ def log(self, first_line=0, num_lines=None): self._job_logs[first_line : first_line + num_lines], ) - def _update_log(self): - log_update = clients.get("execution_engine2").get_job_logs( + def _update_log(self: "Job") -> None: + log_update: dict[str, Any] = clients.get("execution_engine2").get_job_logs( {"job_id": self.job_id, "skip_lines": len(self._job_logs)} ) if log_update["lines"]: self._job_logs = self._job_logs + log_update["lines"] - def _verify_children(self, children: List["Job"]) -> None: + def _verify_children(self: "Job", children: list["Job"] | None) -> None: if not self.batch_job: raise ValueError("Not a batch container job") + if children is None: raise ValueError( "Must supply children when setting children of batch job parent" @@ -535,16 +536,16 @@ def _verify_children(self, children: List["Job"]) -> None: if sorted(inst_child_ids) != sorted(self._acc_state.get("child_jobs")): raise ValueError("Child job id mismatch") - def update_children(self, children: List["Job"]) -> None: + def update_children(self: "Job", children: list["Job"]) -> None: self._verify_children(children) self.children = children def _create_error_state( - self, + self: "Job", error: str, error_msg: str, code: int, - ) -> dict: + ) -> dict[str, Any]: """ Creates an error state to return if 1. the state is missing or unretrievable @@ -571,10 +572,10 @@ def _create_error_state( "created": 0, } - def __repr__(self): + def __repr__(self: "Job") -> str: return "KBase Narrative Job - " + str(self.job_id) - def info(self): + def info(self: "Job") -> None: """ Printed job info """ @@ -590,7 +591,7 @@ def info(self): except BaseException: print("Unable to retrieve current running state!") - def _repr_javascript_(self): + def _repr_javascript_(self: "Job") -> str: """ Called by Jupyter when a Job object is entered into a code cell """ @@ -625,7 +626,7 @@ def _repr_javascript_(self): output_widget_info=json.dumps(output_widget_info), ) - def dump(self): + def dump(self: "Job") -> dict[str, Any]: """ Display job info without having to iterate through the attributes """ diff --git a/src/biokbase/narrative/jobs/jobcomm.py b/src/biokbase/narrative/jobs/jobcomm.py index e674dfece1..e3c3c4f60a 100644 --- a/src/biokbase/narrative/jobs/jobcomm.py +++ b/src/biokbase/narrative/jobs/jobcomm.py @@ -1,13 +1,12 @@ import copy import threading -from typing import List, Union - -from ipykernel.comm import Comm +from typing import Union from biokbase.narrative.common import kblogging from biokbase.narrative.exception_util import JobRequestException, NarrativeException from biokbase.narrative.jobs.jobmanager import JobManager from biokbase.narrative.jobs.util import load_job_constants +from ipykernel.comm import Comm (PARAM, MESSAGE_TYPE) = load_job_constants() @@ -21,6 +20,8 @@ LOOKUP_TIMER_INTERVAL = 5 +INPUT_TYPES = [PARAM["JOB_ID"], PARAM["JOB_ID_LIST"], PARAM["BATCH_ID"]] + class JobRequest: """ @@ -65,9 +66,7 @@ class JobRequest: batch_id str """ - INPUT_TYPES = [PARAM["JOB_ID"], PARAM["JOB_ID_LIST"], PARAM["BATCH_ID"]] - - def __init__(self, rq: dict): + def __init__(self: "JobRequest", rq: dict) -> None: rq = copy.deepcopy(rq) self.raw_request = rq self.msg_id = rq.get("msg_id") # might be useful later? @@ -79,20 +78,20 @@ def __init__(self, rq: dict): raise JobRequestException(MISSING_REQUEST_TYPE_ERR) input_type_count = 0 - for input_type in self.INPUT_TYPES: + for input_type in INPUT_TYPES: if input_type in self.rq_data: input_type_count += 1 if input_type_count > 1: raise JobRequestException(ONE_INPUT_TYPE_ONLY_ERR) @property - def job_id(self): + def job_id(self: "JobRequest"): if PARAM["JOB_ID"] in self.rq_data: return self.rq_data[PARAM["JOB_ID"]] raise JobRequestException(ONE_INPUT_TYPE_ONLY_ERR) @property - def job_id_list(self): + def job_id_list(self: "JobRequest"): if PARAM["JOB_ID_LIST"] in self.rq_data: return self.rq_data[PARAM["JOB_ID_LIST"]] if PARAM["JOB_ID"] in self.rq_data: @@ -100,22 +99,22 @@ def job_id_list(self): raise JobRequestException(ONE_INPUT_TYPE_ONLY_ERR) @property - def batch_id(self): + def batch_id(self: "JobRequest"): if PARAM["BATCH_ID"] in self.rq_data: return self.rq_data[PARAM["BATCH_ID"]] raise JobRequestException(ONE_INPUT_TYPE_ONLY_ERR) - def has_batch_id(self): + def has_batch_id(self: "JobRequest"): return PARAM["BATCH_ID"] in self.rq_data @property - def cell_id_list(self): + def cell_id_list(self: "JobRequest"): if PARAM["CELL_ID_LIST"] in self.rq_data: return self.rq_data[PARAM["CELL_ID_LIST"]] raise JobRequestException(CELLS_NOT_PROVIDED_ERR) @property - def ts(self): + def ts(self: "JobRequest"): """ Optional field sent with STATUS requests indicating to filter out job states in the STATUS response that have not been updated since @@ -171,7 +170,7 @@ def __new__(cls): JobComm.__instance = object.__new__(cls) return JobComm.__instance - def __init__(self): + def __init__(self: "JobComm") -> None: if self._comm is None: self._comm = Comm(target_name="KBaseJobs", data={}) self._comm.on_msg(self._handle_comm_message) @@ -190,7 +189,7 @@ def __init__(self): MESSAGE_TYPE["STOP_UPDATE"]: self._modify_job_updates, } - def _get_job_ids(self, req: JobRequest) -> List[str]: + def _get_job_ids(self: "JobComm", req: JobRequest) -> list[str]: """ Extract the job IDs from a job request object @@ -198,7 +197,7 @@ def _get_job_ids(self, req: JobRequest) -> List[str]: :type req: JobRequest :return: list of job IDs - :rtype: List[str] + :rtype: list[str] """ if req.has_batch_id(): return self._jm.update_batch_job(req.batch_id) @@ -206,9 +205,9 @@ def _get_job_ids(self, req: JobRequest) -> List[str]: return req.job_id_list def start_job_status_loop( - self, + self: "JobComm", init_jobs: bool = False, - cell_list: List[str] = None, + cell_list: list[str] | None = None, ) -> None: """ Starts the job status lookup loop, which runs every LOOKUP_TIMER_INTERVAL seconds. @@ -237,7 +236,7 @@ def start_job_status_loop( if self._lookup_timer is None: self._lookup_job_status_loop() - def stop_job_status_loop(self) -> None: + def stop_job_status_loop(self: "JobComm") -> None: """ Stops the job status lookup loop if it's running. Otherwise, this effectively does nothing. @@ -247,7 +246,7 @@ def stop_job_status_loop(self) -> None: self._lookup_timer = None self._running_lookup_loop = False - def _lookup_job_status_loop(self) -> None: + def _lookup_job_status_loop(self: "JobComm") -> None: """ Run a loop that will look up job info. After running, this spawns a Timer thread on a loop to run itself again. LOOKUP_TIMER_INTERVAL sets the frequency at which the loop runs. @@ -262,7 +261,9 @@ def _lookup_job_status_loop(self) -> None: self._lookup_timer.start() def get_all_job_states( - self, req: JobRequest = None, ignore_refresh_flag: bool = False + self: "JobComm", + req: JobRequest | None = None, + ignore_refresh_flag: bool = False, ) -> dict: """ Fetches status of all jobs in the current workspace and sends them to the front end. @@ -274,7 +275,7 @@ def get_all_job_states( self.send_comm_message(MESSAGE_TYPE["STATUS_ALL"], all_job_states) return all_job_states - def get_job_states_by_cell_id(self, req: JobRequest) -> dict: + def get_job_states_by_cell_id(self: "JobComm", req: JobRequest) -> dict: """ Fetches status of all jobs associated with the given cell ID(s). @@ -303,7 +304,7 @@ def get_job_states_by_cell_id(self, req: JobRequest) -> dict: self.send_comm_message(MESSAGE_TYPE["CELL_JOB_STATUS"], cell_job_states) return cell_job_states - def get_job_info(self, req: JobRequest) -> dict: + def get_job_info(self: "JobComm", req: JobRequest) -> dict: """ Gets job information for a list of job IDs. @@ -332,7 +333,7 @@ def get_job_info(self, req: JobRequest) -> dict: self.send_comm_message(MESSAGE_TYPE["INFO"], job_info) return job_info - def _get_job_states(self, job_id_list: list, ts: int = None) -> dict: + def _get_job_states(self: "JobComm", job_id_list: list, ts: int = None) -> dict: """ Retrieves the job states for the supplied job_ids. @@ -358,7 +359,7 @@ def _get_job_states(self, job_id_list: list, ts: int = None) -> dict: self.send_comm_message(MESSAGE_TYPE["STATUS"], output_states) return output_states - def get_job_state(self, job_id: str) -> dict: + def get_job_state(self: "JobComm", job_id: str) -> dict: """ Retrieve the job state for a single job. @@ -373,7 +374,7 @@ def get_job_state(self, job_id: str) -> dict: """ return self._get_job_states([job_id]) - def get_job_states(self, req: JobRequest) -> dict: + def get_job_states(self: "JobComm", req: JobRequest) -> dict: """ Retrieves the job states for the supplied job_ids. @@ -388,7 +389,7 @@ def get_job_states(self, req: JobRequest) -> dict: job_id_list = self._get_job_ids(req) return self._get_job_states(job_id_list, req.ts) - def _modify_job_updates(self, req: JobRequest) -> dict: + def _modify_job_updates(self: "JobComm", req: JobRequest) -> dict: """ Modifies how many things want to listen to a job update. If this is a request to start a job update, then this starts the update loop that @@ -419,7 +420,7 @@ def _modify_job_updates(self, req: JobRequest) -> dict: self.send_comm_message(MESSAGE_TYPE["STATUS"], output_states) return output_states - def cancel_jobs(self, req: JobRequest) -> dict: + def cancel_jobs(self: "JobComm", req: JobRequest) -> dict: """ Cancel a job or list of jobs. After sending the cancellation request, the job states are refreshed and their new output states returned. @@ -437,7 +438,7 @@ def cancel_jobs(self, req: JobRequest) -> dict: self.send_comm_message(MESSAGE_TYPE["STATUS"], cancel_results) return cancel_results - def retry_jobs(self, req: JobRequest) -> dict: + def retry_jobs(self: "JobComm", req: JobRequest) -> dict: """ Retry a job or list of jobs. @@ -454,7 +455,7 @@ def retry_jobs(self, req: JobRequest) -> dict: self.send_comm_message(MESSAGE_TYPE["RETRY"], retry_results) return retry_results - def get_job_logs(self, req: JobRequest) -> dict: + def get_job_logs(self: "JobComm", req: JobRequest) -> dict: """ Fetch the logs for a job or list of jobs. @@ -476,7 +477,7 @@ def get_job_logs(self, req: JobRequest) -> dict: self.send_comm_message(MESSAGE_TYPE["LOGS"], log_output) return log_output - def _handle_comm_message(self, msg: dict) -> dict: + def _handle_comm_message(self: "JobComm", msg: dict) -> dict: """ Handle incoming messages on the KBaseJobs channel. @@ -510,7 +511,7 @@ def _handle_comm_message(self, msg: dict) -> dict: return self._msg_map[request.request_type](request) - def send_comm_message(self, msg_type: str, content: dict) -> None: + def send_comm_message(self: "JobComm", msg_type: str, content: dict) -> None: """ Sends a ipykernel.Comm message to the KBaseJobs channel with the given msg_type and content. These just get encoded into the message itself. @@ -572,7 +573,7 @@ class exc_to_msg: jc = JobComm() - def __init__(self, req: Union[JobRequest, dict, str] = None): + def __init__(self: "exc_to_msg", req: Union[JobRequest, dict, str] = None): """ req can be several different things because this context manager supports being used in several different places. Generally it is @@ -582,10 +583,10 @@ def __init__(self, req: Union[JobRequest, dict, str] = None): """ self.req = req - def __enter__(self): + def __enter__(self: "exc_to_msg"): pass # nothing to do here - def __exit__(self, exc_type, exc_value, exc_tb): + def __exit__(self: "exc_to_msg", exc_type, exc_value, exc_tb): """ If an exception is caught during execution in the JobComm code, this will send back a comm error message like: diff --git a/src/biokbase/narrative/jobs/jobmanager.py b/src/biokbase/narrative/jobs/jobmanager.py index 30b6be950b..e28d876a40 100644 --- a/src/biokbase/narrative/jobs/jobmanager.py +++ b/src/biokbase/narrative/jobs/jobmanager.py @@ -1,17 +1,16 @@ import copy from datetime import datetime, timedelta, timezone -from typing import List, Tuple +from typing import Any -from IPython.display import HTML -from jinja2 import Template - -import biokbase.narrative.clients as clients -from biokbase.narrative.system import system_variable +from biokbase.narrative import clients from biokbase.narrative.common import kblogging from biokbase.narrative.exception_util import ( JobRequestException, transform_job_exception, ) +from biokbase.narrative.system import system_variable +from IPython.display import HTML +from jinja2 import Template from .job import JOB_INIT_EXCLUDED_JOB_STATE_FIELDS, Job @@ -54,7 +53,7 @@ class JobManager: _log = kblogging.get_logger(__name__) - def __new__(cls): + def __new__(cls) -> "JobManager": if JobManager.__instance is None: JobManager.__instance = object.__new__(cls) return JobManager.__instance @@ -70,14 +69,14 @@ def _reorder_parents_children(states: dict) -> dict: return {job_id: states[job_id] for job_id in ordering} def _check_job_list( - self, input_ids: List[str] = None - ) -> Tuple[List[str], List[str]]: + self: "JobManager", input_ids: list[str] | None = None + ) -> tuple[list[str], list[str]]: """ Deduplicates the input job list, maintaining insertion order. Any jobs not present in self._running_jobs are added to an error list :param input_ids: list of putative job IDs, defaults to [] - :type input_ids: List[str], optional + :type input_ids: list[str], optional :raises JobRequestException: if the input_ids parameter is not a list or or if there are no valid job IDs supplied @@ -85,7 +84,7 @@ def _check_job_list( :return: tuple with items job_ids - valid job IDs error_ids - jobs that the narrative backend does not know about - :rtype: Tuple[List[str], List[str]] + :rtype: tuple[list[str], list[str]] """ if not input_ids: raise JobRequestException(JOBS_MISSING_ERR, input_ids) @@ -107,7 +106,9 @@ def _check_job_list( return job_ids, error_ids - def register_new_job(self, job: Job, refresh: bool = None) -> None: + def register_new_job( + self: "JobManager", job: Job, refresh: bool | None = None + ) -> None: """ Registers a new Job with the manager and stores the job locally. This should only be invoked when a new Job gets started. @@ -120,19 +121,19 @@ def register_new_job(self, job: Job, refresh: bool = None) -> None: kblogging.log_event(self._log, "register_new_job", {"job_id": job.job_id}) if refresh is None: - refresh = not job.was_terminal() + refresh = not job.in_terminal_state() self._running_jobs[job.job_id] = {"job": job, "refresh": refresh} # add the new job to the _jobs_by_cell_id mapping if there is a cell_id present if job.cell_id: - if job.cell_id not in self._jobs_by_cell_id.keys(): + if job.cell_id not in self._jobs_by_cell_id: self._jobs_by_cell_id[job.cell_id] = set() self._jobs_by_cell_id[job.cell_id].add(job.job_id) if job.batch_id: self._jobs_by_cell_id[job.cell_id].add(job.batch_id) - def initialize_jobs(self, cell_ids: List[str] = None) -> None: + def initialize_jobs(self: "JobManager", cell_ids: list[str] | None = None) -> None: """ Initializes this JobManager. This is expected to be run by a running Narrative, and naturally linked to a workspace. @@ -143,7 +144,7 @@ def initialize_jobs(self, cell_ids: List[str] = None) -> None: 4. start the status lookup loop. :param cell_ids: list of cell IDs to filter the existing jobs for, defaults to None - :type cell_ids: List[str], optional + :type cell_ids: list[str], optional :raises NarrativeException: if the call to ee2 fails """ @@ -161,10 +162,10 @@ def initialize_jobs(self, cell_ids: List[str] = None) -> None: except Exception as e: kblogging.log_event(self._log, "init_error", {"err": str(e)}) new_e = transform_job_exception(e, "Unable to initialize jobs") - raise new_e + raise new_e from e self._running_jobs = {} - job_states = self._reorder_parents_children(job_states) + job_states: dict[str, Any] = self._reorder_parents_children(job_states) for job_state in job_states.values(): child_jobs = None if job_state.get("batch_job"): @@ -178,19 +179,19 @@ def initialize_jobs(self, cell_ids: List[str] = None) -> None: # Set to refresh when job is not in terminal state # and when job is present in cells (if given) # and when it is not part of a batch - refresh = not job.was_terminal() and not job.batch_id + refresh = not job.in_terminal_state() and not job.batch_id if cell_ids is not None: refresh = refresh and job.in_cells(cell_ids) self.register_new_job(job, refresh) - def _create_jobs(self, job_ids: List[str]) -> dict: + def _create_jobs(self: "JobManager", job_ids: list[str]) -> dict: """ Given a list of job IDs, creates job objects for them and populates the _running_jobs dictionary. TODO: error handling :param job_ids: job IDs to create job objects for - :type job_ids: List[str] + :type job_ids: list[str] :return: dictionary of job states indexed by job ID :rtype: dict @@ -207,7 +208,7 @@ def _create_jobs(self, job_ids: List[str]) -> dict: return job_states - def get_job(self, job_id: str) -> Job: + def get_job(self: "JobManager", job_id: str) -> Job: """ Retrieve a job from the Job Manager's _running_jobs index. @@ -224,8 +225,8 @@ def get_job(self, job_id: str) -> Job: return self._running_jobs[job_id]["job"] def _construct_job_output_state_set( - self, job_ids: List[str], states: dict = None - ) -> dict: + self: "JobManager", job_ids: list[str], states: dict[str, Any] | None = None + ) -> dict[str, dict[str, Any]]: """ Builds a set of job states for the list of job ids. @@ -263,7 +264,7 @@ def _construct_job_output_state_set( } :param job_ids: list of job IDs - :type job_ids: List[str] + :type job_ids: list[str] :param states: dict of job state data from EE2, indexed by job ID, defaults to None :type states: dict, optional @@ -278,6 +279,9 @@ def _construct_job_output_state_set( if not job_ids: return {} + # ensure states is initialised + if not states: + states = {} output_states = {} jobs_to_lookup = [] @@ -285,17 +289,19 @@ def _construct_job_output_state_set( # These are already post-processed and ready to return. for job_id in job_ids: job = self.get_job(job_id) - if job.was_terminal(): + if job.in_terminal_state(): + # job is already finished, will not change + output_states[job_id] = job.output_state() + elif job_id in states: + job.update_state(states[job_id]) output_states[job_id] = job.output_state() - elif states and job_id in states: - state = states[job_id] - output_states[job_id] = job.output_state(state) else: jobs_to_lookup.append(job_id) fetched_states = {} # Get the rest of states direct from EE2. if jobs_to_lookup: + error_message = "" try: fetched_states = Job.query_ee2_states(jobs_to_lookup, init=False) except Exception as e: @@ -312,16 +318,19 @@ def _construct_job_output_state_set( for job_id in jobs_to_lookup: job = self.get_job(job_id) if job_id in fetched_states: - output_states[job_id] = job.output_state(fetched_states[job_id]) + job.update_state(fetched_states[job_id]) + output_states[job_id] = job.output_state() else: # fetch the current state without updating it - output_states[job_id] = job.output_state({}) + output_states[job_id] = job.output_state() # add an error field with the error message from the failed look up output_states[job_id]["error"] = error_message return output_states - def get_job_states(self, job_ids: List[str], ts: int = None) -> dict: + def get_job_states( + self: "JobManager", job_ids: list[str], ts: int | None = None + ) -> dict: """ Retrieves the job states for the supplied job_ids. @@ -332,7 +341,7 @@ def get_job_states(self, job_ids: List[str], ts: int = None) -> dict: } :param job_ids: job IDs to retrieve job state data for - :type job_ids: List[str] + :type job_ids: list[str] :param ts: timestamp (as generated by time.time_ns()) to filter the jobs, defaults to None :type ts: int, optional @@ -343,7 +352,9 @@ def get_job_states(self, job_ids: List[str], ts: int = None) -> dict: output_states = self._construct_job_output_state_set(job_ids) return self.add_errors_to_results(output_states, error_ids) - def get_all_job_states(self, ignore_refresh_flag=False) -> dict: + def get_all_job_states( + self: "JobManager", ignore_refresh_flag: bool = False + ) -> dict: """ Fetches states for all running jobs. If ignore_refresh_flag is True, then returns states for all jobs this @@ -356,23 +367,26 @@ def get_all_job_states(self, ignore_refresh_flag=False) -> dict: :return: dictionary of job states, indexed by job ID :rtype: dict """ - jobs_to_lookup = [] - # grab the list of running job ids, so we don't run into update-while-iterating problems. - for job_id in self._running_jobs: - if self._running_jobs[job_id]["refresh"] or ignore_refresh_flag: - jobs_to_lookup.append(job_id) - if len(jobs_to_lookup) > 0: + jobs_to_lookup = [ + job_id + for job_id in self._running_jobs + if self._running_jobs[job_id]["refresh"] or ignore_refresh_flag + ] + + if jobs_to_lookup: return self._construct_job_output_state_set(jobs_to_lookup) return {} - def _get_job_ids_by_cell_id(self, cell_id_list: List[str] = None) -> tuple: + def _get_job_ids_by_cell_id( + self: "JobManager", cell_id_list: list[str] | None = None + ) -> tuple: """ Finds jobs with a cell_id in cell_id_list. Mappings of job ID to cell ID are added when new jobs are registered. :param cell_id_list: cell IDs to retrieve job state data for - :type cell_id_list: List[str] + :type cell_id_list: list[str] :return: tuple with two components: job_id_list: list of job IDs associated with the cell IDs supplied @@ -392,12 +406,14 @@ def _get_job_ids_by_cell_id(self, cell_id_list: List[str] = None) -> tuple: job_id_list = set().union(*cell_to_job_mapping.values()) return (job_id_list, cell_to_job_mapping) - def get_job_states_by_cell_id(self, cell_id_list: List[str] = None) -> dict: + def get_job_states_by_cell_id( + self: "JobManager", cell_id_list: list[str] | None = None + ) -> dict: """ Retrieves the job states for jobs associated with the cell_id_list supplied. :param cell_id_list: cell IDs to retrieve job state data for - :type cell_id_list: List[str] + :type cell_id_list: list[str] :return: dictionary with two keys: 'jobs': job states, indexed by job ID @@ -413,7 +429,7 @@ def get_job_states_by_cell_id(self, cell_id_list: List[str] = None) -> dict: return {"jobs": job_states, "mapping": cell_to_job_mapping} - def get_job_info(self, job_ids: List[str]) -> dict: + def get_job_info(self: "JobManager", job_ids: list[str]) -> dict: """ Gets job information for a list of job IDs. @@ -433,7 +449,7 @@ def get_job_info(self, job_ids: List[str]) -> dict: } :param job_ids: job IDs to retrieve job info for - :type job_ids: List[str] + :type job_ids: list[str] :return: job info for each job, indexed by job ID :rtype: dict """ @@ -452,10 +468,10 @@ def get_job_info(self, job_ids: List[str]) -> dict: return self.add_errors_to_results(infos, error_ids) def get_job_logs( - self, + self: "JobManager", job_id: str, first_line: int = 0, - num_lines: int = None, + num_lines: int | None = None, latest: bool = False, ) -> dict: """ @@ -518,27 +534,27 @@ def get_job_logs( logs = logs[first_line:] else: (max_lines, logs) = job.log(first_line=first_line, num_lines=num_lines) - + except Exception as e: return { "job_id": job.job_id, "batch_id": job.batch_id, - "first": first_line, - "latest": latest, - "max_lines": max_lines, - "lines": logs, + "error": e.message, } - except Exception as e: + else: return { "job_id": job.job_id, "batch_id": job.batch_id, - "error": e.message, + "first": first_line, + "latest": latest, + "max_lines": max_lines, + "lines": logs, } def get_job_logs_for_list( - self, - job_id_list: List[str], + self: "JobManager", + job_id_list: list[str], first_line: int = 0, - num_lines: int = None, + num_lines: int | None = None, latest: bool = False, ) -> dict: """ @@ -551,7 +567,7 @@ def get_job_logs_for_list( } :param job_id_list: list of jobs to fetch logs for - :type job_id_list: List[str] + :type job_id_list: list[str] :param first_line: the first line to be returned, defaults to 0 :type first_line: int, optional :param num_lines: number of lines to be returned, defaults to None @@ -570,7 +586,7 @@ def get_job_logs_for_list( return self.add_errors_to_results(output, error_ids) - def cancel_jobs(self, job_id_list: List[str]) -> dict: + def cancel_jobs(self: "JobManager", job_id_list: list[str]) -> dict: """ Cancel a list of jobs and return their new state. After sending the cancellation request, the job states are refreshed and their new output states returned. @@ -588,7 +604,7 @@ def cancel_jobs(self, job_id_list: List[str]) -> dict: } :param job_id_list: job IDs to cancel - :type job_id_list: List[str] + :type job_id_list: list[str] :return: job output states, indexed by job ID :rtype: dict @@ -596,7 +612,7 @@ def cancel_jobs(self, job_id_list: List[str]) -> dict: job_ids, error_ids = self._check_job_list(job_id_list) error_states = {} for job_id in job_ids: - if not self.get_job(job_id).was_terminal(): + if not self.get_job(job_id).in_terminal_state(): error = self._cancel_job(job_id) if error: error_states[job_id] = error.message @@ -608,7 +624,7 @@ def cancel_jobs(self, job_id_list: List[str]) -> dict: return self.add_errors_to_results(job_states, error_ids) - def _cancel_job(self, job_id: str) -> None: + def _cancel_job(self: "JobManager", job_id: str) -> Exception | None: """ Cancel a single job. If an error occurs during cancellation, that error is converted into a NarrativeException and returned to the caller. @@ -633,7 +649,7 @@ def _cancel_job(self, job_id: str) -> None: del self._running_jobs[job_id]["canceling"] return error - def retry_jobs(self, job_id_list: List[str]) -> dict: + def retry_jobs(self: "JobManager", job_id_list: list[str]) -> dict: """ Retry a list of job IDs, returning job output states for the jobs to be retried and the new jobs created by the retry command. @@ -667,7 +683,7 @@ def retry_jobs(self, job_id_list: List[str]) -> dict: } :param job_id_list: list of job IDs - :type job_id_list: List[str] + :type job_id_list: list[str] :raises NarrativeException: if EE2 returns an error from the retry request @@ -676,9 +692,9 @@ def retry_jobs(self, job_id_list: List[str]) -> dict: """ job_ids, error_ids = self._check_job_list(job_id_list) try: - retry_results = clients.get("execution_engine2").retry_jobs( - {"job_ids": job_ids} - ) + retry_results: list[dict[str, Any]] = clients.get( + "execution_engine2" + ).retry_jobs({"job_ids": job_ids}) except Exception as e: raise transform_job_exception(e, "Unable to retry job(s)") from e @@ -706,14 +722,16 @@ def retry_jobs(self, job_id_list: List[str]) -> dict: results_by_job_id[job_id]["error"] = result["error"] return self.add_errors_to_results(results_by_job_id, error_ids) - def add_errors_to_results(self, results: dict, error_ids: List[str]) -> dict: + def add_errors_to_results( + self: "JobManager", results: dict, error_ids: list[str] + ) -> dict: """ Add the generic "not found" error for each job_id in error_ids. :param results: dictionary of job data (output state, info, retry, etc.) indexed by job ID :type results: dict :param error_ids: list of IDs that could not be found - :type error_ids: List[str] + :type error_ids: list[str] :return: input results dictionary augmented by a dictionary containing job ID and a short not found error message for every ID in the error_ids list @@ -726,7 +744,9 @@ def add_errors_to_results(self, results: dict, error_ids: List[str]) -> dict: } return results - def modify_job_refresh(self, job_ids: List[str], update_refresh: bool) -> None: + def modify_job_refresh( + self: "JobManager", job_ids: list[str], update_refresh: bool + ) -> None: """ Modifies how many things want to get the job updated. If this sets the current "refresh" key to be less than 0, it gets reset to 0. @@ -737,7 +757,7 @@ def modify_job_refresh(self, job_ids: List[str], update_refresh: bool) -> None: for job_id in job_ids: self._running_jobs[job_id]["refresh"] = update_refresh - def update_batch_job(self, batch_id: str) -> List[str]: + def update_batch_job(self: "JobManager", batch_id: str) -> list[str]: """ Update a batch job and create child jobs if necessary """ @@ -755,13 +775,13 @@ def update_batch_job(self, batch_id: str) -> List[str]: else: unreg_child_ids.append(job_id) - unreg_child_jobs = [] + unreg_child_jobs: list[Job] = [] if unreg_child_ids: unreg_child_jobs = Job.from_job_ids(unreg_child_ids) for job in unreg_child_jobs: self.register_new_job( job=job, - refresh=not job.was_terminal(), + refresh=not job.in_terminal_state(), ) batch_job.update_children(reg_child_jobs + unreg_child_jobs) diff --git a/src/biokbase/narrative/tests/data/response_data.json b/src/biokbase/narrative/tests/data/response_data.json index ddd42028aa..cfde19e2a8 100644 --- a/src/biokbase/narrative/tests/data/response_data.json +++ b/src/biokbase/narrative/tests/data/response_data.json @@ -408,7 +408,8 @@ "wsName": "wjriehl:1475006266615" }, "tag": "release" - } + }, + "wsid": 10538 }, "job_id": "BATCH_COMPLETED", "outputWidgetInfo": { @@ -446,7 +447,8 @@ "BATCH_RETRY_ERROR" ], "running": 1625755952628, - "status": "error" + "status": "error", + "wsid": 10538 }, "job_id": "BATCH_ERROR_RETRIED", "outputWidgetInfo": null @@ -469,7 +471,8 @@ "job_output": {}, "retry_count": 0, "retry_ids": [], - "status": "created" + "status": "created", + "wsid": 10538 }, "job_id": "BATCH_PARENT", "outputWidgetInfo": null @@ -509,7 +512,8 @@ "wsName": "wjriehl:1475006266615" }, "tag": "release" - } + }, + "wsid": 10538 }, "job_id": "BATCH_RETRY_COMPLETED", "outputWidgetInfo": { @@ -546,7 +550,8 @@ "retry_ids": [], "retry_parent": "BATCH_ERROR_RETRIED", "running": 1625757662469, - "status": "error" + "status": "error", + "wsid": 10538 }, "job_id": "BATCH_RETRY_ERROR", "outputWidgetInfo": null @@ -564,7 +569,8 @@ "retry_ids": [], "retry_parent": "BATCH_TERMINATED_RETRIED", "running": 1625757285899, - "status": "running" + "status": "running", + "wsid": 10538 }, "job_id": "BATCH_RETRY_RUNNING", "outputWidgetInfo": null @@ -583,7 +589,8 @@ "retry_ids": [], "running": 1625755952599, "status": "terminated", - "terminated_code": 0 + "terminated_code": 0, + "wsid": 10538 }, "job_id": "BATCH_TERMINATED", "outputWidgetInfo": null @@ -605,7 +612,8 @@ ], "running": 1625755952563, "status": "terminated", - "terminated_code": 0 + "terminated_code": 0, + "wsid": 10538 }, "job_id": "BATCH_TERMINATED_RETRIED", "outputWidgetInfo": null @@ -641,7 +649,8 @@ "report_window_line_height": "16" }, "tag": "dev" - } + }, + "wsid": 54321 }, "job_id": "JOB_COMPLETED", "outputWidgetInfo": { @@ -664,7 +673,8 @@ "job_output": {}, "retry_count": 0, "retry_ids": [], - "status": "created" + "status": "created", + "wsid": 55555 }, "job_id": "JOB_CREATED", "outputWidgetInfo": null @@ -690,7 +700,8 @@ "retry_count": 0, "retry_ids": [], "running": 1642431847689, - "status": "error" + "status": "error", + "wsid": 65854 }, "job_id": "JOB_ERROR", "outputWidgetInfo": null @@ -707,7 +718,8 @@ "retry_count": 0, "retry_ids": [], "running": 1642521775867, - "status": "running" + "status": "running", + "wsid": 54321 }, "job_id": "JOB_RUNNING", "outputWidgetInfo": null @@ -725,7 +737,8 @@ "retry_count": 0, "retry_ids": [], "status": "terminated", - "terminated_code": 0 + "terminated_code": 0, + "wsid": 54321 }, "job_id": "JOB_TERMINATED", "outputWidgetInfo": null @@ -780,7 +793,8 @@ "wsName": "wjriehl:1475006266615" }, "tag": "release" - } + }, + "wsid": 10538 }, "job_id": "BATCH_COMPLETED", "outputWidgetInfo": { @@ -821,7 +835,8 @@ "BATCH_RETRY_ERROR" ], "running": 1625755952628, - "status": "error" + "status": "error", + "wsid": 10538 }, "job_id": "BATCH_ERROR_RETRIED", "outputWidgetInfo": null @@ -849,7 +864,8 @@ "retry_ids": [], "retry_parent": "BATCH_ERROR_RETRIED", "running": 1625757662469, - "status": "error" + "status": "error", + "wsid": 10538 }, "job_id": "BATCH_RETRY_ERROR", "outputWidgetInfo": null @@ -876,7 +892,8 @@ "job_output": {}, "retry_count": 0, "retry_ids": [], - "status": "created" + "status": "created", + "wsid": 10538 }, "job_id": "BATCH_PARENT", "outputWidgetInfo": null @@ -920,7 +937,8 @@ "wsName": "wjriehl:1475006266615" }, "tag": "release" - } + }, + "wsid": 10538 }, "job_id": "BATCH_RETRY_COMPLETED", "outputWidgetInfo": { @@ -961,7 +979,8 @@ "retry_ids": [], "retry_parent": "BATCH_ERROR_RETRIED", "running": 1625757662469, - "status": "error" + "status": "error", + "wsid": 10538 }, "job_id": "BATCH_RETRY_ERROR", "outputWidgetInfo": null @@ -983,7 +1002,8 @@ "retry_ids": [], "retry_parent": "BATCH_TERMINATED_RETRIED", "running": 1625757285899, - "status": "running" + "status": "running", + "wsid": 10538 }, "job_id": "BATCH_RETRY_RUNNING", "outputWidgetInfo": null @@ -1006,7 +1026,8 @@ "retry_ids": [], "running": 1625755952599, "status": "terminated", - "terminated_code": 0 + "terminated_code": 0, + "wsid": 10538 }, "job_id": "BATCH_TERMINATED", "outputWidgetInfo": null @@ -1031,7 +1052,8 @@ ], "running": 1625755952563, "status": "terminated", - "terminated_code": 0 + "terminated_code": 0, + "wsid": 10538 }, "job_id": "BATCH_TERMINATED_RETRIED", "outputWidgetInfo": null @@ -1072,7 +1094,8 @@ "wsName": "wjriehl:1475006266615" }, "tag": "release" - } + }, + "wsid": 10538 }, "job_id": "BATCH_RETRY_COMPLETED", "outputWidgetInfo": { @@ -1122,7 +1145,8 @@ "report_window_line_height": "16" }, "tag": "dev" - } + }, + "wsid": 54321 }, "job_id": "JOB_COMPLETED", "outputWidgetInfo": { @@ -1149,7 +1173,8 @@ "job_output": {}, "retry_count": 0, "retry_ids": [], - "status": "created" + "status": "created", + "wsid": 55555 }, "job_id": "JOB_CREATED", "outputWidgetInfo": null @@ -1179,7 +1204,8 @@ "retry_count": 0, "retry_ids": [], "running": 1642431847689, - "status": "error" + "status": "error", + "wsid": 65854 }, "job_id": "JOB_ERROR", "outputWidgetInfo": null @@ -1200,7 +1226,8 @@ "retry_count": 0, "retry_ids": [], "running": 1642521775867, - "status": "running" + "status": "running", + "wsid": 54321 }, "job_id": "JOB_RUNNING", "outputWidgetInfo": null @@ -1222,7 +1249,8 @@ "retry_count": 0, "retry_ids": [], "status": "terminated", - "terminated_code": 0 + "terminated_code": 0, + "wsid": 54321 }, "job_id": "JOB_TERMINATED", "outputWidgetInfo": null diff --git a/src/biokbase/narrative/tests/generate_test_results.py b/src/biokbase/narrative/tests/generate_test_results.py index 0c389bdb1c..ee1e1fd5b6 100644 --- a/src/biokbase/narrative/tests/generate_test_results.py +++ b/src/biokbase/narrative/tests/generate_test_results.py @@ -37,11 +37,13 @@ TEST_SPECS[tag] = spec_dict -def get_test_spec(tag, app_id): +def get_test_spec(tag: str, app_id: str) -> dict[str, dict]: return copy.deepcopy(TEST_SPECS[tag][app_id]) -def generate_mappings(all_jobs): +def generate_mappings( + all_jobs: dict[str, dict] +) -> tuple[dict[str, str], dict[str, dict], set[dict]]: # collect retried jobs and generate the cell-to-job mapping retried_jobs = {} jobs_by_cell_id = {} @@ -57,7 +59,7 @@ def generate_mappings(all_jobs): # add the new job to the jobs_by_cell_id mapping if there is a cell_id present if cell_id: - if cell_id not in jobs_by_cell_id.keys(): + if cell_id not in jobs_by_cell_id: jobs_by_cell_id[cell_id] = set() jobs_by_cell_id[cell_id].add(job["job_id"]) @@ -67,7 +69,7 @@ def generate_mappings(all_jobs): return (retried_jobs, jobs_by_cell_id, batch_jobs) -def _generate_job_output(job_id): +def _generate_job_output(job_id: str) -> dict[str, str | dict]: state = get_test_job(job_id) widget_info = state.get("widget_info") @@ -93,14 +95,14 @@ def _generate_job_output(job_id): return {"job_id": job_id, "jobState": state, "outputWidgetInfo": widget_info} -def generate_bad_jobs(): +def generate_bad_jobs() -> dict[str, dict]: return { job_id: {"job_id": job_id, "error": generate_error(job_id, "not_found")} for job_id in BAD_JOBS } -def generate_job_output_state(all_jobs): +def generate_job_output_state(all_jobs: dict[str, dict]) -> dict[str, dict]: """ Generate the expected output from a `job_status` request """ @@ -110,14 +112,13 @@ def generate_job_output_state(all_jobs): return job_status -def generate_job_info(all_jobs): +def generate_job_info(all_jobs: dict[str, dict]) -> dict[str, dict]: """ Expected output from a `job_info` request """ job_info = generate_bad_jobs() for job_id in all_jobs: test_job = get_test_job(job_id) - job_id = test_job.get("job_id") app_id = test_job.get("job_input", {}).get("app_id") tag = ( test_job.get("job_input", {}) @@ -145,7 +146,9 @@ def generate_job_info(all_jobs): return job_info -def generate_job_retries(all_jobs, retried_jobs): +def generate_job_retries( + all_jobs: dict[str, dict], retried_jobs: dict[str, str] +) -> dict[str, dict]: """ Expected output from a `retry_job` request """ @@ -173,14 +176,11 @@ def generate_job_retries(all_jobs, retried_jobs): return job_retries -def log_gen(n_lines): - lines = [] - for i in range(n_lines): - lines.append({"is_error": 0, "line": f"This is line {str(i+1)}"}) - return lines +def log_gen(n_lines: int) -> list[dict[str, int | str]]: + return [{"is_error": 0, "line": f"This is line {i+1}"} for i in range(n_lines)] -def generate_job_logs(all_jobs): +def generate_job_logs(all_jobs: dict[str, dict]) -> dict[str, dict]: """ Expected output from a `job_logs` request. Note that only completed jobs have logs in this case. """ @@ -211,11 +211,9 @@ def generate_job_logs(all_jobs): # mapping of cell IDs to jobs INVALID_CELL_ID = "invalid_cell_id" -TEST_CELL_ID_LIST = list(JOBS_BY_CELL_ID.keys()) + [INVALID_CELL_ID] +TEST_CELL_ID_LIST = [*list(JOBS_BY_CELL_ID.keys()), INVALID_CELL_ID] # mapping expected as output from get_job_states_by_cell_id -TEST_CELL_IDs = { - cell_id: list(JOBS_BY_CELL_ID[cell_id]) for cell_id in JOBS_BY_CELL_ID.keys() -} +TEST_CELL_IDs = {cell_id: list(JOBS_BY_CELL_ID[cell_id]) for cell_id in JOBS_BY_CELL_ID} TEST_CELL_IDs[INVALID_CELL_ID] = [] @@ -230,7 +228,7 @@ def generate_job_logs(all_jobs): config.write_json_file(RESPONSE_DATA_FILE, ALL_RESPONSE_DATA) -def main(args=None): +def main(args: list[str] | None = None) -> None: if args and args[0] == "--force" or not os.path.exists(RESPONSE_DATA_FILE): config.write_json_file(RESPONSE_DATA_FILE, ALL_RESPONSE_DATA) diff --git a/src/biokbase/narrative/tests/job_test_constants.py b/src/biokbase/narrative/tests/job_test_constants.py index 3fa741d867..721f73ca1a 100644 --- a/src/biokbase/narrative/tests/job_test_constants.py +++ b/src/biokbase/narrative/tests/job_test_constants.py @@ -95,7 +95,7 @@ def get_test_jobs(job_ids): BATCH_PARENT_CHILDREN = [BATCH_PARENT] + BATCH_CHILDREN -JOBS_TERMINALITY = { +JOB_TERMINAL_STATE = { job_id: TEST_JOBS[job_id]["status"] in TERMINAL_STATUSES for job_id in TEST_JOBS.keys() } @@ -103,7 +103,7 @@ def get_test_jobs(job_ids): TERMINAL_JOBS = [] ACTIVE_JOBS = [] REFRESH_STATE = {} -for key, value in JOBS_TERMINALITY.items(): +for key, value in JOB_TERMINAL_STATE.items(): if value: TERMINAL_JOBS.append(key) else: diff --git a/src/biokbase/narrative/tests/narrative_mock/mockclients.py b/src/biokbase/narrative/tests/narrative_mock/mockclients.py index e86e912af8..a9a97b2175 100644 --- a/src/biokbase/narrative/tests/narrative_mock/mockclients.py +++ b/src/biokbase/narrative/tests/narrative_mock/mockclients.py @@ -18,7 +18,7 @@ ) from biokbase.workspace.baseclient import ServerError -from ..util import ConfigTests +from src.biokbase.narrative.tests.util import ConfigTests RANDOM_DATE = "2018-08-10T16:47:36+0000" RANDOM_TYPE = "ModuleA.TypeA-1.0" @@ -29,11 +29,11 @@ NARR_HASH = "278abf8f0dbf8ab5ce349598a8674a6e" -def generate_ee2_error(fn): +def generate_ee2_error(fn: str) -> EEServerError: return EEServerError("JSONRPCError", -32000, fn + " failed") -def get_nar_obj(i): +def get_nar_obj(i: int) -> list[str | int | dict[str, str]]: return [ i, "My_Test_Narrative", @@ -291,9 +291,7 @@ def check_job(self, params): job_id = params.get("job_id") if not job_id: return {} - job_state = self.job_state_data.get( - job_id, {"job_id": job_id, "status": "unmocked"} - ) + job_state = self.job_state_data.get(job_id, {"job_id": job_id}) if "exclude_fields" in params: for f in params["exclude_fields"]: if f in job_state: @@ -546,6 +544,9 @@ def check_workspace_jobs(self, params): def check_job(self, params): raise generate_ee2_error("check_job") + def check_jobs(self, params): + raise generate_ee2_error("check_jobs") + def cancel_job(self, params): raise generate_ee2_error(MESSAGE_TYPE["CANCEL"]) @@ -618,7 +619,7 @@ def __exit__(self, exc_type, exc_value, traceback): assert ( getattr(self.target, self.method_name) == self.called - ), f"Method {self.target.__name__}.{self.method_name} was modified during context managment with {self.__class__.name}" + ), f"Method {self.target.__name__}.{self.method_name} was modified during context management with {self.__class__.name}" setattr(self.target, self.method_name, self.orig_method) self.assert_called(self.call_status) diff --git a/src/biokbase/narrative/tests/test_appmanager.py b/src/biokbase/narrative/tests/test_appmanager.py index f99341ef1c..ef86f10110 100644 --- a/src/biokbase/narrative/tests/test_appmanager.py +++ b/src/biokbase/narrative/tests/test_appmanager.py @@ -316,6 +316,8 @@ def test_run_app__from_gui_cell(self, auth, c): c.return_value.send_comm_message, False, cell_id=cell_id ) + # N.b. the following three tests contact the workspace + # src/config/config.json must be set to use the CI configuration @mock.patch(JOB_COMM_MOCK) def test_run_app__bad_id(self, c): c.return_value.send_comm_message = MagicMock() @@ -507,6 +509,8 @@ def test_run_legacy_batch_app__gui_cell(self, auth, c): c.return_value.send_comm_message, False, cell_id=cell_id ) + # N.b. the following three tests contact the workspace + # src/config/config.json must be set to use the CI configuration @mock.patch(JOB_COMM_MOCK) def test_run_legacy_batch_app__bad_id(self, c): c.return_value.send_comm_message = MagicMock() @@ -1094,6 +1098,8 @@ def test_generate_input_bad(self): with self.assertRaises(ValueError): self.am._generate_input({"symbols": -1}) + # N.b. the following test contacts the workspace + # src/config/config.json must be set to use the CI configuration def test_transform_input_good(self): ws_name = self.public_ws test_data = [ diff --git a/src/biokbase/narrative/tests/test_job.py b/src/biokbase/narrative/tests/test_job.py index 52a7a755b9..2f33cf0002 100644 --- a/src/biokbase/narrative/tests/test_job.py +++ b/src/biokbase/narrative/tests/test_job.py @@ -30,7 +30,7 @@ JOB_CREATED, JOB_RUNNING, JOB_TERMINATED, - JOBS_TERMINALITY, + JOB_TERMINAL_STATE, MAX_LOG_LINES, TERMINAL_JOBS, get_test_job, @@ -109,6 +109,7 @@ def create_attrs_from_ee2(job_id): "run_id": narr_cell_info.get("run_id", JOB_ATTR_DEFAULTS["run_id"]), "tag": narr_cell_info.get("tag", JOB_ATTR_DEFAULTS["tag"]), "user": state.get("user", JOB_ATTR_DEFAULTS["user"]), + "wsid": state.get("wsid", JOB_ATTR_DEFAULTS["wsid"]), } @@ -140,17 +141,14 @@ def get_batch_family_jobs(return_list=False): As invoked in appmanager's run_app_batch, i.e., with from_job_id(s) """ - child_jobs = Job.from_job_ids(BATCH_CHILDREN, return_list=True) + child_jobs = Job.from_job_ids(BATCH_CHILDREN) batch_job = Job.from_job_id(BATCH_PARENT, children=child_jobs) if return_list: - return [batch_job] + child_jobs + return [batch_job, *child_jobs] return { BATCH_PARENT: batch_job, - **{ - child_id: child_job - for child_id, child_job in zip(BATCH_CHILDREN, child_jobs) - }, + **{job.job_id: job for job in child_jobs}, } @@ -186,11 +184,11 @@ def check_jobs_equal(self, jobl, jobr): def check_job_attrs_custom(self, job, exp_attr=None): if not exp_attr: exp_attr = {} - attr = dict(JOB_ATTR_DEFAULTS) + attr = copy.deepcopy(JOB_ATTR_DEFAULTS) attr.update(exp_attr) with mock.patch(CLIENTS, get_mock_client): for name, value in attr.items(): - self.assertEqual(value, getattr(job, name)) + assert value == getattr(job, name) def check_job_attrs(self, job, job_id, exp_attrs=None, skip_state=False): # TODO check _acc_state full vs pruned, extra_data @@ -240,10 +238,10 @@ def test_job_init__from_job_ids(self): job_ids.remove(BATCH_PARENT) with mock.patch(CLIENTS, get_mock_client): - jobs = Job.from_job_ids(job_ids, return_list=False) + jobs = Job.from_job_ids(job_ids) - for job_id, job in jobs.items(): - self.check_job_attrs(job, job_id) + for job in jobs: + self.check_job_attrs(job, job.job_id) def test_job_init__extra_state(self): """ @@ -311,6 +309,7 @@ def test_job_from_state__custom(self): "job_id": "0123456789abcdef", "params": params, "run_id": JOB_ATTR_DEFAULTS["run_id"], + "status": JOB_ATTR_DEFAULTS["status"], "tag": JOB_ATTR_DEFAULTS["tag"], "user": "the_user", } @@ -344,9 +343,9 @@ def test_refresh_state__non_terminal(self): """ # ee2_state is fully populated (includes job_input, no job_output) job = create_job_from_ee2(JOB_CREATED) - self.assertFalse(job.was_terminal()) + self.assertFalse(job.in_terminal_state()) state = job.refresh_state() - self.assertFalse(job.was_terminal()) + self.assertFalse(job.in_terminal_state()) self.assertEqual(state["status"], "created") expected_state = create_state_from_ee2(JOB_CREATED) @@ -357,7 +356,7 @@ def test_refresh_state__terminal(self): test that a completed job emits its state without calling check_job """ job = create_job_from_ee2(JOB_COMPLETED) - self.assertTrue(job.was_terminal()) + self.assertTrue(job.in_terminal_state()) expected = create_state_from_ee2(JOB_COMPLETED) with assert_obj_method_called(MockClients, "check_job", call_status=False): @@ -368,53 +367,33 @@ def test_refresh_state__terminal(self): @mock.patch(CLIENTS, get_failing_mock_client) def test_refresh_state__raise_exception(self): """ - test that the correct exception is thrown if check_job cannot be called + test that the correct exception is thrown if check_jobs cannot be called """ job = create_job_from_ee2(JOB_CREATED) - self.assertFalse(job.was_terminal()) - with self.assertRaisesRegex(ServerError, "check_job failed"): + self.assertFalse(job.in_terminal_state()) + with self.assertRaisesRegex(ServerError, "check_jobs failed"): job.refresh_state() - def test_refresh_state__returns_none(self): - def mock_state(self, state=None): - return None - - job = create_job_from_ee2(JOB_CREATED) - expected = { - "status": "error", - "error": { - "code": -1, - "name": "Job Error", - "message": "Unable to return job state", - "error": "Unable to find current job state. Please try again later, or contact KBase.", - }, - "errormsg": "Unable to return job state", - "error_code": -1, - "job_id": job.job_id, - "cell_id": job.cell_id, - "run_id": job.run_id, - "created": 0, - } - - with mock.patch.object(Job, "refresh_state", mock_state): - state = job.output_state() - self.assertEqual(expected, state) - # TODO: improve this test def test_job_update__no_state(self): """ test that without a state object supplied, the job state is unchanged """ job = create_job_from_ee2(JOB_CREATED) - self.assertFalse(job.was_terminal()) + job_copy = copy.deepcopy(job) + assert job.cached_state() == job_copy.cached_state() # should fail with error 'state must be a dict' with self.assertRaisesRegex(TypeError, "state must be a dict"): job.update_state(None) - self.assertFalse(job.was_terminal()) + assert job.cached_state() == job_copy.cached_state() job.update_state({}) - self.assertFalse(job.was_terminal()) + assert job.cached_state() == job_copy.cached_state() + + job.update_state({"wsid": "something or other"}) + assert job.cached_state() != job_copy.cached_state() + assert job.wsid == "something or other" @mock.patch(CLIENTS, get_mock_client) def test_job_update__invalid_job_id(self): @@ -574,7 +553,7 @@ def test_parent_children__ok(self): children=child_jobs, ) - self.assertFalse(parent_job.was_terminal()) + self.assertFalse(parent_job.in_terminal_state()) # Make all child jobs completed with mock.patch.object( @@ -585,7 +564,7 @@ def test_parent_children__ok(self): for child_job in child_jobs: child_job.refresh_state(force_refresh=True) - self.assertTrue(parent_job.was_terminal()) + self.assertTrue(parent_job.in_terminal_state()) def test_parent_children__fail(self): parent_state = create_state_from_ee2(BATCH_PARENT) @@ -645,19 +624,19 @@ def test_get_viewer_params__batch_parent(self): self.assertIsNone(out) @mock.patch(CLIENTS, get_mock_client) - def test_query_job_state(self): + def test_query_job_states_single_job(self): for job_id in ALL_JOBS: exp = create_state_from_ee2( job_id, exclude_fields=JOB_INIT_EXCLUDED_JOB_STATE_FIELDS ) - got = Job.query_ee2_state(job_id, init=True) - self.assertEqual(exp, got) + got = Job.query_ee2_states([job_id], init=True) + assert exp == got[job_id] exp = create_state_from_ee2( job_id, exclude_fields=EXCLUDED_JOB_STATE_FIELDS ) - got = Job.query_ee2_state(job_id, init=False) - self.assertEqual(exp, got) + got = Job.query_ee2_states([job_id], init=False) + assert exp == got[job_id] @mock.patch(CLIENTS, get_mock_client) def test_query_job_states(self): @@ -731,18 +710,18 @@ def mock_check_job(self_, params): with mock.patch.object(MockClients, "check_job", mock_check_job): self.check_job_attrs(job, job_id, {"child_jobs": self.NEW_CHILD_JOBS}) - def test_was_terminal(self): + def test_in_terminal_state(self): all_jobs = get_all_jobs() for job_id, job in all_jobs.items(): - self.assertEqual(JOBS_TERMINALITY[job_id], job.was_terminal()) + self.assertEqual(JOB_TERMINAL_STATE[job_id], job.in_terminal_state()) @mock.patch(CLIENTS, get_mock_client) - def test_was_terminal__batch(self): + def test_in_terminal_state__batch(self): batch_fam = get_batch_family_jobs(return_list=True) batch_job, child_jobs = batch_fam[0], batch_fam[1:] - self.assertFalse(batch_job.was_terminal()) + self.assertFalse(batch_job.in_terminal_state()) def mock_check_job(self_, params): self.assertTrue(params["job_id"] in BATCH_CHILDREN) @@ -752,7 +731,7 @@ def mock_check_job(self_, params): for job in child_jobs: job.refresh_state(force_refresh=True) - self.assertTrue(batch_job.was_terminal()) + self.assertTrue(batch_job.in_terminal_state()) def test_in_cells(self): all_jobs = get_all_jobs() diff --git a/src/biokbase/narrative/tests/test_jobcomm.py b/src/biokbase/narrative/tests/test_jobcomm.py index ce0fc1fea9..ca49991b2e 100644 --- a/src/biokbase/narrative/tests/test_jobcomm.py +++ b/src/biokbase/narrative/tests/test_jobcomm.py @@ -618,6 +618,14 @@ def test_get_job_states__job_id_list__ee2_error(self): exc = Exception("Test exception") exc_message = str(exc) + expected = { + job_id: copy.deepcopy(ALL_RESPONSE_DATA[STATUS][job_id]) + for job_id in ALL_JOBS + } + for job_id in ACTIVE_JOBS: + # add in the ee2_error message + expected[job_id]["error"] = exc_message + def mock_check_jobs(params): raise exc @@ -628,14 +636,6 @@ def mock_check_jobs(params): self.jc._handle_comm_message(req_dict) msg = self.jc._comm.last_message - expected = { - job_id: copy.deepcopy(ALL_RESPONSE_DATA[STATUS][job_id]) - for job_id in ALL_JOBS - } - for job_id in ACTIVE_JOBS: - # add in the ee2_error message - expected[job_id]["error"] = exc_message - self.assertEqual( { "msg_type": STATUS, diff --git a/src/biokbase/narrative/tests/test_jobmanager.py b/src/biokbase/narrative/tests/test_jobmanager.py index d40a9e48c1..5f7be32d0f 100644 --- a/src/biokbase/narrative/tests/test_jobmanager.py +++ b/src/biokbase/narrative/tests/test_jobmanager.py @@ -107,7 +107,7 @@ def test_initialize_jobs(self): terminal_ids = [ job_id for job_id, d in self.jm._running_jobs.items() - if d["job"].was_terminal() + if d["job"].in_terminal_state() ] self.assertEqual( set(TERMINAL_JOBS), @@ -233,12 +233,6 @@ def test__construct_job_output_state_set__ee2_error(self): exc = Exception("Test exception") exc_msg = str(exc) - def mock_check_jobs(params): - raise exc - - with mock.patch.object(MockClients, "check_jobs", side_effect=mock_check_jobs): - job_states = self.jm._construct_job_output_state_set(ALL_JOBS) - expected = { job_id: copy.deepcopy(ALL_RESPONSE_DATA[MESSAGE_TYPE["STATUS"]][job_id]) for job_id in ALL_JOBS @@ -248,6 +242,12 @@ def mock_check_jobs(params): # expect there to be an error message added expected[job_id]["error"] = exc_msg + def mock_check_jobs(params): + raise exc + + with mock.patch.object(MockClients, "check_jobs", side_effect=mock_check_jobs): + job_states = self.jm._construct_job_output_state_set(ALL_JOBS) + self.assertEqual( expected, job_states, @@ -403,8 +403,8 @@ def test_cancel_jobs__bad_inputs(self): def test_cancel_jobs__job_already_finished(self): self.assertEqual(get_test_job(JOB_COMPLETED)["status"], "completed") self.assertEqual(get_test_job(JOB_TERMINATED)["status"], "terminated") - self.assertTrue(self.jm.get_job(JOB_COMPLETED).was_terminal()) - self.assertTrue(self.jm.get_job(JOB_TERMINATED).was_terminal()) + self.assertTrue(self.jm.get_job(JOB_COMPLETED).in_terminal_state()) + self.assertTrue(self.jm.get_job(JOB_TERMINATED).in_terminal_state()) job_id_list = [JOB_COMPLETED, JOB_TERMINATED] with mock.patch( "biokbase.narrative.jobs.jobmanager.JobManager._cancel_job" diff --git a/src/biokbase/narrative/tests/test_narrativeio.py b/src/biokbase/narrative/tests/test_narrativeio.py index 428f3556af..cfb8e962a6 100644 --- a/src/biokbase/narrative/tests/test_narrativeio.py +++ b/src/biokbase/narrative/tests/test_narrativeio.py @@ -5,10 +5,8 @@ import unittest from unittest.mock import patch -from tornado.web import HTTPError - import biokbase.auth -import biokbase.narrative.clients as clients +from biokbase.narrative import clients from biokbase.narrative.common.exceptions import WorkspaceError from biokbase.narrative.common.narrative_ref import NarrativeRef from biokbase.narrative.common.url_config import URLS @@ -16,27 +14,26 @@ LIST_OBJECTS_FIELDS, KBaseWSManagerMixin, ) +from tornado.web import HTTPError from . import util from .narrative_mock.mockclients import MockClients, get_mock_client, get_nar_obj __author__ = "Bill Riehl " -metadata_fields = set( - [ - "objid", - "name", - "type", - "save_date", - "ver", - "saved_by", - "wsid", - "workspace", - "chsum", - "size", - "meta", - ] -) +metadata_fields = { + "objid", + "name", + "type", + "save_date", + "ver", + "saved_by", + "wsid", + "workspace", + "chsum", + "size", + "meta", +} HAS_TEST_TOKEN = False diff --git a/src/biokbase/narrative/tests/test_widgetmanager.py b/src/biokbase/narrative/tests/test_widgetmanager.py index 8fd73085fd..b075cfc457 100644 --- a/src/biokbase/narrative/tests/test_widgetmanager.py +++ b/src/biokbase/narrative/tests/test_widgetmanager.py @@ -74,6 +74,8 @@ def test_show_output_widget_bad(self): check_widget=True, ) + # N.b. the following test contacts the workspace + # src/config/config.json must be set to use the CI configuration def test_show_advanced_viewer_widget(self): title = "Widget Viewer" cell_id = "abcde" From 9c2ec3d46e178c7dd6af5c2a66b23064b998c03b Mon Sep 17 00:00:00 2001 From: ialarmedalien Date: Mon, 13 Nov 2023 07:39:33 -0800 Subject: [PATCH 2/3] Convert most of the unittest-based tests to plain `assert` or `pytest.raises` --- package-lock.json | 6 +- .../narrative/contents/narrativeio.py | 7 +- .../narrative/tests/job_test_constants.py | 52 +- .../tests/narrative_mock/mockclients.py | 18 +- .../tests/narrative_mock/mockcomm.py | 2 +- src/biokbase/narrative/tests/test_app_util.py | 23 +- .../narrative/tests/test_appeditor.py | 23 +- .../narrative/tests/test_appmanager.py | 148 ++-- src/biokbase/narrative/tests/test_auth.py | 19 +- src/biokbase/narrative/tests/test_batch.py | 168 ++-- src/biokbase/narrative/tests/test_clients.py | 6 +- .../tests/test_content_manager_util.py | 36 +- .../narrative/tests/test_exception_util.py | 63 +- src/biokbase/narrative/tests/test_job.py | 188 ++--- src/biokbase/narrative/tests/test_job_util.py | 16 +- src/biokbase/narrative/tests/test_jobcomm.py | 787 ++++++++---------- .../narrative/tests/test_jobmanager.py | 310 ++++--- .../narrative/tests/test_kbasewsmanager.py | 9 +- src/biokbase/narrative/tests/test_kvp.py | 13 + .../narrative/tests/test_log_proxy.py | 58 +- src/biokbase/narrative/tests/test_logging.py | 3 +- .../narrative/tests/test_narrative_logger.py | 31 +- .../narrative/tests/test_narrative_ref.py | 49 +- .../narrative/tests/test_narrativeio.py | 204 ++--- .../narrative/tests/test_specmanager.py | 26 +- .../narrative/tests/test_staging_helper.py | 101 ++- src/biokbase/narrative/tests/test_system.py | 3 +- src/biokbase/narrative/tests/test_upa_api.py | 37 +- src/biokbase/narrative/tests/test_updater.py | 24 +- .../narrative/tests/test_url_config.py | 6 +- .../narrative/tests/test_user_service.py | 2 +- src/biokbase/narrative/tests/test_viewers.py | 80 +- .../narrative/tests/test_widgetmanager.py | 109 ++- src/biokbase/narrative/tests/util.py | 43 +- .../test_fix_workspace_info.py | 170 ++-- 35 files changed, 1327 insertions(+), 1513 deletions(-) create mode 100644 src/biokbase/narrative/tests/test_kvp.py diff --git a/package-lock.json b/package-lock.json index 7bf13a7d96..684f7dc42b 100644 --- a/package-lock.json +++ b/package-lock.json @@ -4237,9 +4237,9 @@ } }, "node_modules/caniuse-lite": { - "version": "1.0.30001523", - "resolved": "https://registry.npmjs.org/caniuse-lite/-/caniuse-lite-1.0.30001523.tgz", - "integrity": "sha512-I5q5cisATTPZ1mc588Z//pj/Ox80ERYDfR71YnvY7raS/NOk8xXlZcB0sF7JdqaV//kOaa6aus7lRfpdnt1eBA==", + "version": "1.0.30001561", + "resolved": "https://registry.npmjs.org/caniuse-lite/-/caniuse-lite-1.0.30001561.tgz", + "integrity": "sha512-NTt0DNoKe958Q0BE0j0c1V9jbUzhBxHIEJy7asmGrpE0yG63KTV7PLHPnK2E1O9RsQrQ081I3NLuXGS6zht3cw==", "dev": true, "funding": [ { diff --git a/src/biokbase/narrative/contents/narrativeio.py b/src/biokbase/narrative/contents/narrativeio.py index 9c586af837..bd32e6d437 100644 --- a/src/biokbase/narrative/contents/narrativeio.py +++ b/src/biokbase/narrative/contents/narrativeio.py @@ -99,15 +99,14 @@ def narrative_exists(self, ref): except WorkspaceError as err: if err.http_code == 404: return False - else: - raise + raise def _validate_nar_type(self, t, ref): if not t.startswith(NARRATIVE_TYPE): err = "Expected a Narrative object" if ref is not None: - err += " with reference {}".format(ref) - err += ", got a {}".format(t) + err += f" with reference {ref}" + err += f", got a {t}" raise HTTPError(500, err) def read_narrative(self, ref, content=True, include_metadata=True): diff --git a/src/biokbase/narrative/tests/job_test_constants.py b/src/biokbase/narrative/tests/job_test_constants.py index 721f73ca1a..7fd5f1eb6c 100644 --- a/src/biokbase/narrative/tests/job_test_constants.py +++ b/src/biokbase/narrative/tests/job_test_constants.py @@ -1,14 +1,27 @@ import copy +from typing import Any from biokbase.narrative.jobs.job import TERMINAL_STATUSES from .util import ConfigTests config = ConfigTests() -TEST_JOBS = config.load_json_file(config.get("jobs", "ee2_job_test_data_file")) - - -def generate_error(job_id, err_type): +TEST_JOBS: dict[str, dict] = config.load_json_file( + config.get("jobs", "ee2_job_test_data_file") +) + + +def generate_error(job_id: str, err_type: str) -> str: + """Given a job id and an error type, generate the appropriate error string. + + :param job_id: job ID + :type job_id: str + :param err_type: error type + :type err_type: str + :raises KeyError: if the error type does not exist + :return: error string + :rtype: str + """ user_id = None status = None @@ -32,11 +45,25 @@ def generate_error(job_id, err_type): return error_strings[err_type] -def get_test_job(job_id): +def get_test_job(job_id: str) -> dict[str, Any]: + """Given a job ID, fetch the appropriate job. + + :param job_id: job ID + :type job_id: str + :return: job data + :rtype: dict[str, Any] + """ return copy.deepcopy(TEST_JOBS[job_id]) -def get_test_jobs(job_ids): +def get_test_jobs(job_ids: list[str]) -> dict[str, dict[str, Any]]: + """Given a list of job IDs, fetch the appropriate jobs. + + :param job_ids: list of job IDs + :type job_ids: list[str] + :return: dict of jobs keyed by job ID + :rtype: dict[str, dict[str, Any]] + """ return {job_id: get_test_job(job_id) for job_id in job_ids} @@ -93,16 +120,15 @@ def get_test_jobs(job_ids): BATCH_RETRY_ERROR, ] -BATCH_PARENT_CHILDREN = [BATCH_PARENT] + BATCH_CHILDREN +BATCH_PARENT_CHILDREN = [BATCH_PARENT, *BATCH_CHILDREN] -JOB_TERMINAL_STATE = { - job_id: TEST_JOBS[job_id]["status"] in TERMINAL_STATUSES - for job_id in TEST_JOBS.keys() +JOB_TERMINAL_STATE: dict[str, bool] = { + job_id: TEST_JOBS[job_id]["status"] in TERMINAL_STATUSES for job_id in TEST_JOBS } -TERMINAL_JOBS = [] -ACTIVE_JOBS = [] -REFRESH_STATE = {} +TERMINAL_JOBS: list[str] = [] +ACTIVE_JOBS: list[str] = [] +REFRESH_STATE: dict[str, bool] = {} for key, value in JOB_TERMINAL_STATE.items(): if value: TERMINAL_JOBS.append(key) diff --git a/src/biokbase/narrative/tests/narrative_mock/mockclients.py b/src/biokbase/narrative/tests/narrative_mock/mockclients.py index a9a97b2175..73f1f1ee44 100644 --- a/src/biokbase/narrative/tests/narrative_mock/mockclients.py +++ b/src/biokbase/narrative/tests/narrative_mock/mockclients.py @@ -81,7 +81,7 @@ def test_my_function(self): config = ConfigTests() _job_state_data = TEST_JOBS - def __init__(self, client_name=None, token=None): + def __init__(self, client_name=None, token=None) -> None: if token is not None: assert isinstance(token, str) self.client_name = client_name @@ -326,9 +326,7 @@ def log_gen(log_params, total_lines=MAX_LOG_LINES): lines = [] if skip < total_lines: for i in range(total_lines - skip): - lines.append( - {"is_error": 0, "line": "This is line {}".format(i + skip)} - ) + lines.append({"is_error": 0, "line": f"This is line {i + skip}"}) return {"last_line_number": max(total_lines, skip), "lines": lines} if job_id == JOB_COMPLETED: @@ -355,6 +353,7 @@ def log_gen(log_params, total_lines=MAX_LOG_LINES): def sync_call(self, call, params): if call == "NarrativeService.list_objects_with_sets": return self._mock_ns_list_objects_with_sets(params) + return None def _mock_ns_list_objects_with_sets(self, params): """ @@ -374,7 +373,7 @@ def _mock_ns_list_objects_with_sets(self, params): if params.get("workspaces"): ws_name = params["workspaces"][0] dp_id = 999 - dp_ref = "{}/{}".format(ws_id, dp_id) + dp_ref = f"{ws_id}/{dp_id}" data = { "data": [ @@ -511,10 +510,7 @@ def _mock_ns_list_objects_with_sets(self, params): data["data"] = list( filter( lambda x: any( - [ - x["object_info"][2].lower().startswith(t.lower()) - for t in types - ] + x["object_info"][2].lower().startswith(t.lower()) for t in types ), data["data"], ) @@ -535,7 +531,7 @@ def get_failing_mock_client(client_name, token=None): class FailingMockClient: - def __init__(self, token=None): + def __init__(self, token=None) -> None: pass def check_workspace_jobs(self, params): @@ -590,7 +586,7 @@ class assert_obj_method_called: ) """ - def __init__(self, target, method_name, call_status=True): + def __init__(self, target, method_name, call_status=True) -> None: self.target = target self.method_name = method_name self.call_status = call_status diff --git a/src/biokbase/narrative/tests/narrative_mock/mockcomm.py b/src/biokbase/narrative/tests/narrative_mock/mockcomm.py index 17457b02ea..235c59b44b 100644 --- a/src/biokbase/narrative/tests/narrative_mock/mockcomm.py +++ b/src/biokbase/narrative/tests/narrative_mock/mockcomm.py @@ -5,7 +5,7 @@ class MockComm: analyzed during the test. """ - def __init__(self, *args, **kwargs): + def __init__(self, *args, **kwargs) -> None: """Mock the init""" self.messages = [] diff --git a/src/biokbase/narrative/tests/test_app_util.py b/src/biokbase/narrative/tests/test_app_util.py index 46aa46821a..e3486a4f84 100644 --- a/src/biokbase/narrative/tests/test_app_util.py +++ b/src/biokbase/narrative/tests/test_app_util.py @@ -3,13 +3,10 @@ """ import copy import os +import re from unittest import mock import pytest -import re -from biokbase.narrative.common.url_config import URLS -from biokbase.workspace.client import Workspace - from biokbase.narrative.app_util import ( app_param, check_tag, @@ -19,10 +16,11 @@ map_outputs_from_state, transform_param_value, ) - -from biokbase.narrative.tests.conftest import narrative_vcr as vcr +from biokbase.narrative.common.url_config import URLS from biokbase.narrative.tests import util +from biokbase.narrative.tests.conftest import narrative_vcr as vcr from biokbase.narrative.upa import is_upa +from biokbase.workspace.client import Workspace config = util.ConfigTests() user_name = config.get("users", "test_user") @@ -74,7 +72,7 @@ def set_ws_name(ws_name): ] -@pytest.mark.parametrize("result,path,expected", get_result_sub_path_cases) +@pytest.mark.parametrize(("result", "path", "expected"), get_result_sub_path_cases) def test_get_result_sub_path(result, path, expected): assert get_result_sub_path(result, path) == expected @@ -275,7 +273,7 @@ def test_map_outputs_from_state_bad_spec(workspace_name): ] -@pytest.mark.parametrize("field_type,spec_add,expect_add", app_param_cases) +@pytest.mark.parametrize(("field_type", "spec_add", "expect_add"), app_param_cases) def test_app_param(field_type, spec_add, expect_add): spec_param = copy.deepcopy(base_app_param) expected = copy.deepcopy(base_expect) @@ -307,7 +305,8 @@ def test_app_param(field_type, spec_add, expect_add): @pytest.mark.parametrize( - "transform_type,value,spec_param,expected", transform_param_value_simple_cases + ("transform_type", "value", "spec_param", "expected"), + transform_param_value_simple_cases, ) def test_transform_param_value_simple(transform_type, value, spec_param, expected): assert transform_param_value(transform_type, value, spec_param) == expected @@ -328,7 +327,7 @@ def test_transform_param_value_fail(): ] -@pytest.mark.parametrize("value,expected", textsubdata_cases) +@pytest.mark.parametrize(("value", "expected"), textsubdata_cases) def test_transform_param_value_textsubdata(value, expected): spec = {"type": "textsubdata"} assert transform_param_value(None, value, spec) == expected @@ -506,7 +505,7 @@ def get_workspace(_): class RefChainWorkspace: - def __init__(self): + def __init__(self) -> None: pass def get_object_info3(self, params): @@ -568,7 +567,7 @@ def get_ref_path_mock_ws(name="workspace"): ) @mock.patch("biokbase.narrative.app_util.clients.get", get_ref_path_mock_ws) def test_transform_param_value_upa_path(tf_type): - upa_path = f"69375/2/2;67729/2/2" + upa_path = "69375/2/2;67729/2/2" assert transform_param_value(tf_type, upa_path, None) == upa_path diff --git a/src/biokbase/narrative/tests/test_appeditor.py b/src/biokbase/narrative/tests/test_appeditor.py index 7643ec4628..2764271e1f 100644 --- a/src/biokbase/narrative/tests/test_appeditor.py +++ b/src/biokbase/narrative/tests/test_appeditor.py @@ -5,6 +5,7 @@ import re import unittest +import pytest from biokbase.narrative.appeditor import generate_app_cell from .util import ConfigTests @@ -22,24 +23,26 @@ def setUpClass(cls): def test_gen_app_cell_post_validation(self): js = generate_app_cell(validated_spec=self.specs_list[0]) - self.assertIsNotNone(js) + assert js is not None def test_gen_app_cell_pre_valid(self): js = generate_app_cell( spec_tuple=(json.dumps(self.spec_json), self.display_yaml) ) - self.assertIsNotNone(js) - self.assertIsNotNone(js.data) - self.assertIn( - "A description string, with "quoted" values, shouldn't fail.", - js.data, + assert js is not None + assert js.data is not None + assert ( + "A description string, with "quoted" values, shouldn't fail." + in js.data ) - self.assertIn("Test Simple Inputs with "quotes"", js.data) - self.assertIn("A simple test spec with a single 'input'.", js.data) + assert "Test Simple Inputs with "quotes"" in js.data + assert "A simple test spec with a single 'input'." in js.data def test_gen_app_cell_fail_validation(self): - with self.assertRaisesRegexp( + with pytest.raises( Exception, - re.escape("Can't find sub-node [categories] within path [/] in spec.json"), + match=re.escape( + "Can't find sub-node [categories] within path [/] in spec.json" + ), ): generate_app_cell(spec_tuple=("{}", self.display_yaml)) diff --git a/src/biokbase/narrative/tests/test_appmanager.py b/src/biokbase/narrative/tests/test_appmanager.py index ef86f10110..23c1f6b84e 100644 --- a/src/biokbase/narrative/tests/test_appmanager.py +++ b/src/biokbase/narrative/tests/test_appmanager.py @@ -10,10 +10,9 @@ from unittest import mock from unittest.mock import MagicMock -from IPython.display import HTML, Javascript - +import pytest from biokbase.auth import TokenInfo -import biokbase.narrative.app_util as app_util +from biokbase.narrative import app_util from biokbase.narrative.jobs.appmanager import BATCH_APP, AppManager from biokbase.narrative.jobs.job import Job from biokbase.narrative.jobs.jobcomm import MESSAGE_TYPE @@ -24,6 +23,7 @@ READS_OBJ_1, READS_OBJ_2, ) +from IPython.display import HTML, Javascript from .narrative_mock.mockclients import WSID_STANDARD, get_mock_client from .util import ConfigTests @@ -205,52 +205,49 @@ def run_app_expect_error( """ output = io.StringIO() sys.stdout = output - self.assertIsNone(run_func()) + assert run_func() is None sys.stdout = sys.__stdout__ # reset to normal output_str = output.getvalue() if print_error is not None and len(print_error): - self.assertIn( - f"Error while trying to start your app ({func_name})!", output_str - ) - self.assertIn(print_error, output_str) + assert f"Error while trying to start your app ({func_name})!" in output_str + assert print_error in output_str else: - self.assertEqual( - "", output_str - ) # if nothing gets written to a StringIO, getvalue returns an empty string + # if nothing gets written to a StringIO, getvalue returns an empty string + assert output_str == "" self._verify_comm_error(comm_mock, cell_id=cell_id) def test_reload(self): self.am.reload() info = self.am.app_usage(self.good_app_id, self.good_tag) - self.assertTrue(info) + assert info def test_app_usage(self): # good id and good tag usage = self.am.app_usage(self.good_app_id, self.good_tag) - self.assertTrue(usage) + assert usage # bad id - with self.assertRaises(ValueError): + with pytest.raises(ValueError): self.am.app_usage(self.bad_app_id) # bad tag - with self.assertRaises(ValueError): + with pytest.raises(ValueError): self.am.app_usage(self.good_app_id, self.bad_tag) def test_app_usage_html(self): usage = self.am.app_usage(self.good_app_id, self.good_tag) - self.assertTrue(usage._repr_html_()) + assert usage._repr_html_() def test_app_usage_str(self): usage = self.am.app_usage(self.good_app_id, self.good_tag) - self.assertTrue(str(usage)) + assert str(usage) def test_available_apps_good(self): apps = self.am.available_apps(self.good_tag) - self.assertIsInstance(apps, HTML) + assert isinstance(apps, HTML) def test_available_apps_bad(self): - with self.assertRaises(ValueError): + with pytest.raises(ValueError): self.am.available_apps(self.bad_tag) # Testing run_app @@ -275,8 +272,8 @@ def test_run_app__dry_run(self, auth, c): output = self.am.run_app( self.test_app_id, self.test_app_params, tag=self.test_tag, dry_run=True ) - self.assertEqual(expected, output) - self.assertEqual(mock_comm.call_count, 0) + assert expected == output + assert mock_comm.call_count == 0 @mock.patch(CLIENTS_AM, get_mock_client) @mock.patch(JOB_COMM_MOCK) @@ -289,11 +286,11 @@ def test_run_app__good_inputs(self, auth, c): new_job = self.am.run_app( self.test_app_id, self.test_app_params, tag=self.test_tag ) - self.assertIsInstance(new_job, Job) - self.assertEqual(self.jm.get_job(self.test_job_id), new_job) + assert isinstance(new_job, Job) + assert self.jm.get_job(self.test_job_id) == new_job self._verify_comm_success(c.return_value.send_comm_message, False) - self.assertEqual(False, self.jm._running_jobs[new_job.job_id]["refresh"]) + assert self.jm._running_jobs[new_job.job_id]["refresh"] is False @mock.patch(CLIENTS_AM, get_mock_client) @mock.patch(JOB_COMM_MOCK) @@ -304,13 +301,14 @@ def test_run_app__good_inputs(self, auth, c): def test_run_app__from_gui_cell(self, auth, c): cell_id = "12345" c.return_value.send_comm_message = MagicMock() - self.assertIsNone( + assert ( self.am.run_app( self.test_app_id, self.test_app_params, tag=self.test_tag, cell_id=cell_id, ) + is None ) self._verify_comm_success( c.return_value.send_comm_message, False, cell_id=cell_id @@ -371,7 +369,7 @@ def run_func(): ) def test_run_app__missing_inputs(self, auth, c): c.return_value.send_comm_message = MagicMock() - self.assertIsNotNone(self.am.run_app(self.good_app_id, None, tag=self.good_tag)) + assert self.am.run_app(self.good_app_id, None, tag=self.good_tag) is not None self._verify_comm_success(c.return_value.send_comm_message, False) @mock.patch(CLIENTS_AM, get_mock_client) @@ -465,7 +463,7 @@ def test_run_legacy_batch_app__dry_run_good_inputs(self, auth, c): "wsid": WSID_STANDARD, } - self.assertEqual(job_runner_inputs, expected) + assert job_runner_inputs == expected @mock.patch(CLIENTS_AM, get_mock_client) @mock.patch(JOB_COMM_MOCK) @@ -482,11 +480,11 @@ def test_run_legacy_batch_app__good_inputs(self, auth, c): version=self.test_app_version, tag=self.test_tag, ) - self.assertIsInstance(new_job, Job) - self.assertEqual(self.jm.get_job(self.test_job_id), new_job) + assert isinstance(new_job, Job) + assert self.jm.get_job(self.test_job_id) == new_job self._verify_comm_success(c.return_value.send_comm_message, False) - self.assertEqual(False, self.jm._running_jobs[new_job.job_id]["refresh"]) + assert self.jm._running_jobs[new_job.job_id]["refresh"] is False @mock.patch(CLIENTS_AM, get_mock_client) @mock.patch(JOB_COMM_MOCK) @@ -497,13 +495,14 @@ def test_run_legacy_batch_app__good_inputs(self, auth, c): def test_run_legacy_batch_app__gui_cell(self, auth, c): cell_id = "12345" c.return_value.send_comm_message = MagicMock() - self.assertIsNone( + assert ( self.am.run_legacy_batch_app( self.test_app_id, [self.test_app_params, self.test_app_params], tag=self.test_tag, cell_id=cell_id, ) + is None ) self._verify_comm_success( c.return_value.send_comm_message, False, cell_id=cell_id @@ -568,8 +567,9 @@ def run_func(): ) def test_run_legacy_batch_app__missing_inputs(self, auth, c): c.return_value.send_comm_message = MagicMock() - self.assertIsNotNone( + assert ( self.am.run_legacy_batch_app(self.good_app_id, None, tag=self.good_tag) + is not None ) self._verify_comm_success(c.return_value.send_comm_message, False) @@ -623,8 +623,8 @@ def test_run_local_app_ok(self, auth, c): {"param0": "fakegenome"}, tag="release", ) - self.assertIsInstance(result, Javascript) - self.assertIn("KBaseNarrativeOutputCell", result.data) + assert isinstance(result, Javascript) + assert "KBaseNarrativeOutputCell" in result.data @mock.patch(CLIENTS_AM, get_mock_client) @mock.patch(JOB_COMM_MOCK) @@ -733,15 +733,15 @@ def test_run_app_batch__dry_run(self, auth, c): expected_batch_run_keys = {"method", "service_ver", "params", "app_id", "meta"} # expect only the above keys in each batch run params (note the missing wsid key) for param_set in batch_run_params: - self.assertTrue(expected_batch_run_keys == set(param_set.keys())) - self.assertEqual(["wsid"], list(batch_params.keys())) + assert expected_batch_run_keys == set(param_set.keys()) + assert ["wsid"] == list(batch_params.keys()) # expect shared_params to have been merged into respective param_sets for exp, outp in zip( iter_bulk_run_good_inputs_param_sets(spec_mapped=True), batch_run_params ): got = outp["params"][0] - self.assertDictEqual({**got, **exp}, got) # assert exp_params <= got_params + assert {**got, **exp} == got def mod(param_set): for key, value in param_set.items(): @@ -766,10 +766,10 @@ def mod(param_set): ] exp_batch_params = {"wsid": WSID_STANDARD} - self.assertEqual(exp_batch_run_params, batch_run_params) - self.assertEqual(exp_batch_params, batch_params) + assert exp_batch_run_params == batch_run_params + assert exp_batch_params == batch_params - self.assertEqual(mock_comm.call_count, 0) + assert mock_comm.call_count == 0 @mock.patch(CLIENTS_AM, get_mock_client) @mock.patch(JOB_COMM_MOCK) @@ -782,23 +782,22 @@ def test_run_app_batch__good_inputs(self, auth, c): test_input = get_bulk_run_good_inputs() new_jobs = self.am.run_app_batch(test_input) - self.assertIsInstance(new_jobs, dict) - self.assertIn("parent_job", new_jobs) - self.assertIn("child_jobs", new_jobs) - self.assertTrue(new_jobs["parent_job"]) + assert isinstance(new_jobs, dict) + assert "parent_job" in new_jobs + assert "child_jobs" in new_jobs + assert new_jobs["parent_job"] parent_job = new_jobs["parent_job"] child_jobs = new_jobs["child_jobs"] - self.assertIsInstance(parent_job, Job) - self.assertIsInstance(child_jobs, list) - self.assertEqual(len(child_jobs), 3) - self.assertEqual( - [job.job_id for job in child_jobs], - [f"{self.test_job_id}_child_{i}" for i in range(len(child_jobs))], - ) + assert isinstance(parent_job, Job) + assert isinstance(child_jobs, list) + assert len(child_jobs) == 3 + assert [job.job_id for job in child_jobs] == [ + f"{self.test_job_id}_child_{i}" for i in range(len(child_jobs)) + ] self._verify_comm_success(c.return_value.send_comm_message, True, num_jobs=4) - for job in [parent_job] + child_jobs: - self.assertEqual(False, self.jm._running_jobs[job.job_id]["refresh"]) + for job in [parent_job, *child_jobs]: + assert self.jm._running_jobs[job.job_id]["refresh"] is False @mock.patch(CLIENTS_AM, get_mock_client) @mock.patch(JOB_COMM_MOCK) @@ -814,10 +813,11 @@ def test_run_app_batch__from_gui_cell(self, auth, c): # test with / w/o run_id # should return None, fire a couple of messages for run_id in run_ids: - self.assertIsNone( + assert ( self.am.run_app_batch( get_bulk_run_good_inputs(), cell_id=cell_id, run_id=run_id ) + is None ) self._verify_comm_success( @@ -969,25 +969,25 @@ def test_reconstitute_shared_params(self): # Merge shared_params into each params dict self.am._reconstitute_shared_params(app_info_el) - self.assertEqual(expected, app_info_el) + assert expected == app_info_el # No shared_params means no change self.am._reconstitute_shared_params(app_info_el) - self.assertEqual(expected, app_info_el) + assert expected == app_info_el @mock.patch(CLIENTS_AM_SM, get_mock_client) def test_app_description(self): desc = self.am.app_description(self.good_app_id, tag=self.good_tag) - self.assertIsInstance(desc, HTML) + assert isinstance(desc, HTML) @mock.patch(CLIENTS_AM_SM, get_mock_client) def test_app_description_bad_tag(self): - with self.assertRaises(ValueError): + with pytest.raises(ValueError): self.am.app_description(self.good_app_id, tag=self.bad_tag) @mock.patch(CLIENTS_AM_SM, get_mock_client) def test_app_description_bad_name(self): - with self.assertRaises(ValueError): + with pytest.raises(ValueError): self.am.app_description(self.bad_app_id) @mock.patch(CLIENTS_AM_SM, get_mock_client) @@ -1016,9 +1016,9 @@ def test_validate_params(self): (params, ws_inputs) = app_util.validate_parameters( app_id, tag, spec_params, inputs ) - self.assertDictEqual(params, inputs) - self.assertIn("12345/8/1", ws_inputs) - self.assertIn("12345/7/1", ws_inputs) + assert params == inputs + assert "12345/8/1" in ws_inputs + assert "12345/7/1" in ws_inputs @mock.patch(CLIENTS_AM_SM, get_mock_client) @mock.patch(CLIENTS_SM, get_mock_client) @@ -1073,13 +1073,13 @@ def test_input_mapping(self): "workspace": ws_name, } ] - self.assertDictEqual(expected[0], mapped_inputs[0]) + assert expected[0] == mapped_inputs[0] ref_path = ( ws_name + "/MyReadsSet; " + ws_name + "/rhodobacterium.art.q10.PE.reads" ) # ref_paths get mocked as 1/1/1;2/2/2;...N/N/N;18836/5/1 ret = app_util.transform_param_value("resolved-ref", ref_path, None) - self.assertEqual(ret, "1/1/1;18836/5/1") + assert ret == "1/1/1;18836/5/1" @mock.patch(CLIENTS_AM_SM, get_mock_client) def test_generate_input(self): @@ -1088,14 +1088,14 @@ def test_generate_input(self): num_symbols = 8 generator = {"symbols": num_symbols, "prefix": prefix, "suffix": suffix} rand_str = self.am._generate_input(generator) - self.assertTrue(rand_str.startswith(prefix)) - self.assertTrue(rand_str.endswith(suffix)) - self.assertEqual(len(rand_str), len(prefix) + len(suffix) + num_symbols) + assert rand_str.startswith(prefix) + assert rand_str.endswith(suffix) + assert len(rand_str) == len(prefix) + len(suffix) + num_symbols def test_generate_input_bad(self): - with self.assertRaises(ValueError): + with pytest.raises(ValueError): self.am._generate_input({"symbols": "foo"}) - with self.assertRaises(ValueError): + with pytest.raises(ValueError): self.am._generate_input({"symbols": -1}) # N.b. the following test contacts the workspace @@ -1169,10 +1169,10 @@ def test_transform_input_good(self): for test in test_data: spec = test.get("spec") ret = app_util.transform_param_value(test["type"], test["value"], spec) - self.assertEqual(ret, test["expected"]) + assert ret == test["expected"] def test_transform_input_bad(self): - with self.assertRaises(ValueError): + with pytest.raises(ValueError): app_util.transform_param_value("foo", "bar", None) def _transform_comm_messages(self, comm_mock): @@ -1205,10 +1205,10 @@ def _verify_comm_error(self, comm_mock, cell_id=None, run_id=None) -> None: expected_message["content"]["cell_id"] = cell_id for key in ["error_message", "error_stacktrace"]: - self.assertTrue(key in transformed_call_args_list[0]["content"]) + assert key in transformed_call_args_list[0]["content"] del transformed_call_args_list[0]["content"][key] - self.assertEqual(transformed_call_args_list, [expected_message]) + assert transformed_call_args_list == [expected_message] def _single_messages(self, cell_id=None, run_id=None): return [ @@ -1254,7 +1254,7 @@ def _verify_comm_success( else: expected = self._single_messages(cell_id, run_id) - self.assertEqual(transformed_call_args_list, expected) + assert transformed_call_args_list == expected if __name__ == "__main__": diff --git a/src/biokbase/narrative/tests/test_auth.py b/src/biokbase/narrative/tests/test_auth.py index 040c2d4614..3047c81bdd 100644 --- a/src/biokbase/narrative/tests/test_auth.py +++ b/src/biokbase/narrative/tests/test_auth.py @@ -2,8 +2,6 @@ import os import pytest -from requests import HTTPError - from biokbase.auth import ( TokenInfo, UserInfo, @@ -18,6 +16,7 @@ ) from biokbase.narrative.common.url_config import URLS from biokbase.narrative.common.util import kbase_env +from requests import HTTPError AUTH_URL = URLS.auth + "/api/V2/" @@ -46,7 +45,7 @@ } -@pytest.fixture +@pytest.fixture() def mock_auth_call(requests_mock): def run_mock_auth( verb: str, endpoint: str, token: str, return_data: dict, status_code=200 @@ -68,9 +67,11 @@ def run_mock_auth( return run_mock_auth -@pytest.fixture +@pytest.fixture() def mock_token_endpoint(mock_auth_call): - def token_mocker(token, verb, return_info={}, status_code=200): + def token_mocker(token, verb, return_info=None, status_code=200): + if return_info is None: + return_info = {} return mock_auth_call( verb, "token", token, return_info, status_code=status_code ) @@ -78,9 +79,11 @@ def token_mocker(token, verb, return_info={}, status_code=200): return token_mocker -@pytest.fixture +@pytest.fixture() def mock_display_names_call(mock_auth_call): - def names_mocker(token, user_ids, return_info={}, status_code=200): + def names_mocker(token, user_ids, return_info=None, status_code=200): + if return_info is None: + return_info = {} return mock_auth_call( "GET", f"users/?list={','.join(user_ids)}", @@ -92,7 +95,7 @@ def names_mocker(token, user_ids, return_info={}, status_code=200): return names_mocker -@pytest.fixture +@pytest.fixture() def mock_me_call(mock_auth_call): def me_mocker(token, return_info, status_code=200): return mock_auth_call("GET", "me", token, return_info, status_code=status_code) diff --git a/src/biokbase/narrative/tests/test_batch.py b/src/biokbase/narrative/tests/test_batch.py index 56c3499895..269f7facb5 100644 --- a/src/biokbase/narrative/tests/test_batch.py +++ b/src/biokbase/narrative/tests/test_batch.py @@ -5,6 +5,7 @@ from unittest import mock import biokbase.narrative.jobs.specmanager +import pytest from biokbase.narrative.jobs.batch import ( _generate_vals, _is_singleton, @@ -37,17 +38,17 @@ def test_list_objects(self): name=t.get("name"), fuzzy_name=t.get("fuzzy_name", True), ) - self.assertEqual(len(objs), t["count"]) + assert len(objs) == t["count"] for o in objs: if t.get("type"): - self.assertTrue(o["type"].startswith(t.get("type"))) - [self.assertIn(k, o) for k in req_keys] + assert o["type"].startswith(t.get("type")) + for k in req_keys: + assert k in o @mock.patch("biokbase.narrative.jobs.batch.clients.get", get_mock_client) def test_list_objects_bad_type(self): - with self.assertRaises(ValueError) as e: + with pytest.raises(ValueError, match="is not a valid type."): list_objects(obj_type="NotAType") - self.assertIn("is not a valid type.", str(e.exception)) @mock.patch( "biokbase.narrative.jobs.batch.specmanager.clients.get", get_mock_client @@ -81,21 +82,21 @@ def test_get_input_scaffold(self): "trailing_min_quality": None, "translate_to_phred33": None, } - self.assertEqual(scaffold_standard, scaffold) + assert scaffold_standard == scaffold @mock.patch("biokbase.narrative.jobs.specmanager.clients.get", get_mock_client) def test_get_input_scaffold_bad_id(self): - with self.assertRaises(ValueError) as e: + with pytest.raises( + ValueError, match='Unknown app id "foo" tagged as "release"' + ): get_input_scaffold("foo") - self.assertIn('Unknown app id "foo" tagged as "release"', str(e.exception)) @mock.patch("biokbase.narrative.jobs.specmanager.clients.get", get_mock_client) def test_get_input_scaffold_bad_tag(self): - with self.assertRaises(ValueError) as e: + with pytest.raises( + ValueError, match="Can't find tag bar - allowed tags are release, beta, dev" + ): get_input_scaffold("foo", tag="bar") - self.assertIn( - "Can't find tag bar - allowed tags are release, beta, dev", str(e.exception) - ) @mock.patch( "biokbase.narrative.jobs.batch.specmanager.clients.get", get_mock_client @@ -128,7 +129,7 @@ def test_get_input_scaffold_defaults(self): "trailing_min_quality": "3", "translate_to_phred33": "1", } - self.assertEqual(scaffold_standard, scaffold) + assert scaffold_standard == scaffold @mock.patch("biokbase.narrative.jobs.batch.StagingHelper", MockStagingHelper) def test_list_files(self): @@ -139,7 +140,7 @@ def test_list_files(self): ] for f in name_filters: files = list_files(name=f.get("name")) - self.assertEqual(len(files), f.get("count")) + assert len(files) == f.get("count") def test__generate_vals(self): good_inputs = [ @@ -170,28 +171,28 @@ def test__generate_vals(self): ret_val = _generate_vals(i["tuple"]) if "is_float" in i: ret_val = [round(x, 2) for x in ret_val] - self.assertEqual(ret_val, i["vals"]) + assert ret_val == i["vals"] unreachable_inputs = [(-1, -1, 5), (0, -1, 10), (-20, 5, -30), (10, -1, 20)] for t in unreachable_inputs: - with self.assertRaises(ValueError) as e: + with pytest.raises( + ValueError, + match="The maximum value of this tuple will never be reached based on the interval value", + ): _generate_vals(t) - self.assertIn( - "The maximum value of this tuple will never be reached based on the interval value", - str(e.exception), - ) - with self.assertRaises(ValueError) as e: + + with pytest.raises( + ValueError, match="The input tuple must be entirely numeric" + ): _generate_vals(("a", -1, 1)) - self.assertIn("The input tuple must be entirely numeric", str(e.exception)) - with self.assertRaises(ValueError) as e: + + with pytest.raises(ValueError, match="The interval value must not be 0"): _generate_vals((10, 0, 1)) - self.assertIn("The interval value must not be 0", str(e.exception)) wrong_size = [(1,), (1, 2), (1, 2, 3, 4), ()] for s in wrong_size: - with self.assertRaises(ValueError) as e: + with pytest.raises(ValueError, match="The input tuple must have 3 values"): _generate_vals(s) - self.assertIn("The input tuple must have 3 values", str(e.exception)) def test__is_singleton(self): scalar_param = { @@ -201,10 +202,10 @@ def test__is_singleton(self): "is_output": 0, "type": "int", } - self.assertTrue(_is_singleton(1, scalar_param)) - self.assertTrue(_is_singleton("foo", scalar_param)) - self.assertFalse(_is_singleton([1, 2, 3], scalar_param)) - self.assertFalse(_is_singleton([1], scalar_param)) + assert _is_singleton(1, scalar_param) + assert _is_singleton("foo", scalar_param) + assert _is_singleton([1, 2, 3], scalar_param) is False + assert _is_singleton([1], scalar_param) is False group_param = { "allow_multiple": 0, @@ -213,8 +214,8 @@ def test__is_singleton(self): "parameter_ids": ["first", "second"], "type": "group", } - self.assertTrue(_is_singleton({"first": 1, "second": 2}, group_param)) - self.assertFalse(_is_singleton([{"first": 1, "second": 2}], group_param)) + assert _is_singleton({"first": 1, "second": 2}, group_param) is True + assert _is_singleton([{"first": 1, "second": 2}], group_param) is False list_param = { "allow_multiple": 1, @@ -222,9 +223,9 @@ def test__is_singleton(self): "is_group": False, "type": "int", } - self.assertTrue(_is_singleton(["foo"], list_param)) - self.assertFalse(_is_singleton([["a", "b"], ["c", "d"]], list_param)) - self.assertFalse(_is_singleton([["a"]], list_param)) + assert _is_singleton(["foo"], list_param) is True + assert _is_singleton([["a", "b"], ["c", "d"]], list_param) is False + assert _is_singleton([["a"]], list_param) is False @mock.patch( "biokbase.narrative.jobs.batch.specmanager.clients.get", get_mock_client @@ -234,28 +235,23 @@ def test_generate_input_batch(self): biokbase.narrative.jobs.specmanager.SpecManager().reload() app_id = "kb_trimmomatic/run_trimmomatic" # fail with no inputs - with self.assertRaises(ValueError) as e: + with pytest.raises( + ValueError, + match="No inputs were given! If you just want to build an empty input set, try get_input_scaffold.", + ): generate_input_batch(app_id) - self.assertIn( - "No inputs were given! If you just want to build an empty input set, try get_input_scaffold.", - str(e.exception), - ) # fail with bad app id - with self.assertRaises(ValueError) as e: + with pytest.raises(ValueError, match="Unknown app id"): generate_input_batch("nope") - self.assertIn("Unknown app id", str(e.exception)) # fail with bad tag, and good app id - with self.assertRaises(ValueError) as e: + with pytest.raises(ValueError, match="Can't find tag foo"): generate_input_batch(app_id, tag="foo") - self.assertIn("Can't find tag foo", str(e.exception)) # fail with no output template and no default - with self.assertRaises(ValueError) as e: + with pytest.raises(ValueError, match="No output template provided"): generate_input_batch(app_id, input_reads_ref="abcde") - self.assertIn("No output template provided", str(e.exception)) # fail with incorrect input - with self.assertRaises(ValueError) as e: + with pytest.raises(ValueError, match="is not a parameter"): generate_input_batch(app_id, not_an_input="something") - self.assertIn("is not a parameter", str(e.exception)) # a simple test, should make 4, all with same input and output strings. inputs = { @@ -265,15 +261,15 @@ def test_generate_input_batch(self): "translate_to_phred33": [0, 1], } input_batch = generate_input_batch(app_id, **inputs) - self.assertEqual(len(input_batch), 4) + assert len(input_batch) == 4 for b in input_batch: - self.assertEqual(b["input_reads_ref"], inputs["input_reads_ref"]) - self.assertEqual(b["output_reads_name"], inputs["output_reads_name"]) - self.assertIn( - b["adapter_clip"][0]["palindrome_clip_threshold"], - inputs["palindrome_clip_threshold"], + assert b["input_reads_ref"] == inputs["input_reads_ref"] + assert b["output_reads_name"] == inputs["output_reads_name"] + assert ( + b["adapter_clip"][0]["palindrome_clip_threshold"] + in inputs["palindrome_clip_threshold"] ) - self.assertIn(b["translate_to_phred33"], inputs["translate_to_phred33"]) + assert b["translate_to_phred33"] in inputs["translate_to_phred33"] # more complex test, should make several, uses ranges. inputs = { @@ -282,9 +278,8 @@ def test_generate_input_batch(self): "min_length": (0, 10, 100), } input_batch = generate_input_batch(app_id, **inputs) - self.assertEqual( - len(input_batch), 22 - ) # product of [0,10,20,30,40,50,60,70,80,90,100] and [5,7] + assert len(input_batch) == 22 + # product of [0,10,20,30,40,50,60,70,80,90,100] and [5,7] len_ranges = {} for i in range(0, 101, 10): len_ranges[i] = 0 @@ -295,24 +290,24 @@ def test_generate_input_batch(self): for b in input_batch: palindrome_ranges[b["adapter_clip"][0]["palindrome_clip_threshold"]] += 1 - self.assertIn( - b["adapter_clip"][0]["palindrome_clip_threshold"], - inputs["palindrome_clip_threshold"], + assert ( + b["adapter_clip"][0]["palindrome_clip_threshold"] + in inputs["palindrome_clip_threshold"] ) len_ranges[b["min_length"]] += 1 - self.assertIn(b["min_length"], len_ranges) - self.assertIn(b["output_reads_name"], out_strs) + assert b["min_length"] in len_ranges + assert b["output_reads_name"] in out_strs out_strs[b["output_reads_name"]] += 1 # make sure each value is used the right number of times. - self.assertEqual(len(len_ranges.keys()), 11) + assert len(len_ranges.keys()) == 11 for v in len_ranges.values(): - self.assertEqual(v, 2) - self.assertEqual(len(palindrome_ranges.keys()), 2) + assert v == 2 + assert len(palindrome_ranges.keys()) == 2 for v in palindrome_ranges.values(): - self.assertEqual(v, 11) - self.assertEqual(len(out_strs.keys()), 22) + assert v == 11 + assert len(out_strs.keys()) == 22 for v in out_strs.values(): - self.assertEqual(v, 1) + assert v == 1 def test__prepare_output_vals(self): basic_params_dict = { @@ -333,35 +328,32 @@ def test__prepare_output_vals(self): out_vals = _prepare_output_vals(o, basic_params_dict, 3) o2 = o.copy() o2.update({"other_param": "default_output${run_number}"}) - self.assertEqual(o2, out_vals) + assert o2 == out_vals # check default output value out_vals = _prepare_output_vals( {"output": "foo", "other_param": None}, basic_params_dict, 3 ) - self.assertEqual( - out_vals["other_param"], - basic_params_dict["other_param"]["default"] + "${run_number}", + assert ( + out_vals["other_param"] + == basic_params_dict["other_param"]["default"] + "${run_number}" ) # some fails - with self.assertRaises(ValueError) as e: + with pytest.raises( + ValueError, + match="The output parameter output must have 5 values if it's a list", + ): _prepare_output_vals({"output": ["a", "b", "c"]}, basic_params_dict, 5) - self.assertIn( - "The output parameter output must have 5 values if it's a list", - str(e.exception), - ) - with self.assertRaises(ValueError) as e: + with pytest.raises( + ValueError, + match='No output template provided for parameter "output" and no default value found!', + ): _prepare_output_vals({"output": None}, basic_params_dict, 3) - self.assertIn( - 'No output template provided for parameter "output" and no default value found!', - str(e.exception), - ) - with self.assertRaises(ValueError) as e: + with pytest.raises( + ValueError, + match="Output template field not_a_param doesn't match a parameter id or 'run_number'", + ): _prepare_output_vals({"output": "foo_${not_a_param}"}, basic_params_dict, 3) - self.assertIn( - "Output template field not_a_param doesn't match a parameter id or 'run_number'", - str(e.exception), - ) diff --git a/src/biokbase/narrative/tests/test_clients.py b/src/biokbase/narrative/tests/test_clients.py index 6d995ffae5..15c25a00f4 100644 --- a/src/biokbase/narrative/tests/test_clients.py +++ b/src/biokbase/narrative/tests/test_clients.py @@ -1,15 +1,13 @@ import pytest - -from biokbase.narrative import clients from biokbase.catalog.Client import Catalog as Catalog_Client from biokbase.execution_engine2.execution_engine2Client import ( execution_engine2 as EE2_Client, ) +from biokbase.narrative import clients from biokbase.narrative_method_store.client import NarrativeMethodStore as NMS_Client from biokbase.service.Client import Client as Service_Client from biokbase.workspace.client import Workspace as WS_Client - name_to_type_tests = [ ("workspace", WS_Client), ("execution_engine2", EE2_Client), @@ -19,7 +17,7 @@ ] -@pytest.mark.parametrize("client_name,client_type", name_to_type_tests) +@pytest.mark.parametrize(("client_name", "client_type"), name_to_type_tests) def test_valid_clients(client_name, client_type): client = clients.get(client_name) assert isinstance(client, client_type) diff --git a/src/biokbase/narrative/tests/test_content_manager_util.py b/src/biokbase/narrative/tests/test_content_manager_util.py index 66a65b1fa1..d1a421eb1f 100644 --- a/src/biokbase/narrative/tests/test_content_manager_util.py +++ b/src/biokbase/narrative/tests/test_content_manager_util.py @@ -8,21 +8,21 @@ def test_base_model(self): name = "foo" path = "bar" model = base_model(name, path) - self.assertIn("name", model) - self.assertEqual(model["name"], name) - self.assertIn("path", model) - self.assertEqual(model["path"], path) - self.assertIn("last_modified", model) - self.assertEqual(model["last_modified"], "00-00-0000") - self.assertIn("created", model) - self.assertEqual(model["created"], "00-00-0000") - self.assertIn("content", model) - self.assertIsNone(model["content"]) - self.assertIn("format", model) - self.assertIsNone(model["format"]) - self.assertIn("mimetype", model) - self.assertIsNone(model["mimetype"]) - self.assertIn("writable", model) - self.assertFalse(model["writable"]) - self.assertIn("type", model) - self.assertIsNone(model["type"]) + assert "name" in model + assert model["name"] == name + assert "path" in model + assert model["path"] == path + assert "last_modified" in model + assert model["last_modified"] == "00-00-0000" + assert "created" in model + assert model["created"] == "00-00-0000" + assert "content" in model + assert model["content"] is None + assert "format" in model + assert model["format"] is None + assert "mimetype" in model + assert model["mimetype"] is None + assert "writable" in model + assert not model["writable"] + assert "type" in model + assert model["type"] is None diff --git a/src/biokbase/narrative/tests/test_exception_util.py b/src/biokbase/narrative/tests/test_exception_util.py index 36f09ca646..2eb595dd6f 100644 --- a/src/biokbase/narrative/tests/test_exception_util.py +++ b/src/biokbase/narrative/tests/test_exception_util.py @@ -1,10 +1,9 @@ import unittest import requests -from requests.exceptions import HTTPError - from biokbase.execution_engine2.baseclient import ServerError as EEServerError from biokbase.narrative.exception_util import transform_job_exception +from requests.exceptions import HTTPError ERROR_MSG = "some error message" @@ -19,11 +18,11 @@ def test_transform_ee2_err(self): name = "EEError" ee2_err = EEServerError(name, code, message) nar_err = transform_job_exception(ee2_err) - self.assertEqual(nar_err.code, code) - self.assertEqual(nar_err.message, message) - self.assertEqual(nar_err.name, name) - self.assertEqual(nar_err.source, "ee2") - self.assertIsNone(nar_err.error) + assert nar_err.code == code + assert nar_err.message == message + assert nar_err.name == name + assert nar_err.source == "ee2" + assert nar_err.error is None def test_transform_ee2_err__with_error(self): code = 1000 @@ -32,11 +31,11 @@ def test_transform_ee2_err__with_error(self): error = "Unable to perform some request" ee2_err = EEServerError(name, code, message) nar_err = transform_job_exception(ee2_err, error) - self.assertEqual(nar_err.code, code) - self.assertEqual(nar_err.message, message) - self.assertEqual(nar_err.name, name) - self.assertEqual(nar_err.source, "ee2") - self.assertEqual(nar_err.error, error) + assert nar_err.code == code + assert nar_err.message == message + assert nar_err.name == name + assert nar_err.source == "ee2" + assert nar_err.error == error def test_transform_http_err_unavailable(self): codes = [404, 502, 503] @@ -47,11 +46,11 @@ def test_transform_http_err_unavailable(self): res.status_code = c err = HTTPError(HTTP_ERROR_MSG, response=res) nar_err = transform_job_exception(err) - self.assertEqual(nar_err.code, c) - self.assertEqual(nar_err.message, message) - self.assertEqual(nar_err.name, name) - self.assertEqual(nar_err.source, "network") - self.assertIsNone(nar_err.error) + assert nar_err.code == c + assert nar_err.message == message + assert nar_err.name == name + assert nar_err.source == "network" + assert nar_err.error is None def test_transform_http_err_timeout(self): codes = [504, 598, 599] @@ -62,11 +61,11 @@ def test_transform_http_err_timeout(self): res.status_code = c err = HTTPError(HTTP_ERROR_MSG, response=res) nar_err = transform_job_exception(err) - self.assertEqual(nar_err.code, c) - self.assertEqual(nar_err.message, message) - self.assertEqual(nar_err.name, name) - self.assertEqual(nar_err.source, "network") - self.assertIsNone(nar_err.error) + assert nar_err.code == c + assert nar_err.message == message + assert nar_err.name == name + assert nar_err.source == "network" + assert nar_err.error is None def test_transform_http_err_internal(self): code = 500 @@ -76,11 +75,11 @@ def test_transform_http_err_internal(self): res.status_code = code err = HTTPError(HTTP_ERROR_MSG, response=res) nar_err = transform_job_exception(err) - self.assertEqual(nar_err.code, code) - self.assertEqual(nar_err.message, message) - self.assertEqual(nar_err.name, name) - self.assertEqual(nar_err.source, "network") - self.assertIsNone(nar_err.error) + assert nar_err.code == code + assert nar_err.message == message + assert nar_err.name == name + assert nar_err.source == "network" + assert nar_err.error is None def test_transform_http_err_unknown(self): code = 666 @@ -90,8 +89,8 @@ def test_transform_http_err_unknown(self): res.status_code = code err = HTTPError(HTTP_ERROR_MSG, response=res) nar_err = transform_job_exception(err) - self.assertEqual(nar_err.code, code) - self.assertEqual(nar_err.message, message) - self.assertEqual(nar_err.name, name) - self.assertEqual(nar_err.source, "network") - self.assertIsNone(nar_err.error) + assert nar_err.code == code + assert nar_err.message == message + assert nar_err.name == name + assert nar_err.source == "network" + assert nar_err.error is None diff --git a/src/biokbase/narrative/tests/test_job.py b/src/biokbase/narrative/tests/test_job.py index 2f33cf0002..b2a1e02c23 100644 --- a/src/biokbase/narrative/tests/test_job.py +++ b/src/biokbase/narrative/tests/test_job.py @@ -1,11 +1,13 @@ import copy import itertools +import re import sys import unittest from contextlib import contextmanager from io import StringIO from unittest import mock +import pytest from biokbase.execution_engine2.baseclient import ServerError from biokbase.narrative.app_util import map_inputs_from_job, map_outputs_from_state from biokbase.narrative.jobs.job import ( @@ -29,8 +31,8 @@ JOB_COMPLETED, JOB_CREATED, JOB_RUNNING, - JOB_TERMINATED, JOB_TERMINAL_STATE, + JOB_TERMINATED, MAX_LOG_LINES, TERMINAL_JOBS, get_test_job, @@ -173,13 +175,13 @@ def setUpClass(cls): cls.NEW_CHILD_JOBS = ["cerulean", "magenta"] def check_jobs_equal(self, jobl, jobr): - self.assertEqual(jobl._acc_state, jobr._acc_state) + assert jobl._acc_state == jobr._acc_state with mock.patch(CLIENTS, get_mock_client): - self.assertEqual(jobl.refresh_state(), jobr.refresh_state()) + assert jobl.refresh_state() == jobr.refresh_state() for attr in JOB_ATTRS: - self.assertEqual(getattr(jobl, attr), getattr(jobr, attr)) + assert getattr(jobl, attr), getattr(jobr == attr) def check_job_attrs_custom(self, job, exp_attr=None): if not exp_attr: @@ -199,7 +201,7 @@ def check_job_attrs(self, job, job_id, exp_attrs=None, skip_state=False): if not exp_attrs and not skip_state: state = create_state_from_ee2(job_id) with mock.patch(CLIENTS, get_mock_client): - self.assertEqual(state, job.refresh_state()) + assert state == job.refresh_state() attrs = create_attrs_from_ee2(job_id) attrs.update(exp_attrs) @@ -208,17 +210,15 @@ def check_job_attrs(self, job, job_id, exp_attrs=None, skip_state=False): if name in ["child_jobs", "retry_ids"]: # job.child_jobs and job.retry_ids may query EE2 with mock.patch(CLIENTS, get_mock_client): - self.assertEqual(value, getattr(job, name)) + assert value == getattr(job, name) else: with assert_obj_method_called( MockClients, "check_job", call_status=False ): - self.assertEqual(value, getattr(job, name)) + assert value == getattr(job, name) def test_job_init__error_no_job_id(self): - with self.assertRaisesRegex( - ValueError, "Cannot create a job without a job ID!" - ): + with pytest.raises(ValueError, match="Cannot create a job without a job ID!"): Job({"params": {}, "app_id": "this/that"}) def test_job_init__from_job_id(self): @@ -277,7 +277,7 @@ def test_job_init__batch_family(self): self.check_job_attrs(job, job_id) batch_job = batch_jobs[BATCH_PARENT] - self.assertEqual(batch_job.job_id, batch_job.batch_id) + assert batch_job.job_id == batch_job.batch_id def test_job_from_state__custom(self): """ @@ -324,17 +324,17 @@ def test_set_job_attrs(self): job = create_job_from_ee2(JOB_COMPLETED) expected = create_state_from_ee2(JOB_COMPLETED) # job is completed so refresh_state will do nothing - self.assertEqual(job.refresh_state(), expected) + assert job.refresh_state() == expected for attr in ALL_ATTRS: - with self.assertRaisesRegex( + with pytest.raises( AttributeError, - "Job attributes must be updated using the `update_state` method", + match="Job attributes must be updated using the `update_state` method", ): setattr(job, attr, "BLAM!") # ensure nothing has changed - self.assertEqual(job.refresh_state(), expected) + assert job.refresh_state() == expected @mock.patch(CLIENTS, get_mock_client) def test_refresh_state__non_terminal(self): @@ -343,26 +343,26 @@ def test_refresh_state__non_terminal(self): """ # ee2_state is fully populated (includes job_input, no job_output) job = create_job_from_ee2(JOB_CREATED) - self.assertFalse(job.in_terminal_state()) + assert not job.in_terminal_state() state = job.refresh_state() - self.assertFalse(job.in_terminal_state()) - self.assertEqual(state["status"], "created") + assert not job.in_terminal_state() + assert state["status"] == "created" expected_state = create_state_from_ee2(JOB_CREATED) - self.assertEqual(state, expected_state) + assert state == expected_state def test_refresh_state__terminal(self): """ test that a completed job emits its state without calling check_job """ job = create_job_from_ee2(JOB_COMPLETED) - self.assertTrue(job.in_terminal_state()) + assert job.in_terminal_state() expected = create_state_from_ee2(JOB_COMPLETED) with assert_obj_method_called(MockClients, "check_job", call_status=False): state = job.refresh_state() - self.assertEqual(state["status"], "completed") - self.assertEqual(state, expected) + assert state["status"] == "completed" + assert state == expected @mock.patch(CLIENTS, get_failing_mock_client) def test_refresh_state__raise_exception(self): @@ -370,8 +370,8 @@ def test_refresh_state__raise_exception(self): test that the correct exception is thrown if check_jobs cannot be called """ job = create_job_from_ee2(JOB_CREATED) - self.assertFalse(job.in_terminal_state()) - with self.assertRaisesRegex(ServerError, "check_jobs failed"): + assert not job.in_terminal_state() + with pytest.raises(ServerError, match="check_jobs failed"): job.refresh_state() # TODO: improve this test @@ -384,7 +384,7 @@ def test_job_update__no_state(self): assert job.cached_state() == job_copy.cached_state() # should fail with error 'state must be a dict' - with self.assertRaisesRegex(TypeError, "state must be a dict"): + with pytest.raises(TypeError, match="state must be a dict"): job.update_state(None) assert job.cached_state() == job_copy.cached_state() @@ -402,10 +402,10 @@ def test_job_update__invalid_job_id(self): """ job = create_job_from_ee2(JOB_RUNNING) expected = create_state_from_ee2(JOB_RUNNING) - self.assertEqual(job.refresh_state(), expected) + assert job.refresh_state() == expected # try to update it with the job state from a different job - with self.assertRaisesRegex(ValueError, "Job ID mismatch in update_state"): + with pytest.raises(ValueError, match="Job ID mismatch in update_state"): job.update_state(get_test_job(JOB_COMPLETED)) @mock.patch(CLIENTS, get_mock_client) @@ -426,37 +426,37 @@ def test_job_info(self): ) with capture_stdout() as (out, err): job.info() - self.assertIn(info_str, out.getvalue().strip()) + assert info_str in out.getvalue().strip() def test_repr(self): job = create_job_from_ee2(JOB_COMPLETED) job_str = job.__repr__() - self.assertRegex("KBase Narrative Job - " + job.job_id, job_str) + assert re.search(job_str, "KBase Narrative Job - " + job.job_id) @mock.patch(CLIENTS, get_mock_client) def test_repr_js(self): job = create_job_from_ee2(JOB_COMPLETED) js_out = job._repr_javascript_() - self.assertIsInstance(js_out, str) + assert isinstance(js_out, str) # spot check to make sure the core pieces are present. needs the # element.html part, job_id, and widget - self.assertIn("element.html", js_out) - self.assertIn(job.job_id, js_out) - self.assertIn("kbaseNarrativeJobStatus", js_out) + assert "element.html" in js_out + assert job.job_id in js_out + assert "kbaseNarrativeJobStatus" in js_out @mock.patch("biokbase.narrative.widgetmanager.WidgetManager.show_output_widget") @mock.patch(CLIENTS, get_mock_client) def test_show_output_widget(self, mock_method): mock_method.return_value = True job = Job(get_test_job(JOB_COMPLETED)) - self.assertTrue(job.show_output_widget()) + assert job.show_output_widget() mock_method.assert_called_once() @mock.patch(CLIENTS, get_mock_client) def test_show_output_widget__incomplete_state(self): job = Job(get_test_job(JOB_CREATED)) - self.assertRegex( - job.show_output_widget(), "Job is incomplete! It has status 'created'" + assert re.search( + "Job is incomplete! It has status 'created'", job.show_output_widget() ) @mock.patch(CLIENTS, get_mock_client) @@ -468,34 +468,34 @@ def test_log(self): job = create_job_from_ee2(JOB_COMPLETED) logs = job.log() # we know there's MAX_LOG_LINES lines total, so roll with it that way. - self.assertEqual(logs[0], total_lines) - self.assertEqual(len(logs[1]), total_lines) + assert logs[0] == total_lines + assert len(logs[1]) == total_lines for i in range(len(logs[1])): line = logs[1][i] - self.assertIn("is_error", line) - self.assertIn("line", line) - self.assertIn(str(i), line["line"]) + assert "is_error" in line + assert "line" in line + assert str(i) in line["line"] # grab the last half offset = int(MAX_LOG_LINES / 2) logs = job.log(first_line=offset) - self.assertEqual(logs[0], total_lines) - self.assertEqual(len(logs[1]), offset) + assert logs[0] == total_lines + assert len(logs[1]) == offset for i in range(total_lines - offset): - self.assertIn(str(i + offset), logs[1][i]["line"]) + assert str(i + offset) in logs[1][i]["line"] # grab a bite from the middle num_fetch = int(MAX_LOG_LINES / 5) logs = job.log(first_line=offset, num_lines=num_fetch) - self.assertEqual(logs[0], total_lines) - self.assertEqual(len(logs[1]), num_fetch) + assert logs[0] == total_lines + assert len(logs[1]) == num_fetch for i in range(num_fetch): - self.assertIn(str(i + offset), logs[1][i]["line"]) + assert str(i + offset) in logs[1][i]["line"] # should normalize negative numbers properly logs = job.log(first_line=-5) - self.assertEqual(logs[0], total_lines) - self.assertEqual(len(logs[1]), total_lines) + assert logs[0] == total_lines + assert len(logs[1]) == total_lines logs = job.log(num_lines=-5) - self.assertEqual(logs[0], total_lines) - self.assertEqual(len(logs[1]), 0) + assert logs[0] == total_lines + assert len(logs[1]) == 0 @mock.patch(CLIENTS, get_mock_client) def test_parameters(self): @@ -504,14 +504,14 @@ def test_parameters(self): """ job_state = get_test_job(JOB_COMPLETED) job_params = job_state.get("job_input", {}).get("params") - self.assertIsNotNone(job_params) + assert job_params is not None job = Job(job_state) - self.assertIsNotNone(job.params) + assert job.params is not None with assert_obj_method_called(MockClients, "check_job", call_status=False): params = job.parameters() - self.assertIsNotNone(params) - self.assertEqual(params, job_params) + assert params is not None + assert params == job_params @mock.patch(CLIENTS, get_mock_client) def test_parameters__param_fetch_ok(self): @@ -521,16 +521,16 @@ def test_parameters__param_fetch_ok(self): """ job_state = get_test_job(JOB_CREATED) job_params = job_state.get("job_input", {}).get("params") - self.assertIsNotNone(job_params) + assert job_params is not None # delete the job params from the input del job_state["job_input"]["params"] job = Job(job_state) - self.assertEqual(job.params, JOB_ATTR_DEFAULTS["params"]) + assert job.params == JOB_ATTR_DEFAULTS["params"] with assert_obj_method_called(MockClients, "check_job", call_status=True): params = job.parameters() - self.assertEqual(params, job_params) + assert params == job_params @mock.patch(CLIENTS, get_failing_mock_client) def test_parameters__param_fetch_fail(self): @@ -540,9 +540,9 @@ def test_parameters__param_fetch_fail(self): job_state = get_test_job(JOB_TERMINATED) del job_state["job_input"]["params"] job = Job(job_state) - self.assertEqual(job.params, JOB_ATTR_DEFAULTS["params"]) + assert job.params == JOB_ATTR_DEFAULTS["params"] - with self.assertRaisesRegex(Exception, "Unable to fetch parameters for job"): + with pytest.raises(Exception, match="Unable to fetch parameters for job"): job.parameters() @mock.patch(CLIENTS, get_mock_client) @@ -553,7 +553,7 @@ def test_parent_children__ok(self): children=child_jobs, ) - self.assertFalse(parent_job.in_terminal_state()) + assert not parent_job.in_terminal_state() # Make all child jobs completed with mock.patch.object( @@ -564,33 +564,34 @@ def test_parent_children__ok(self): for child_job in child_jobs: child_job.refresh_state(force_refresh=True) - self.assertTrue(parent_job.in_terminal_state()) + assert parent_job.in_terminal_state() def test_parent_children__fail(self): parent_state = create_state_from_ee2(BATCH_PARENT) child_states = [create_state_from_ee2(job_id) for job_id in BATCH_CHILDREN] - with self.assertRaisesRegex( - ValueError, "Must supply children when setting children of batch job parent" + with pytest.raises( + ValueError, + match="Must supply children when setting children of batch job parent", ): Job(parent_state) child_jobs = [Job(child_state) for child_state in child_states] - with self.assertRaisesRegex(ValueError, CHILD_ID_MISMATCH): + with pytest.raises(ValueError, match=CHILD_ID_MISMATCH): Job( parent_state, children=child_jobs[1:], ) - with self.assertRaisesRegex(ValueError, CHILD_ID_MISMATCH): + with pytest.raises(ValueError, match=CHILD_ID_MISMATCH): Job( parent_state, children=child_jobs * 2, ) - with self.assertRaisesRegex(ValueError, CHILD_ID_MISMATCH): + with pytest.raises(ValueError, match=CHILD_ID_MISMATCH): Job( parent_state, - children=child_jobs + [create_job_from_ee2(JOB_COMPLETED)], + children=[*child_jobs, create_job_from_ee2(JOB_COMPLETED)], ) def test_get_viewer_params__active(self): @@ -600,7 +601,7 @@ def test_get_viewer_params__active(self): job = create_job_from_ee2(job_id) state = create_state_from_ee2(job_id) out = job.get_viewer_params(state) - self.assertIsNone(out) + assert out is None @mock.patch(CLIENTS, get_mock_client) def test_get_viewer_params__finished(self): @@ -609,7 +610,7 @@ def test_get_viewer_params__finished(self): state = create_state_from_ee2(job_id) exp = get_widget_info(job_id) got = job.get_viewer_params(state) - self.assertEqual(exp, got) + assert exp == got def test_get_viewer_params__batch_parent(self): """ @@ -621,7 +622,7 @@ def test_get_viewer_params__batch_parent(self): job = create_job_from_ee2(BATCH_PARENT, children=batch_children) out = job.get_viewer_params(state) - self.assertIsNone(out) + assert out is None @mock.patch(CLIENTS, get_mock_client) def test_query_job_states_single_job(self): @@ -645,14 +646,14 @@ def test_query_job_states(self): exp = create_state_from_ee2( job_id, exclude_fields=JOB_INIT_EXCLUDED_JOB_STATE_FIELDS ) - self.assertEqual(exp, got) + assert exp == got states = Job.query_ee2_states(ALL_JOBS, init=False) for job_id, got in states.items(): exp = create_state_from_ee2( job_id, exclude_fields=EXCLUDED_JOB_STATE_FIELDS ) - self.assertEqual(exp, got) + assert exp == got def test_refresh_attrs__non_batch_active(self): """ @@ -663,7 +664,7 @@ def test_refresh_attrs__non_batch_active(self): self.check_job_attrs(job, job_id) def mock_check_job(self_, params): - self.assertEqual(params["job_id"], job_id) + assert params["job_id"] == job_id return {"retry_ids": self.NEW_RETRY_IDS} with mock.patch.object(MockClients, "check_job", mock_check_job): @@ -678,7 +679,7 @@ def test_refresh_attrs__non_batch_terminal(self): self.check_job_attrs(job, job_id) def mock_check_job(self_, params): - self.assertEqual(params["job_id"], job_id) + assert params["job_id"] == job_id return {"retry_ids": self.NEW_RETRY_IDS} with mock.patch.object(MockClients, "check_job", mock_check_job): @@ -704,7 +705,7 @@ def test_refresh_attrs__batch(self): self.check_job_attrs(job, job_id) def mock_check_job(self_, params): - self.assertEqual(params["job_id"], job_id) + assert params["job_id"] == job_id return {"child_jobs": self.NEW_CHILD_JOBS} with mock.patch.object(MockClients, "check_job", mock_check_job): @@ -714,24 +715,24 @@ def test_in_terminal_state(self): all_jobs = get_all_jobs() for job_id, job in all_jobs.items(): - self.assertEqual(JOB_TERMINAL_STATE[job_id], job.in_terminal_state()) + assert JOB_TERMINAL_STATE[job_id] == job.in_terminal_state() @mock.patch(CLIENTS, get_mock_client) def test_in_terminal_state__batch(self): batch_fam = get_batch_family_jobs(return_list=True) batch_job, child_jobs = batch_fam[0], batch_fam[1:] - self.assertFalse(batch_job.in_terminal_state()) + assert not batch_job.in_terminal_state() def mock_check_job(self_, params): - self.assertTrue(params["job_id"] in BATCH_CHILDREN) + assert params["job_id"] in BATCH_CHILDREN return {"status": COMPLETED_STATUS} with mock.patch.object(MockClients, "check_job", mock_check_job): for job in child_jobs: job.refresh_state(force_refresh=True) - self.assertTrue(batch_job.in_terminal_state()) + assert batch_job.in_terminal_state() def test_in_cells(self): all_jobs = get_all_jobs() @@ -739,20 +740,21 @@ def test_in_cells(self): # Iterate through all combinations of cell IDs for combo_len in range(len(cell_ids) + 1): for combo in itertools.combinations(cell_ids, combo_len): - combo = list(combo) + combo_list = list(combo) # Get jobs expected to be associated with the cell IDs exp_job_ids = [ job_id for cell_id, job_ids in JOBS_BY_CELL_ID.items() for job_id in job_ids - if cell_id in combo + if cell_id in combo_list ] for job_id, job in all_jobs.items(): - self.assertEqual(job_id in exp_job_ids, job.in_cells(combo)) + expected = job_id in exp_job_ids + assert job.in_cells(combo_list) == expected def test_in_cells__none(self): job = create_job_from_ee2(JOB_COMPLETED) - with self.assertRaisesRegex(ValueError, "cell_ids cannot be None"): + with pytest.raises(ValueError, match="cell_ids cannot be None"): job.in_cells(None) def test_in_cells__batch__same_cell(self): @@ -762,9 +764,9 @@ def test_in_cells__batch__same_cell(self): for job in child_jobs: job._acc_state["job_input"]["narrative_cell_info"]["cell_id"] = "hello" - self.assertTrue(batch_job.in_cells(["hi", "hello"])) + assert batch_job.in_cells(["hi", "hello"]) - self.assertFalse(batch_job.in_cells(["goodbye", "hasta manana"])) + assert not batch_job.in_cells(["goodbye", "hasta manana"]) def test_in_cells__batch__diff_cells(self): batch_fam = get_batch_family_jobs(return_list=True) @@ -775,17 +777,17 @@ def test_in_cells__batch__diff_cells(self): job._acc_state["job_input"]["narrative_cell_info"]["cell_id"] = cell_id for cell_id in children_cell_ids: - self.assertTrue(batch_job.in_cells([cell_id])) - self.assertTrue(batch_job.in_cells(["A", cell_id, "B"])) - self.assertTrue(batch_job.in_cells([cell_id, "B", "A"])) - self.assertTrue(batch_job.in_cells(["B", "A", cell_id])) + assert batch_job.in_cells([cell_id]) + assert batch_job.in_cells(["A", cell_id, "B"]) + assert batch_job.in_cells([cell_id, "B", "A"]) + assert batch_job.in_cells(["B", "A", cell_id]) - self.assertFalse(batch_job.in_cells(["goodbye", "hasta manana"])) + assert not batch_job.in_cells(["goodbye", "hasta manana"]) def test_app_name(self): for job in get_all_jobs().values(): if job.batch_job: - self.assertEqual("batch", job.app_name) + assert job.app_name == "batch" else: test_spec = get_test_spec(job.tag, job.app_id) - self.assertEqual(test_spec["info"]["name"], job.app_name) + assert test_spec["info"]["name"] == job.app_name diff --git a/src/biokbase/narrative/tests/test_job_util.py b/src/biokbase/narrative/tests/test_job_util.py index 656d86f160..554957fc36 100644 --- a/src/biokbase/narrative/tests/test_job_util.py +++ b/src/biokbase/narrative/tests/test_job_util.py @@ -1,5 +1,6 @@ import unittest +import pytest from biokbase.narrative.jobs.util import load_job_constants @@ -14,7 +15,7 @@ def test_load_job_constants__no_file(self): "job_constants", "does_not_exist.json", ] - with self.assertRaises(FileNotFoundError): + with pytest.raises(FileNotFoundError): load_job_constants(file_path) def test_load_job_constants__missing_section(self): @@ -27,8 +28,9 @@ def test_load_job_constants__missing_section(self): "job_constants", "job_config-missing-datatype.json", ] - with self.assertRaisesRegex( - ValueError, "job_config.json is missing the 'message_types' config section" + with pytest.raises( + ValueError, + match="job_config.json is missing the 'message_types' config section", ): load_job_constants(file_path) @@ -42,9 +44,9 @@ def test_load_job_constants__missing_value(self): "job_constants", "job_config-missing-item.json", ] - with self.assertRaisesRegex( + with pytest.raises( ValueError, - "job_config.json is missing the following values for params: BATCH_ID, FIRST_LINE, JOB_ID, LATEST, NUM_LINES, TS", + match="job_config.json is missing the following values for params: BATCH_ID, FIRST_LINE, JOB_ID, LATEST, NUM_LINES, TS", ): load_job_constants(file_path) @@ -52,9 +54,9 @@ def test_load_job_constants__valid(self): # the live file! (params, message_types) = load_job_constants() for item in ["BATCH_ID", "JOB_ID"]: - self.assertIn(item, params) + assert item in params for item in ["STATUS", "RETRY", "INFO", "ERROR"]: - self.assertIn(item, message_types) + assert item in message_types if __name__ == "__main__": diff --git a/src/biokbase/narrative/tests/test_jobcomm.py b/src/biokbase/narrative/tests/test_jobcomm.py index ca49991b2e..757e41558c 100644 --- a/src/biokbase/narrative/tests/test_jobcomm.py +++ b/src/biokbase/narrative/tests/test_jobcomm.py @@ -5,6 +5,7 @@ import unittest from unittest import mock +import pytest from biokbase.narrative.exception_util import ( JobRequestException, NarrativeException, @@ -169,19 +170,16 @@ def check_error_message(self, req, err, extra_params=None): source = req msg = self.jc._comm.last_message - self.assertEqual( - { - "msg_type": ERROR, - "content": { - "request": request, - "source": source, - **extra_params, - "name": type(err).__name__, - "message": str(err), - }, + assert msg == { + "msg_type": ERROR, + "content": { + "request": request, + "source": source, + **extra_params, + "name": type(err).__name__, + "message": str(err), }, - msg, - ) + } def check_job_id_list__no_jobs(self, request_type): job_id_list = [None, ""] @@ -190,12 +188,12 @@ def check_job_id_list__no_jobs(self, request_type): err = JobRequestException(JOBS_MISSING_ERR, job_id_list) # using handler - with self.assertRaisesRegex(type(err), re.escape(str(err))): + with pytest.raises(type(err), match=re.escape(str(err))): self.jc._handle_comm_message(req_dict) self.check_error_message(req_dict, err) # run directly - with self.assertRaisesRegex(type(err), re.escape(str(err))): + with pytest.raises(type(err), match=re.escape(str(err))): self.jc._msg_map[request_type](req) def check_job_id_list__dne_jobs(self, request_type, response_type=None): @@ -208,17 +206,14 @@ def check_job_id_list__dne_jobs(self, request_type, response_type=None): req_dict = make_comm_msg(request_type, job_id_list, False) self.jc._handle_comm_message(req_dict) msg = self.jc._comm.last_message - self.assertEqual( - { - "msg_type": response_type if response_type else request_type, - "content": expected_output, - }, - msg, - ) + assert msg == { + "msg_type": response_type if response_type else request_type, + "content": expected_output, + } def check_id_error(self, req_dict, err): self.jc._comm.clear_message_cache() - with self.assertRaisesRegex(type(err), re.escape(str(err))): + with pytest.raises(type(err), match=re.escape(str(err))): self.jc._handle_comm_message(req_dict) self.check_error_message(req_dict, err) @@ -231,7 +226,7 @@ def check_job_id__no_job_test(self, request_type): self.check_id_error(req_dict, err) # run directly - with self.assertRaisesRegex(type(err), re.escape(str(err))): + with pytest.raises(type(err), match=re.escape(str(err))): self.jc._msg_map[request_type](req) def check_job_id__dne_test(self, request_type): @@ -261,77 +256,62 @@ def check_batch_id__not_batch_test(self, request_type): def test_send_comm_msg_ok(self): self.jc.send_comm_message("some_msg", {"foo": "bar"}) msg = self.jc._comm.last_message - self.assertEqual( - msg, - { - "msg_type": "some_msg", - "content": {"foo": "bar"}, - }, - ) + assert msg == { + "msg_type": "some_msg", + "content": {"foo": "bar"}, + } self.jc._comm.clear_message_cache() def test_send_error_msg__JobRequest(self): req = make_comm_msg("bar", "aeaeae", True) self.jc.send_error_message(req, {"extra": "field"}) msg = self.jc._comm.last_message - self.assertEqual( - { - "msg_type": ERROR, - "content": { - "source": "bar", - "extra": "field", - "request": req.rq_data, - }, + assert msg == { + "msg_type": ERROR, + "content": { + "source": "bar", + "extra": "field", + "request": req.rq_data, }, - msg, - ) + } def test_send_error_msg__dict(self): req_dict = make_comm_msg("bar", "aeaeae", False) self.jc.send_error_message(req_dict, {"extra": "field"}) msg = self.jc._comm.last_message - self.assertEqual( - { - "msg_type": ERROR, - "content": { - "source": "bar", - "extra": "field", - "request": req_dict["content"]["data"], - }, + assert msg == { + "msg_type": ERROR, + "content": { + "source": "bar", + "extra": "field", + "request": req_dict["content"]["data"], }, - msg, - ) + } def test_send_error_msg__None(self): self.jc.send_error_message(None, {"extra": "field"}) msg = self.jc._comm.last_message - self.assertEqual( - { - "msg_type": ERROR, - "content": { - "source": None, - "extra": "field", - "request": None, - }, + assert msg == { + "msg_type": ERROR, + "content": { + "source": None, + "extra": "field", + "request": None, }, - msg, - ) + } def test_send_error_msg__str(self): source = "test_jobcomm" self.jc.send_error_message(source, {"extra": "field"}) msg = self.jc._comm.last_message - self.assertEqual( - { - "msg_type": ERROR, - "content": { - "source": source, - "extra": "field", - "request": source, - }, + assert msg == { + "msg_type": ERROR, + "content": { + "source": source, + "extra": "field", + "request": source, }, - msg, - ) + } # --------------------- # Requests @@ -341,7 +321,7 @@ def test_req_no_inputs__succeed(self): req_dict = make_comm_msg(STATUS_ALL, None, False) self.jc._handle_comm_message(req_dict) msg = self.jc._comm.last_message - self.assertEqual(STATUS_ALL, msg["msg_type"]) + assert STATUS_ALL == msg["msg_type"] def test_req_no_inputs__fail(self): functions = [ @@ -357,7 +337,7 @@ def test_req_no_inputs__fail(self): for msg_type in functions: req_dict = make_comm_msg(msg_type, None, False) err = JobRequestException(ONE_INPUT_TYPE_ONLY_ERR) - with self.assertRaisesRegex(type(err), str(err)): + with pytest.raises(type(err), match=str(err)): self.jc._handle_comm_message(req_dict) self.check_error_message(req_dict, err) @@ -375,7 +355,7 @@ def test_req_multiple_inputs__fail(self): msg_type, {"job_id": "something", "batch_id": "another_thing"}, False ) err = JobRequestException(ONE_INPUT_TYPE_ONLY_ERR) - with self.assertRaisesRegex(type(err), str(err)): + with pytest.raises(type(err), match=str(err)): self.jc._handle_comm_message(req_dict) self.check_error_message(req_dict, err) @@ -384,28 +364,26 @@ def test_req_multiple_inputs__fail(self): # --------------------- @mock.patch(CLIENTS, get_mock_client) def test_start_stop_job_status_loop(self): - self.assertFalse(self.jc._running_lookup_loop) - self.assertIsNone(self.jc._lookup_timer) + assert self.jc._running_lookup_loop is False + assert self.jc._lookup_timer is None self.jc.start_job_status_loop() msg = self.jc._comm.last_message - self.assertEqual( - { - "msg_type": STATUS_ALL, - "content": { - job_id: ALL_RESPONSE_DATA[STATUS][job_id] - for job_id in REFRESH_STATE - if REFRESH_STATE[job_id] - }, + assert msg == { + "msg_type": STATUS_ALL, + "content": { + job_id: ALL_RESPONSE_DATA[STATUS][job_id] + for job_id in REFRESH_STATE + if REFRESH_STATE[job_id] }, - msg, - ) - self.assertTrue(self.jc._running_lookup_loop) - self.assertIsNotNone(self.jc._lookup_timer) + } + + assert self.jc._running_lookup_loop is True + assert self.jc._lookup_timer is not None self.jc.stop_job_status_loop() - self.assertFalse(self.jc._running_lookup_loop) - self.assertIsNone(self.jc._lookup_timer) + assert self.jc._running_lookup_loop is False + assert self.jc._lookup_timer is None @mock.patch(CLIENTS, get_mock_client) def test_start_job_status_loop__cell_ids(self): @@ -413,20 +391,18 @@ def test_start_job_status_loop__cell_ids(self): # Iterate through all combinations of cell IDs for combo_len in range(len(cell_ids) + 1): for combo in itertools.combinations(cell_ids, combo_len): - combo = list(combo) - self.jm._running_jobs = {} - self.assertFalse(self.jc._running_lookup_loop) - self.assertIsNone(self.jc._lookup_timer) + assert self.jc._running_lookup_loop is False + assert self.jc._lookup_timer is None - self.jc.start_job_status_loop(init_jobs=True, cell_list=combo) + self.jc.start_job_status_loop(init_jobs=True, cell_list=list(combo)) msg = self.jc._comm.last_message exp_job_ids = [ job_id for cell_id, job_ids in JOBS_BY_CELL_ID.items() for job_id in job_ids - if cell_id in combo and REFRESH_STATE[job_id] + if cell_id in list(combo) and REFRESH_STATE[job_id] ] exp_msg = { "msg_type": "job_status_all", @@ -435,36 +411,33 @@ def test_start_job_status_loop__cell_ids(self): for job_id in exp_job_ids }, } - self.assertEqual(exp_msg, msg) + assert exp_msg == msg if exp_job_ids: - self.assertTrue(self.jc._running_lookup_loop) - self.assertTrue(self.jc._lookup_timer) + assert self.jc._running_lookup_loop + assert self.jc._lookup_timer self.jc.stop_job_status_loop() - self.assertFalse(self.jc._running_lookup_loop) - self.assertIsNone(self.jc._lookup_timer) + assert self.jc._running_lookup_loop is False + assert self.jc._lookup_timer is None @mock.patch(CLIENTS, get_failing_mock_client) def test_start_job_status_loop__initialise_jobs_error(self): # check_workspace_jobs throws an EEServerError self.jc.start_job_status_loop(init_jobs=True) - self.assertEqual( - self.jc._comm.last_message, - { - "msg_type": ERROR, - "content": { - "code": -32000, - "error": "Unable to get initial jobs list", - "message": "check_workspace_jobs failed", - "name": "JSONRPCError", - "request": "jc.start_job_status_loop", - "source": "ee2", - }, + assert self.jc._comm.last_message == { + "msg_type": ERROR, + "content": { + "code": -32000, + "error": "Unable to get initial jobs list", + "message": "check_workspace_jobs failed", + "name": "JSONRPCError", + "request": "jc.start_job_status_loop", + "source": "ee2", }, - ) - self.assertFalse(self.jc._running_lookup_loop) + } + assert self.jc._running_lookup_loop is False @mock.patch(CLIENTS, get_mock_client) def test_start_job_status_loop__no_jobs_stop_loop(self): @@ -472,16 +445,14 @@ def test_start_job_status_loop__no_jobs_stop_loop(self): self.jm._running_jobs = {} self.jm._jobs_by_cell_id = {} self.jm = JobManager() - self.assertEqual(self.jm._running_jobs, {}) + assert self.jm._running_jobs == {} # this will trigger a call to get_all_job_states # a message containing all jobs (i.e. {}) will be sent out # when it returns 0 jobs, the JobComm will run stop_job_status_loop self.jc.start_job_status_loop() - self.assertFalse(self.jc._running_lookup_loop) - self.assertIsNone(self.jc._lookup_timer) - self.assertEqual( - self.jc._comm.last_message, {"msg_type": STATUS_ALL, "content": {}} - ) + assert self.jc._running_lookup_loop is False + assert self.jc._lookup_timer is None + assert self.jc._comm.last_message == {"msg_type": STATUS_ALL, "content": {}} # --------------------- # Lookup all job states @@ -521,21 +492,18 @@ def check_job_output_states( output_states = self.jc._handle_comm_message(req_dict) msg = self.jc._comm.last_message - self.assertEqual( - { - "msg_type": response_type, - "content": output_states, - }, - msg, - ) + assert msg == { + "msg_type": response_type, + "content": output_states, + } for job_id, state in output_states.items(): - self.assertEqual(ALL_RESPONSE_DATA[STATUS][job_id], state) + assert ALL_RESPONSE_DATA[STATUS][job_id] == state if job_id in ok_states: validate_job_state(state) else: # every valid job ID should be in either error_states or ok_states - self.assertIn(job_id, error_states) + assert job_id in error_states @mock.patch(CLIENTS, get_mock_client) def test_get_all_job_states__ok(self): @@ -553,8 +521,8 @@ def test_get_job_state__1_ok(self): ) def test_get_job_state__no_job(self): - with self.assertRaisesRegex( - JobRequestException, re.escape(f"{JOBS_MISSING_ERR}: {[None]}") + with pytest.raises( + JobRequestException, match=re.escape(f"{JOBS_MISSING_ERR}: {[None]}") ): self.jc.get_job_state(None) @@ -636,13 +604,10 @@ def mock_check_jobs(params): self.jc._handle_comm_message(req_dict) msg = self.jc._comm.last_message - self.assertEqual( - { - "msg_type": STATUS, - "content": expected, - }, - msg, - ) + assert msg == { + "msg_type": STATUS, + "content": expected, + } # ----------------------- # get cell job states @@ -651,7 +616,7 @@ def test_get_job_states_by_cell_id__cell_id_list_none(self): cell_id_list = None req_dict = make_comm_msg(CELL_JOB_STATUS, {CELL_ID_LIST: cell_id_list}, False) err = JobRequestException(CELLS_NOT_PROVIDED_ERR) - with self.assertRaisesRegex(type(err), re.escape(str(err))): + with pytest.raises(type(err), match=re.escape(str(err))): self.jc._handle_comm_message(req_dict) self.check_error_message(req_dict, err) @@ -659,7 +624,7 @@ def test_get_job_states_by_cell_id__empty_cell_id_list(self): cell_id_list = [] req_dict = make_comm_msg(CELL_JOB_STATUS, {CELL_ID_LIST: cell_id_list}, False) err = JobRequestException(CELLS_NOT_PROVIDED_ERR) - with self.assertRaisesRegex(type(err), re.escape(str(err))): + with pytest.raises(type(err), match=re.escape(str(err))): self.jc._handle_comm_message(req_dict) self.check_error_message(req_dict, err) @@ -669,28 +634,22 @@ def test_get_job_states_by_cell_id__invalid_cell_id_list(self): req_dict = make_comm_msg(CELL_JOB_STATUS, {CELL_ID_LIST: cell_id_list}, False) self.jc._handle_comm_message(req_dict) msg = self.jc._comm.last_message - self.assertEqual( - { - "msg_type": CELL_JOB_STATUS, - "content": NO_JOBS_MAPPING, - }, - msg, - ) + assert msg == { + "msg_type": CELL_JOB_STATUS, + "content": NO_JOBS_MAPPING, + } @mock.patch(CLIENTS, get_mock_client) def test_get_job_states_by_cell_id__invalid_cell_id_list_req(self): cell_id_list = ["a", "b", "c"] req_dict = make_comm_msg(CELL_JOB_STATUS, {CELL_ID_LIST: cell_id_list}, False) result = self.jc._handle_comm_message(req_dict) - self.assertEqual(result, NO_JOBS_MAPPING) + assert result == NO_JOBS_MAPPING msg = self.jc._comm.last_message - self.assertEqual( - { - "msg_type": CELL_JOB_STATUS, - "content": NO_JOBS_MAPPING, - }, - msg, - ) + assert msg == { + "msg_type": CELL_JOB_STATUS, + "content": NO_JOBS_MAPPING, + } @mock.patch(CLIENTS, get_mock_client) def test_get_job_states_by_cell_id__all_results(self): @@ -703,14 +662,12 @@ def test_get_job_states_by_cell_id__all_results(self): req_dict = make_comm_msg(CELL_JOB_STATUS, {CELL_ID_LIST: cell_id_list}, False) self.jc._handle_comm_message(req_dict) msg = self.jc._comm.last_message - self.assertEqual(set(msg.keys()), set(["msg_type", "content"])) - self.assertEqual(msg["msg_type"], CELL_JOB_STATUS) - self.assertEqual(msg["content"]["jobs"], expected_states) - self.assertEqual(set(cell_id_list), set(msg["content"]["mapping"].keys())) - for key in msg["content"]["mapping"].keys(): - self.assertEqual( - set(TEST_CELL_IDs[key]), set(msg["content"]["mapping"][key]) - ) + assert set(msg.keys()), set(["msg_type" == "content"]) + assert msg["msg_type"] == CELL_JOB_STATUS + assert msg["content"]["jobs"] == expected_states + assert set(cell_id_list) == set(msg["content"]["mapping"].keys()) + for key in msg["content"]["mapping"]: + assert set(TEST_CELL_IDs[key]) == set(msg["content"]["mapping"][key]) # ----------------------- # Lookup job info @@ -722,13 +679,10 @@ def check_job_info_results(self, job_args, job_id_list): self.jc._handle_comm_message(req_dict) msg = self.jc._comm.last_message expected = {job_id: ALL_RESPONSE_DATA[INFO][job_id] for job_id in job_id_list} - self.assertEqual( - { - "msg_type": INFO, - "content": expected, - }, - msg, - ) + assert msg == { + "msg_type": INFO, + "content": expected, + } @mock.patch(CLIENTS, get_mock_client) def test_get_job_info__job_id__ok(self): @@ -787,7 +741,7 @@ def test_cancel_jobs__job_id__invalid(self): for job_id in job_id_list: req_dict = make_comm_msg(CANCEL, {JOB_ID: job_id}, False) err = JobRequestException(JOBS_MISSING_ERR, [job_id]) - with self.assertRaisesRegex(type(err), re.escape(str(err))): + with pytest.raises(type(err), match=re.escape(str(err))): self.jc._handle_comm_message(req_dict) self.check_error_message(req_dict, err) @@ -811,13 +765,13 @@ def test_cancel_jobs__job_id_list__no_jobs(self): job_id_list = None req_dict = make_comm_msg(CANCEL, {JOB_ID_LIST: job_id_list}, False) err = JobRequestException(JOBS_MISSING_ERR, job_id_list) - with self.assertRaisesRegex(type(err), str(err)): + with pytest.raises(type(err), match=str(err)): self.jc._handle_comm_message(req_dict) job_id_list = [None, ""] req_dict = make_comm_msg(CANCEL, job_id_list, False) err = JobRequestException(JOBS_MISSING_ERR, job_id_list) - with self.assertRaisesRegex(type(err), re.escape(str(err))): + with pytest.raises(type(err), match=re.escape(str(err))): self.jc._handle_comm_message(req_dict) self.check_error_message(req_dict, err) @@ -861,14 +815,11 @@ def test_cancel_jobs__job_id_list__failure(self): }, } - self.assertEqual(output, expected) - self.assertEqual( - self.jc._comm.last_message, - { - "msg_type": STATUS, - "content": expected, - }, - ) + assert output == expected + assert self.jc._comm.last_message == { + "msg_type": STATUS, + "content": expected, + } # ------------ # Retry list of jobs @@ -881,15 +832,12 @@ def check_retry_jobs(self, job_args, job_id_list): job_id: ALL_RESPONSE_DATA[RETRY][job_id] for job_id in job_id_list if job_id } retry_data = self.jc._handle_comm_message(req_dict) - self.assertEqual(expected, retry_data) + assert expected == retry_data retry_msg = self.jc._comm.pop_message() - self.assertEqual( - { - "msg_type": RETRY, - "content": expected, - }, - retry_msg, - ) + assert retry_msg == { + "msg_type": RETRY, + "content": expected, + } def test_retry_jobs__job_id__ok(self): job_id_list = [BATCH_TERMINATED_RETRIED] @@ -934,24 +882,21 @@ def test_retry_jobs__job_id_list__failure(self): generate_ee2_error(RETRY), "Unable to retry job(s)" ) - with self.assertRaisesRegex(type(err), re.escape(str(err))): + with pytest.raises(type(err), match=re.escape(str(err))): self.jc._handle_comm_message(req_dict) msg = self.jc._comm.last_message - self.assertEqual( - msg, - { - "msg_type": ERROR, - "content": { - "request": req_dict["content"]["data"], - "source": RETRY, - "name": "JSONRPCError", - "error": "Unable to retry job(s)", - "code": -32000, - "message": RETRY + " failed", - }, + assert msg == { + "msg_type": ERROR, + "content": { + "request": req_dict["content"]["data"], + "source": RETRY, + "name": "JSONRPCError", + "error": "Unable to retry job(s)", + "code": -32000, + "message": RETRY + " failed", }, - ) + } # ----------------- # Fetching job logs @@ -984,13 +929,13 @@ def test_get_job_logs__job_id__ok(self): req_dict = make_comm_msg(LOGS, [job_id], False, content) self.jc._handle_comm_message(req_dict) msg = self.jc._comm.last_message - self.assertEqual(LOGS, msg["msg_type"]) + assert LOGS == msg["msg_type"] msg_content = msg["content"][job_id] - self.assertEqual(job_id, msg_content["job_id"]) - self.assertEqual(None, msg_content["batch_id"]) - self.assertEqual(lines_available, msg_content["max_lines"]) - self.assertEqual(c[3], len(msg_content["lines"])) - self.assertEqual(c[2], msg_content["latest"]) + assert job_id == msg_content["job_id"] + assert None == msg_content["batch_id"] + assert lines_available == msg_content["max_lines"] + assert c[3] == len(msg_content["lines"]) + assert c[2] == msg_content["latest"] first = 0 if c[1] is None and c[2] is True else c[0] n_lines = c[1] if c[1] else lines_available if first < 0: @@ -998,10 +943,10 @@ def test_get_job_logs__job_id__ok(self): if c[2]: first = lines_available - min(n_lines, lines_available) - self.assertEqual(first, msg_content["first"]) + assert first == msg_content["first"] for idx, line in enumerate(msg_content["lines"]): - self.assertIn(str(first + idx), line["line"]) - self.assertEqual(0, line["is_error"]) + assert str(first + idx) in line["line"] + assert 0 == line["is_error"] @mock.patch(CLIENTS, get_mock_client) def test_get_job_logs__job_id__failure(self): @@ -1009,25 +954,22 @@ def test_get_job_logs__job_id__failure(self): req_dict = make_comm_msg(LOGS, job_id, False) self.jc._handle_comm_message(req_dict) msg = self.jc._comm.last_message - self.assertEqual( - msg, - { - "msg_type": LOGS, - "content": { - JOB_CREATED: { - "job_id": JOB_CREATED, - "batch_id": None, - "error": "Cannot find job log with id: " + JOB_CREATED, - } - }, + assert msg == { + "msg_type": LOGS, + "content": { + JOB_CREATED: { + "job_id": JOB_CREATED, + "batch_id": None, + "error": "Cannot find job log with id: " + JOB_CREATED, + } }, - ) + } def test_get_job_logs__job_id__no_job(self): job_id = None req_dict = make_comm_msg(LOGS, {JOB_ID: job_id}, False) err = JobRequestException(JOBS_MISSING_ERR, [job_id]) - with self.assertRaisesRegex(type(err), re.escape(str(err))): + with pytest.raises(type(err), match=re.escape(str(err))): self.jc._handle_comm_message(req_dict) self.check_error_message(req_dict, err) @@ -1036,18 +978,15 @@ def test_get_job_logs__job_id__job_dne(self): req_dict = make_comm_msg(LOGS, JOB_NOT_FOUND, False) self.jc._handle_comm_message(req_dict) msg = self.jc._comm.last_message - self.assertEqual( - msg, - { - "msg_type": LOGS, - "content": { - JOB_NOT_FOUND: { - "job_id": JOB_NOT_FOUND, - "error": generate_error(JOB_NOT_FOUND, "not_found"), - } - }, + assert msg == { + "msg_type": LOGS, + "content": { + JOB_NOT_FOUND: { + "job_id": JOB_NOT_FOUND, + "error": generate_error(JOB_NOT_FOUND, "not_found"), + } }, - ) + } @mock.patch(CLIENTS, get_mock_client) def test_get_job_logs__job_id_list__one_ok_one_bad_one_fetch_fail(self): @@ -1056,30 +995,27 @@ def test_get_job_logs__job_id_list__one_ok_one_bad_one_fetch_fail(self): ) self.jc._handle_comm_message(req_dict) msg = self.jc._comm.last_message - self.assertEqual(LOGS, msg["msg_type"]) - - self.assertEqual( - msg["content"], - { - JOB_COMPLETED: { - "job_id": JOB_COMPLETED, - "first": 0, - "max_lines": MAX_LOG_LINES, - "latest": False, - "batch_id": None, - "lines": LOG_LINES, - }, - JOB_CREATED: { - "job_id": JOB_CREATED, - "batch_id": None, - "error": generate_error(JOB_CREATED, "no_logs"), - }, - JOB_NOT_FOUND: { - "job_id": JOB_NOT_FOUND, - "error": generate_error(JOB_NOT_FOUND, "not_found"), - }, + assert LOGS == msg["msg_type"] + + assert msg["content"] == { + JOB_COMPLETED: { + "job_id": JOB_COMPLETED, + "first": 0, + "max_lines": MAX_LOG_LINES, + "latest": False, + "batch_id": None, + "lines": LOG_LINES, }, - ) + JOB_CREATED: { + "job_id": JOB_CREATED, + "batch_id": None, + "error": generate_error(JOB_CREATED, "no_logs"), + }, + JOB_NOT_FOUND: { + "job_id": JOB_NOT_FOUND, + "error": generate_error(JOB_NOT_FOUND, "not_found"), + }, + } @mock.patch(CLIENTS, get_mock_client) def test_get_job_logs__job_id_list__one_ok_one_bad_one_fetch_fail__with_params( @@ -1097,30 +1033,27 @@ def test_get_job_logs__job_id_list__one_ok_one_bad_one_fetch_fail__with_params( ) self.jc._handle_comm_message(req_dict) msg = self.jc._comm.last_message - self.assertEqual(LOGS, msg["msg_type"]) - - self.assertEqual( - msg["content"], - { - JOB_COMPLETED: { - "job_id": JOB_COMPLETED, - "first": first, - "max_lines": MAX_LOG_LINES, - "latest": True, - "batch_id": None, - "lines": lines, - }, - JOB_CREATED: { - "job_id": JOB_CREATED, - "batch_id": None, - "error": generate_error(JOB_CREATED, "no_logs"), - }, - JOB_NOT_FOUND: { - "job_id": JOB_NOT_FOUND, - "error": generate_error(JOB_NOT_FOUND, "not_found"), - }, + assert LOGS == msg["msg_type"] + + assert msg["content"] == { + JOB_COMPLETED: { + "job_id": JOB_COMPLETED, + "first": first, + "max_lines": MAX_LOG_LINES, + "latest": True, + "batch_id": None, + "lines": lines, }, - ) + JOB_CREATED: { + "job_id": JOB_CREATED, + "batch_id": None, + "error": generate_error(JOB_CREATED, "no_logs"), + }, + JOB_NOT_FOUND: { + "job_id": JOB_NOT_FOUND, + "error": generate_error(JOB_NOT_FOUND, "not_found"), + }, + } # ------------------------ # Modify job update @@ -1135,14 +1068,11 @@ def test_modify_job_update__job_id_list__start__ok(self): ) for job_id in ALL_JOBS: if job_id in job_id_list: - self.assertEqual(self.jm._running_jobs[job_id]["refresh"], True) + assert self.jm._running_jobs[job_id]["refresh"] else: - self.assertEqual( - self.jm._running_jobs[job_id]["refresh"], - REFRESH_STATE[job_id], - ) - self.assertTrue(self.jc._lookup_timer) - self.assertTrue(self.jc._running_lookup_loop) + assert self.jm._running_jobs[job_id]["refresh"] == REFRESH_STATE[job_id] + assert self.jc._lookup_timer + assert self.jc._running_lookup_loop @mock.patch(CLIENTS, get_mock_client) def test_modify_job_update__job_id_list__stop__ok(self): @@ -1154,23 +1084,17 @@ def test_modify_job_update__job_id_list__stop__ok(self): ) for job_id in ALL_JOBS: if job_id in job_id_list: - self.assertEqual( - self.jm._running_jobs[job_id]["refresh"], - False, - ) + assert self.jm._running_jobs[job_id]["refresh"] is False else: - self.assertEqual( - self.jm._running_jobs[job_id]["refresh"], - REFRESH_STATE[job_id], - ) - self.assertIsNone(self.jc._lookup_timer) - self.assertFalse(self.jc._running_lookup_loop) + assert self.jm._running_jobs[job_id]["refresh"] == REFRESH_STATE[job_id] + assert self.jc._lookup_timer is None + assert self.jc._running_lookup_loop is False def test_modify_job_update__job_id_list__no_jobs(self): job_id_list = [None] req_dict = make_comm_msg(START_UPDATE, job_id_list, False) err = JobRequestException(JOBS_MISSING_ERR, job_id_list) - with self.assertRaisesRegex(type(err), re.escape(str(err))): + with pytest.raises(type(err), match=re.escape(str(err))): self.jc._handle_comm_message(req_dict) self.check_error_message(req_dict, err) @@ -1185,14 +1109,11 @@ def test_modify_job_update__job_id_list__stop__ok_bad_job(self): for job_id in ALL_JOBS: if job_id in job_id_list: - self.assertEqual(self.jm._running_jobs[job_id]["refresh"], False) + assert self.jm._running_jobs[job_id]["refresh"] is False else: - self.assertEqual( - self.jm._running_jobs[job_id]["refresh"], - REFRESH_STATE[job_id], - ) - self.assertIsNone(self.jc._lookup_timer) - self.assertFalse(self.jc._running_lookup_loop) + assert self.jm._running_jobs[job_id]["refresh"] == REFRESH_STATE[job_id] + assert self.jc._lookup_timer is None + assert self.jc._running_lookup_loop is False @mock.patch(CLIENTS, get_mock_client) def test_modify_job_update__job_id_list__stop__loop_still_running(self): @@ -1207,14 +1128,11 @@ def test_modify_job_update__job_id_list__stop__loop_still_running(self): ) for job_id in ALL_JOBS: if job_id in job_id_list: - self.assertEqual(self.jm._running_jobs[job_id]["refresh"], False) + assert self.jm._running_jobs[job_id]["refresh"] is False else: - self.assertEqual( - self.jm._running_jobs[job_id]["refresh"], - REFRESH_STATE[job_id], - ) - self.assertTrue(self.jc._lookup_timer) - self.assertTrue(self.jc._running_lookup_loop) + assert self.jm._running_jobs[job_id]["refresh"] == REFRESH_STATE[job_id] + assert self.jc._lookup_timer + assert self.jc._running_lookup_loop # ------------------------ # Modify job update batch @@ -1230,14 +1148,11 @@ def test_modify_job_update__batch_id__start__ok(self): ) for job_id in ALL_JOBS: if job_id in job_id_list: - self.assertEqual(self.jm._running_jobs[job_id]["refresh"], True) + assert self.jm._running_jobs[job_id]["refresh"] else: - self.assertEqual( - self.jm._running_jobs[job_id]["refresh"], - REFRESH_STATE[job_id], - ) - self.assertTrue(self.jc._lookup_timer) - self.assertTrue(self.jc._running_lookup_loop) + assert self.jm._running_jobs[job_id]["refresh"] == REFRESH_STATE[job_id] + assert self.jc._lookup_timer + assert self.jc._running_lookup_loop @mock.patch(CLIENTS, get_mock_client) def test_modify_job_update__batch_id__stop__ok(self): @@ -1250,14 +1165,11 @@ def test_modify_job_update__batch_id__stop__ok(self): ) for job_id in ALL_JOBS: if job_id in job_id_list: - self.assertEqual(self.jm._running_jobs[job_id]["refresh"], False) + assert self.jm._running_jobs[job_id]["refresh"] is False else: - self.assertEqual( - self.jm._running_jobs[job_id]["refresh"], - REFRESH_STATE[job_id], - ) - self.assertIsNone(self.jc._lookup_timer) - self.assertFalse(self.jc._running_lookup_loop) + assert self.jm._running_jobs[job_id]["refresh"] == REFRESH_STATE[job_id] + assert self.jc._lookup_timer is None + assert self.jc._running_lookup_loop is False def test_modify_job_update__batch_id__no_job(self): self.check_batch_id__no_job_test(START_UPDATE) @@ -1275,16 +1187,17 @@ def test_modify_job_update__batch_id__not_batch(self): # Handle bad comm messages # ------------------------ def test_handle_comm_message_bad(self): - with self.assertRaisesRegex(JobRequestException, INVALID_REQUEST_ERR): + with pytest.raises(JobRequestException, match=INVALID_REQUEST_ERR): self.jc._handle_comm_message({"foo": "bar"}) - with self.assertRaisesRegex(JobRequestException, MISSING_REQUEST_TYPE_ERR): + with pytest.raises(JobRequestException, match=MISSING_REQUEST_TYPE_ERR): self.jc._handle_comm_message({"content": {"data": {"request_type": None}}}) def test_handle_comm_message_unknown(self): unknown = "NotAJobRequest" - with self.assertRaisesRegex( - JobRequestException, re.escape(f"Unknown KBaseJobs message '{unknown}'") + with pytest.raises( + JobRequestException, + match=re.escape(f"Unknown KBaseJobs message '{unknown}'"), ): self.jc._handle_comm_message( {"content": {"data": {"request_type": unknown}}} @@ -1304,13 +1217,13 @@ def test_request_ok(self): "content": {"data": {"request_type": "a_request"}}, } rq = JobRequest(rq_msg) - self.assertEqual(rq.msg_id, "some_id") - self.assertEqual(rq.request_type, "a_request") - self.assertEqual(rq.raw_request, rq_msg) - self.assertEqual(rq.rq_data, {"request_type": "a_request"}) - with self.assertRaisesRegex(JobRequestException, ONE_INPUT_TYPE_ONLY_ERR): + assert rq.msg_id == "some_id" + assert rq.request_type == "a_request" + assert rq.raw_request == rq_msg + assert rq.rq_data == {"request_type": "a_request"} + with pytest.raises(JobRequestException, match=ONE_INPUT_TYPE_ONLY_ERR): rq.job_id - with self.assertRaisesRegex(JobRequestException, ONE_INPUT_TYPE_ONLY_ERR): + with pytest.raises(JobRequestException, match=ONE_INPUT_TYPE_ONLY_ERR): rq.job_id_list def test_request_no_data(self): @@ -1319,7 +1232,7 @@ def test_request_no_data(self): rq_msg3 = {"msg_id": "some_id", "content": {"data": None}} rq_msg4 = {"msg_id": "some_id", "content": {"what": "?"}} for msg in [rq_msg1, rq_msg2, rq_msg3, rq_msg4]: - with self.assertRaisesRegex(JobRequestException, INVALID_REQUEST_ERR): + with pytest.raises(JobRequestException, match=INVALID_REQUEST_ERR): JobRequest(msg) def test_request_no_req(self): @@ -1327,7 +1240,7 @@ def test_request_no_req(self): rq_msg2 = {"msg_id": "some_id", "content": {"data": {"request_type": ""}}} rq_msg3 = {"msg_id": "some_id", "content": {"data": {"what": {}}}} for msg in [rq_msg1, rq_msg2, rq_msg3]: - with self.assertRaisesRegex(JobRequestException, MISSING_REQUEST_TYPE_ERR): + with pytest.raises(JobRequestException, match=MISSING_REQUEST_TYPE_ERR): JobRequest(msg) def test_request_more_than_one_input(self): @@ -1342,7 +1255,7 @@ def test_request_more_than_one_input(self): ) for co in combos: msg = make_comm_msg(STATUS, {**co[0], **co[1]}, False) - with self.assertRaisesRegex(JobRequestException, ONE_INPUT_TYPE_ONLY_ERR): + with pytest.raises(JobRequestException, match=ONE_INPUT_TYPE_ONLY_ERR): JobRequest(msg) # all three @@ -1355,20 +1268,20 @@ def test_request_more_than_one_input(self): }, False, ) - with self.assertRaisesRegex(JobRequestException, ONE_INPUT_TYPE_ONLY_ERR): + with pytest.raises(JobRequestException, match=ONE_INPUT_TYPE_ONLY_ERR): JobRequest(msg) def test_request__no_input(self): msg = make_comm_msg(STATUS, {}, False) req = JobRequest(msg) - with self.assertRaisesRegex(JobRequestException, ONE_INPUT_TYPE_ONLY_ERR): + with pytest.raises(JobRequestException, match=ONE_INPUT_TYPE_ONLY_ERR): req.job_id - with self.assertRaisesRegex(JobRequestException, ONE_INPUT_TYPE_ONLY_ERR): + with pytest.raises(JobRequestException, match=ONE_INPUT_TYPE_ONLY_ERR): req.job_id_list - with self.assertRaisesRegex(JobRequestException, ONE_INPUT_TYPE_ONLY_ERR): + with pytest.raises(JobRequestException, match=ONE_INPUT_TYPE_ONLY_ERR): req.batch_id - with self.assertRaisesRegex(JobRequestException, CELLS_NOT_PROVIDED_ERR): + with pytest.raises(JobRequestException, match=CELLS_NOT_PROVIDED_ERR): req.cell_id_list @@ -1412,23 +1325,20 @@ def test_with_nested_try__raise(self): def f(): raise RuntimeError(message) - with self.assertRaisesRegex(RuntimeError, message): + with pytest.raises(RuntimeError, match=message): f_var = [] self.bar(req, f, f_var) - self.assertEqual(["A"], f_var) + assert ["A"] == f_var msg = self.jc._comm.last_message - self.assertEqual( - { - "msg_type": ERROR, - "content": { - "source": req_type, - "name": "RuntimeError", - "message": message, - "request": req.rq_data, - }, + assert { + "msg_type": ERROR, + "content": { + "source": req_type, + "name": "RuntimeError", + "message": message, + "request": req.rq_data, }, - msg, - ) + } == msg def test_with_nested_try__succeed(self): job_id_list = [BATCH_ERROR_RETRIED, JOB_RUNNING] @@ -1445,9 +1355,9 @@ def f(): f_var = [] self.bar(req, f, f_var) - self.assertEqual(["B", "C"], f_var) + assert ["B", "C"] == f_var msg = self.jc._comm.last_message - self.assertIsNone(msg) + assert msg is None def test_NarrativeException(self): job_id_list = BATCH_CHILDREN @@ -1464,24 +1374,21 @@ def test_NarrativeException(self): def f(): raise transform_job_exception(Exception(message), error) - with self.assertRaisesRegex(NarrativeException, message): + with pytest.raises(NarrativeException, match=message): self.foo(req, f) msg = self.jc._comm.last_message - self.assertEqual( - { - "msg_type": ERROR, - "content": { - "request": req.rq_data, - "source": req_type, - # Below are from transform_job_exception - "name": "Exception", - "message": message, - "error": error, - "code": -1, - }, + assert msg == { + "msg_type": ERROR, + "content": { + "request": req.rq_data, + "source": req_type, + # Below are from transform_job_exception + "name": "Exception", + "message": message, + "error": error, + "code": -1, }, - msg, - ) + } def test_JobRequestException(self): job_id = BATCH_PARENT @@ -1497,21 +1404,18 @@ def test_JobRequestException(self): def f(): raise JobRequestException(message, "a0a0a0") - with self.assertRaisesRegex(JobRequestException, f"{message}: a0a0a0"): + with pytest.raises(JobRequestException, match=f"{message}: a0a0a0"): self.foo(req, f) msg = self.jc._comm.last_message - self.assertEqual( - { - "msg_type": ERROR, - "content": { - "request": req.rq_data, - "source": req_type, - "name": "JobRequestException", - "message": f"{message}: a0a0a0", - }, + assert msg == { + "msg_type": ERROR, + "content": { + "request": req.rq_data, + "source": req_type, + "name": "JobRequestException", + "message": f"{message}: a0a0a0", }, - msg, - ) + } def test_ValueError(self): job_id_list = [JOB_RUNNING, JOB_COMPLETED] @@ -1527,21 +1431,18 @@ def test_ValueError(self): def f(): raise ValueError(message) - with self.assertRaisesRegex(ValueError, message): + with pytest.raises(ValueError, match=message): self.foo(req, f) msg = self.jc._comm.last_message - self.assertEqual( - { - "msg_type": ERROR, - "content": { - "request": req.rq_data, - "source": req_type, - "name": "ValueError", - "message": message, - }, + assert msg == { + "msg_type": ERROR, + "content": { + "request": req.rq_data, + "source": req_type, + "name": "ValueError", + "message": message, }, - msg, - ) + } def test_dict_req__no_err(self): job_id = JOB_ERROR @@ -1559,7 +1460,7 @@ def f(): self.foo(req_dict, f) msg = self.jc._comm.last_message - self.assertIsNone(msg) + assert msg is None def test_dict_req__error_down_the_stack(self): job_id = JOB_CREATED @@ -1579,21 +1480,18 @@ def f(i=5): raise ValueError(message) f(i - 1) - with self.assertRaisesRegex(ValueError, message): + with pytest.raises(ValueError, match=message): self.foo(req_dict, f) msg = self.jc._comm.last_message - self.assertEqual( - { - "msg_type": ERROR, - "content": { - "request": req_dict["content"]["data"], - "source": req_type, - "name": "ValueError", - "message": message, - }, + assert msg == { + "msg_type": ERROR, + "content": { + "request": req_dict["content"]["data"], + "source": req_type, + "name": "ValueError", + "message": message, }, - msg, - ) + } def test_dict_req__both_inputs(self): req_type = STATUS @@ -1610,21 +1508,18 @@ def test_dict_req__both_inputs(self): def f(): raise ValueError(message) - with self.assertRaisesRegex(ValueError, message): + with pytest.raises(ValueError, match=message): self.foo(req_dict, f) msg = self.jc._comm.last_message - self.assertEqual( - { - "msg_type": ERROR, - "content": { - "request": req_dict["content"]["data"], - "source": req_type, - "name": "ValueError", - "message": message, - }, + assert msg == { + "msg_type": ERROR, + "content": { + "request": req_dict["content"]["data"], + "source": req_type, + "name": "ValueError", + "message": message, }, - msg, - ) + } def test_None_req(self): source = None @@ -1634,21 +1529,18 @@ def test_None_req(self): def f(): raise err - with self.assertRaisesRegex(type(err), str(err)): + with pytest.raises(type(err), match=str(err)): self.foo(source, f) msg = self.jc._comm.last_message - self.assertEqual( - { - "msg_type": ERROR, - "content": { - "source": source, - "name": "ValueError", - "message": message, - "request": None, - }, + assert msg == { + "msg_type": ERROR, + "content": { + "source": source, + "name": "ValueError", + "message": message, + "request": None, }, - msg, - ) + } def test_str_req(self): source = "test_jobcomm" @@ -1658,18 +1550,15 @@ def test_str_req(self): def f(): raise err - with self.assertRaisesRegex(type(err), str(err)): + with pytest.raises(type(err), match=str(err)): self.foo(source, f) msg = self.jc._comm.last_message - self.assertEqual( - { - "msg_type": ERROR, - "content": { - "source": source, - "name": "ValueError", - "message": message, - "request": source, - }, + assert msg == { + "msg_type": ERROR, + "content": { + "source": source, + "name": "ValueError", + "message": message, + "request": source, }, - msg, - ) + } diff --git a/src/biokbase/narrative/tests/test_jobmanager.py b/src/biokbase/narrative/tests/test_jobmanager.py index 5f7be32d0f..27da6d7cac 100644 --- a/src/biokbase/narrative/tests/test_jobmanager.py +++ b/src/biokbase/narrative/tests/test_jobmanager.py @@ -6,8 +6,7 @@ from datetime import datetime from unittest import mock -from IPython.display import HTML - +import pytest from biokbase.narrative.exception_util import JobRequestException, NarrativeException from biokbase.narrative.jobs.job import ( EXCLUDED_JOB_STATE_FIELDS, @@ -51,6 +50,7 @@ generate_error, get_test_job, ) +from IPython.display import HTML from .narrative_mock.mockclients import ( MockClients, @@ -87,14 +87,14 @@ def reset_job_manager(self): self.jm._jobs_by_cell_id = {} self.jm = JobManager() - self.assertEqual(self.jm._running_jobs, {}) - self.assertEqual(self.jm._jobs_by_cell_id, {}) + assert self.jm._running_jobs == {} + assert self.jm._jobs_by_cell_id == {} @mock.patch(CLIENTS, get_failing_mock_client) def test_initialize_jobs_ee2_fail(self): # init jobs should fail. specifically, ee2.check_workspace_jobs should error. - with self.assertRaisesRegex( - NarrativeException, re.escape("check_workspace_jobs failed") + with pytest.raises( + NarrativeException, match=re.escape("check_workspace_jobs failed") ): self.jm.initialize_jobs() @@ -109,18 +109,15 @@ def test_initialize_jobs(self): for job_id, d in self.jm._running_jobs.items() if d["job"].in_terminal_state() ] - self.assertEqual( - set(TERMINAL_JOBS), - set(terminal_ids), - ) - self.assertEqual(set(ALL_JOBS), set(self.jm._running_jobs.keys())) + assert set(TERMINAL_JOBS) == set(terminal_ids) + assert set(ALL_JOBS) == set(self.jm._running_jobs.keys()) for job_id in TERMINAL_IDS: - self.assertFalse(self.jm._running_jobs[job_id]["refresh"]) + assert self.jm._running_jobs[job_id]["refresh"] is False for job_id in NON_TERMINAL_IDS: - self.assertTrue(self.jm._running_jobs[job_id]["refresh"]) + assert self.jm._running_jobs[job_id]["refresh"] is True - self.assertEqual(self.jm._jobs_by_cell_id, JOBS_BY_CELL_ID) + assert self.jm._jobs_by_cell_id == JOBS_BY_CELL_ID @mock.patch(CLIENTS, get_mock_client) def test_initialize_jobs__cell_ids(self): @@ -131,48 +128,48 @@ def test_initialize_jobs__cell_ids(self): # Iterate through all combinations of cell IDs for combo_len in range(len(cell_ids) + 1): for combo in itertools.combinations(cell_ids, combo_len): - combo = list(combo) + combo_list = list(combo) # Get jobs expected to be associated with the cell IDs exp_job_ids = [ job_id for cell_id, job_ids in JOBS_BY_CELL_ID.items() for job_id in job_ids - if cell_id in combo + if cell_id in combo_list ] + self.reset_job_manager() - self.jm.initialize_jobs(cell_ids=combo) + self.jm.initialize_jobs(cell_ids=combo_list) for job_id, d in self.jm._running_jobs.items(): - refresh = d["refresh"] - - self.assertEqual( - job_id in exp_job_ids and REFRESH_STATE[job_id], - refresh, + expected_refresh_state = ( + job_id in exp_job_ids and REFRESH_STATE[job_id] ) + assert expected_refresh_state == d["refresh"] def test__check_job_list_fail__wrong_type(self): - with self.assertRaisesRegex(JobRequestException, f"{JOBS_MISSING_ERR}: {{}}"): + with pytest.raises(JobRequestException, match=f"{JOBS_MISSING_ERR}: {{}}"): self.jm._check_job_list({}) def test__check_job_list_fail__none(self): - with self.assertRaisesRegex(JobRequestException, f"{JOBS_MISSING_ERR}: {None}"): + with pytest.raises(JobRequestException, match=f"{JOBS_MISSING_ERR}: {None}"): self.jm._check_job_list(None) def test__check_job_list_fail__no_args(self): - with self.assertRaisesRegex( - JobRequestException, re.escape(f"{JOBS_MISSING_ERR}: {None}") + with pytest.raises( + JobRequestException, match=re.escape(f"{JOBS_MISSING_ERR}: {None}") ): self.jm._check_job_list() def test__check_job_list_fail__empty(self): - with self.assertRaisesRegex( - JobRequestException, re.escape(f"{JOBS_MISSING_ERR}: {[]}") + with pytest.raises( + JobRequestException, match=re.escape(f"{JOBS_MISSING_ERR}: {[]}") ): self.jm._check_job_list([]) def test__check_job_list_fail__nonsense_list_items(self): - with self.assertRaisesRegex( - JobRequestException, re.escape(f'{JOBS_MISSING_ERR}: {["", "", None]}') + with pytest.raises( + JobRequestException, + match=re.escape(f'{JOBS_MISSING_ERR}: {["", "", None]}'), ): self.jm._check_job_list(["", "", None]) @@ -183,50 +180,39 @@ def test__check_job_list(self): job_b = JOB_COMPLETED job_c = "job_c" job_d = "job_d" - self.assertEqual( - self.jm._check_job_list([job_c]), - ( - [], - [job_c], - ), + assert self.jm._check_job_list([job_c]) == ( + [], + [job_c], ) - self.assertEqual( - self.jm._check_job_list([job_c, None, "", job_c, job_c, None, job_d]), - ( - [], - [job_c, job_d], - ), + assert self.jm._check_job_list( + [job_c, None, "", job_c, job_c, None, job_d] + ) == ( + [], + [job_c, job_d], ) - self.assertEqual( - self.jm._check_job_list([job_c, None, "", None, job_a, job_a, job_a]), - ( - [job_a], - [job_c], - ), + assert self.jm._check_job_list( + [job_c, None, "", None, job_a, job_a, job_a] + ) == ( + [job_a], + [job_c], ) - self.assertEqual( - self.jm._check_job_list([None, job_a, None, "", None, job_b]), - ( - [job_a, job_b], - [], - ), + assert self.jm._check_job_list([None, job_a, None, "", None, job_b]) == ( + [job_a, job_b], + [], ) @mock.patch(CLIENTS, get_mock_client) def test__construct_job_output_state_set(self): - self.assertEqual( - self.jm._construct_job_output_state_set(ALL_JOBS), - { - job_id: ALL_RESPONSE_DATA[MESSAGE_TYPE["STATUS"]][job_id] - for job_id in ALL_JOBS - }, - ) + assert self.jm._construct_job_output_state_set(ALL_JOBS) == { + job_id: ALL_RESPONSE_DATA[MESSAGE_TYPE["STATUS"]][job_id] + for job_id in ALL_JOBS + } def test__construct_job_output_state_set__empty_list(self): - self.assertEqual(self.jm._construct_job_output_state_set([]), {}) + assert self.jm._construct_job_output_state_set([]) == {} @mock.patch(CLIENTS, get_mock_client) def test__construct_job_output_state_set__ee2_error(self): @@ -248,37 +234,34 @@ def mock_check_jobs(params): with mock.patch.object(MockClients, "check_jobs", side_effect=mock_check_jobs): job_states = self.jm._construct_job_output_state_set(ALL_JOBS) - self.assertEqual( - expected, - job_states, - ) + assert expected == job_states def test__create_jobs__empty_list(self): - self.assertEqual(self.jm._create_jobs([]), {}) + assert self.jm._create_jobs([]) == {} def test__create_jobs__jobs_already_exist(self): job_list = self.jm._running_jobs.keys() - self.assertEqual(self.jm._create_jobs(job_list), {}) + assert self.jm._create_jobs(job_list) == {} def test__get_job_good(self): job_id = ALL_JOBS[0] job = self.jm.get_job(job_id) - self.assertEqual(job_id, job.job_id) - self.assertIsInstance(job, Job) + assert job_id == job.job_id + assert isinstance(job, Job) def test__get_job_fail(self): inputs = [None, "", JOB_NOT_FOUND] for bad_input in inputs: - with self.assertRaisesRegex( - JobRequestException, f"{JOB_NOT_REG_ERR}: {bad_input}" + with pytest.raises( + JobRequestException, match=f"{JOB_NOT_REG_ERR}: {bad_input}" ): self.jm.get_job(bad_input) @mock.patch(CLIENTS, get_mock_client) def test_list_jobs_html(self): jobs_html = self.jm.list_jobs() - self.assertIsInstance(jobs_html, HTML) + assert isinstance(jobs_html, HTML) html = jobs_html.data counts = { @@ -309,28 +292,28 @@ def test_list_jobs_html(self): n_not_started += 1 for job_id in ALL_JOBS: - self.assertIn(f'{job_id}', html) + assert f'{job_id}' in html for param in counts: for value in counts[param]: - self.assertIn(f'{str(value)}', html) + assert f'{str(value)}' in html value_count = html.count(f'{str(value)}') - self.assertEqual(counts[param][value], value_count) + assert counts[param][value] == value_count if n_incomplete: incomplete_count = html.count('Incomplete') - self.assertEqual(incomplete_count, n_incomplete) + assert incomplete_count == n_incomplete if n_not_started: not_started_count = html.count('Not started') - self.assertEqual(not_started_count, n_not_started) + assert not_started_count == n_not_started def test_list_jobs_twice__no_jobs(self): # with no jobs with mock.patch.object(self.jm, "_running_jobs", {}): expected = "No running jobs!" - self.assertEqual(self.jm.list_jobs(), expected) - self.assertEqual(self.jm.list_jobs(), expected) + assert self.jm.list_jobs() == expected + assert self.jm.list_jobs() == expected def test_list_jobs_twice__jobs(self): """ @@ -371,7 +354,7 @@ def test_list_jobs_twice__jobs(self): # compare date-times for dt0, dt1 in zip(date_times_0, date_times_1): - self.assertTrue((dt1 - dt0).total_seconds() <= 5) # usually 1s suffices + assert (dt1 - dt0).total_seconds() <= 5 # usually 1s suffices # just strip delta-times (don't compare) # delta-times are difficult to parse into date-times @@ -381,43 +364,40 @@ def test_list_jobs_twice__jobs(self): jobs_html_1 = re.sub(time_re_pattern, sub, jobs_html_1) # compare stripped - self.assertEqual(jobs_html_0, jobs_html_1) + assert jobs_html_0 == jobs_html_1 def test_cancel_jobs__bad_inputs(self): - with self.assertRaisesRegex( - JobRequestException, re.escape(f"{JOBS_MISSING_ERR}: {[]}") + with pytest.raises( + JobRequestException, match=re.escape(f"{JOBS_MISSING_ERR}: {[]}") ): self.jm.cancel_jobs([]) - with self.assertRaisesRegex( - JobRequestException, re.escape(f'{JOBS_MISSING_ERR}: {["", "", None]}') + with pytest.raises( + JobRequestException, + match=re.escape(f'{JOBS_MISSING_ERR}: {["", "", None]}'), ): self.jm.cancel_jobs(["", "", None]) job_states = self.jm.cancel_jobs([JOB_NOT_FOUND]) - self.assertEqual( - {JOB_NOT_FOUND: ALL_RESPONSE_DATA[MESSAGE_TYPE["STATUS"]][JOB_NOT_FOUND]}, - job_states, - ) + assert { + JOB_NOT_FOUND: ALL_RESPONSE_DATA[MESSAGE_TYPE["STATUS"]][JOB_NOT_FOUND] + } == job_states def test_cancel_jobs__job_already_finished(self): - self.assertEqual(get_test_job(JOB_COMPLETED)["status"], "completed") - self.assertEqual(get_test_job(JOB_TERMINATED)["status"], "terminated") - self.assertTrue(self.jm.get_job(JOB_COMPLETED).in_terminal_state()) - self.assertTrue(self.jm.get_job(JOB_TERMINATED).in_terminal_state()) + assert get_test_job(JOB_COMPLETED)["status"] == "completed" + assert get_test_job(JOB_TERMINATED)["status"] == "terminated" + assert self.jm.get_job(JOB_COMPLETED).in_terminal_state() is True + assert self.jm.get_job(JOB_TERMINATED).in_terminal_state() is True job_id_list = [JOB_COMPLETED, JOB_TERMINATED] with mock.patch( "biokbase.narrative.jobs.jobmanager.JobManager._cancel_job" ) as mock_cancel_job: canceled_jobs = self.jm.cancel_jobs(job_id_list) mock_cancel_job.assert_not_called() - self.assertEqual( - { - job_id: ALL_RESPONSE_DATA[MESSAGE_TYPE["STATUS"]][job_id] - for job_id in job_id_list - }, - canceled_jobs, - ) + assert { + job_id: ALL_RESPONSE_DATA[MESSAGE_TYPE["STATUS"]][job_id] + for job_id in job_id_list + } == canceled_jobs @mock.patch(CLIENTS, get_mock_client) def test_cancel_jobs__run_ee2_cancel_job(self): @@ -444,8 +424,8 @@ def test_cancel_jobs__run_ee2_cancel_job(self): self.jm._running_jobs[JOB_CREATED]["refresh"] = 1 def check_state(arg): - self.assertFalse(self.jm._running_jobs[arg["job_id"]]["refresh"]) - self.assertEqual(self.jm._running_jobs[arg["job_id"]]["canceling"], True) + assert self.jm._running_jobs[arg["job_id"]]["refresh"] is False + assert self.jm._running_jobs[arg["job_id"]]["canceling"] is True # patch MockClients.cancel_job so we can test the input with mock.patch.object( @@ -455,10 +435,10 @@ def check_state(arg): ) as mock_cancel_job: results = self.jm.cancel_jobs(jobs) for job in [JOB_RUNNING, JOB_CREATED]: - self.assertNotIn("canceling", self.jm._running_jobs[job]) - self.assertEqual(self.jm._running_jobs[job]["refresh"], 1) - self.assertEqual(results.keys(), expected.keys()) - self.assertEqual(results, expected) + assert "canceling" not in self.jm._running_jobs[job] + assert self.jm._running_jobs[job]["refresh"] == 1 + assert results.keys() == expected.keys() + assert results == expected mock_cancel_job.assert_has_calls( [ mock.call({"job_id": JOB_RUNNING}), @@ -477,7 +457,7 @@ def _check_retry_jobs( expected, retry_results, ): - self.assertEqual(expected, retry_results) + assert expected == retry_results orig_ids = [ result["job_id"] for result in retry_results.values() @@ -496,11 +476,11 @@ def _check_retry_jobs( for job_id in orig_ids + retry_ids: job = self.jm.get_job(job_id) - self.assertIn(job_id, self.jm._running_jobs) - self.assertIsNotNone(job._acc_state) + assert job_id in self.jm._running_jobs + assert job._acc_state is not None for job_id in dne_ids: - self.assertNotIn(job_id, self.jm._running_jobs) + assert job_id not in self.jm._running_jobs @mock.patch(CLIENTS, get_mock_client) def test_retry_jobs__success(self): @@ -564,13 +544,14 @@ def test_retry_jobs__none_exist(self): self._check_retry_jobs(expected, retry_results) def test_retry_jobs__bad_inputs(self): - with self.assertRaisesRegex( - JobRequestException, re.escape(f"{JOBS_MISSING_ERR}: {[]}") + with pytest.raises( + JobRequestException, match=re.escape(f"{JOBS_MISSING_ERR}: {[]}") ): self.jm.retry_jobs([]) - with self.assertRaisesRegex( - JobRequestException, re.escape(f'{JOBS_MISSING_ERR}: {["", "", None]}') + with pytest.raises( + JobRequestException, + match=re.escape(f'{JOBS_MISSING_ERR}: {["", "", None]}'), ): self.jm.retry_jobs(["", "", None]) @@ -578,44 +559,36 @@ def test_retry_jobs__bad_inputs(self): def test_get_all_job_states(self): states = self.jm.get_all_job_states() refreshing_jobs = [job_id for job_id, state in REFRESH_STATE.items() if state] - self.assertEqual(set(refreshing_jobs), set(states.keys())) - self.assertEqual( - states, - { - job_id: ALL_RESPONSE_DATA[MESSAGE_TYPE["STATUS"]][job_id] - for job_id in refreshing_jobs - }, - ) + assert set(refreshing_jobs) == set(states.keys()) + assert states == { + job_id: ALL_RESPONSE_DATA[MESSAGE_TYPE["STATUS"]][job_id] + for job_id in refreshing_jobs + } @mock.patch(CLIENTS, get_mock_client) def test_get_all_job_states__ignore_refresh_flag(self): states = self.jm.get_all_job_states(ignore_refresh_flag=True) - self.assertEqual(set(ALL_JOBS), set(states.keys())) - self.assertEqual( - { - job_id: ALL_RESPONSE_DATA[MESSAGE_TYPE["STATUS"]][job_id] - for job_id in ALL_JOBS - }, - states, - ) + assert set(ALL_JOBS) == set(states.keys()) + assert states == { + job_id: ALL_RESPONSE_DATA[MESSAGE_TYPE["STATUS"]][job_id] + for job_id in ALL_JOBS + } # get_job_states_by_cell_id @mock.patch(CLIENTS, get_mock_client) def test_get_job_states_by_cell_id__cell_id_list_None(self): - with self.assertRaisesRegex(JobRequestException, CELLS_NOT_PROVIDED_ERR): + with pytest.raises(JobRequestException, match=CELLS_NOT_PROVIDED_ERR): self.jm.get_job_states_by_cell_id(cell_id_list=None) @mock.patch(CLIENTS, get_mock_client) def test_get_job_states_by_cell_id__cell_id_list_empty(self): - with self.assertRaisesRegex(JobRequestException, CELLS_NOT_PROVIDED_ERR): + with pytest.raises(JobRequestException, match=CELLS_NOT_PROVIDED_ERR): self.jm.get_job_states_by_cell_id(cell_id_list=[]) @mock.patch(CLIENTS, get_mock_client) def test_get_job_states_by_cell_id__cell_id_list_no_results(self): result = self.jm.get_job_states_by_cell_id(cell_id_list=["a", "b", "c"]) - self.assertEqual( - {"jobs": {}, "mapping": {"a": set(), "b": set(), "c": set()}}, result - ) + assert {"jobs": {}, "mapping": {"a": set(), "b": set(), "c": set()}} == result def check_get_job_states_by_cell_id_results(self, cell_ids, expected_ids): expected_states = { @@ -624,11 +597,11 @@ def check_get_job_states_by_cell_id_results(self, cell_ids, expected_ids): if job_id in expected_ids } result = self.jm.get_job_states_by_cell_id(cell_id_list=cell_ids) - self.assertEqual(set(expected_ids), set(result["jobs"].keys())) - self.assertEqual(expected_states, result["jobs"]) - self.assertEqual(set(cell_ids), set(result["mapping"].keys())) + assert set(expected_ids) == set(result["jobs"].keys()) + assert expected_states == result["jobs"] + assert set(cell_ids) == set(result["mapping"].keys()) for key in result["mapping"].keys(): - self.assertEqual(set(TEST_CELL_IDs[key]), set(result["mapping"][key])) + assert set(TEST_CELL_IDs[key]) == set(result["mapping"][key]) @mock.patch(CLIENTS, get_mock_client) def test_get_job_states_by_cell_id__cell_id_list_all_results(self): @@ -698,36 +671,36 @@ def test_get_job_states(self): } res = self.jm.get_job_states(job_ids) - self.assertEqual(exp, res) + assert exp == res def test_get_job_states__empty(self): - with self.assertRaisesRegex( - JobRequestException, re.escape(f"{JOBS_MISSING_ERR}: {[]}") + with pytest.raises( + JobRequestException, match=re.escape(f"{JOBS_MISSING_ERR}: {[]}") ): self.jm.get_job_states([]) def test_update_batch_job__dne(self): - with self.assertRaisesRegex( - JobRequestException, f"{JOB_NOT_REG_ERR}: {JOB_NOT_FOUND}" + with pytest.raises( + JobRequestException, match=f"{JOB_NOT_REG_ERR}: {JOB_NOT_FOUND}" ): self.jm.update_batch_job(JOB_NOT_FOUND) def test_update_batch_job__not_batch(self): - with self.assertRaisesRegex( - JobRequestException, f"{JOB_NOT_BATCH_ERR}: {JOB_CREATED}" + with pytest.raises( + JobRequestException, match=f"{JOB_NOT_BATCH_ERR}: {JOB_CREATED}" ): self.jm.update_batch_job(JOB_CREATED) - with self.assertRaisesRegex( - JobRequestException, f"{JOB_NOT_BATCH_ERR}: {BATCH_TERMINATED}" + with pytest.raises( + JobRequestException, match=f"{JOB_NOT_BATCH_ERR}: {BATCH_TERMINATED}" ): self.jm.update_batch_job(BATCH_TERMINATED) @mock.patch(CLIENTS, get_mock_client) def test_update_batch_job__no_change(self): job_ids = self.jm.update_batch_job(BATCH_PARENT) - self.assertEqual(BATCH_PARENT, job_ids[0]) - self.assertCountEqual(BATCH_CHILDREN, job_ids[1:]) + assert BATCH_PARENT == job_ids[0] + assert len(BATCH_CHILDREN) == len(job_ids[1:]) @mock.patch(CLIENTS, get_mock_client) def test_update_batch_job__change(self): @@ -767,47 +740,44 @@ def mock_check_job(params): ] ) - self.assertEqual(BATCH_PARENT, job_ids[0]) - self.assertCountEqual(new_child_ids, job_ids[1:]) + assert BATCH_PARENT == job_ids[0] + assert len(new_child_ids) == len(job_ids[1:]) batch_job = self.jm.get_job(BATCH_PARENT) reg_child_jobs = [ self.jm.get_job(job_id) for job_id in batch_job._acc_state["child_jobs"] ] - self.assertCountEqual(batch_job.children, reg_child_jobs) - self.assertCountEqual(batch_job._acc_state["child_jobs"], new_child_ids) + assert len(batch_job.children) == len(reg_child_jobs) + assert len(batch_job._acc_state["child_jobs"]) == len(new_child_ids) with mock.patch.object( MockClients, "check_job", side_effect=mock_check_job ) as m: - self.assertCountEqual(batch_job.child_jobs, new_child_ids) + assert len(batch_job.child_jobs) == len(new_child_ids) def test_modify_job_refresh(self): for job_id, refreshing in REFRESH_STATE.items(): - self.assertEqual(self.jm._running_jobs[job_id]["refresh"], refreshing) + assert self.jm._running_jobs[job_id]["refresh"] == refreshing self.jm.modify_job_refresh([job_id], False) # stop - self.assertEqual(self.jm._running_jobs[job_id]["refresh"], False) + assert self.jm._running_jobs[job_id]["refresh"] is False self.jm.modify_job_refresh([job_id], False) # stop harder - self.assertEqual(self.jm._running_jobs[job_id]["refresh"], False) + assert self.jm._running_jobs[job_id]["refresh"] is False self.jm.modify_job_refresh([job_id], True) # start - self.assertEqual(self.jm._running_jobs[job_id]["refresh"], True) + assert self.jm._running_jobs[job_id]["refresh"] is True self.jm.modify_job_refresh([job_id], True) # start some more - self.assertEqual(self.jm._running_jobs[job_id]["refresh"], True) + assert self.jm._running_jobs[job_id]["refresh"] is True self.jm.modify_job_refresh([job_id], False) # stop - self.assertEqual(self.jm._running_jobs[job_id]["refresh"], False) + assert self.jm._running_jobs[job_id]["refresh"] is False @mock.patch(CLIENTS, get_mock_client) def test_get_job_info(self): infos = self.jm.get_job_info(ALL_JOBS) - self.assertCountEqual(ALL_JOBS, infos.keys()) - self.assertEqual( - infos, - { - job_id: ALL_RESPONSE_DATA[MESSAGE_TYPE["INFO"]][job_id] - for job_id in ALL_JOBS - }, - ) + assert len(ALL_JOBS) == len(infos.keys()) + assert infos == { + job_id: ALL_RESPONSE_DATA[MESSAGE_TYPE["INFO"]][job_id] + for job_id in ALL_JOBS + } if __name__ == "__main__": diff --git a/src/biokbase/narrative/tests/test_kbasewsmanager.py b/src/biokbase/narrative/tests/test_kbasewsmanager.py index a696a46d18..bdde99dacb 100644 --- a/src/biokbase/narrative/tests/test_kbasewsmanager.py +++ b/src/biokbase/narrative/tests/test_kbasewsmanager.py @@ -1,10 +1,10 @@ import unittest from unittest import mock -from tornado.web import HTTPError - +import pytest from biokbase.narrative.common.narrative_ref import NarrativeRef from biokbase.narrative.contents.kbasewsmanager import KBaseWSManager +from tornado.web import HTTPError from .narrative_mock.mockclients import get_mock_client @@ -26,7 +26,7 @@ def test__parse_path(self): ), ] for c in cases: - self.assertEqual(c[1], manager._parse_path(c[0])) + assert c[1] == manager._parse_path(c[0]) def test__parse_path_bad(self): manager = KBaseWSManager() @@ -40,6 +40,5 @@ def test__parse_path_bad(self): "ws.1.2", ] for c in cases: - with self.assertRaises(HTTPError) as e: + with pytest.raises(HTTPError, match=f"Invalid Narrative path {c}"): manager._parse_path(c) - self.assertIn("Invalid Narrative path {}".format(c), str(e.exception)) diff --git a/src/biokbase/narrative/tests/test_kvp.py b/src/biokbase/narrative/tests/test_kvp.py new file mode 100644 index 0000000000..da4687e80a --- /dev/null +++ b/src/biokbase/narrative/tests/test_kvp.py @@ -0,0 +1,13 @@ +from biokbase.narrative.common.kvp import parse_kvp + + +def test_parse_kvp() -> None: + for user_input, text, kvp in ( + ("foo", "foo", {}), + ("name=val", "", {"name": "val"}), + ("a name=val boy", "a boy", {"name": "val"}), + ): + rkvp = {} + rtext = parse_kvp(user_input, rkvp) + assert text == rtext + assert kvp == rkvp diff --git a/src/biokbase/narrative/tests/test_log_proxy.py b/src/biokbase/narrative/tests/test_log_proxy.py index ecbd13ff57..1846d29e65 100644 --- a/src/biokbase/narrative/tests/test_log_proxy.py +++ b/src/biokbase/narrative/tests/test_log_proxy.py @@ -7,7 +7,8 @@ import time import unittest -from biokbase.narrative.common import log_proxy as proxy +import pytest +from biokbase.narrative.common import log_proxy __author__ = "Dan Gunter " @@ -21,8 +22,8 @@ class MainTestCase(unittest.TestCase): def setUp(self): self._config(["db: test", "collection: kblog"]) - if proxy.g_log is None: - proxy.g_log = logging.getLogger(proxy.LOGGER_NAME) + if log_proxy.g_log is None: + log_proxy.g_log = logging.getLogger(log_proxy.LOGGER_NAME) def _config(self, lines): text = "\n".join(lines) @@ -33,34 +34,39 @@ def test_run_proxy(self): pid = os.fork() if pid == 0: print("Run child") - proxy.run(self) + log_proxy.run(self) else: time.sleep(1) print("Wait for child to start") # let it start time.sleep(4) # send it a HUP to stop it - print("Send child ({:d}) a HUP".format(pid)) + print(f"Send child ({pid:d}) a HUP") os.kill(pid, signal.SIGHUP) # wait for it to stop - print("Wait for child ({:d}) to stop".format(pid)) + print(f"Wait for child ({pid:d}) to stop") cpid, r = os.waitpid(pid, 0) - print("cpid, r: {}, {}".format(cpid, r)) - self.assertTrue(r < 2, "Bad exit status ({:d}) from proxy".format(r)) + print(f"cpid, r: {cpid}, {r}") + assert r < 2, f"Bad exit status ({r:d}) from proxy" def test_configuration(self): # empty self._config([]) - self.assertRaises(ValueError, proxy.DBConfiguration, self.conf) + with pytest.raises(ValueError): + log_proxy.DBConfiguration(self.conf) + # , proxy.DBConfiguration, self.conf) # missing collection self._config(["db: test"]) - self.assertRaises(KeyError, proxy.DBConfiguration, self.conf) + with pytest.raises(KeyError): + log_proxy.DBConfiguration(self.conf) # bad db name self._config(["db: 1test", "collection: kblog"]) - self.assertRaises(ValueError, proxy.DBConfiguration, self.conf) + with pytest.raises(ValueError): + log_proxy.DBConfiguration(self.conf) # bad db name self._config(["db: test.er", "collection: kblog"]) - self.assertRaises(ValueError, proxy.DBConfiguration, self.conf) + with pytest.raises(ValueError): + log_proxy.DBConfiguration(self.conf) # too long self._config( [ @@ -69,26 +75,29 @@ def test_configuration(self): "ddddddddddddddddddddddddddddddd", ] ) - self.assertRaises(ValueError, proxy.DBConfiguration, self.conf) + with pytest.raises(ValueError): + log_proxy.DBConfiguration(self.conf) # bad collection self._config(["db: test", "collection: kb$log"]) - self.assertRaises(ValueError, proxy.DBConfiguration, self.conf) + with pytest.raises(ValueError): + log_proxy.DBConfiguration(self.conf) # user, no pass self._config(["db: test", "collection: kblog", "user: joe"]) - self.assertRaises(KeyError, proxy.DBConfiguration, self.conf) + with pytest.raises(KeyError): + log_proxy.DBConfiguration(self.conf) class LogRecordTest(unittest.TestCase): def setUp(self): - if proxy.g_log is None: - proxy.g_log = logging.getLogger(proxy.LOGGER_NAME) + if log_proxy.g_log is None: + log_proxy.g_log = logging.getLogger(log_proxy.LOGGER_NAME) def test_basic(self): - for input in {}, {"message": "hello"}: - kbrec = proxy.DBRecord(input) - kbrec = proxy.DBRecord({"message": "greeting;Hello=World"}) - self.assertEqual(kbrec.record["event"], "greeting") - self.assertEqual(kbrec.record["Hello"], "World") + for inpt in {}, {"message": "hello"}: + log_proxy.DBRecord(inpt) + kbrec = log_proxy.DBRecord({"message": "greeting;Hello=World"}) + assert kbrec.record["event"] == "greeting" + assert kbrec.record["Hello"] == "World" def test_strict(self): for inp in ( @@ -96,8 +105,9 @@ def test_strict(self): {12: "xanthium"}, {"message": "Hello=World;greeting"}, ): - proxy.DBRecord(inp) - self.assertRaises(ValueError, proxy.DBRecord, inp, strict=True) + log_proxy.DBRecord(inp) + with pytest.raises(ValueError): + log_proxy.DBRecord(inp, strict=True) if __name__ == "__main__": diff --git a/src/biokbase/narrative/tests/test_logging.py b/src/biokbase/narrative/tests/test_logging.py index cb94e25c21..3ed7f3a016 100644 --- a/src/biokbase/narrative/tests/test_logging.py +++ b/src/biokbase/narrative/tests/test_logging.py @@ -54,7 +54,7 @@ def test_simple(self): self.stop_receiver(kblog) # check that receiver got the (buffered) messages - self.assertEqual(data, "helloworld") + assert data == "helloworld" @unittest.skip("Skipping buffering test for now") def test_buffering(self): @@ -72,7 +72,6 @@ def test_buffering(self): self.stop_receiver(kblog) # check that receiver got the (buffered) messages - # self.assertEqual(data, "helloworld") if __name__ == "__main__": diff --git a/src/biokbase/narrative/tests/test_narrative_logger.py b/src/biokbase/narrative/tests/test_narrative_logger.py index 7b1370c0ee..6b5e87233a 100644 --- a/src/biokbase/narrative/tests/test_narrative_logger.py +++ b/src/biokbase/narrative/tests/test_narrative_logger.py @@ -9,6 +9,16 @@ from . import util +def assert_log_msg(msg, event, narrative, version): + data = json.loads(msg) + assert data["type"] == "narrative" + assert data["user"] == "anonymous" + assert data["env"] == "ci" + assert data["narr_ver"] == version + assert data["narrative"] == narrative + assert data["operation"] == event + + class NarrativeLoggerTestCase(unittest.TestCase): @classmethod def setUpClass(cls): @@ -42,9 +52,9 @@ def stop_log_stack(cls): def test_logger_init(self): logger = NarrativeLogger() - self.assertEqual(logger.host, URLS.log_host) - self.assertEqual(logger.port, URLS.log_port) - self.assertEqual(logger.env, kbase_env.env) + assert logger.host == URLS.log_host + assert logger.port == URLS.log_port + assert logger.env == kbase_env.env def test_null_logger(self): URLS.log_host = None @@ -55,7 +65,7 @@ def test_null_logger(self): time.sleep(self.poll_interval * 4) data = self.log_recv.get_data() self.stop_log_stack() - self.assertFalse(data) + assert not data URLS.log_host = self.log_host def test_open_narr(self): @@ -65,7 +75,7 @@ def test_open_narr(self): logger = NarrativeLogger() logger.narrative_open(narrative, version) time.sleep(self.poll_interval * 4) - self.assert_log_msg(self.log_recv.get_data(), "open", narrative, version) + assert_log_msg(self.log_recv.get_data(), "open", narrative, version) self.stop_log_stack() def test_save_narr(self): @@ -75,7 +85,7 @@ def test_save_narr(self): logger = NarrativeLogger() logger.narrative_save(narrative, version) time.sleep(self.poll_interval * 4) - self.assert_log_msg(self.log_recv.get_data(), "save", narrative, version) + assert_log_msg(self.log_recv.get_data(), "save", narrative, version) self.stop_log_stack() def test_failed_message(self): @@ -91,15 +101,6 @@ def test_failed_message(self): "Log writing threw an unexpected exception without a live socket!" ) - def assert_log_msg(self, msg, event, narrative, version): - data = json.loads(msg) - self.assertEqual(data["type"], "narrative") - self.assertEqual(data["user"], "anonymous") - self.assertEqual(data["env"], "ci") - self.assertEqual(data["narr_ver"], version) - self.assertEqual(data["narrative"], narrative) - self.assertEqual(data["operation"], event) - if __name__ == "__main__": unittest.main() diff --git a/src/biokbase/narrative/tests/test_narrative_ref.py b/src/biokbase/narrative/tests/test_narrative_ref.py index e2e1528256..3b0fef20e3 100644 --- a/src/biokbase/narrative/tests/test_narrative_ref.py +++ b/src/biokbase/narrative/tests/test_narrative_ref.py @@ -1,6 +1,7 @@ import unittest from unittest import mock +import pytest from biokbase.narrative.common.exceptions import WorkspaceError from biokbase.narrative.common.narrative_ref import NarrativeRef @@ -11,47 +12,45 @@ class NarrativeRefTestCase(unittest.TestCase): @mock.patch("biokbase.narrative.common.narrative_ref.clients.get", get_mock_client) def test_no_objid_ok(self): ref = NarrativeRef({"wsid": 123, "objid": None, "ver": None}) - self.assertEqual(ref.wsid, 123) - self.assertEqual(ref.objid, 1) - self.assertIsNone(ref.ver) + assert ref.wsid == 123 + assert ref.objid == 1 + assert ref.ver is None def test_ok(self): ref = NarrativeRef({"wsid": 123, "objid": 456, "ver": 7}) - self.assertEqual(ref.wsid, 123) - self.assertEqual(ref.objid, 456) - self.assertEqual(ref.ver, 7) + assert ref.wsid == 123 + assert ref.objid == 456 + assert ref.ver == 7 @mock.patch("biokbase.narrative.common.narrative_ref.clients.get", get_mock_client) def test_no_objid_fail(self): - with self.assertRaises(RuntimeError) as e: + with pytest.raises( + RuntimeError, + match="Couldn't find Narrative object id in Workspace metadata", + ): NarrativeRef({"wsid": 678, "objid": None, "ver": None}) - self.assertIn( - "Couldn't find Narrative object id in Workspace metadata", str(e.exception) - ) @mock.patch("biokbase.narrative.common.narrative_ref.clients.get", get_mock_client) def test_ref_init_fail(self): - with self.assertRaises(ValueError) as e: + with pytest.raises( + ValueError, + match="A numerical Workspace id is required for a Narrative ref, not x", + ): NarrativeRef({"wsid": "x", "objid": None, "ver": None}) - self.assertIn( - "A numerical Workspace id is required for a Narrative ref, not x", - str(e.exception), - ) - with self.assertRaises(ValueError) as e: + with pytest.raises(ValueError, match="objid must be numerical, not x"): NarrativeRef({"wsid": 678, "objid": "x", "ver": None}) - self.assertIn("objid must be numerical, not x", str(e.exception)) - with self.assertRaises(ValueError) as e: + with pytest.raises( + ValueError, + match="If ver is present in the ref, it must be numerical, not x", + ): NarrativeRef({"wsid": 678, "objid": 1, "ver": "x"}) - self.assertIn( - "If ver is present in the ref, it must be numerical, not x", - str(e.exception), - ) @mock.patch("biokbase.narrative.common.narrative_ref.clients.get", get_mock_client) def test_no_ws_perm(self): - with self.assertRaises(WorkspaceError) as e: + with pytest.raises( + WorkspaceError, match="You do not have access to this workspace" + ) as e: NarrativeRef({"wsid": 789, "objid": None, "ver": None}) - self.assertEqual(403, e.exception.http_code) - self.assertIn("You do not have access to this workspace", e.exception.message) + assert e._excinfo[1].http_code == 403 diff --git a/src/biokbase/narrative/tests/test_narrativeio.py b/src/biokbase/narrative/tests/test_narrativeio.py index cfb8e962a6..21b63d3aa6 100644 --- a/src/biokbase/narrative/tests/test_narrativeio.py +++ b/src/biokbase/narrative/tests/test_narrativeio.py @@ -6,6 +6,7 @@ from unittest.mock import patch import biokbase.auth +import pytest from biokbase.narrative import clients from biokbase.narrative.common.exceptions import WorkspaceError from biokbase.narrative.common.narrative_ref import NarrativeRef @@ -36,6 +37,8 @@ } HAS_TEST_TOKEN = False +READ_NARRATIVE_REF_WARNING = "read_narrative must use a NarrativeRef as input!" + def get_exp_nar(i): return dict(zip(LIST_OBJECTS_FIELDS, get_nar_obj(i))) @@ -45,6 +48,7 @@ def skipUnlessToken(): global HAS_TEST_TOKEN if not HAS_TEST_TOKEN: return unittest.skip("No auth token") + return None def str_to_ref(s): @@ -157,7 +161,7 @@ def logout(self): biokbase.auth.set_environ_token(None) def test_mixin_instantiated(self): - self.assertIsInstance( + assert isinstance( self.mixin, biokbase.narrative.contents.narrativeio.KBaseWSManagerMixin ) @@ -166,24 +170,20 @@ def test_mixin_instantiated(self): def test_narrative_exists_valid(self): if self.test_token is None: self.skipTest("No auth token") - self.assertTrue(self.mixin.narrative_exists(self.public_nar["ref"])) + assert self.mixin.narrative_exists(self.public_nar["ref"]) def test_narrative_exists_invalid(self): - self.assertFalse(self.mixin.narrative_exists(self.invalid_nar_ref)) + assert not self.mixin.narrative_exists(self.invalid_nar_ref) def test_narrative_exists_bad(self): - with self.assertRaises(AssertionError) as err: + with pytest.raises(AssertionError, match=READ_NARRATIVE_REF_WARNING): self.mixin.narrative_exists(self.bad_nar_ref) - self.assertEqual( - "read_narrative must use a NarrativeRef as input!", str(err.exception) - ) def test_narrative_exists_noauth(self): if self.test_token is None: self.skipTest("No auth token") - with self.assertRaises(WorkspaceError) as err: + with pytest.raises(WorkspaceError): self.mixin.narrative_exists(self.private_nar["ref"]) - self.assertIsNotNone(err) # test KBaseWSManagerMixin.read_narrative ##### @@ -200,7 +200,7 @@ def validate_narrative(self, nar, with_content, with_meta): return "Narrative needs to be a dict to be valid." # expected keys: - exp_keys = set(["info"]) + exp_keys = {"info"} if with_content: exp_keys.update( [ @@ -227,18 +227,16 @@ def validate_narrative(self, nar, with_content, with_meta): if with_meta: if not nar["info"][10]: return "Narrative metadata not returned when expected" - meta_keys = set( - [ - "creator", - "data_dependencies", - "description", - "format", - "job_info", - "name", - "type", - "ws_name", - ] - ) + meta_keys = { + "creator", + "data_dependencies", + "description", + "format", + "job_info", + "name", + "type", + "ws_name", + } missing_keys = meta_keys - set(nar["info"][10]) if missing_keys: return "Narrative metadata is missing the following keys: {}".format( @@ -246,18 +244,13 @@ def validate_narrative(self, nar, with_content, with_meta): ) return None - def validate_metadata(self, std_meta, meta): - """ - Validates a Narrative's typed object metadata. - """ - def test_read_narrative_valid_content_metadata(self): if self.test_token is None: self.skipTest("No auth token") nar = self.mixin.read_narrative( self.public_nar["ref"], content=True, include_metadata=False ) - self.assertIsNone(self.validate_narrative(nar, True, True)) + assert self.validate_narrative(nar, True, True) is None def test_read_narrative_valid_content_no_metadata(self): if self.test_token is None: @@ -265,7 +258,7 @@ def test_read_narrative_valid_content_no_metadata(self): nar = self.mixin.read_narrative( self.public_nar["ref"], content=True, include_metadata=True ) - self.assertIsNone(self.validate_narrative(nar, True, False)) + assert self.validate_narrative(nar, True, False) is None def test_read_narrative_valid_no_content_metadata(self): if self.test_token is None: @@ -273,7 +266,7 @@ def test_read_narrative_valid_no_content_metadata(self): nar = self.mixin.read_narrative( self.public_nar["ref"], content=False, include_metadata=True ) - self.assertIsNone(self.validate_narrative(nar, False, True)) + assert self.validate_narrative(nar, False, True) is None def test_read_narrative_valid_no_content_no_metadata(self): if self.test_token is None: @@ -281,35 +274,29 @@ def test_read_narrative_valid_no_content_no_metadata(self): nar = self.mixin.read_narrative( self.public_nar["ref"], content=False, include_metadata=False ) - self.assertIsNone(self.validate_narrative(nar, False, False)) + assert self.validate_narrative(nar, False, False) is None def test_read_narrative_private_anon(self): if self.test_token is None: self.skipTest("No auth token") - with self.assertRaises(WorkspaceError) as err: + with pytest.raises(WorkspaceError): self.mixin.read_narrative(self.private_nar["ref"]) - self.assertIsNotNone(err) def test_read_narrative_unauth_login(self): if self.test_token is None: self.skipTest("No auth token") self.login() - with self.assertRaises(WorkspaceError) as err: + with pytest.raises(WorkspaceError): self.mixin.read_narrative(self.unauth_nar["ref"]) - self.assertIsNotNone(err) self.logout() def test_read_narrative_invalid(self): - with self.assertRaises(WorkspaceError) as err: + with pytest.raises(WorkspaceError): self.mixin.read_narrative(self.invalid_nar_ref) - self.assertIsNotNone(err) def test_read_narrative_bad(self): - with self.assertRaises(AssertionError) as err: + with pytest.raises(AssertionError, match=READ_NARRATIVE_REF_WARNING): self.mixin.read_narrative(self.bad_nar_ref) - self.assertEqual( - "read_narrative must use a NarrativeRef as input!", str(err.exception) - ) # test KBaseWSManagerMixin.write_narrative ##### @@ -323,23 +310,21 @@ def test_write_narrative_valid_auth(self): result = self.mixin.write_narrative( self.private_nar["ref"], nar, self.test_user ) - self.assertTrue( - result[1] == self.private_nar["ws"] and result[2] == self.private_nar["obj"] - ) - self.assertEqual(result[0]["metadata"]["is_temporary"], "false") + assert result[1] == self.private_nar["ws"] + assert result[2] == self.private_nar["obj"] + assert result[0]["metadata"]["is_temporary"] == "false" ws = clients.get("workspace") ws_info = ws.get_workspace_info({"id": result[1]}) - self.assertEqual(ws_info[8]["searchtags"], "narrative") + assert ws_info[8]["searchtags"] == "narrative" self.logout() def test_write_narrative_valid_anon(self): if self.test_token is None: self.skipTest("No auth token") nar = self.mixin.read_narrative(self.public_nar["ref"])["data"] - with self.assertRaises(WorkspaceError) as err: + with pytest.raises(WorkspaceError): self.mixin.write_narrative(self.public_nar["ref"], nar, "Anonymous") - self.assertIsNotNone(err) def test_write_narrative_valid_unauth(self): pass @@ -349,9 +334,8 @@ def test_write_narrative_invalid_ref(self): self.skipTest("No auth token") self.login() nar = self.mixin.read_narrative(self.public_nar["ref"])["data"] - with self.assertRaises(WorkspaceError) as err: + with pytest.raises(WorkspaceError): self.mixin.write_narrative(self.invalid_nar_ref, nar, self.test_user) - self.assertIsNotNone(err) self.logout() def test_write_narrative_shared_write_access(self): @@ -398,22 +382,21 @@ def test_write_narrative_bad_ref(self): self.skipTest("No auth token") self.login() nar = self.mixin.read_narrative(self.public_nar["ref"])["data"] - with self.assertRaises(AssertionError) as err: + with pytest.raises( + AssertionError, match="write_narrative must use a NarrativeRef as input!" + ): self.mixin.write_narrative(self.bad_nar_ref, nar, self.test_user) - self.assertEqual( - "write_narrative must use a NarrativeRef as input!", str(err.exception) - ) self.logout() def test_write_narrative_bad_file(self): if self.test_token is None: self.skipTest("No auth token") self.login() - with self.assertRaises(HTTPError) as err: + with pytest.raises(HTTPError) as err: self.mixin.write_narrative( self.private_nar["ref"], {"not": "a narrative"}, self.test_user ) - self.assertEqual(err.exception.status_code, 400) + assert err._excinfo[1].status_code == 400 self.logout() # test KBaseWSManagerMixin.rename_narrative ##### @@ -431,7 +414,7 @@ def test_rename_narrative_valid_auth(self): nar = self.mixin.read_narrative( self.private_nar["ref"], content=False, include_metadata=True ) - self.assertEqual(new_name, nar["info"][10]["name"]) + assert new_name == nar["info"][10]["name"] # now, put it back to the old name, so it doesn't break other tests... self.mixin.rename_narrative(self.private_nar["ref"], self.test_user, cur_name) @@ -440,36 +423,30 @@ def test_rename_narrative_valid_auth(self): def test_rename_narrative_valid_anon(self): if self.test_token is None: self.skipTest("No auth token") - with self.assertRaises(WorkspaceError) as err: + with pytest.raises(WorkspaceError): self.mixin.rename_narrative( self.public_nar["ref"], self.test_user, "new_name" ) - self.assertIsNotNone(err) def test_rename_narrative_unauth(self): if self.test_token is None: self.skipTest("No auth token") self.login() - with self.assertRaises(WorkspaceError) as err: + with pytest.raises(WorkspaceError): self.mixin.rename_narrative( self.unauth_nar["ref"], self.test_user, "new_name" ) - self.assertIsNotNone(err) self.logout() def test_rename_narrative_invalid(self): - with self.assertRaises(WorkspaceError) as err: + with pytest.raises(WorkspaceError): self.mixin.rename_narrative( self.invalid_nar_ref, self.test_user, "new_name" ) - self.assertIsNotNone(err) def test_rename_narrative_bad(self): - with self.assertRaises(AssertionError) as err: + with pytest.raises(AssertionError, match=READ_NARRATIVE_REF_WARNING): self.mixin.rename_narrative(self.bad_nar_ref, self.test_user, "new_name") - self.assertEqual( - "read_narrative must use a NarrativeRef as input!", str(err.exception) - ) # test KBaseWSManagerMixin.copy_narrative ##### @@ -484,10 +461,10 @@ def test_copy_narrative_invalid(self): # test KBaseWSManagerMixin.list_narratives ##### def validate_narrative_list(self, nar_list): - self.assertIsInstance(nar_list, list) + assert isinstance(nar_list, list) for nar in nar_list: - self.assertIsInstance(nar, dict) - self.assertTrue(set(nar.keys()).issubset(metadata_fields)) + assert isinstance(nar, dict) + assert set(nar.keys()).issubset(metadata_fields) def test_list_all_narratives_anon(self): res = self.mixin.list_narratives() @@ -510,9 +487,8 @@ def test_list_narrative_ws_valid_anon(self): def test_list_narrative_ws_valid_noperm_anon(self): if self.test_token is None: self.skipTest("No auth token") - with self.assertRaises(WorkspaceError) as err: + with pytest.raises(WorkspaceError): self.mixin.list_narratives(ws_id=self.private_nar["ws"]) - self.assertIsNotNone(err) def test_list_narrative_ws_valid_login(self): if self.test_token is None: @@ -531,7 +507,7 @@ def test_list_narratives__no_ws_id__0_ws_ids(self): ): nar_l = self.mixin.list_narratives() - self.assertEqual([], nar_l) + assert [] == nar_l self.validate_narrative_list(nar_l) @patch("biokbase.narrative.clients.get", get_mock_client) @@ -543,7 +519,7 @@ def test_list_narratives__no_ws_id__9999_ws_ids(self): ): nar_l = self.mixin.list_narratives() - self.assertEqual([get_exp_nar(i) for i in range(9999)], nar_l) + assert [get_exp_nar(i) for i in range(9999)] == nar_l self.validate_narrative_list(nar_l) @patch("biokbase.narrative.clients.get", get_mock_client) @@ -555,7 +531,7 @@ def test_list_narratives__no_ws_id__10000_ws_ids(self): ): nar_l = self.mixin.list_narratives() - self.assertEqual([get_exp_nar(i) for i in range(10000)], nar_l) + assert [get_exp_nar(i) for i in range(10000)] == nar_l self.validate_narrative_list(nar_l) @patch("biokbase.narrative.clients.get", get_mock_client) @@ -567,27 +543,24 @@ def test_list_narratives__no_ws_id__10001_ws_ids(self): ): nar_l = self.mixin.list_narratives() - self.assertEqual([get_exp_nar(i) for i in range(10001)], nar_l) + assert [get_exp_nar(i) for i in range(10001)] == nar_l self.validate_narrative_list(nar_l) def test_list_narrative_ws_invalid(self): - with self.assertRaises(WorkspaceError) as err: + with pytest.raises(WorkspaceError): self.mixin.list_narratives(ws_id=self.invalid_ws_id) - self.assertIsNotNone(err) def test_list_narrative_ws_valid_noperm_auth(self): if self.test_token is None: self.skipTest("No auth token") self.login() - with self.assertRaises(WorkspaceError) as err: + with pytest.raises(WorkspaceError): self.mixin.list_narratives(ws_id=self.unauth_nar["ws"]) - self.assertIsNotNone(err) self.logout() def test_list_narrative_ws_bad(self): - with self.assertRaises(ValueError) as err: + with pytest.raises(ValueError): self.mixin.list_narratives(ws_id=self.bad_nar_ref) - self.assertIsNotNone(err) # test KBaseWSManagerMixin.narrative_permissions ##### # params: @@ -599,21 +572,23 @@ def test_narrative_permissions_anon(self): if self.test_token is None: self.skipTest("No auth token") ret = self.mixin.narrative_permissions(self.public_nar["ref"]) - self.assertTrue(isinstance(ret, dict) and ret["*"] == "r") + assert isinstance(ret, dict) + assert ret["*"] == "r" def test_narrative_permissions_valid_login(self): if self.test_token is None: self.skipTest("No auth token") self.login() ret = self.mixin.narrative_permissions(self.public_nar["ref"]) - self.assertTrue(isinstance(ret, dict) and ret["*"] == "r") + assert isinstance(ret, dict) + assert ret["*"] == "r" self.logout() def test_narrative_permissions_invalid_login(self): if self.test_token is None: self.skipTest("No auth token") self.login() - with self.assertRaises(WorkspaceError): + with pytest.raises(WorkspaceError): self.mixin.narrative_permissions(self.invalid_nar_ref) self.logout() @@ -622,16 +597,16 @@ def test_narrative_permissions_inaccessible_login(self): self.skipTest("No auth token") self.login() ret = self.mixin.narrative_permissions(self.unauth_nar["ref"]) - self.assertTrue(isinstance(ret, dict) and ret[self.test_user] == "n") + assert isinstance(ret, dict) + assert ret[self.test_user] == "n" self.logout() def test_narrative_permissions_bad(self): - with self.assertRaises(AssertionError) as err: + with pytest.raises( + AssertionError, + match="narrative_permissions must use a NarrativeRef as input!", + ): self.mixin.narrative_permissions(self.bad_nar_ref) - self.assertEqual( - "narrative_permissions must use a NarrativeRef as input!", - str(err.exception), - ) # test KBaseWSManagerMixin.narrative_writable ##### @@ -639,15 +614,14 @@ def test_narrative_writable_anon(self): if self.test_token is None: self.skipTest("No auth token") ret = self.mixin.narrative_writable(self.public_nar["ref"], self.test_user) - self.assertFalse(ret) + assert not ret def test_narrative_writable_valid_login_nouser(self): if self.test_token is None: self.skipTest("No auth token") self.login() - with self.assertRaises(ValueError) as err: + with pytest.raises(ValueError): self.mixin.narrative_writable(self.public_nar["ref"], None) - self.assertIsNotNone(err) self.logout() def test_narrative_writable_valid_login_user(self): @@ -655,16 +629,15 @@ def test_narrative_writable_valid_login_user(self): self.skipTest("No auth token") self.login() ret = self.mixin.narrative_writable(self.public_nar["ref"], self.test_user) - self.assertTrue(ret) + assert ret self.logout() def test_narrative_writable_invalid_login_user(self): if self.test_token is None: self.skipTest("No auth token") self.login() - with self.assertRaises(WorkspaceError) as err: + with pytest.raises(WorkspaceError): self.mixin.narrative_writable(self.invalid_nar_ref, self.test_user) - self.assertIsNotNone(err) self.logout() def test_narrative_writable_inaccessible_login_user(self): @@ -672,47 +645,42 @@ def test_narrative_writable_inaccessible_login_user(self): self.skipTest("No auth token") self.login() ret = self.mixin.narrative_writable(self.unauth_nar["ref"], self.test_user) - self.assertFalse(ret) + assert not ret self.logout() def test_narrative_writable_bad_login_user(self): if self.test_token is None: self.skipTest("No auth token") self.login() - with self.assertRaises(AssertionError) as err: + with pytest.raises( + AssertionError, + match="narrative_permissions must use a NarrativeRef as input!", + ): self.mixin.narrative_writable(self.bad_nar_ref, self.test_user) - self.assertEqual( - "narrative_permissions must use a NarrativeRef as input!", - str(err.exception), - ) self.logout() # test KBaseWSManagerMixin._validate_nar_type ##### def test__validate_nar_type_ok(self): - self.assertIsNone( + assert ( self.mixin._validate_nar_type("KBaseNarrative.Narrative-123.45", None) + is None ) - self.assertIsNone( - self.mixin._validate_nar_type("KBaseNarrative.Narrative", None) - ) + assert self.mixin._validate_nar_type("KBaseNarrative.Narrative", None) is None def test__validate_nar_type_fail(self): bad_type = "NotANarrative" ref = "123/45" - with self.assertRaises(HTTPError) as err: + + with pytest.raises( + HTTPError, match=f"Expected a Narrative object, got a {bad_type}" + ): self.mixin._validate_nar_type(bad_type, None) - self.assertIn( - "Expected a Narrative object, got a {}".format(bad_type), str(err.exception) - ) - with self.assertRaises(HTTPError) as err: + with pytest.raises( + HTTPError, + match=f"Expected a Narrative object with reference {ref}, got a {bad_type}", + ): self.mixin._validate_nar_type(bad_type, ref) - self.assertIn( - "Expected a Narrative object with reference {}, got a {}".format( - ref, bad_type - ), - str(err.exception), - ) if __name__ == "__main__": diff --git a/src/biokbase/narrative/tests/test_specmanager.py b/src/biokbase/narrative/tests/test_specmanager.py index e098f76c18..b91739453a 100644 --- a/src/biokbase/narrative/tests/test_specmanager.py +++ b/src/biokbase/narrative/tests/test_specmanager.py @@ -1,6 +1,7 @@ import unittest from unittest import mock +import pytest from biokbase.narrative.jobs.specmanager import SpecManager from .narrative_mock.mockclients import get_mock_client @@ -21,35 +22,32 @@ def tearDownClass(cls): def test_apps_present(self): # on startup, should have app_specs - self.assertTrue(self.good_tag in self.sm.app_specs) + assert self.good_tag in self.sm.app_specs def test_check_app(self): # good id and good tag - self.assertTrue(self.sm.check_app(self.good_app_id, self.good_tag)) + assert self.sm.check_app(self.good_app_id, self.good_tag) # good id and bad tag no raise - self.assertFalse(self.sm.check_app(self.good_app_id, self.bad_tag)) + assert self.sm.check_app(self.good_app_id, self.bad_tag) is False # bad id and good tag no raise - self.assertFalse(self.sm.check_app(self.bad_app_id, self.good_tag)) + assert self.sm.check_app(self.bad_app_id, self.good_tag) is False # bad id with raise - with self.assertRaises(ValueError): + with pytest.raises(ValueError): self.sm.check_app(self.bad_app_id, raise_exception=True) @mock.patch("biokbase.narrative.jobs.specmanager.clients.get", get_mock_client) def test_get_type_spec(self): self.sm.reload() - self.assertIn( - "export_functions", list(self.sm.get_type_spec("KBaseFBA.FBA").keys()) + assert "export_functions" in list(self.sm.get_type_spec("KBaseFBA.FBA").keys()) + assert "export_functions" in list( + self.sm.get_type_spec("KBaseFBA.NU_FBA").keys() ) - self.assertIn( - "export_functions", list(self.sm.get_type_spec("KBaseFBA.NU_FBA").keys()) - ) - with self.assertRaisesRegex(ValueError, "Unknown type"): - self.assertIn( - "export_functions", - list(self.sm.get_type_spec("KBaseExpression.NU_FBA").keys()), + with pytest.raises(ValueError, match="Unknown type"): + assert "export_functions" in list( + self.sm.get_type_spec("KBaseExpression.NU_FBA").keys() ) diff --git a/src/biokbase/narrative/tests/test_staging_helper.py b/src/biokbase/narrative/tests/test_staging_helper.py index bf74850bf9..a6095e1363 100644 --- a/src/biokbase/narrative/tests/test_staging_helper.py +++ b/src/biokbase/narrative/tests/test_staging_helper.py @@ -2,6 +2,7 @@ import unittest from unittest.mock import patch +import pytest from biokbase.narrative.staging.helper import Helper @@ -13,111 +14,119 @@ def setUp(self): def test_missing_token(self): os.environ["KB_AUTH_TOKEN"] = "" - with self.assertRaises(ValueError) as context: + with pytest.raises(ValueError, match="Cannot retrieve auth token"): Helper() - self.assertEqual("Cannot retrieve auth token", str(context.exception)) def test_token(self): - self.assertEqual(self.good_fake_token, self.staging_helper._token) + assert self.good_fake_token == self.staging_helper._token def test_staging_url(self): - self.assertTrue( - "kbase.us/services/staging_service" in self.staging_helper._staging_url - ) + assert "kbase.us/services/staging_service" in self.staging_helper._staging_url - @unittest.skip("Skipped: test contacts the staging service, but should not") + @pytest.mark.skip("Skipped: test contacts the staging service, but should not") def test_unauthorized_token(self): - with self.assertRaises(ValueError) as context: + import re + + with pytest.raises( + ValueError, + match=re.escape( + "The server could not fulfill the request.\nServer message: b'Error connecting to auth service: 401 Unauthorized\\n10020 Invalid token'\nReason: Unauthorized\nError code: 401\n" + ), + ): self.staging_helper.list() - self.assertTrue("Reason: Unauthorized" in str(context.exception)) - self.assertTrue("Error code: 401" in str(context.exception)) def mock_fetch_url( end_point, values=None, headers=None, method="GET", save_path=None - ): + ) -> str | None: if "list" in end_point: print("mocking __fetch_url list endpoint") return '[{"path": "tgu/test_file_1", "isFolder": false},\ {"path": "tgu/test_dir", "isFolder": true},\ {"path": "tgu/test_dir/test_file_2", "isFolder": false}]' - elif "jgi-metadata" in end_point: + + if "jgi-metadata" in end_point: print("mocking __fetch_url jgi-metadata endpoint") return '{"file_name": "test_file", "file_status": "BACKUP_COMPLETE"}' - elif "metadata" in end_point: + + if "metadata" in end_point: print("mocking __fetch_url metadata endpoint") return '{"head": "head_line", "tail": "tail_line", "lineCount": 10}' - elif "search" in end_point: + + if "search" in end_point: print("mocking __fetch_url search endpoint") return '[{"isFolder": false, "mtime": 1515526154896, "name": "LMS-PROC-315.pdf"}]' - elif "delete" in end_point: + + if "delete" in end_point: print("mocking __fetch_url delete endpoint") return "successfully deleted tgu2/test.pdf" - elif "download" in end_point: + + if "download" in end_point: print("mocking __fetch_url download endpoint") - elif "mv" in end_point: + return None + + if "mv" in end_point: print("mocking __fetch_url mv endpoint") return "successfully moved tgu2/test.pdf to tgu2/test_1.pdf" + return None + @patch.object(Helper, "_Helper__fetch_url", side_effect=mock_fetch_url) def test_list(self, _fetch_url): file_list = self.staging_helper.list() - self.assertTrue("tgu/test_file_1" in file_list) - self.assertTrue("tgu/test_dir/test_file_2" in file_list) - self.assertTrue("tgu/test_dir" not in file_list) + assert "tgu/test_file_1" in file_list + assert "tgu/test_dir/test_file_2" in file_list + assert "tgu/test_dir" not in file_list def test_missing_path(self): - with self.assertRaises(ValueError) as context: + with pytest.raises(ValueError, match="Must provide path argument"): self.staging_helper.metadata() - self.assertEqual("Must provide path argument", str(context.exception)) @patch.object(Helper, "_Helper__fetch_url", side_effect=mock_fetch_url) def test_metadata(self, _fetch_url): metadata = self.staging_helper.metadata("test_fake_file") - self.assertTrue("head" in metadata) - self.assertEqual(metadata.get("head"), "head_line") - self.assertTrue("tail" in metadata) - self.assertEqual(metadata.get("tail"), "tail_line") - self.assertTrue("lineCount" in metadata) - self.assertEqual(metadata.get("lineCount"), 10) + assert "head" in metadata + assert metadata.get("head") == "head_line" + assert "tail" in metadata + assert metadata.get("tail") == "tail_line" + assert "lineCount" in metadata + assert metadata.get("lineCount") == 10 @patch.object(Helper, "_Helper__fetch_url", side_effect=mock_fetch_url) def test_jgi_metadata(self, _fetch_url): metadata = self.staging_helper.jgi_metadata("test_fake_file") - self.assertTrue("file_name" in metadata) - self.assertEqual(metadata.get("file_name"), "test_file") - self.assertTrue("file_status" in metadata) - self.assertEqual(metadata.get("file_status"), "BACKUP_COMPLETE") + assert "file_name" in metadata + assert metadata.get("file_name") == "test_file" + assert "file_status" in metadata + assert metadata.get("file_status") == "BACKUP_COMPLETE" @patch.object(Helper, "_Helper__fetch_url", side_effect=mock_fetch_url) def test_search(self, _fetch_url): search_ret = self.staging_helper.search("test_fake_file") - self.assertTrue(isinstance(search_ret, (list))) + assert isinstance(search_ret, list) element = search_ret[0] - self.assertTrue("isFolder" in element) - self.assertFalse(element.get("isFolder")) - self.assertTrue("name" in element) - self.assertEqual(element.get("name"), "LMS-PROC-315.pdf") + assert "isFolder" in element + assert not element.get("isFolder") + assert "name" in element + assert element.get("name") == "LMS-PROC-315.pdf" @patch.object(Helper, "_Helper__fetch_url", side_effect=mock_fetch_url) def test_delete(self, _fetch_url): delete_ret = self.staging_helper.delete("test_fake_file") - self.assertTrue("server_response" in delete_ret) - self.assertEqual( - delete_ret.get("server_response"), "successfully deleted tgu2/test.pdf" - ) + assert "server_response" in delete_ret + assert delete_ret.get("server_response") == "successfully deleted tgu2/test.pdf" @patch.object(Helper, "_Helper__fetch_url", side_effect=mock_fetch_url) def test_download(self, _fetch_url): download_ret = self.staging_helper.download("test_fake_file") - self.assertTrue("test_fake_file" in download_ret) + assert "test_fake_file" in download_ret @patch.object(Helper, "_Helper__fetch_url", side_effect=mock_fetch_url) def test_mv(self, _fetch_url): mv_ret = self.staging_helper.mv("test.pdf ", "test_1.pdf") - self.assertTrue("server_response" in mv_ret) - self.assertEqual( - mv_ret.get("server_response"), - "successfully moved tgu2/test.pdf to tgu2/test_1.pdf", + assert "server_response" in mv_ret + assert ( + mv_ret.get("server_response") + == "successfully moved tgu2/test.pdf to tgu2/test_1.pdf" ) diff --git a/src/biokbase/narrative/tests/test_system.py b/src/biokbase/narrative/tests/test_system.py index aed5e5f68e..2c58561b8d 100644 --- a/src/biokbase/narrative/tests/test_system.py +++ b/src/biokbase/narrative/tests/test_system.py @@ -2,9 +2,8 @@ import time from unittest import mock -import pytest - import biokbase.auth +import pytest from biokbase.narrative.system import strict_system_variable, system_variable from . import util diff --git a/src/biokbase/narrative/tests/test_upa_api.py b/src/biokbase/narrative/tests/test_upa_api.py index 174c37f438..121bff3e6d 100644 --- a/src/biokbase/narrative/tests/test_upa_api.py +++ b/src/biokbase/narrative/tests/test_upa_api.py @@ -5,6 +5,7 @@ import unittest from unittest import mock +import pytest from biokbase.narrative.upa import ( deserialize, external_tag, @@ -83,7 +84,7 @@ def setUpClass(self): "foo/bar/3;1/2/3", ] - self.bad_refs = ["foo", "1", "1/2/3/4" "1;2"] + self.bad_refs = ["foo", "1", "1/2/3/41;2"] for bad_upa in self.bad_upas: self.bad_serials.append(external_tag + bad_upa) @@ -91,46 +92,43 @@ def setUpClass(self): def test_serialize_good(self): for pair in self.serialize_test_data: serial_upa = serialize(pair["upa"]) - self.assertEqual(serial_upa, pair["serial"]) + assert serial_upa == pair["serial"] def test_serialize_external_good(self): for pair in self.serialize_external_test_data: serial_upa = serialize_external(pair["upa"]) - self.assertEqual(serial_upa, pair["serial"]) + assert serial_upa == pair["serial"] @mock.patch("biokbase.narrative.upa.system_variable", mock_sys_var) def test_deserialize_good(self): for pair in self.serialize_test_data + self.serialize_external_test_data: if not isinstance(pair["upa"], list): deserial_upa = deserialize(pair["serial"]) - self.assertEqual(deserial_upa, pair["upa"]) + assert deserial_upa == pair["upa"] @mock.patch("biokbase.narrative.upa.system_variable", mock_sys_var) def test_serialize_bad(self): for bad_upa in self.bad_upas: - with self.assertRaisesRegex( + with pytest.raises( ValueError, - r'^".+" is not a valid UPA\. It may have already been serialized\.$', + match=r'^".+" is not a valid UPA\. It may have already been serialized\.$', ): serialize(bad_upa) @mock.patch("biokbase.narrative.upa.system_variable", mock_sys_var) def test_deserialize_bad(self): for bad_serial in self.bad_serials: - with self.assertRaisesRegex( - ValueError, 'Deserialized UPA: ".+" is invalid!' - ): + with pytest.raises(ValueError, match='Deserialized UPA: ".+" is invalid!'): deserialize(bad_serial) @mock.patch("biokbase.narrative.upa.system_variable", mock_sys_var) def test_deserialize_bad_type(self): bad_types = [["123/4/5", "6/7/8"], {"123": "456"}, None] for t in bad_types: - with self.assertRaises(ValueError) as e: + with pytest.raises( + ValueError, match="Can only deserialize UPAs from strings." + ): deserialize(t) - self.assertEqual( - str(e.exception), "Can only deserialize UPAs from strings." - ) def test_missing_ws_deserialize(self): tmp = None @@ -138,12 +136,11 @@ def test_missing_ws_deserialize(self): tmp = os.environ.get("KB_WORKSPACE_ID") del os.environ["KB_WORKSPACE_ID"] try: - with self.assertRaises(RuntimeError) as e: + with pytest.raises( + RuntimeError, + match="Currently loaded workspace is unknown! Unable to deserialize UPA.", + ): deserialize("[1]/2/3") - self.assertEqual( - str(e.exception), - "Currently loaded workspace is unknown! Unable to deserialize UPA.", - ) finally: if tmp is not None: os.environ["KB_WORKSPACE_ID"] = tmp @@ -154,10 +151,10 @@ def test_is_ref(self): UPAs should pass, as well as shorter ws_name/obj_name, ws_id/obj_name, ws_id/obj_id references. """ for ref in self.good_refs: - self.assertTrue(is_ref(ref)) + assert is_ref(ref) for ref in self.bad_refs: - self.assertFalse(is_ref(ref)) + assert is_ref(ref) is False if __name__ == "__main__": diff --git a/src/biokbase/narrative/tests/test_updater.py b/src/biokbase/narrative/tests/test_updater.py index dba74fddc0..57e5ddf05c 100644 --- a/src/biokbase/narrative/tests/test_updater.py +++ b/src/biokbase/narrative/tests/test_updater.py @@ -10,8 +10,8 @@ class KeyErrorTest(ValueError): - def __init__(self, keyname, source): - ValueError.__init__(self, "Key {} not found in {}".format(keyname, source)) + def __init__(self, keyname, source) -> None: + ValueError.__init__(self, f"Key {keyname} not found in {source}") class UpdaterTestCase(unittest.TestCase): @@ -122,35 +122,33 @@ def validate_cell(self, cell): def test_update_narrative(self): nar_update = update_narrative(self.test_nar) - self.assertTrue(self.validate_narrative(nar_update)) + assert self.validate_narrative(nar_update) def test_update_narrative_big(self): nar_update = update_narrative(self.test_nar_big) - self.assertTrue(self.validate_narrative(nar_update)) + assert self.validate_narrative(nar_update) def test_update_narrative_poplar(self): nar_update = update_narrative(self.test_nar_poplar) - self.assertTrue(self.validate_narrative(nar_update)) + assert self.validate_narrative(nar_update) def test_find_app(self): info = find_app_info("NarrativeTest/test_input_params") - self.assertTrue(isinstance(info, dict)) + assert isinstance(info, dict) def test_find_bad_app(self): - self.assertIsNone(find_app_info("NotAnAppModule")) + assert find_app_info("NotAnAppModule") is None def test_suggest_apps(self): obsolete_id = "build_a_metabolic_model" suggestions = suggest_apps(obsolete_id) - self.assertTrue(isinstance(suggestions, list)) - self.assertEqual( - suggestions[0]["spec"]["info"]["id"], "fba_tools/build_metabolic_model" - ) + assert isinstance(suggestions, list) + assert suggestions[0]["spec"]["info"]["id"] == "fba_tools/build_metabolic_model" def test_suggest_apps_none(self): suggestions = suggest_apps("NotAnAppModule") - self.assertTrue(isinstance(suggestions, list)) - self.assertEqual(len(suggestions), 0) + assert isinstance(suggestions, list) + assert len(suggestions) == 0 if __name__ == "__main__": diff --git a/src/biokbase/narrative/tests/test_url_config.py b/src/biokbase/narrative/tests/test_url_config.py index dae2cd9061..ec04af1f22 100644 --- a/src/biokbase/narrative/tests/test_url_config.py +++ b/src/biokbase/narrative/tests/test_url_config.py @@ -6,14 +6,14 @@ class UrlConfigTest(unittest.TestCase): def test_getter(self): url = URLS.workspace - self.assertTrue(url.endswith("/services/ws")) + assert url.endswith("/services/ws") def test_get_url(self): url = URLS.get_url("workspace") - self.assertTrue(url.endswith("/services/ws")) + assert url.endswith("/services/ws") def test_missing_url(self): - self.assertIsNone(URLS.nope) + assert URLS.nope is None if __name__ == "__main__": diff --git a/src/biokbase/narrative/tests/test_user_service.py b/src/biokbase/narrative/tests/test_user_service.py index 269d17784e..88b28752c3 100644 --- a/src/biokbase/narrative/tests/test_user_service.py +++ b/src/biokbase/narrative/tests/test_user_service.py @@ -12,4 +12,4 @@ class UserServiceTestCase(unittest.TestCase): def test_user_trust(self): us = UserService() - self.assertTrue(us.is_trusted_user("anybody")) + assert us.is_trusted_user("anybody") diff --git a/src/biokbase/narrative/tests/test_viewers.py b/src/biokbase/narrative/tests/test_viewers.py index b9839129b4..e4ea348e23 100644 --- a/src/biokbase/narrative/tests/test_viewers.py +++ b/src/biokbase/narrative/tests/test_viewers.py @@ -1,6 +1,7 @@ import unittest import biokbase.auth +import pytest from . import util @@ -30,18 +31,17 @@ def setUpClass(cls): def test_bad_view_as_clustergrammer_params(self): from biokbase.narrative import viewers - with self.assertRaises(AssertionError): + with pytest.raises(AssertionError): viewers.view_as_clustergrammer(self.generic_ref, col_categories="Time") - with self.assertRaises(AssertionError): + with pytest.raises(AssertionError): viewers.view_as_clustergrammer(self.generic_ref, row_categories="Time") - with self.assertRaises(AssertionError): + with pytest.raises(AssertionError): viewers.view_as_clustergrammer(self.generic_ref, normalize_on="Time") - with self.assertRaisesRegex(ValueError, "not a compatible data type"): + with pytest.raises(ValueError, match="not a compatible data type"): viewers.view_as_clustergrammer(self.attribute_set_ref) def test__get_categories(self): import pandas as pd - from biokbase.narrative import viewers ids = ["WRI_RS00010_CDS_1", "WRI_RS00015_CDS_1", "WRI_RS00025_CDS_1"] @@ -83,23 +83,21 @@ def test__get_categories(self): labels=[[0, 1, 2], [0, 1, 2]], names=["ID", "test_attribute_1"], ) - self.assertEqual(ids, viewers._get_categories(ids, self.generic_ref)) - with self.assertRaisesRegex(ValueError, "not in the provided mapping"): + assert ids, viewers._get_categories(ids == self.generic_ref) + with pytest.raises(ValueError, match="not in the provided mapping"): viewers._get_categories( ["boo"], self.generic_ref, self.attribute_set_ref, mapping ) - with self.assertRaisesRegex(ValueError, "has no attribute"): + with pytest.raises(ValueError, match="has no attribute"): viewers._get_categories(["boo"], self.generic_ref, self.attribute_set_ref) - self.assertEqual( - index, - viewers._get_categories( - ids, - self.generic_ref, - self.attribute_set_ref, - mapping, - clustergrammer=True, - ), + assert index == viewers._get_categories( + ids, + self.generic_ref, + self.attribute_set_ref, + mapping, + clustergrammer=True, ) + pd.testing.assert_index_equal( multi_index, viewers._get_categories( @@ -110,47 +108,43 @@ def test__get_categories(self): {"test_attribute_1"}, ), ) - self.assertEqual( - filtered_index, - viewers._get_categories( - ids, - self.generic_ref, - self.attribute_set_ref, - mapping, - {"test_attribute_1"}, - clustergrammer=True, - ), + assert filtered_index == viewers._get_categories( + ids, + self.generic_ref, + self.attribute_set_ref, + mapping, + {"test_attribute_1"}, + clustergrammer=True, ) def test_get_df(self): import pandas as pd - from biokbase.narrative import viewers res = viewers.get_df(self.generic_ref) - self.assertIsInstance(res, pd.DataFrame) - self.assertEqual(res.shape, (3, 4)) - self.assertIsInstance(res.index, pd.MultiIndex) + assert isinstance(res, pd.DataFrame) + assert res.shape, 3 == 4 + assert isinstance(res.index, pd.MultiIndex) res = viewers.get_df(self.generic_ref, None, None) - self.assertIsInstance(res, pd.DataFrame) - self.assertEqual(res.shape, (3, 4)) - self.assertIsInstance(res.index, pd.Index) + assert isinstance(res, pd.DataFrame) + assert res.shape, 3 == 4 + assert isinstance(res.index, pd.Index) res = viewers.get_df(self.generic_ref, clustergrammer=True) - self.assertIsInstance(res, pd.DataFrame) - self.assertEqual(res.shape, (3, 4)) - self.assertIsInstance(res.index, pd.Index) + assert isinstance(res, pd.DataFrame) + assert res.shape, 3 == 4 + assert isinstance(res.index, pd.Index) res = viewers.get_df(self.expression_matrix_ref) - self.assertIsInstance(res, pd.DataFrame) - self.assertEqual(res.shape, (4297, 16)) - self.assertIsInstance(res.index, pd.Index) + assert isinstance(res, pd.DataFrame) + assert res.shape, 4297 == 16 + assert isinstance(res.index, pd.Index) def test_view_as_clustergrammer(self): from biokbase.narrative import viewers - self.assertEqual( - str(type(viewers.view_as_clustergrammer(self.generic_ref))), - "", + assert ( + str(type(viewers.view_as_clustergrammer(self.generic_ref))) + == "" ) diff --git a/src/biokbase/narrative/tests/test_widgetmanager.py b/src/biokbase/narrative/tests/test_widgetmanager.py index b075cfc457..dc1022dd96 100644 --- a/src/biokbase/narrative/tests/test_widgetmanager.py +++ b/src/biokbase/narrative/tests/test_widgetmanager.py @@ -3,7 +3,7 @@ from unittest import mock import IPython - +import pytest from biokbase.narrative.widgetmanager import WidgetManager from .narrative_mock.mockclients import get_mock_client @@ -15,6 +15,20 @@ __author__ = "Bill Riehl " +def assert_is_valid_cell_code(js_obj): + code_lines = js_obj.data.strip().split("\n") + assert code_lines[0].strip().startswith("element.html(\"
None: self._path_prefix = os.path.join( os.environ["NARRATIVE_DIR"], "src", "biokbase", "narrative", "tests" ) @@ -66,7 +66,7 @@ def load_json_file(self, filename): location in /src/biokbase/narrative/tests """ json_file_path = self.file_path(filename) - with open(json_file_path, "r") as f: + with open(json_file_path) as f: data = json.loads(f.read()) f.close() return data @@ -121,7 +121,7 @@ def upload_narrative(nar_file, auth_token, user_id, url=ci_ws, set_public=False) """ # read the file - with open(nar_file, "r") as f: + with open(nar_file) as f: nar = json.loads(f.read()) f.close() @@ -163,7 +163,7 @@ def upload_narrative(nar_file, auth_token, user_id, url=ci_ws, set_public=False) return { "ws": ws_info[0], "obj": obj_info[0][0], - "refstr": "{}/{}".format(ws_info[0], obj_info[0][0]), + "refstr": f"{ws_info[0]}/{obj_info[0][0]}", "ref": NarrativeRef({"wsid": ws_info[0], "objid": obj_info[0][0]}), } @@ -185,7 +185,7 @@ def read_token_file(path): if not os.path.isfile(path): return None - with open(path, "r") as f: + with open(path) as f: token = f.read().strip() f.close() return token @@ -196,41 +196,16 @@ def read_json_file(path): Generically reads in any JSON file and returns it as a dict. Especially intended for reading a Narrative file. """ - with open(path, "r") as f: + with open(path) as f: data = json.loads(f.read()) f.close() return data -class MyTestCase(unittest.TestCase): - def test_kvparse(self): - for user_input, text, kvp in ( - ("foo", "foo", {}), - ("name=val", "", {"name": "val"}), - ("a name=val boy", "a boy", {"name": "val"}), - ): - rkvp = {} - rtext = util.parse_kvp(user_input, rkvp) - self.assertEqual( - text, - rtext, - "Text '{}' does not match " - "result '{}' " - "from input '{}'".format(text, rtext, user_input), - ) - self.assertEqual( - text, - rtext, - "Dict '{}' does not match " - "result '{}' " - "from input '{}'".format(kvp, rkvp, user_input), - ) - - class SocketServerBuf(socketserver.TCPServer): allow_reuse_address = True - def __init__(self, addr, handler): + def __init__(self, addr, handler) -> None: socketserver.TCPServer.__init__(self, addr, handler) self.buf = "" @@ -283,7 +258,7 @@ def handle(self): def start_tcp_server(host, port, poll_interval, bufferer=LogProxyMessageBufferer): - _log.info("Starting server on {}:{}".format(host, port)) + _log.info(f"Starting server on {host}:{port}") server = SocketServerBuf((host, port), bufferer) thr = threading.Thread(target=server.serve_forever, args=[poll_interval]) thr.daemon = True @@ -318,7 +293,7 @@ def validate_job_state(job_state: dict) -> None: assert isinstance(job_state["jobState"], dict), "jobState is not a dict" assert "outputWidgetInfo" in job_state, "outputWidgetInfo key missing" assert isinstance( - job_state["outputWidgetInfo"], (dict, NoneType) + job_state["outputWidgetInfo"], dict | NoneType ), "outputWidgetInfo is not a dict or None" state = job_state["jobState"] # list of tuples - first = key name, second = value type diff --git a/src/scripts/fix_ws_metadata/test_fix_workspace_info.py b/src/scripts/fix_ws_metadata/test_fix_workspace_info.py index 809ca0fb89..ad07cdc508 100644 --- a/src/scripts/fix_ws_metadata/test_fix_workspace_info.py +++ b/src/scripts/fix_ws_metadata/test_fix_workspace_info.py @@ -2,9 +2,9 @@ import unittest from unittest import mock -from requests.exceptions import HTTPError - +import pytest from biokbase.workspace.baseclient import ServerError +from requests.exceptions import HTTPError from . import fix_workspace_info @@ -32,8 +32,7 @@ def raise_for_status(self): tok = kwargs["headers"]["Authorization"] if "good" in tok: return MockResponse(json.dumps({"user": FAKE_ADMIN_ID}), 200) - else: - return MockResponse("Bad!", 401) + return MockResponse("Bad!", 401) class MockWorkspace: @@ -154,13 +153,12 @@ def test_parse_args(self): ] for input_args in good_args_set: args = fix_workspace_info.parse_args(input_args) - self.assertEqual(args.token, token) - self.assertEqual(args.auth_url, auth_url) - self.assertEqual(args.ws_url, ws_url) + assert args.token == token + assert args.auth_url == auth_url + assert args.ws_url == ws_url for bad_args in bad_args_set: - with self.assertRaises(ValueError) as e: + with pytest.raises(ValueError, match=bad_args[1]): fix_workspace_info.parse_args(bad_args[0]) - self.assertIn(bad_args[1], str(e.exception)) def test__admin_update_metadata(self): reset_fake_ws_db() @@ -169,28 +167,28 @@ def test__admin_update_metadata(self): ws_id = 1 fix_workspace_info._admin_update_metadata(ws, FAKE_ADMIN_ID, ws_id, new_meta) - self.assertEqual(ws.fake_ws_db[str(ws_id)]["ws_info"][8]["foo"], "bar") - self.assertNotIn( - FAKE_ADMIN_ID, - ws.administer( + assert ws.fake_ws_db[str(ws_id)]["ws_info"][8]["foo"] == "bar" + assert ( + FAKE_ADMIN_ID + not in ws.administer( { "command": "getPermissionsMass", "params": {"workspaces": [{"id": ws_id}]}, } - )["perms"][0], + )["perms"][0] ) ws_id = 3 fix_workspace_info._admin_update_metadata(ws, FAKE_ADMIN_ID, ws_id, new_meta) - self.assertEqual(ws.fake_ws_db[str(ws_id)]["ws_info"][8]["foo"], "bar") - self.assertEqual( + assert ws.fake_ws_db[str(ws_id)]["ws_info"][8]["foo"] == "bar" + assert ( ws.administer( { "command": "getPermissionsMass", "params": {"workspaces": [{"id": ws_id}]}, } - )["perms"][0].get(FAKE_ADMIN_ID), - "r", + )["perms"][0].get(FAKE_ADMIN_ID) + == "r" ) @mock.patch( @@ -199,8 +197,8 @@ def test__admin_update_metadata(self): ) def test__get_user_id(self, request_mock): userid = fix_workspace_info._get_user_id("some_endpoint", "goodtoken") - self.assertEqual(userid, FAKE_ADMIN_ID) - with self.assertRaises(HTTPError): + assert userid == FAKE_ADMIN_ID + with pytest.raises(HTTPError): fix_workspace_info._get_user_id("some_endpoint", "badtoken") @mock.patch( @@ -212,7 +210,7 @@ def test_fix_all_workspace_info(self, ws_mock, request_mock): reset_fake_ws_db() fake_ws = MockWorkspace() fix_workspace_info.Workspace = MockWorkspace - with self.assertRaises(HTTPError): + with pytest.raises(HTTPError): fix_workspace_info.fix_all_workspace_info( "fake_ws", "fake_auth", "bad_token", 20 ) @@ -222,94 +220,74 @@ def test_fix_all_workspace_info(self, ws_mock, request_mock): ) # TODO: add actual tests for results of "database" # ws1 - no change to metadata - self.assertEqual(fake_ws.fake_ws_db["1"]["ws_info"][8], {}) + assert fake_ws.fake_ws_db["1"]["ws_info"][8] == {} # ws2 - add cell_count = 1 - self.assertEqual( - fake_ws.fake_ws_db["2"]["ws_info"][8], - { - "is_temporary": "false", - "narrative": "1", - "narrative_nice_name": "Test", - "cell_count": "1", - "searchtags": "narrative", - }, - ) + assert fake_ws.fake_ws_db["2"]["ws_info"][8] == { + "is_temporary": "false", + "narrative": "1", + "narrative_nice_name": "Test", + "cell_count": "1", + "searchtags": "narrative", + } # ws3 - not temp, fix name, add num-cells - self.assertEqual( - fake_ws.fake_ws_db["3"]["ws_info"][8], - { - "is_temporary": "false", - "narrative": "1", - "narrative_nice_name": "Test3", - "cell_count": "2", - "searchtags": "narrative", - }, - ) + assert fake_ws.fake_ws_db["3"]["ws_info"][8] == { + "is_temporary": "false", + "narrative": "1", + "narrative_nice_name": "Test3", + "cell_count": "2", + "searchtags": "narrative", + } # ws4 - not temp, 1 cell - self.assertEqual( - fake_ws.fake_ws_db["4"]["ws_info"][8], - { - "is_temporary": "false", - "narrative": "1", - "narrative_nice_name": "Test4", - "cell_count": "1", - "searchtags": "narrative", - }, - ) + assert fake_ws.fake_ws_db["4"]["ws_info"][8] == { + "is_temporary": "false", + "narrative": "1", + "narrative_nice_name": "Test4", + "cell_count": "1", + "searchtags": "narrative", + } # ws5 - add num cells. even though there's > 1, it's configured. - self.assertEqual( - fake_ws.fake_ws_db["5"]["ws_info"][8], - {"is_temporary": "false", "narrative": "1", "narrative_nice_name": "Test5"}, - ) + assert fake_ws.fake_ws_db["5"]["ws_info"][8] == { + "is_temporary": "false", + "narrative": "1", + "narrative_nice_name": "Test5", + } # ws6 - fix id, add cell count - self.assertEqual( - fake_ws.fake_ws_db["6"]["ws_info"][8], - { - "is_temporary": "false", - "narrative": "3", - "narrative_nice_name": "Test6", - "cell_count": "1", - "searchtags": "narrative", - }, - ) + assert fake_ws.fake_ws_db["6"]["ws_info"][8] == { + "is_temporary": "false", + "narrative": "3", + "narrative_nice_name": "Test6", + "cell_count": "1", + "searchtags": "narrative", + } # ws7 - fix id, cell count - self.assertEqual( - fake_ws.fake_ws_db["7"]["ws_info"][8], - { - "is_temporary": "false", - "narrative": "3", - "narrative_nice_name": "Test7", - "cell_count": "1", - "searchtags": "narrative", - }, - ) + assert fake_ws.fake_ws_db["7"]["ws_info"][8] == { + "is_temporary": "false", + "narrative": "3", + "narrative_nice_name": "Test7", + "cell_count": "1", + "searchtags": "narrative", + } # ws8 - missing metadata all together, so add it - self.assertEqual( - fake_ws.fake_ws_db["8"]["ws_info"][8], - { - "is_temporary": "false", - "narrative": "3", - "narrative_nice_name": "Test8", - "cell_count": "1", - "searchtags": "narrative", - }, - ) + assert fake_ws.fake_ws_db["8"]["ws_info"][8] == { + "is_temporary": "false", + "narrative": "3", + "narrative_nice_name": "Test8", + "cell_count": "1", + "searchtags": "narrative", + } # ws9 - missing metadata here, too, but it's temporary - self.assertEqual( - fake_ws.fake_ws_db["9"]["ws_info"][8], - { - "is_temporary": "true", - "narrative": "3", - "narrative_nice_name": "Untitled", - "cell_count": "1", - "searchtags": "narrative", - }, - ) + assert fake_ws.fake_ws_db["9"]["ws_info"][8] == { + "is_temporary": "true", + "narrative": "3", + "narrative_nice_name": "Untitled", + "cell_count": "1", + "searchtags": "narrative", + } From 7840bf1ae647a4926fe21ce51b398199bd0b9e02 Mon Sep 17 00:00:00 2001 From: ialarmedalien Date: Fri, 22 Dec 2023 08:19:27 -0800 Subject: [PATCH 3/3] Fixing a couple of code hygiene issues --- src/biokbase/auth.py | 11 ++++------- src/biokbase/narrative/jobs/specmanager.py | 7 ++++++- .../narrative/tests/narrative_mock/mockclients.py | 1 + src/biokbase/narrative/tests/test_app_util.py | 3 --- src/biokbase/narrative/tests/test_specmanager.py | 9 +++++---- src/biokbase/narrative/tests/test_system.py | 3 +-- 6 files changed, 17 insertions(+), 17 deletions(-) diff --git a/src/biokbase/auth.py b/src/biokbase/auth.py index 12a229c0ed..3659b2c96c 100644 --- a/src/biokbase/auth.py +++ b/src/biokbase/auth.py @@ -54,26 +54,23 @@ def __init__(self, user_dict: dict): self.user_name = user_dict.get("user") -def validate_token(): +def validate_token() -> bool: """ Validates the currently set auth token. Returns True if valid, False otherwise. """ headers = {"Authorization": get_auth_token()} r = requests.get(token_api_url + endpt_token, headers=headers) - if r.status_code == 200: - return True - else: - return False + return r.status_code == 200 -def set_environ_token(token: str) -> None: +def set_environ_token(token: str | None) -> None: """ Sets a login token in the local environment variable. """ kbase_env.auth_token = token -def get_auth_token() -> Optional[str]: +def get_auth_token() -> str | None: """ Returns the current login token being used, or None if one isn't set. """ diff --git a/src/biokbase/narrative/jobs/specmanager.py b/src/biokbase/narrative/jobs/specmanager.py index 43b868e149..4fad76e2cd 100644 --- a/src/biokbase/narrative/jobs/specmanager.py +++ b/src/biokbase/narrative/jobs/specmanager.py @@ -140,7 +140,12 @@ def app_usage(self, app_id, tag="release"): return AppUsage(usage) - def check_app(self, app_id, tag="release", raise_exception=False): + def check_app( + self: "SpecManager", + app_id: str, + tag: str = "release", + raise_exception: bool = False, + ): """ Checks if a method (and release tag) is available for running and such. If raise_exception==True, and either the tag or app_id are invalid, a ValueError is raised. diff --git a/src/biokbase/narrative/tests/narrative_mock/mockclients.py b/src/biokbase/narrative/tests/narrative_mock/mockclients.py index 73f1f1ee44..4b226afee9 100644 --- a/src/biokbase/narrative/tests/narrative_mock/mockclients.py +++ b/src/biokbase/narrative/tests/narrative_mock/mockclients.py @@ -532,6 +532,7 @@ def get_failing_mock_client(client_name, token=None): class FailingMockClient: def __init__(self, token=None) -> None: + # nothing to do here pass def check_workspace_jobs(self, params): diff --git a/src/biokbase/narrative/tests/test_app_util.py b/src/biokbase/narrative/tests/test_app_util.py index e3486a4f84..d618adbe71 100644 --- a/src/biokbase/narrative/tests/test_app_util.py +++ b/src/biokbase/narrative/tests/test_app_util.py @@ -505,9 +505,6 @@ def get_workspace(_): class RefChainWorkspace: - def __init__(self) -> None: - pass - def get_object_info3(self, params): """ Makes quite a few assumptions about input, as it's used for a specific test. diff --git a/src/biokbase/narrative/tests/test_specmanager.py b/src/biokbase/narrative/tests/test_specmanager.py index b91739453a..9179026166 100644 --- a/src/biokbase/narrative/tests/test_specmanager.py +++ b/src/biokbase/narrative/tests/test_specmanager.py @@ -34,8 +34,11 @@ def test_check_app(self): # bad id and good tag no raise assert self.sm.check_app(self.bad_app_id, self.good_tag) is False + def test_check_app_error(self): # bad id with raise - with pytest.raises(ValueError): + with pytest.raises( + ValueError, match=f'Unknown app id "{self.bad_app_id}" tagged as "release"' + ): self.sm.check_app(self.bad_app_id, raise_exception=True) @mock.patch("biokbase.narrative.jobs.specmanager.clients.get", get_mock_client) @@ -46,9 +49,7 @@ def test_get_type_spec(self): self.sm.get_type_spec("KBaseFBA.NU_FBA").keys() ) with pytest.raises(ValueError, match="Unknown type"): - assert "export_functions" in list( - self.sm.get_type_spec("KBaseExpression.NU_FBA").keys() - ) + self.sm.get_type_spec("KBaseExpression.NU_FBA") if __name__ == "__main__": diff --git a/src/biokbase/narrative/tests/test_system.py b/src/biokbase/narrative/tests/test_system.py index 2c58561b8d..b02e0344e1 100644 --- a/src/biokbase/narrative/tests/test_system.py +++ b/src/biokbase/narrative/tests/test_system.py @@ -101,7 +101,6 @@ def test_strict_sys_var_user_bad(): biokbase.auth.set_environ_token(bad_fake_token) with pytest.raises( ValueError, match='Unable to retrieve system variable: "user_id"' - ) as e: + ): strict_system_variable("user_id") - assert e biokbase.auth.set_environ_token(None)