diff --git a/.github/dependabot.yml b/.github/dependabot.yml new file mode 100644 index 00000000..0b3ba9da --- /dev/null +++ b/.github/dependabot.yml @@ -0,0 +1,14 @@ +# This is the configuration file for Dependabot. You can find configuration information below. +# https://docs.github.com/github/administering-a-repository/configuration-options-for-dependency-updates +# Note: Dependabot has a configurable max open PR limit of 5 + +version: 2 +updates: + + # Maintain dependencies for our GitHub Actions + - package-ecosystem: "github-actions" + directory: "/" + schedule: + interval: "daily" + labels: + - "dependencies" diff --git a/.github/workflows/codeql-analysis.yml b/.github/workflows/codeql-analysis.yml index f071112e..19d0d428 100644 --- a/.github/workflows/codeql-analysis.yml +++ b/.github/workflows/codeql-analysis.yml @@ -20,11 +20,11 @@ jobs: strategy: fail-fast: false matrix: - python-version: ["3.9", "3.10", "3.11"] + python-version: ["3.10", "3.11"] steps: - name: Checkout repository - uses: actions/checkout@v2 + uses: actions/checkout@v3 - name: Initialize CodeQL uses: github/codeql-action/init@v2 @@ -35,10 +35,13 @@ jobs: uses: github/codeql-action/analyze@v2 - name: Setup Python - uses: actions/setup-python@v2 + uses: actions/setup-python@v4 with: python-version: ${{ matrix.python-version }} + - name: Pre Commit Checks + uses: pre-commit/action@v3.0.0 + - name: Setup Temp Directory run: mkdir broker_dir @@ -47,7 +50,7 @@ jobs: BROKER_DIRECTORY: "${{ github.workspace }}/broker_dir" run: | pip install -U pip - pip install -U .[test,docker] + pip install -U .[dev,docker] ls -l "$BROKER_DIRECTORY" broker --version pytest -v tests/ --ignore tests/functional diff --git a/.github/workflows/python-publish.yml b/.github/workflows/python-publish.yml index 127a33e5..e7062f93 100644 --- a/.github/workflows/python-publish.yml +++ b/.github/workflows/python-publish.yml @@ -2,7 +2,7 @@ name: PythonPackage on: push: - tags: + tags: - "*" jobs: @@ -12,7 +12,7 @@ jobs: strategy: matrix: # build/push in lowest support python version - python-version: [ 3.9 ] + python-version: [ 3.10 ] steps: - uses: actions/checkout@v2 diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 00000000..218e7286 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,12 @@ +# configuration for pre-commit git hooks +repos: +- repo: https://github.com/psf/black + rev: 23.3.0 + hooks: + - id: black +- repo: https://github.com/astral-sh/ruff-pre-commit + # Ruff version. + rev: v0.0.277 + hooks: + - id: ruff + args: [--fix, --exit-non-zero-on-fix] diff --git a/MANIFEST.in b/MANIFEST.in deleted file mode 100644 index bb3ec5f0..00000000 --- a/MANIFEST.in +++ /dev/null @@ -1 +0,0 @@ -include README.md diff --git a/README.md b/README.md index fe88f3b5..ae65c72d 100644 --- a/README.md +++ b/README.md @@ -19,6 +19,8 @@ Copy the example settings file to `broker_settings.yaml` and edit it. (optional) If you are using the Container provider, install the extra dependency based on your container runtime of choice with either `pip install broker[podman]` or `pip install broker[docker]`. +(optional) If you are using the Beaker provider, install the extra dependency with `dnf install krb5-devel` and then `pip install broker[beaker]`. + To run Broker outside of its base directory, specify the directory with the `BROKER_DIRECTORY` environment variable. Configure the `broker_settings.yaml` file to set configuration values for broker's interaction with its providers. 
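For orientation on the new provider referenced in the README change above: the Beaker bind added in broker/binds/beaker.py below is driven roughly as in this sketch. It is illustrative only, not part of the diff; the hub URL and job XML path are placeholders, and the curated return keys come from _curate_job_info in the new module:

    from broker.binds.beaker import BeakerBind

    # auth defaults to system Kerberos ("krbv"); pass auth="basic" plus
    # username/password kwargs to use explicit credentials instead
    bind = BeakerBind(hub_url="https://beaker.example.com")  # placeholder hub URL

    # submits the XML via `bkr job-submit`, then polls `bkr job-results`
    # until the job passes, fails, or exceeds max_wait
    job_info = bind.execute_job("my-job.xml", max_wait="24h")
    print(job_info["hostname"])  # curated keys: job_id, whiteboard, hostname, distro
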
diff --git a/broker/__init__.py b/broker/__init__.py index 5449b8ad..24b2d562 100644 --- a/broker/__init__.py +++ b/broker/__init__.py @@ -1,3 +1,2 @@ -from broker.broker import Broker - -VMBroker = Broker +"""Shortcuts for the broker module.""" +from broker.broker import Broker # noqa: F401 diff --git a/broker/binds/__init__.py b/broker/binds/__init__.py index e69de29b..be665490 100644 --- a/broker/binds/__init__.py +++ b/broker/binds/__init__.py @@ -0,0 +1 @@ +"""Binds provide interfaces between a provider's interface and the Broker Provider class.""" diff --git a/broker/binds/beaker.py b/broker/binds/beaker.py new file mode 100644 index 00000000..ba616c18 --- /dev/null +++ b/broker/binds/beaker.py @@ -0,0 +1,246 @@ +"""A wrapper around the Beaker CLI.""" +import json +from pathlib import Path +import subprocess +import time +from xml.etree import ElementTree as ET + +from logzero import logger + +from broker import helpers +from broker.exceptions import BeakerBindError + + +def _elementree_to_dict(etree): + """Convert an ElementTree object to a dictionary.""" + data = {} + if etree.attrib: + data.update(etree.attrib) + if etree.text: + data["text"] = etree.text + for child in etree: + child_data = _elementree_to_dict(child) + if (tag := child.tag) in data: + if not isinstance(data[tag], list): + data[tag] = [data[tag]] + data[tag].append(child_data) + else: + data[tag] = child_data + return data + + +def _curate_job_info(job_info_dict): + curated_info = { + "job_id": "id", + # "reservation_id": "current_reservation/recipe_id", + "whiteboard": "whiteboard/text", + "hostname": "recipeSet/recipe/system", + "distro": "recipeSet/recipe/distro", + } + return helpers.dict_from_paths(job_info_dict, curated_info) + + +class BeakerBind: + """A bind class providing a basic interface to the Beaker CLI.""" + + def __init__(self, hub_url, auth="krbv", **kwargs): + self.hub_url = hub_url + self._base_args = ["--insecure", f"--hub={self.hub_url}"] + if auth == "basic": + # If we're not using system kerberos auth, add in explicit basic auth + self.username = kwargs.pop("username", None) + self.password = kwargs.pop("password", None) + self._base_args.extend( + [ + f"--username {self.username}", + f"--password {self.password}", + ] + ) + self.__dict__.update(kwargs) + + def _exec_command(self, *cmd_args, **cmd_kwargs): + """Execute a beaker command and return the result. 
+ + cmd_args: Expanded into feature flags for the beaker command + cmd_kwargs: Expanded into args and values for the beaker command + """ + raise_on_error = cmd_kwargs.pop("raise_on_error", True) + exec_cmd, cmd_args = ["bkr"], list(cmd_args) + # check through kwargs and if any are True add to cmd_args + del_keys = [] + for k, v in cmd_kwargs.items(): + if isinstance(v, bool) or v is None: + del_keys.append(k) + if v is True: + cmd_args.append(f"--{k}" if not k.startswith("--") else k) + for k in del_keys: + del cmd_kwargs[k] + exec_cmd.extend(cmd_args) + exec_cmd.extend(self._base_args) + exec_cmd.extend([f"--{k.replace('_', '-')}={v}" for k, v in cmd_kwargs.items()]) + logger.debug(f"Executing beaker command: {exec_cmd}") + proc = subprocess.Popen( + exec_cmd, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + ) + stdout, stderr = proc.communicate() + result = helpers.Result( + stdout=stdout.decode(), + stderr=stderr.decode(), + status=proc.returncode, + ) + if result.status != 0 and raise_on_error: + raise BeakerBindError( + f"Beaker command failed:\nCommand={' '.join(exec_cmd)}\nResult={result}", + ) + logger.debug(f"Beaker command result: {result.stdout}") + return result + + def job_submit(self, job_xml, wait=False): + """Submit a job to Beaker and optionally wait for it to complete.""" + # wait behavior seems buggy to me, so best to avoid it + if not Path(job_xml).exists(): + raise FileNotFoundError(f"Job XML file {job_xml} not found") + result = self._exec_command("job-submit", job_xml, wait=wait) + if not wait: + # get the job id from the output + # format is "Submitted: ['J:7849837'] where the number is the job id + for line in result.stdout.splitlines(): + if line.startswith("Submitted:"): + return line.split("'")[1].replace("J:", "") + + def job_watch(self, job_id): + """Watch a job via the job-watch command. This can be buggy.""" + job_id = f"J:{job_id}" if not job_id.startswith("J:") else job_id + return self._exec_command("job-watch", job_id) + + def job_results(self, job_id, format="beaker-results-xml", pretty=False): + """Get the results of a job in the specified format.""" + job_id = f"J:{job_id}" if not job_id.startswith("J:") else job_id + return self._exec_command("job-results", job_id, format=format, prettyxml=pretty) + + def job_clone(self, job_id, wait=False, **kwargs): + """Clone a job by the specified job id.""" + job_id = f"J:{job_id}" if not job_id.startswith("J:") else job_id + return self._exec_command("job-clone", job_id, wait=wait, **kwargs) + + def job_list(self, *args, **kwargs): + """List jobs matching the criteria specified by args and kwargs.""" + return self._exec_command("job-list", *args, **kwargs) + + def job_cancel(self, job_id): + """Cancel a job by the specified job id.""" + if not job_id.startswith("J:") and not job_id.startswith("RS:"): + job_id = f"J:{job_id}" + return self._exec_command("job-cancel", job_id) + + def job_delete(self, job_id): + """Delete a job by the specified job id.""" + job_id = f"J:{job_id}" if not job_id.startswith("J:") else job_id + return self._exec_command("job-delete", job_id) + + def system_release(self, system_id): + """Release a system by the specified system id.""" + return self._exec_command("system-release", system_id) + + def system_list(self, **kwargs): + """Due to the number of arguments, we will not validate before submitting. 
+
+        Accepted arguments are:
+            available                       available to be used by this user
+            free                            available to this user and not currently being used
+            removed                         which have been removed
+            mine                            owned by this user
+            type=TYPE                       of TYPE
+            status=STATUS                   with STATUS
+            pool=POOL                       in POOL
+            arch=ARCH                       with ARCH
+            dev-vendor-id=VENDOR-ID         with a device that has VENDOR-ID
+            dev-device-id=DEVICE-ID         with a device that has DEVICE-ID
+            dev-sub-vendor-id=SUBVENDOR-ID  with a device that has SUBVENDOR-ID
+            dev-sub-device-id=SUBDEVICE-ID  with a device that has SUBDEVICE-ID
+            dev-driver=DRIVER               with a device that has DRIVER
+            dev-description=DESCRIPTION     with a device that has DESCRIPTION
+            xml-filter=XML                  matching the given XML filter
+            host-filter=NAME                matching pre-defined host filter
+        """
+        # convert the flags passed in kwargs to arguments
+        args = [
+            f"--{key}" for key in ("available", "free", "removed", "mine") if kwargs.pop(key, False)
+        ]
+        return self._exec_command("system-list", *args, **kwargs)
+
+    def user_systems(self):
+        """Return a list of system ids owned by the current user.
+
+        This is used for inventory syncing against Beaker.
+        """
+        result = self.system_list(mine=True, raise_on_error=False)
+        if result.status != 0:
+            return []
+        else:
+            return result.stdout.splitlines()
+
+    def system_details(self, system_id, format="json"):
+        """Get details about a system by the specified system id."""
+        return self._exec_command("system-details", system_id, format=format)
+
+    def execute_job(self, job, max_wait="24h"):
+        """Submit a job, periodically checking the status until it completes.
+
+        return: a dictionary of the results.
+        """
+        if Path(job).exists():  # job xml path passed in
+            job_id = self.job_submit(job, wait=False)
+        else:  # using a job id
+            job_id = self.job_clone(job)
+        logger.info(f"Submitted job: {job_id}")
+        _max_wait = time.time() + helpers.translate_timeout(max_wait or "24h")
+        while time.time() < _max_wait:
+            time.sleep(60)
+            result = self.job_results(job_id, pretty=True)
+            if 'result="Pass"' in result.stdout:
+                return _curate_job_info(_elementree_to_dict(ET.fromstring(result.stdout)))
+            elif 'result="Fail"' in result.stdout or "Exception: " in result.stdout:
+                raise BeakerBindError(f"Job {job_id} failed:\n{result}")
+            elif 'result="Warn"' in result.stdout:
+                res_dict = _elementree_to_dict(ET.fromstring(result.stdout))
+                raise BeakerBindError(
+                    f"Job {job_id} resulted in a warning. 
Status: {res_dict['status']}" + ) + raise BeakerBindError(f"Job {job_id} did not complete within {max_wait}") + + def system_details_curated(self, system_id): + """Return a curated dictionary of system details.""" + full_details = json.loads(self.system_details(system_id).stdout) + curated_details = { + "hostname": full_details["fqdn"], + "mac_address": full_details["mac_address"], + "owner": "{display_name} <{email_address}>".format( + display_name=full_details["owner"]["display_name"], + email_address=full_details["owner"]["email_address"], + ), + "id": full_details["id"], + } + if current_res := full_details.get("current_reservation"): + curated_details.update( + { + "reservation_id": current_res["recipe_id"], + "reserved_on": current_res.get("start_time"), + "expires_on": current_res.get("finish_time"), + "reserved_for": "{display_name} <{email_address}>".format( + display_name=current_res["user"]["display_name"], + email_address=current_res["user"]["email_address"], + ), + } + ) + return curated_details + + def jobid_from_system(self, system_hostname): + """Return the job id for the current reservation on the system.""" + for job_id in json.loads(self.job_list(mine=True).stdout): + job_result = self.job_results(job_id, pretty=True) + job_detail = _curate_job_info(_elementree_to_dict(ET.fromstring(job_result.stdout))) + if job_detail["hostname"] == system_hostname: + return job_id diff --git a/broker/binds/containers.py b/broker/binds/containers.py index 322285c2..cf291e1c 100644 --- a/broker/binds/containers.py +++ b/broker/binds/containers.py @@ -1,4 +1,9 @@ +"""A collection of classes to ease interaction with Docker and Podman libraries.""" + + class ContainerBind: + """A base class that provides common functionality for Docker and Podman containers.""" + _sensitive_attrs = ["password", "host_password"] def __init__(self, host=None, username=None, password=None, port=22, timeout=None): @@ -12,53 +17,57 @@ def __init__(self, host=None, username=None, password=None, port=22, timeout=Non @property def client(self): + """Return the client instance. 
Create one if it does not exist.""" if not isinstance(self._client, self._ClientClass): self._client = self._ClientClass(base_url=self.uri, timeout=self.timeout) return self._client @property def images(self): + """Return a list of images on the container host.""" return self.client.images.list() @property def containers(self): + """Return a list of containers on the container host.""" return self.client.containers.list(all=True) def image_info(self, name): + """Return curated information about an image on the container host.""" if image := self.client.images.get(name): return { "id": image.short_id, "tags": image.tags, "size": image.attrs["Size"], - "config": { - k: v for k, v in image.attrs["Config"].items() if k != "Env" - }, + "config": {k: v for k, v in image.attrs["Config"].items() if k != "Env"}, } def create_container(self, image, command=None, **kwargs): - """Create and return running container instance""" + """Create and return running container instance.""" kwargs = self._sanitize_create_args(kwargs) return self.client.containers.create(image, command, **kwargs) def execute(self, image, command=None, remove=True, **kwargs): - """Run a container and return the raw result""" - return self.client.containers.run( - image, command=command, remove=remove, **kwargs - ).decode() + """Run a container and return the raw result.""" + return self.client.containers.run(image, command=command, remove=remove, **kwargs).decode() def remove_container(self, container=None): + """Remove a container from the container host.""" if container: container.remove(v=True, force=True) def pull_image(self, name): + """Pull an image into the container host.""" return self.client.images.pull(name) @staticmethod def get_logs(container): - return "\n".join(map(lambda x: x.decode(), container.logs(stream=False))) + """Return the logs from a container.""" + return "\n".join(x.decode() for x in container.logs(stream=False)) @staticmethod def get_attrs(cont): + """Return curated information about a container.""" return { "id": cont.id, "image": cont.attrs.get("ImageName", cont.attrs["Image"]), @@ -69,6 +78,7 @@ def get_attrs(cont): } def __repr__(self): + """Return a string representation of the object.""" inner = ", ".join( f"{k}={'******' if k in self._sensitive_attrs and v else v}" for k, v in self.__dict__.items() @@ -78,6 +88,8 @@ def __repr__(self): class PodmanBind(ContainerBind): + """Handles Podman-specific connection and implementation differences.""" + def __init__(self, **kwargs): super().__init__(**kwargs) from podman import PodmanClient @@ -86,24 +98,28 @@ def __init__(self, **kwargs): if self.host == "localhost": self.uri = "unix:///run/user/1000/podman/podman.sock" else: - self.uri = ( - "http+ssh://{username}@{host}:{port}/run/podman/podman.sock".format( - **kwargs - ) - ) + self.uri = "http+ssh://{username}@{host}:{port}/run/podman/podman.sock".format(**kwargs) def _sanitize_create_args(self, kwargs): from podman.domain.containers_create import CreateMixin + try: CreateMixin._render_payload(kwargs) except TypeError as err: - sanitized = err.args[0].replace("Unknown keyword argument(s): ", "").replace("'", "").split(" ,") + sanitized = ( + err.args[0] + .replace("Unknown keyword argument(s): ", "") + .replace("'", "") + .split(" ,") + ) kwargs = {k: v for k, v in kwargs.items() if k not in sanitized} kwargs = self._sanitize_create_args(kwargs) return kwargs class DockerBind(ContainerBind): + """Handles Docker-specific connection and implementation differences.""" + def __init__(self, port=2375, 
**kwargs): kwargs["port"] = port super().__init__(**kwargs) @@ -117,4 +133,5 @@ def __init__(self, port=2375, **kwargs): def _sanitize_create_args(self, kwargs): from docker.models.containers import RUN_CREATE_KWARGS + return {k: v for k, v in kwargs.items() if k in RUN_CREATE_KWARGS} diff --git a/broker/broker.py b/broker/broker.py index 2171f00a..da5374b2 100644 --- a/broker/broker.py +++ b/broker/broker.py @@ -1,9 +1,30 @@ +"""Main interface for the Broker API. + +This module provides the main interface for the Broker API, which allows users to +manage cloud resources across multiple providers. + +It defines the `Host` class, which represents a cloud resource, and the `Broker` class, +which provides methods for managing hosts. + +The `Broker` class is decorated with `mp_decorator`, which enables multiprocessing for +certain methods. The `Host` class is defined in the `broker.hosts` module, +and the provider classes are defined in the `broker.providers` module. + +Exceptions are defined in the `broker.exceptions` module, +and helper functions are defined in the `broker.helpers` module. + +Note: + This module (or parent directory) should be used as the main entry point for the Broker API. + +""" +from concurrent.futures import ThreadPoolExecutor, as_completed from contextlib import contextmanager + from logzero import logger -from broker.providers import PROVIDERS, PROVIDER_ACTIONS, _provider_imports -from broker.hosts import Host + from broker import exceptions, helpers -from concurrent.futures import ThreadPoolExecutor, as_completed +from broker.hosts import Host +from broker.providers import PROVIDER_ACTIONS, PROVIDERS, _provider_imports # load all the provider class so they are registered for _import in _provider_imports: @@ -11,15 +32,16 @@ def _try_teardown(host_obj): - """Try a host's teardown method and return an exception message if needed""" + """Try a host's teardown method and return an exception message if needed.""" try: host_obj.teardown() - except Exception as err: + except Exception as err: # noqa: BLE001 + logger.debug(f"Tell Jake the exception was: {err}") return exceptions.HostError(host_obj, f"error during teardown:\n{err}") class mp_decorator: - """This decorator wraps Broker methods to enable multiprocessing + """Decorator wrapping Broker methods to enable multiprocessing. The decorated method is expected to return an itearable. 
""" @@ -34,6 +56,7 @@ def __init__(self, func=None): self.func = func def __get__(self, instance, owner): + """Support instance methods.""" if not instance: return self.func @@ -46,8 +69,7 @@ def mp_split(*args, **kwargs): max_workers_count = self.MAX_WORKERS or count with self.EXECUTOR(max_workers=max_workers_count) as workers: completed_futures = as_completed( - workers.submit(self.func, instance, *args, **kwargs) - for _ in range(count) + workers.submit(self.func, instance, *args, **kwargs) for _ in range(count) ) for f in completed_futures: results.extend(f.result()) @@ -57,6 +79,8 @@ def mp_split(*args, **kwargs): class Broker: + """Main Broker class to be used as the primary interface for the Broker API.""" + # map exceptions for easier access when used as a library BrokerError = exceptions.BrokerError AuthenticationError = exceptions.AuthenticationError @@ -86,7 +110,7 @@ def __init__(self, **kwargs): self._kwargs = kwargs def _act(self, provider, method, checkout=False): - """Perform a general action against a provider's method""" + """Perform a general action against a provider's method.""" logger.debug(f"Resolving action {method} on provider {provider}.") provider_inst = provider(**self._kwargs) helpers.emit( @@ -97,9 +121,7 @@ def _act(self, provider, method, checkout=False): } ) method_obj = getattr(provider_inst, method) - logger.debug( - f"On {provider_inst=} executing {method_obj=} with params {self._kwargs=}." - ) + logger.debug(f"On {provider_inst=} executing {method_obj=} with params {self._kwargs=}.") result = method_obj(**self._kwargs) logger.debug(f"Action {result=}") if result and checkout: @@ -117,7 +139,7 @@ def _update_provider_actions(self, kwargs): @mp_decorator def _checkout(self): - """checkout one or more VMs + """Checkout one or more VMs. :return: List of Host objects """ @@ -125,7 +147,7 @@ def _checkout(self): logger.debug(f"Doing _checkout(): {self._provider_actions=}") if not self._provider_actions: raise self.BrokerError("Could not determine an appropriate provider") - for action in self._provider_actions.keys(): + for action in self._provider_actions: provider, method = PROVIDER_ACTIONS[action] logger.info(f"Using provider {provider.__name__} to checkout") try: @@ -138,7 +160,7 @@ def _checkout(self): return hosts def checkout(self): - """checkout one or more VMs + """Checkout one or more VMs. :return: Host obj or list of Host objects """ @@ -152,11 +174,11 @@ def checkout(self): self._hosts.extend(hosts) helpers.update_inventory([host.to_dict() for host in hosts]) if err: - raise err - return hosts if not len(hosts) == 1 else hosts[0] + raise self.BrokerError(f"Error during checkout from {self}") from err + return hosts if len(hosts) != 1 else hosts[0] def execute(self, **kwargs): - """execute a provider action + """Execute a provider action. 
:return: Any results given back by the provider """ @@ -164,17 +186,16 @@ def execute(self, **kwargs): self._kwargs.update(kwargs) if not self._provider_actions: raise self.BrokerError("Could not determine an appropriate provider") - for action, arg in self._provider_actions.items(): + for action, _arg in self._provider_actions.items(): provider, method = PROVIDER_ACTIONS[action] logger.info(f"Using provider {provider.__name__} for execution") return self._act(provider, method) - def nick_help(self): - """Use a provider's nick_help method to get argument information""" - if self._provider_actions: - provider, _ = PROVIDER_ACTIONS[[*self._provider_actions.keys()][0]] - logger.info(f"Querying provider {provider.__name__}") - self._act(provider, "nick_help", checkout=False) + def provider_help(self, provider_name): + """Use a provider's provider_help method to get argument information.""" + provider = PROVIDERS[provider_name] + logger.info(f"Querying provider {provider.__name__}") + self._act(provider, "provider_help", checkout=False) def _checkin(self, host): logger.info(f"Checking in {host.hostname or host.name}") @@ -188,7 +209,7 @@ def _checkin(self, host): return host def checkin(self, sequential=False, host=None, in_context=False): - """checkin one or more VMs + """Checkin one or more VMs. :param host: can be one of: None - Will use the contents of self._hosts @@ -214,9 +235,7 @@ def checkin(self, sequential=False, host=None, in_context=False): hosts = [hosts] if in_context: hosts = [ - host - for host in hosts - if not getattr(host, "_skip_context_checkin", False) + host for host in hosts if not getattr(host, "_skip_context_checkin", False) ] if not hosts: logger.debug("Checkin called with no hosts, taking no action") @@ -230,16 +249,12 @@ def checkin(self, sequential=False, host=None, in_context=False): ) for completed in completed_checkins: _host = completed.result() - self._hosts = [ - h for h in self._hosts if not (h.to_dict() == _host.to_dict()) - ] - logger.debug( - f"Completed checkin process for {_host.hostname or _host.name}" - ) + self._hosts = [h for h in self._hosts if h.to_dict() != _host.to_dict()] + logger.debug(f"Completed checkin process for {_host.hostname or _host.name}") helpers.update_inventory(remove=[h.hostname for h in hosts]) def _extend(self, host): - """extend a single VM""" + """Extend a single VM.""" logger.info(f"Extending host {host.hostname}") provider = PROVIDERS[host._broker_provider] self._kwargs["target_vm"] = host @@ -247,7 +262,7 @@ def _extend(self, host): return host def extend(self, sequential=False, host=None): - """extend one or more VMs + """Extend one or more VMs. 
:param host: can be one of: None - Will use the contents of self._hosts @@ -275,16 +290,14 @@ def extend(self, sequential=False, host=None): return with ThreadPoolExecutor(max_workers=1 if sequential else len(hosts)) as workers: - completed_extends = as_completed( - workers.submit(self._extend, _host) for _host in hosts - ) + completed_extends = as_completed(workers.submit(self._extend, _host) for _host in hosts) for completed in completed_extends: _host = completed.result() logger.info(f"Completed extend for {_host.hostname or _host.name}") @staticmethod def sync_inventory(provider): - """Acquire a list of hosts from a provider and update our inventory""" + """Acquire a list of hosts from a provider and update our inventory.""" additional_arg, instance = None, {} if "::" in provider: provider, instance = provider.split("::") @@ -298,12 +311,12 @@ def sync_inventory(provider): prov_inventory = PROVIDERS[provider](**instance).get_inventory(additional_arg) curr_inventory = [ hostname if (hostname := host.get("hostname")) else host.get("name") - for host in helpers.load_inventory(filter=f"_broker_provider={provider}") + for host in helpers.load_inventory(filter=f'@inv._broker_provider == "{provider}"') ] helpers.update_inventory(add=prov_inventory, remove=curr_inventory) def reconstruct_host(self, host_export_data): - """reconstruct a host from export data""" + """Reconstruct a host from export data.""" logger.debug(f"reconstructing host with export: {host_export_data}") provider = PROVIDERS.get(host_export_data.get("_broker_provider")) if not provider: @@ -320,7 +333,7 @@ def reconstruct_host(self, host_export_data): return host def from_inventory(self, filter=None): - """Reconstruct one or more hosts from the local inventory + """Reconstruct one or more hosts from the local inventory. :param filter: A broker-spec filter string """ @@ -330,8 +343,10 @@ def from_inventory(self, filter=None): @classmethod @contextmanager def multi_manager(cls, **multi_dict): - """Given a mapping of names to Broker argument dictionaries: - create multiple Broker instances, check them out in parallel, yield, then checkin. + """Allow a user to check out multiple hosts at once. + + Given a mapping of names to Broker argument dictionaries: + create multiple Broker instances, check them out in parallel, yield, then checkin. 
Example: with Broker.multi_mode( @@ -364,18 +379,14 @@ def multi_manager(cls, **multi_dict): all_hosts.extend(broker_inst._hosts) # run setup on all hosts in parallel with ThreadPoolExecutor(max_workers=len(all_hosts)) as workers: - completed_setups = as_completed( - workers.submit(host.setup) for host in all_hosts - ) + completed_setups = as_completed(workers.submit(host.setup) for host in all_hosts) for completed in completed_setups: completed.result() # yield control to the user yield {name: broker._hosts for name, broker in broker_instances.items()} # teardown all hosts in parallel with ThreadPoolExecutor(max_workers=len(all_hosts)) as workers: - completed_teardowns = as_completed( - workers.submit(host.teardown) for host in all_hosts - ) + completed_teardowns = as_completed(workers.submit(host.teardown) for host in all_hosts) for completed in completed_teardowns: completed.result() # checkin all hosts in parallel @@ -387,6 +398,7 @@ def multi_manager(cls, **multi_dict): completed.result() def __repr__(self): + """Return a string representation of the Broker instance.""" inner = ", ".join( f"{k}={v}" for k, v in self.__dict__.items() @@ -395,6 +407,7 @@ def __repr__(self): return f"{self.__class__.__name__}({inner})" def __enter__(self): + """Checkout hosts and return them to the user.""" try: hosts = self.checkout() if not hosts: @@ -411,6 +424,7 @@ def __enter__(self): raise err def __exit__(self, exc_type, exc_value, exc_traceback): + """Teardown and checkin hosts.""" last_exception = None for host in self._hosts: last_exception = _try_teardown(host) diff --git a/broker/commands.py b/broker/commands.py index e45fda7b..08844dbe 100644 --- a/broker/commands.py +++ b/broker/commands.py @@ -1,19 +1,21 @@ +"""Defines the CLI commands for Broker.""" from functools import wraps import signal import sys + import click from logzero import logger + from broker import exceptions, helpers, settings -from broker.broker import PROVIDERS, PROVIDER_ACTIONS, Broker +from broker.broker import Broker from broker.logger import LOG_LEVEL -from broker import exceptions, helpers, settings - +from broker.providers import PROVIDER_HELP, PROVIDERS signal.signal(signal.SIGINT, helpers.handle_keyboardinterrupt) def loggedcli(group=None, *cli_args, **cli_kwargs): - """Updates the group command wrapper function in order to add logging""" + """Update the group command wrapper function in order to add logging.""" if not group: group = cli # default to the main cli group @@ -35,21 +37,23 @@ def wrapper(*args, **kwargs): class ExceptionHandler(click.Group): - """Wraps click group to catch and handle raised exceptions""" + """Wraps click group to catch and handle raised exceptions.""" def __call__(self, *args, **kwargs): + """Override the __call__ method to catch and handle exceptions.""" try: - return self.main(*args, **kwargs) - except Exception as err: + res = self.main(*args, **kwargs) + helpers.emit(return_code=0) + return res + except Exception as err: # noqa: BLE001 if not isinstance(err, exceptions.BrokerError): err = exceptions.BrokerError(err) helpers.emit(return_code=err.error_code, error_message=str(err.message)) sys.exit(err.error_code) - helpers.emit(return_code=0) def provider_options(command): - """Applies provider-specific decorators to each command this decorates""" + """Apply provider-specific decorators to each command this decorates.""" for prov in PROVIDERS.values(): if prov.hidden: continue @@ -59,8 +63,9 @@ def provider_options(command): def populate_providers(click_group): - """Populates the 
subcommands for providers subcommand using provider information - Providers become subcommands and their actions become arguments to their subcommand + """Populate the subcommands for providers subcommand using provider information. + + Providers become subcommands and their actions become arguments to their subcommand. Example: Usage: broker providers AnsibleTower [OPTIONS] @@ -78,11 +83,18 @@ def populate_providers(click_group): group=click_group, name=prov, hidden=prov_class.hidden, - context_settings={"allow_extra_args": True, "ignore_unknown_options": True}, + context_settings={ + "allow_extra_args": True, + "ignore_unknown_options": True, + }, ) @click.pass_context def provider_cmd(ctx, *args, **kwargs): # the actual subcommand - """Get information about a provider's actions""" + """Get information about a provider's actions.""" + # add additional args flags to the kwargs + for arg in ctx.args: + if arg.startswith("--"): + kwargs[arg[2:]] = True # if additional arguments were passed, include them in the broker args # strip leading -- characters kwargs.update( @@ -92,24 +104,21 @@ def provider_cmd(ctx, *args, **kwargs): # the actual subcommand } ) broker_inst = Broker(**kwargs) - broker_inst.nick_help() + broker_inst.provider_help(ctx.info_name) # iterate through available actions and populate options from them - for action in ( - action - for action, prov_info in PROVIDER_ACTIONS.items() - if prov_info[0] == prov_class - ): - action = action.replace("_", "-") - plural = ( - action.replace("y", "ies") if action.endswith("y") else f"{action}s" - ) - provider_cmd = click.option( - f"--{plural}", is_flag=True, help=f"Get available {plural}" - )(provider_cmd) - provider_cmd = click.option( - f"--{action}", type=str, help=f"Get information about a {action}" - )(provider_cmd) + for option, (p_cls, is_flag) in PROVIDER_HELP.items(): + if p_cls is not prov_class: + continue + option = option.replace("_", "-") # noqa: PLW2901 + if is_flag: + provider_cmd = click.option( + f"--{option}", is_flag=True, help=f"Get available {option}" + )(provider_cmd) + else: + provider_cmd = click.option( + f"--{option}", type=str, help=f"Get information about a {option}" + )(provider_cmd) provider_cmd = click.option( "--results-limit", type=int, @@ -145,17 +154,18 @@ def provider_cmd(ctx, *args, **kwargs): # the actual subcommand help="Get broker system-level information", ) def cli(version): + """Command-line interface for interacting with providers.""" if version: import pkg_resources broker_version = pkg_resources.get_distribution("broker").version # check the latest version publish to PyPi try: - import requests from packaging.version import Version + import requests latest_version = Version( - requests.get("https://pypi.org/pypi/broker/json").json()["info"][ + requests.get("https://pypi.org/pypi/broker/json", timeout=60).json()["info"][ "version" ] ) @@ -164,8 +174,8 @@ def cli(version): f"A newer version of broker is available: {latest_version}", fg="yellow", ) - except: - pass + except requests.exceptions.RequestException as err: + logger.warning(f"Unable to check for latest version: {err}") click.echo(f"Version: {broker_version}") broker_directory = settings.BROKER_DIRECTORY.absolute() click.echo(f"Broker Directory: {broker_directory}") @@ -177,9 +187,7 @@ def cli(version): @loggedcli(context_settings={"allow_extra_args": True, "ignore_unknown_options": True}) @click.option("-b", "--background", is_flag=True, help="Run checkout in the background") @click.option("-n", "--nick", type=str, help="Use a 
nickname defined in your settings") -@click.option( - "-c", "--count", type=int, help="Number of times broker repeats the checkout" -) +@click.option("-c", "--count", type=int, help="Number of times broker repeats the checkout") @click.option( "--args-file", type=click.Path(exists=True), @@ -188,18 +196,11 @@ def cli(version): @provider_options @click.pass_context def checkout(ctx, background, nick, count, args_file, **kwargs): - """Checkout or "create" a Virtual Machine broker instance - COMMAND: broker checkout --workflow "workflow-name" --workflow-arg1 something - or - COMMAND: broker checkout --nick "nickname" - - :param ctx: clicks context object + """Checkout or "create" a Virtual Machine broker instance. - :param background: run a new broker subprocess to carry out command + COMMAND: broker checkout --workflow "workflow-name" --workflow_arg1 something - :param nick: shortcut for arguments saved in settings.yaml, passed in as a string - - :param args_file: this broker argument will be replaced with the contents of the file passed in + COMMAND: broker checkout --nick "nickname" """ broker_args = helpers.clean_dict(kwargs) if nick: @@ -218,13 +219,12 @@ def checkout(ctx, background, nick, count, args_file, **kwargs): ) if background: helpers.fork_broker() - broker_inst = Broker(**broker_args) - broker_inst.checkout() + Broker(**broker_args).checkout() @cli.group(cls=ExceptionHandler) def providers(): - """Get information about a provider and its actions""" + """Get information about a provider and its actions.""" pass @@ -236,38 +236,20 @@ def providers(): @click.option("-b", "--background", is_flag=True, help="Run checkin in the background") @click.option("--all", "all_", is_flag=True, help="Select all VMs") @click.option("--sequential", is_flag=True, help="Run checkins sequentially") -@click.option( - "--filter", type=str, help="Checkin only what matches the specified filter" -) +@click.option("--filter", type=str, help="Checkin only what matches the specified filter") def checkin(vm, background, all_, sequential, filter): - """Checkin or "remove" a VM or series of VM broker instances - - COMMAND: broker checkin ||all - - :param vm: Hostname or local id of host - - :param background: run a new broker subprocess to carry out command + """Checkin or "remove" a VM or series of VM broker instances. - :param all_: Flag for whether to checkin everything - - :param sequential: Flag for whether to run checkins sequentially - - :param filter: a filter string matching broker's specification + COMMAND: broker checkin ||--all """ if background: helpers.fork_broker() inventory = helpers.load_inventory(filter=filter) to_remove = [] for num, host in enumerate(inventory): - if ( - str(num) in vm - or host.get("hostname") in vm - or host.get("name") in vm - or all_ - ): + if str(num) in vm or host.get("hostname") in vm or host.get("name") in vm or all_: to_remove.append(Broker().reconstruct_host(host)) - broker_inst = Broker(hosts=to_remove) - broker_inst.checkin(sequential=sequential) + Broker(hosts=to_remove).checkin(sequential=sequential) @loggedcli() @@ -277,12 +259,11 @@ def checkin(vm, background, all_, sequential, filter): type=str, help="Class-style name of a supported broker provider. 
(AnsibleTower)", ) -@click.option( - "--filter", type=str, help="Display only what matches the specified filter" -) +@click.option("--filter", type=str, help="Display only what matches the specified filter") def inventory(details, sync, filter): - """Get a list of all VMs you've checked out showing hostname and local id - hostname pulled from list of dictionaries + """Get a list of all VMs you've checked out showing hostname and local id. + + hostname pulled from list of dictionaries. """ if sync: Broker.sync_inventory(provider=sync) @@ -293,10 +274,13 @@ def inventory(details, sync, filter): emit_data.append(host) if (display_name := host.get("hostname")) is None: display_name = host.get("name") + # if we're filtering, then don't show an index. + # Otherwise, a user might perform an action on the incorrect (unfiltered) index. + index = f"{num}: " if filter is None else "" if details: - logger.info(f"{num}: {display_name}, Details: {helpers.yaml_format(host)}") + logger.info(f"{index}{display_name}:\n{helpers.yaml_format(host)}") else: - logger.info(f"{num}: {display_name}") + logger.info(f"{index}{display_name}") helpers.emit({"inventory": emit_data}) @@ -305,24 +289,12 @@ def inventory(details, sync, filter): @click.option("-b", "--background", is_flag=True, help="Run extend in the background") @click.option("--all", "all_", is_flag=True, help="Select all VMs") @click.option("--sequential", is_flag=True, help="Run extends sequentially") -@click.option( - "--filter", type=str, help="Extend only what matches the specified filter" -) +@click.option("--filter", type=str, help="Extend only what matches the specified filter") @provider_options def extend(vm, background, all_, sequential, filter, **kwargs): - """Extend a host's lease time - - COMMAND: broker extend || - - :param vm: Hostname, VM Name, or local id of host - - :param background: run a new broker subprocess to carry out command + """Extend a host's lease time. 
-    :param all_: Click option all
-
-    :param sequential: Flag for whether to run extends sequentially
-
-    :param filter: a filter string matching broker's specification
+    COMMAND: broker extend <vm hostname>|<vm name>|<local id>|--all
     """
     broker_args = helpers.clean_dict(kwargs)
     if background:
@@ -330,42 +302,27 @@ def extend(vm, background, all_, sequential, filter, **kwargs):
     inventory = helpers.load_inventory(filter=filter)
     to_extend = []
     for num, host in enumerate(inventory):
-        if str(num) in vm or host["hostname"] in vm or host["name"] in vm or all_:
+        if str(num) in vm or host["hostname"] in vm or host.get("name") in vm or all_:
            to_extend.append(Broker().reconstruct_host(host))
-    broker_inst = Broker(hosts=to_extend, **broker_args)
-    broker_inst.extend(sequential=sequential)
+    Broker(hosts=to_extend, **broker_args).extend(sequential=sequential)


 @loggedcli()
 @click.argument("vm", type=str, nargs=-1)
-@click.option(
-    "-b", "--background", is_flag=True, help="Run duplicate in the background"
-)
-@click.option(
-    "-c", "--count", type=int, help="Number of times broker repeats the duplicate"
-)
+@click.option("-b", "--background", is_flag=True, help="Run duplicate in the background")
+@click.option("-c", "--count", type=int, help="Number of times broker repeats the duplicate")
 @click.option("--all", "all_", is_flag=True, help="Select all VMs")
-@click.option(
-    "--filter", type=str, help="Duplicate only what matches the specified filter"
-)
+@click.option("--filter", type=str, help="Duplicate only what matches the specified filter")
 def duplicate(vm, background, count, all_, filter):
-    """Duplicate a broker-procured vm
+    """Duplicate a broker-procured vm.

     COMMAND: broker duplicate <vm hostname>|<local id>|all
-
-    :param vm: Hostname or local id of host
-
-    :param background: run a new broker subprocess to carry out command
-
-    :param all_: Click option all
-
-    :param filter: a filter string matching broker's specification
     """
     if background:
         helpers.fork_broker()
     inventory = helpers.load_inventory(filter=filter)
     for num, host in enumerate(inventory):
-        if str(num) in vm or host["hostname"] in vm or host["name"] in vm or all_:
+        if str(num) in vm or host["hostname"] in vm or host.get("name") in vm or all_:
            broker_args = host.get("_broker_args")
            if broker_args:
                if count:
@@ -374,17 +331,13 @@ def duplicate(vm, background, count, all_, filter):
                broker_inst = Broker(**broker_args)
                broker_inst.checkout()
            else:
-                logger.warning(
-                    f"Unable to duplicate {host['hostname']}, no _broker_args found"
-                )
+                logger.warning(f"Unable to duplicate {host['hostname']}, no _broker_args found")


 @loggedcli(context_settings={"allow_extra_args": True, "ignore_unknown_options": True})
 @click.option("-b", "--background", is_flag=True, help="Run execute in the background")
 @click.option("--nick", type=str, help="Use a nickname defined in your settings")
-@click.option(
-    "--output-format", "-o", type=click.Choice(["log", "raw", "yaml"]), default="log"
-)
+@click.option("--output-format", "-o", type=click.Choice(["log", "raw", "yaml"]), default="log")
 @click.option(
     "--artifacts",
     type=click.Choice(["merge", "last"]),
@@ -398,22 +351,11 @@ def duplicate(vm, background, count, all_, filter):
 @provider_options
 @click.pass_context
 def execute(ctx, background, nick, output_format, artifacts, args_file, **kwargs):
-    """Execute an arbitrary provider action
-    COMMAND: broker execute --workflow "workflow-name" --workflow-arg1 something
-    or
-    COMMAND: broker execute --nick "nickname"
-
-    :param ctx: clicks context object
-
-    :param background: run a new broker subprocess to carry
out command + """Execute an arbitrary provider action. - :param nick: shortcut for arguments saved in settings.yaml, passed in as a string + COMMAND: broker execute --workflow "workflow-name" --workflow_arg1 something - :param output_format: change the format of the output to one of the choice options - - :param artifacts: AnsibleTower provider specific option for choosing what to return - - :param args_file: this broker argument will be replaced with the contents of the file passed in + COMMAND: broker execute --nick "nickname" """ broker_args = helpers.clean_dict(kwargs) if nick: @@ -432,12 +374,11 @@ def execute(ctx, background, nick, output_format, artifacts, args_file, **kwargs ) if background: helpers.fork_broker() - broker_inst = Broker(**broker_args) - result = broker_inst.execute() + result = Broker(**broker_args).execute() helpers.emit({"output": result}) if output_format == "raw": - print(result) + click.echo(result) elif output_format == "log": logger.info(result) elif output_format == "yaml": - print(helpers.yaml_format(result)) + click.echo(helpers.yaml_format(result)) diff --git a/broker/exceptions.py b/broker/exceptions.py index 399bbed8..c9a6ef19 100644 --- a/broker/exceptions.py +++ b/broker/exceptions.py @@ -1,26 +1,37 @@ -"""A collection of Broker-specific exceptions""" +"""A collection of Broker-specific exceptions.""" +import logging + from logzero import logger class BrokerError(Exception): + """Base class for Broker exceptions.""" + error_code = 1 def __init__(self, message="An unhandled exception occured!"): - if logger.level == 10 and isinstance(message, Exception): + # Log the exception if the logger is set to DEBUG + if logger.level == logging.DEBUG and isinstance(message, Exception): logger.exception(message) self.message = message logger.error(f"{self.__class__.__name__}: {self.message}") class AuthenticationError(BrokerError): + """Raised when authentication with a provider or Host fails.""" + error_code = 5 class PermissionError(BrokerError): + """Raised when the user does not have permission to perform an action.""" + error_code = 6 class ProviderError(BrokerError): + """Raised when a provider-specific error occurs.""" + error_code = 7 def __init__(self, provider=None, message="Unspecified exception"): @@ -29,17 +40,35 @@ def __init__(self, provider=None, message="Unspecified exception"): class ConfigurationError(BrokerError): + """Raised when a Broker configuration error occurs.""" + error_code = 8 class NotImplementedError(BrokerError): + """Raised when a method or function has not been implemented.""" + error_code = 9 class HostError(BrokerError): + """Raised when a Host-specific error occurs.""" + error_code = 10 def __init__(self, host=None, message="Unspecified exception"): if host: self.message = f"{host.hostname or host.name}: {message}" super().__init__(message=self.message) + + +class ContainerBindError(BrokerError): + """Raised when a problem occurs at the container's bind level.""" + + error_code = 11 + + +class BeakerBindError(BrokerError): + """Raised when a problem occurs at the Beaker bind level.""" + + error_code = 12 diff --git a/broker/helpers.py b/broker/helpers.py index 2abbee10..1b0fe0a6 100644 --- a/broker/helpers.py +++ b/broker/helpers.py @@ -1,35 +1,37 @@ -"""Miscellaneous helpers live here""" +"""Miscellaneous helpers live here.""" import collections +from collections import UserDict, namedtuple +from collections.abc import MutableMapping from contextlib import contextmanager +from copy import deepcopy import getpass import 
inspect import json import os +from pathlib import Path import sys import tarfile +import threading import time -from collections import UserDict, namedtuple -from collections.abc import MutableMapping -from copy import deepcopy -from pathlib import Path from uuid import uuid4 -import yaml +import click from logzero import logger +import yaml -from broker import exceptions, settings -from broker import logger as b_log +from broker import exceptions, logger as b_log, settings FilterTest = namedtuple("FilterTest", "haystack needle test") +INVENTORY_LOCK = threading.Lock() def clean_dict(in_dict): - """Remove entries from a dict where value is None""" + """Remove entries from a dict where value is None.""" return {k: v for k, v in in_dict.items() if v is not None} def merge_dicts(dict1, dict2): - """Merge two nested dictionaries together + """Merge two nested dictionaries together. :return: merged dictionary """ @@ -49,7 +51,8 @@ def merge_dicts(dict1, dict2): def flatten_dict(nested_dict, parent_key="", separator="_"): - """Flatten a nested dictionary, keeping nested notation in key + """Flatten a nested dictionary, keeping nested notation in key. + { 'key': 'value1', 'another': { @@ -64,11 +67,10 @@ def flatten_dict(nested_dict, parent_key="", separator="_"): "another_nested2": [1, 2], "another_nested2_deep": "value3" } - note that dictionaries nested in lists will be removed from the list + note that dictionaries nested in lists will be removed from the list. :return: dictionary """ - flattened = [] for key, value in nested_dict.items(): new_key = f"{parent_key}{separator}{key}" if parent_key else key @@ -76,7 +78,8 @@ def flatten_dict(nested_dict, parent_key="", separator="_"): flattened.extend(flatten_dict(value, new_key, separator).items()) elif isinstance(value, list): to_remove = [] - value = value.copy() # avoid mutating nested structures + # avoid mutating nested structures + value = value.copy() # noqa: PLW2901 for index, val in enumerate(value): if isinstance(val, dict): flattened.extend(flatten_dict(val, new_key, separator).items()) @@ -89,15 +92,45 @@ def flatten_dict(nested_dict, parent_key="", separator="_"): return dict(flattened) +def dict_from_paths(source_dict, paths): + """Given a dictionary of desired keys and nested paths, return a new dictionary. 
+ + Example: + source_dict = { + "key1": "value1", + "key2": { + "nested1": "value2", + "nested2": { + "deep": "value3" + } + } + } + paths = { + "key1": "key1", + "key2": "key2/nested2/deep" + } + returns { + "key1": "value1", + "key2": "value3" + } + """ + result = {} + for key, path in paths.items(): + if "/" not in path: + result[key] = source_dict.get(path) + else: + top, rem = path.split("/", 1) + result.update(dict_from_paths(source_dict[top], {key: rem})) + return result + + def eval_filter(filter_list, raw_filter, filter_key="inv"): - """Run each filter through an eval to get the results""" - filter_list = [ - MockStub(item) if isinstance(item, dict) else item for item in filter_list - ] + """Run each filter through an eval to get the results.""" + filter_list = [MockStub(item) if isinstance(item, dict) else item for item in filter_list] for raw_f in raw_filter.split("|"): if f"@{filter_key}[" in raw_f: # perform a list filter on the inventory - filter_list = eval( + filter_list = eval( # noqa: S307 raw_f.replace(f"@{filter_key}", filter_key), {filter_key: filter_list} ) filter_list = filter_list if isinstance(filter_list, list) else [filter_list] @@ -105,7 +138,7 @@ def eval_filter(filter_list, raw_filter, filter_key="inv"): # perform an attribute filter on each host filter_list = list( filter( - lambda item: eval( + lambda item: eval( # noqa: S307 raw_f.replace(f"@{filter_key}", filter_key), {filter_key: item} ), filter_list, @@ -115,7 +148,7 @@ def eval_filter(filter_list, raw_filter, filter_key="inv"): def resolve_nick(nick): - """Checks if the nickname exists. Used to define broker arguments + """Check if the nickname exists. Used to define broker arguments. :param nick: String representing the name of a nick @@ -127,7 +160,7 @@ def resolve_nick(nick): def load_file(file, warn=True): - """Verifies existence and loads data from json and yaml files""" + """Verify the existence of and load data from json and yaml files.""" file = Path(file) if not file.exists() or file.suffix not in (".json", ".yaml", ".yml"): if warn: @@ -145,15 +178,14 @@ def load_file(file, warn=True): def resolve_file_args(broker_args): - """Check for files being passed in as values to arguments, - then attempt to resolve them. If not resolved, keep arg/value pair intact. + """Check for files being passed in as values to arguments then attempt to resolve them. + + If not resolved, keep arg/value pair intact. """ final_args = {} # parse the eventual args_file first if val := broker_args.pop("args_file", None): - if isinstance(val, Path) or ( - isinstance(val, str) and val[-4:] in ("json", "yaml", ".yml") - ): + if isinstance(val, Path) or (isinstance(val, str) and val[-4:] in ("json", "yaml", ".yml")): if data := load_file(val): if isinstance(data, dict): final_args.update(data) @@ -164,9 +196,7 @@ def resolve_file_args(broker_args): raise exceptions.BrokerError(f"No data loaded from {val}") for key, val in broker_args.items(): - if isinstance(val, Path) or ( - isinstance(val, str) and val[-4:] in ("json", "yaml", ".yml") - ): + if isinstance(val, Path) or (isinstance(val, str) and val[-4:] in ("json", "yaml", ".yml")): if data := load_file(val): final_args.update({key: data}) else: @@ -177,7 +207,7 @@ def resolve_file_args(broker_args): def load_inventory(filter=None): - """Loads all local hosts in inventory + """Load all local hosts in inventory. 
:return: list of dictionaries """ @@ -186,7 +216,7 @@ def load_inventory(filter=None): def update_inventory(add=None, remove=None): - """Updates list of local hosts in the checkout interface + """Update list of local hosts in the checkout interface. :param add: list of dictionaries representing new hosts @@ -200,20 +230,19 @@ def update_inventory(add=None, remove=None): add = [] if remove and not isinstance(remove, list): remove = [remove] - with FileLock(settings.inventory_path): + with INVENTORY_LOCK: inv_data = load_inventory() if inv_data: settings.inventory_path.unlink() if remove: for host in inv_data[::-1]: - if host["hostname"] in remove or host["name"] in remove: + if host["hostname"] in remove or host.get("name") in remove: # iterate through new hosts and update with old host data if it would nullify for new_host in add: - if ( - host["hostname"] == new_host["hostname"] - or host["name"] == new_host["name"] - ): + if host["hostname"] == new_host["hostname"] or host.get( + "name" + ) == new_host.get("name"): # update missing data in the new_host with the old_host data new_host.update(merge_dicts(new_host, host)) inv_data.remove(host) @@ -226,7 +255,7 @@ def update_inventory(add=None, remove=None): def yaml_format(in_struct): - """Convert a yaml-compatible structure to a yaml dumped string + """Convert a yaml-compatible structure to a yaml dumped string. :param in_struct: yaml-compatible structure or string containing structure @@ -238,22 +267,26 @@ def yaml_format(in_struct): class Emitter: - """This class provides a simple interface to emit messages to a - json-formatted file. This file also has an instance of this class - called "emit" that should be used instead of this class directly. + """Class that provides a simple interface to emit messages to a json-formatted file. + + This module also has an instance of this class called "emit" that should be used + instead of this class directly. Usage examples: helpers.emit(key=value, another=5) helpers.emit({"key": "value", "another": 5}) """ + EMIT_LOCK = threading.Lock() + def __init__(self, emit_file=None): - """Can empty init and set the file later""" + """Can empty init and set the file later.""" self.file = None if emit_file: self.file = self.set_file(emit_file) def set_file(self, file_path): + """Set the file to emit to.""" if file_path: self.file = Path(file_path) self.file.parent.mkdir(exist_ok=True, parents=True) @@ -262,21 +295,23 @@ def set_file(self, file_path): self.file.touch() def emit_to_file(self, *args, **kwargs): + """Emit data to the file, keeping existing data in-place.""" if not self.file: return for arg in args: if not isinstance(arg, dict): raise exceptions.BrokerError(f"Received an invalid data emission {arg}") kwargs.update(arg) - for key in kwargs.keys(): + for key in kwargs: if getattr(kwargs[key], "json", None): kwargs[key] = kwargs[key].json - with FileLock(self.file): + with self.EMIT_LOCK: curr_data = json.loads(self.file.read_text() or "{}") curr_data.update(kwargs) self.file.write_text(json.dumps(curr_data, indent=4, sort_keys=True)) def __call__(self, *args, **kwargs): + """Allow emit to be used like a function.""" return self.emit_to_file(*args, **kwargs) @@ -284,10 +319,10 @@ def __call__(self, *args, **kwargs): class MockStub(UserDict): - """Test helper class. Allows for both arbitrary mocking and stubbing""" + """Test helper class. 
Allows for both arbitrary mocking and stubbing.""" def __init__(self, in_dict=None): - """Initialize the class and all nested dictionaries""" + """Initialize the class and all nested dictionaries.""" if in_dict is None: in_dict = {} for key, value in in_dict.items(): @@ -304,9 +339,15 @@ def __init__(self, in_dict=None): super().__init__(in_dict) def __getattr__(self, name): + """Fallback to returning self if attribute doesn't exist.""" return self def __getitem__(self, key): + """Get an item from the dictionary-like object. + + If the key is a string, this method will attempt to get an attribute with that name. + If the key is not found, this method will return the object itself. + """ if isinstance(key, str): item = getattr(self, key, self) try: @@ -316,31 +357,38 @@ def __getitem__(self, key): return item def __call__(self, *args, **kwargs): + """Allow MockStub to be used like a function.""" return self def __hash__(self): + """Return a hash value for the object. + + The hash value is computed using the hash value of all hashable attributes of the object. + """ return hash( - tuple( - ( - kp - for kp in self.__dict__.items() - if isinstance(kp[1], collections.abc.Hashable) - ) - ) + tuple(kp for kp in self.__dict__.items() if isinstance(kp[1], collections.abc.Hashable)) ) def update_log_level(ctx, param, value): + """Update the log level and file logging settings for the Broker. + + Args: + ctx: The Click context object. + param: The Click parameter object. + value: The new log level value. + """ b_log.set_log_level(value) b_log.set_file_logging(value) def set_emit_file(ctx, param, value): - global emit + """Update the file that the Broker emits data to.""" emit.set_file(value) def fork_broker(): + """Fork the Broker process to run in the background.""" pid = os.fork() if pid: logger.info(f"Running broker in the background with pid: {pid}") @@ -349,22 +397,27 @@ def fork_broker(): def handle_keyboardinterrupt(*args): - choice = input( + """Handle keyboard interrupts gracefully. + + Offer the user a choice between keeping Broker alive in the background or killing it. + """ + choice = click.prompt( "\nEnding Broker while running won't end processes being monitored.\n" "Would you like to switch Broker to run in the background?\n" - "[y/n]: " + "[y/n]: ", + type=click.Choice(["y", "n"]), + default="n", ) - if choice.lower()[0] == "y": + if choice == "y": fork_broker() else: raise exceptions.BrokerError("Broker killed by user.") def translate_timeout(timeout): - """Allows for flexible timeout definitions, converts other units to ms + """Allow for flexible timeout definitions, converts other units to ms. 
+
+    acceptable units are (s)econds, (m)inutes, (h)ours, (d)ays
+    """
     if isinstance(timeout, str):
         timeout, unit = int(timeout[:-1]), timeout[-1]
@@ -383,12 +436,12 @@ def translate_timeout(timeout):


 def simple_retry(cmd, cmd_args=None, cmd_kwargs=None, max_timeout=60, _cur_timeout=1):
-    """Re(Try) a function given its args and kwargs up until a max timeout"""
+    """Re(Try) a function given its args and kwargs up until a max timeout."""
     cmd_args = cmd_args if cmd_args else []
     cmd_kwargs = cmd_kwargs if cmd_kwargs else {}
     try:
         return cmd(*cmd_args, **cmd_kwargs)
-    except Exception as err:
+    except Exception as err:  # noqa: BLE001 - Could be anything
         new_wait = _cur_timeout * 2
         if new_wait > max_timeout:
             raise err
@@ -401,8 +454,9 @@ def simple_retry(cmd, cmd_args=None, cmd_kwargs=None, max_timeout=60, _cur_timeo


 class FileLock:
-    """Basic file locking class that acquires and releases locks
-    recommended usage is the context manager which will handle everything for you
+    """Basic file locking class that acquires and releases locks.
+
+    Recommended usage is the context manager which will handle everything for you

     with FileLock("basic_file.txt"):
         Path("basic_file.txt").write_text("some text")
@@ -415,6 +469,7 @@ def __init__(self, file_name, timeout=10):
         self.timeout = timeout

     def wait_file(self):
+        """Wait for the lock file to be released, then acquire it."""
         timeout_after = time.time() + self.timeout
         while self.lock.exists():
             if time.time() <= timeout_after:
@@ -426,26 +481,29 @@ def wait_file(self):
         self.lock.touch()

     def return_file(self):
+        """Release the lock file."""
         self.lock.unlink()

-    def __enter__(self):
+    def __enter__(self):  # noqa: D105
         self.wait_file()

-    def __exit__(self, *tb_info):
+    def __exit__(self, *tb_info):  # noqa: D105
         self.return_file()


 class Result:
-    """Dummy result class for presenting results in dot access"""
+    """Dummy result class for presenting results in dot access."""

     def __init__(self, **kwargs):
         self.__dict__.update(kwargs)

     def __repr__(self):
+        """Return a string representation of the object."""
         return f"stdout:\n{self.stdout}\nstderr:\n{self.stderr}\nstatus: {self.status}"

     @classmethod
     def from_ssh(cls, stdout, channel):
+        """Create a Result object from an SSH channel."""
         return cls(
             stdout=stdout,
             status=channel.get_exit_status(),
@@ -454,6 +512,7 @@ def from_ssh(cls, stdout, channel):

     @classmethod
     def from_duplexed_exec(cls, duplex_exec):
+        """Create a Result object from a duplexed exec object from the docker library."""
         if duplex_exec.output[0]:
             stdout = duplex_exec.output[0].decode("utf-8")
         else:
@@ -470,6 +529,7 @@ def from_duplexed_exec(cls, duplex_exec):

     @classmethod
     def from_nonduplexed_exec(cls, nonduplex_exec):
+        """Create a Result object from a nonduplexed exec object from the docker library."""
         return cls(
             status=nonduplex_exec.exit_code,
             stdout=nonduplex_exec.output.decode("utf-8"),
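To illustrate the doubling backoff visible in `simple_retry` above, a minimal hedged sketch with a hypothetical flaky callable:

```python
import random

from broker.helpers import simple_retry

def flaky_call():  # hypothetical function for illustration
    if random.random() < 0.5:
        raise ConnectionError("not ready yet")
    return "ok"

# Retries with a doubling wait (starting from _cur_timeout=1) until a call
# succeeds or the next wait would exceed max_timeout, at which point the
# last exception is re-raised.
result = simple_retry(flaky_call, max_timeout=60)
```

@@ -479,19 +539,18 @@ def from_nonduplexed_exec(cls, nonduplex_exec):

 def find_origin():
     """Move up the call stack to find tests, fixtures, or cli invocations.
+
+    Additionally, return the jenkins url, if it exists.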
""" prev, jenkins_url = None, os.environ.get("BUILD_URL") for frame in inspect.stack(): - if frame.function == "checkout" and frame.filename.endswith( - "broker/commands.py" - ): + if frame.function == "checkout" and frame.filename.endswith("broker/commands.py"): return f"broker_cli:{getpass.getuser()}", jenkins_url if frame.function.startswith("test_"): return f"{frame.function}:{frame.filename}", jenkins_url if frame.function == "call_fixture_func": # attempt to find the test name from the fixture's request object - if request := _frame.frame.f_locals.get("request"): + if request := _frame.frame.f_locals.get("request"): # noqa: F821 return f"{prev} for {request.node._nodeid}", jenkins_url # otherwise, return the fixture name and filename return prev or "Uknown fixture", jenkins_url @@ -501,7 +560,7 @@ def find_origin(): @contextmanager def data_to_tempfile(data, path=None, as_tar=False): - """Write data to a temporary file and return the path""" + """Write data to a temporary file and return the path.""" path = Path(path or uuid4().hex[-10]) logger.debug(f"Creating temporary file {path.absolute()}") if isinstance(data, bytes): @@ -521,7 +580,7 @@ def data_to_tempfile(data, path=None, as_tar=False): @contextmanager def temporary_tar(paths): - """Create a temporary tar file and return the path""" + """Create a temporary tar file and return the path.""" temp_tar = Path(f"{uuid4().hex[-10]}.tar") with tarfile.open(temp_tar, mode="w") as tar: for path in paths: diff --git a/broker/hosts.py b/broker/hosts.py index 2145085e..d45fcb94 100644 --- a/broker/hosts.py +++ b/broker/hosts.py @@ -1,14 +1,38 @@ +"""Module for managing hosts. + +This module defines the `Host` class, which represents a host that can be accessed via SSH or Bind. +The `Host` class provides methods for connecting to the host, executing commands, and transferring files. +It additionally exposes a common interface for Broker to manage host creation, checkin, and deletion. +It is recommended to subclass the `Host` class for custom behavior. + +Usage: + To use the `Host` class, create a new `Host` object with the required parameters: + + ``` + from broker.hosts import Host + + host = Host(hostname="example.com", username="user", password="password") + ``` +""" from logzero import logger -from broker.exceptions import NotImplementedError, HostError + +from broker.exceptions import HostError, NotImplementedError from broker.session import ContainerSession, Session from broker.settings import settings class Host: + """Class representing a host that can be accessed via SSH or Bind. + + This class provides methods for connecting to the host, executing commands, and transferring files. + It additionally exposes a common interface for Broker to manage host creation, checkin, and deletion. + It is recommended to subclass the `Host` class for custom behavior. + """ + default_timeout = 0 # timeout in ms, 0 is infinite def __init__(self, **kwargs): - """Create a Host instance + """Create a Host instance. 

         Expected kwargs:
             hostname: str - Hostname or IP address of the host, required
@@ -24,6 +48,7 @@ def __init__(self, **kwargs):
         if not self.hostname:
             # check to see if we're being reconstructed, likely for checkin
             import inspect
+
             if any(f.function == "reconstruct_host" for f in inspect.stack()):
                 logger.debug("Ignoring missing hostname and ip for checkin reconstruction.")
             else:
@@ -31,22 +56,25 @@ def __init__(self, **kwargs):
         self.name = kwargs.pop("name", None)
         self.username = kwargs.pop("username", settings.HOST_USERNAME)
         self.password = kwargs.pop("password", settings.HOST_PASSWORD)
-        self.timeout = kwargs.pop(
-            "connection_timeout", settings.HOST_CONNECTION_TIMEOUT
-        )
+        self.timeout = kwargs.pop("connection_timeout", settings.HOST_CONNECTION_TIMEOUT)
         self.port = kwargs.pop("port", settings.HOST_SSH_PORT)
         self.key_filename = kwargs.pop("key_filename", settings.HOST_SSH_KEY_FILENAME)
-        self.__dict__.update(kwargs) # Make every other kwarg an attribute
+        self.__dict__.update(kwargs)  # Make every other kwarg an attribute
         self._session = None

     def __del__(self):
-        """Try to close the connection on garbage collection of the host instance"""
+        """Try to close the connection on garbage collection of the host instance."""
         self.close()
         # object.__del__ DNE, so I don't have to call it here.
         # If host inherits from a different class with __del__, it should get called through super

     @property
     def session(self):
+        """Return the session object for the host.
+
+        If the session object does not exist, it will be created by calling the `connect` method.
+        If the host is a non-SSH-enabled container host, a `ContainerSession` object will be created instead.
+        """
         # This attribute may be missing after pickling
         if not isinstance(getattr(self, "_session", None), Session):
             # Check to see if we're a non-ssh-enabled Container Host
@@ -56,9 +84,16 @@ def session(self):
             self.connect()
         return self._session

-    def connect(
-        self, username=None, password=None, timeout=None, port=22, key_filename=None
-    ):
+    def connect(self, username=None, password=None, timeout=None, port=22, key_filename=None):
+        """Connect to the host using SSH.
+
+        Args:
+            username (str): The username to use for the SSH connection.
+            password (str): The password to use for the SSH connection.
+            timeout (int): The timeout for the SSH connection in seconds.
+            port (int): The port to use for the SSH connection. Defaults to 22.
+            key_filename (str): The path to the private key file to use for the SSH connection.
+        """
         username = username or self.username
         password = password or self.password
         timeout = timeout or self.timeout
@@ -75,16 +110,18 @@ def connect(
             password=password,
             port=_port,
             key_filename=key_filename,
-            timeout=timeout
+            timeout=timeout,
         )

     def close(self):
+        """Close the SSH connection to the host."""
         # This attribute may be missing after pickling
         if isinstance(getattr(self, "_session", None), Session):
             self._session.session.disconnect()
         self._session = None

     def release(self):
+        """Release the host using the appropriate method for the provider."""
         raise NotImplementedError("release has not been implemented for this provider")

     # @cached_property
@@ -95,6 +132,15 @@ def _pkg_mgr(self):
         return None

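Tying the docstrings above together, a hedged usage sketch (hostname and credentials are hypothetical) showing the lazy connect behavior:

```python
from broker.hosts import Host

host = Host(hostname="example.com", username="root", password="secret")
res = host.execute("uname -a")  # first use of .session triggers connect()
print(res)                      # Result-like object exposing stdout/stderr/status
host.close()
```

     def execute(self, command, timeout=None):
+        """Execute a command on the host using SSH.
+
+        Args:
+            command (str): The command to execute on the host.
+            timeout (int): The timeout for the SSH connection in seconds. Defaults to `None`.
+
+        Returns:
+            str: The output of the command executed on the host.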
+ """ timeout = timeout or self.default_timeout logger.debug(f"{self.hostname} executing command: {command}") res = self.session.run(command, timeout=timeout) @@ -102,27 +148,33 @@ def execute(self, command, timeout=None): return res def to_dict(self): + """Return a dict representation of the host.""" + keep_keys = ( + "hostname", + "_broker_provider", + "_broker_args", + "tower_inventory", + "job_id", + "_attrs", + ) ret_dict = { - "hostname": self.hostname, "name": getattr(self, "name", None), - "_broker_provider": self._broker_provider, "_broker_provider_instance": self._prov_inst.instance, "type": "host", - "_broker_args": self._broker_args, } - if hasattr(self, "tower_inventory"): - ret_dict["tower_inventory"] = self.tower_inventory + ret_dict.update({k: v for k, v in self.__dict__.items() if k in keep_keys}) return ret_dict def setup(self): - """Automatically ran when entering a Broker context manager""" + """Automatically ran when entering a Broker context manager.""" pass def teardown(self): - """Automatically ran when exiting a Broker context manager""" + """Automatically ran when exiting a Broker context manager.""" pass def __repr__(self): + """Return a string representation of the host.""" inner = ", ".join( f"{k}={v}" for k, v in self.__dict__.items() @@ -132,4 +184,5 @@ def __repr__(self): @classmethod def from_dict(cls, arg_dict): + """Create a Host instance from a dict.""" return cls(**arg_dict, from_dict=True) diff --git a/broker/logger.py b/broker/logger.py index c66fe553..fedf2abc 100644 --- a/broker/logger.py +++ b/broker/logger.py @@ -1,18 +1,18 @@ -# -*- encoding: utf-8 -*- """Module handling internal and dependency logging.""" import copy from enum import IntEnum import logging + +import awxkit import logzero import urllib3 -from broker.settings import BROKER_DIRECTORY, settings -from dynaconf.vendor.box.box_list import BoxList -from dynaconf.vendor.box.box import Box -import awxkit +from broker.settings import BROKER_DIRECTORY, settings class LOG_LEVEL(IntEnum): + """Bare class for log levels. 
+    """Bare class for log levels.
+
+    Trace is added for custom logging.
+    """
+
     TRACE = 5
     DEBUG = logging.DEBUG
     INFO = logging.INFO
@@ -21,12 +21,14 @@ class RedactingFilter(logging.Filter):
-    """Custom logging.Filter to redact secrets from the Dynaconf config"""
+    """Custom logging.Filter to redact secrets from the Dynaconf config."""
+
     def __init__(self, sensitive):
-        super(RedactingFilter, self).__init__()
+        super().__init__()
         self._sensitive = sensitive

     def filter(self, record):
+        """Filter the record and redact the sensitive keys."""
         if isinstance(record.args, dict):
             record.args = self.redact_dynaconf(record.args)
         else:
@@ -34,16 +36,13 @@ def filter(self, record):
         return True

     def redact_dynaconf(self, data):
-        """
-        This method goes over the data and redacts all values of keys
-        that match the sensitive ones
-        """
-        if isinstance(data, (list, tuple)):
+        """Go over the data and redact all values of keys that match the sensitive ones."""
+        if isinstance(data, list | tuple):
             data_copy = [self.redact_dynaconf(item) for item in data]
         elif isinstance(data, dict):
             data_copy = copy.deepcopy(data)
             for k, v in data_copy.items():
-                if isinstance(v, (dict, list)):
+                if isinstance(v, dict | list):
                     data_copy[k] = self.redact_dynaconf(v)
                 elif k in self._sensitive and v:
                     data_copy[k] = "******"
@@ -56,7 +55,9 @@

 logging.addLevelName(LOG_LEVEL.TRACE.value, "TRACE")
 logzero.DEFAULT_COLORS[LOG_LEVEL.TRACE.value] = logzero.colors.Fore.MAGENTA

+
 def patch_awx_for_verbosity(api):
+    """Patch the awxkit API to log when we're at trace level."""
     client = api.client
     awx_log = client.log

@@ -66,9 +67,7 @@ def patch(cls, name):
         func = getattr(cls, name)

         def the_patch(self, *args, **kwargs):
-            awx_log.log(
-                LOG_LEVEL.TRACE.value, f"Calling {self=} {func=}(*{args=}, **{kwargs=}"
-            )
+            awx_log.log(LOG_LEVEL.TRACE.value, f"Calling {self=} {func=}(*{args=}, **{kwargs=})")
             retval = func(self, *args, **kwargs)
             awx_log.log(
                 LOG_LEVEL.TRACE.value,
@@ -83,6 +82,7 @@ def the_patch(self, *args, **kwargs):


 def resolve_log_level(level):
+    """Resolve the log level from a string."""
     try:
         log_level = LOG_LEVEL[level.upper()]
     except KeyError:
@@ -91,10 +91,10 @@ def resolve_log_level(level):


 def formatter_factory(log_level, color=True):
+    """Create a logzero formatter based on the log level."""
     log_fmt = "%(color)s[%(levelname)s %(asctime)s]%(end_color)s %(message)s"
     debug_fmt = (
-        "%(color)s[%(levelname)1.1s %(asctime)s %(module)s:%(lineno)d]"
-        "%(end_color)s %(message)s"
+        "%(color)s[%(levelname)1.1s %(asctime)s %(module)s:%(lineno)d]%(end_color)s %(message)s"
     )
     formatter = logzero.LogFormatter(
         fmt=debug_fmt if log_level <= LOG_LEVEL.DEBUG else log_fmt, color=color
@@ -103,15 +103,14 @@ def formatter_factory(log_level, color=True):


 def set_log_level(level=settings.logging.console_level):
-    if level == "silent":
-        log_level = LOG_LEVEL.INFO
-    else:
-        log_level = resolve_log_level(level)
+    """Set the log level for logzero."""
+    log_level = LOG_LEVEL.INFO if level == "silent" else resolve_log_level(level)
     logzero.formatter(formatter=formatter_factory(log_level))
     logzero.loglevel(level=log_level)


 def set_file_logging(level=settings.logging.file_level, path="logs/broker.log"):
+    """Set the file logging for logzero."""
     silent = False
     if level == "silent":
         silent = True
@@ -137,8 +136,10 @@ def setup_logzero(
     name=None,
     path="logs/broker.log",
 ):
+    """Call logzero setup with the given settings."""
     urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
-    patch_awx_for_verbosity(awxkit.api)
+    if isinstance(level, str) and level.lower() == "trace":
+        patch_awx_for_verbosity(awxkit.api)
     set_log_level(level)
     set_file_logging(file_level, path)
     if formatter:
diff --git a/broker/providers/__init__.py b/broker/providers/__init__.py
index fe91e616..a84bbcdb 100644
--- a/broker/providers/__init__.py
+++ b/broker/providers/__init__.py
@@ -1,46 +1,103 @@
+"""Module for Broker providers.
+
+This module defines the `Provider` class, which is the base class for all Broker providers.
+It provides useful methods for registering provider actions and must be inherited by all
+Broker providers.
+
+Attributes:
+    PROVIDERS (dict): Dictionary of provider names and classes.
+    PROVIDER_ACTIONS (dict): Dictionary of provider actions and their corresponding methods.
+    PROVIDER_HELP (dict): Dictionary providing information to construct `broker providers --help`
+
+Classes:
+    Provider: Base class for all Broker providers.
+
+Usage:
+    To create a new Broker provider, create a new class that inherits from the `Provider` class
+    and implements the required methods. For example:
+
+    ```
+    from broker.providers import Provider
+
+    class MyProvider(Provider):
+        def provider_help(self):
+            # implementation here
+
+        def get_inventory(self, **inventory_opts):
+            # implementation here
+    ```
+
+Note: The `Provider` class should not be used directly.
+
+"""
 from abc import ABCMeta, abstractmethod
-import dynaconf
+import inspect
 from pathlib import Path

-from broker import exceptions
-from broker.settings import settings
+import dynaconf
 from logzero import logger

+from broker import exceptions
+from broker.settings import settings

 # populate a list of all provider module names
 _provider_imports = [
-    f.stem
-    for f in Path(__file__).parent.glob("*.py")
-    if f.is_file() and f.stem != "__init__"
+    f.stem for f in Path(__file__).parent.glob("*.py") if f.is_file() and f.stem != "__init__"
 ]
 # ProviderName: ProviderClassObject
 PROVIDERS = {}
 # action: (InterfaceClass, "method_name")
 PROVIDER_ACTIONS = {}
+# option_name: (ProviderClass, is_flag)
+PROVIDER_HELP = {}


 class ProviderMeta(ABCMeta):
-    """Metaclass that registers provider classes and actions"""
+    """Metaclass that registers provider classes and actions."""

     def __new__(cls, name, bases, attrs):
-        """Register provider classes and actions"""
+        """Register provider classes and actions."""
         new_cls = super().__new__(cls, name, bases, attrs)
         if name != "Provider":
             PROVIDERS[name] = new_cls
             logger.debug(f"Registered provider {name}")
-            for attr in attrs.values():
-                if hasattr(attr, "_as_action"):
-                    for action in attr._as_action:
-                        PROVIDER_ACTIONS[action] = (new_cls, attr.__name__)
+            for attr, obj in attrs.items():
+                if attr == "provider_help":
+                    # register the help options based on the function arguments
+                    for option, param in inspect.signature(obj).parameters.items():
+                        if option not in ("self", "kwargs"):
+                            # {option: (cls, is_flag)}
+                            PROVIDER_HELP[option] = (
+                                new_cls,
+                                isinstance(param.default, bool),
+                            )
+                            logger.debug(f"Registered help option {option} for provider {name}")
+                elif hasattr(obj, "_as_action"):
+                    for action in obj._as_action:
+                        PROVIDER_ACTIONS[action] = (new_cls, attr)
                         logger.debug(f"Registered action {action} for provider {name}")
         return new_cls

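To make the registration above concrete, a hedged sketch of what the metaclass records for a hypothetical provider (class and option names invented for illustration):

```python
class DemoProvider(Provider):  # hypothetical
    def provider_help(self, workflows=False, workflow=None, **kwargs):
        """List workflows or describe one workflow."""

# After class creation, the metaclass has registered:
#   PROVIDER_HELP["workflows"] == (DemoProvider, True)   # bool default -> CLI flag
#   PROVIDER_HELP["workflow"] == (DemoProvider, False)   # None default -> value option
```

 class Provider(metaclass=ProviderMeta):
+    """Abstract base class for all providers.
+
+    This class should be subclassed by all provider implementations. It provides a
+    metaclass that registers provider classes and actions.
+
+    Attributes:
+        _validators (list): A list of Dynaconf Validators specific to the provider.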
+        hidden (bool): A flag to hide the provider from the CLI.
+        _checkout_options (list): A list of checkout options to add to each command.
+        _execute_options (list): A list of execute options to add to each command.
+        _fresh_settings (dynaconf.Dynaconf): A clone of the global settings object.
+        _sensitive_attrs (list): A list of sensitive attributes that should not be logged.
+    """
+
     # Populate with a list of Dynaconf Validators specific to your provider
     _validators = []
-    # Set to true if you don't want your provider shown in the CLI
+    # Used to hide the provider from the CLI
     hidden = False
     # Populate these to add your checkout and execute options to each command
     # _checkout_options = [click.option("--workflow", type=str, help="Help text")]
@@ -57,7 +114,7 @@ def __init__(self, **kwargs):
         self._validate_settings(self.instance)

     def _validate_settings(self, instance_name=None):
-        """Load and validate provider settings
+        """Load and validate provider settings.

         Each provider's settings can include an instances list with specific instance
         details.
@@ -78,13 +135,10 @@ def _validate_settings(self, instance_name=None):
                 if instance_name in candidate:
                     instance = candidate
                     break
-                elif (
-                    candidate.values()[0].get("default")
-                    or len(fresh_settings.instances) == 1
-                ):
+                elif candidate.values()[0].get("default") or len(fresh_settings.instances) == 1:
                     instance = candidate
             self.instance, *_ = instance  # store the instance name on the provider
-            fresh_settings.update((inst_vals := instance.values()[0]))
+            fresh_settings.update(inst_vals := instance.values()[0])
             settings[section_name] = fresh_settings
             if not inst_vals.get("override_envars"):
                 # if a provider instance doesn't want to override envars, load them
@@ -96,15 +150,16 @@ def _validate_settings(self, instance_name=None):
         try:
             settings.validators.validate(only=section_name)
         except dynaconf.ValidationError as err:
-            raise exceptions.ConfigurationError(err)
+            raise exceptions.ConfigurationError(err) from err

     def _set_attributes(self, obj, attrs):
         obj.__dict__.update(attrs)

-    def _get_params(arg_list, kwargs):
+    def _get_params(self, arg_list, kwargs):
         return {k: v for k, v in kwargs.items() if k in arg_list}

     def construct_host(self, host_cls, provider_params, **kwargs):
+        """Construct a host object from a host class and include relevant provider params."""
         host_inst = host_cls(**provider_params, **kwargs)
         host_attrs = self._get_params(self._construct_params)
         host_attrs["release"] = self._host_release
@@ -112,22 +167,28 @@ def construct_host(self, host_cls, provider_params, **kwargs):
         return host_inst

     @abstractmethod
-    def nick_help(self):
-        pass
+    def provider_help(self):
+        """Help options that will be added to the CLI.
+
+        Anything other than 'self' and 'kwargs' will be added as a help option.
+        To specify a flag, set the default value to a boolean.
+        Everything else should default to None.
+        """

     @abstractmethod
     def get_inventory(self, **kwargs):
-        pass
+        """Pull inventory information from the provider."""

     @abstractmethod
     def extend(self):
-        pass
+        """Extend the reservation of a host. Not all providers support this."""

     @abstractmethod
     def release(self, host_obj):
-        pass
+        """Release/return a host to the provider.
Often this is a deletion or removal."""

     def __repr__(self):
+        """Return a string representation of the provider."""
         inner = ", ".join(
             f"{k}={'******' if k in self._sensitive_attrs and v else v}"
             for k, v in self.__dict__.items()
@@ -135,9 +196,17 @@ def __repr__(self):
         )
         return f"{self.__class__.__name__}({inner})"

+    @staticmethod
+    def auto_hide(cls):
+        """Decorate a provider class to hide it from the CLI."""
+        if not settings.get(cls.__name__.upper(), False):
+            cls.hidden = True
+        return cls
+
     @staticmethod
     def register_action(*as_names):
-        """Decorator to register a provider action
+        """Decorate a provider method to register it as a provider action.

         :param as_names: One or more action names to register the decorated function as
         """
diff --git a/broker/providers/ansible_tower.py b/broker/providers/ansible_tower.py
index 40d24f0b..1b17f01d 100644
--- a/broker/providers/ansible_tower.py
+++ b/broker/providers/ansible_tower.py
@@ -1,32 +1,33 @@
-import click
+"""Ansible Tower provider implementation."""
+from datetime import datetime
+from functools import cache, cached_property
 import inspect
 import json
-import yaml
 from urllib import parse as url_parser
-from functools import cache, cached_property
+
+import click
 from dynaconf import Validator
+from logzero import logger
+import yaml
+
 from broker import exceptions
-from broker.helpers import find_origin, eval_filter
+from broker.helpers import eval_filter, find_origin
 from broker.settings import settings
-from logzero import logger
-from datetime import datetime

 try:
     import awxkit
-except ImportError:
+except ImportError as err:
     raise exceptions.ProviderError(
         provider="AnsibleTower", message="Unable to import awxkit. Is it installed?"
-    )
+    ) from err

-from broker.providers import Provider
 from broker import helpers
+from broker.providers import Provider


 @cache
-def get_awxkit_and_uname(
-    config=None, root=None, url=None, token=None, uname=None, pword=None
-):
-    """Return an awxkit api object and resolved username"""
+def get_awxkit_and_uname(config=None, root=None, url=None, token=None, uname=None, pword=None):
+    """Return an awxkit api object and resolved username."""
     # Prefer token if its set, otherwise use username/password
     # auth paths for the API taken from:
     # https://github.com/ansible/awx/blob/ddb6c5d0cce60779be279b702a15a2fddfcd0724/awxkit/awxkit/cli/client.py#L85-L94
@@ -40,22 +41,20 @@ def get_awxkit_and_uname(
         logger.info("Using token authentication")
         config.token = token
         try:
-            root.connection.login(
-                username=None, password=None, token=token, auth_type="Bearer"
-            )
+            root.connection.login(username=None, password=None, token=token, auth_type="Bearer")
         except awxkit.exceptions.Unauthorized as err:
-            raise exceptions.AuthenticationError(err.args[0])
+            raise exceptions.AuthenticationError(err.args[0]) from err
         versions = root.get().available_versions
         try:
             # lookup the user that authenticated with the token
             # If a username was specified in config, use that instead
             my_username = uname or versions.v2.get().me.get().results[0].username
-        except (IndexError, AttributeError):
+        except (IndexError, AttributeError) as err:
             # lookup failed for whatever reason
             raise exceptions.ProviderError(
                 provider="AnsibleTower",
                 message="Failed to lookup a username for the given token, please check credentials",
-            )
+            ) from err
     else:  # dynaconf validators should have checked that either token or password was provided
         helpers.emit(auth_type="password")
         if datetime.now() < datetime(2023, 2, 6):
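A hedged sketch of `register_action` in use, mirroring how the providers below register their own actions (the provider and method names here are hypothetical):

```python
class ExampleProvider(Provider):  # hypothetical
    @Provider.register_action("my_workflow")
    def run_my_workflow(self, **kwargs):
        """Handle checkouts that pass `my_workflow=...`."""

# PROVIDER_ACTIONS now maps "my_workflow" -> (ExampleProvider, "run_my_workflow")
```

@@ -77,7 +76,9 @@ def get_awxkit_and_uname(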
     return versions.v2.get(), my_username


+@Provider.auto_hide
 class AnsibleTower(Provider):
+    """Ansible Tower provider provides a Broker-specific wrapper around awxkit."""

     _validators = [
         Validator("ANSIBLETOWER.release_workflow", default="remove-vm"),
@@ -104,9 +105,7 @@ class AnsibleTower(Provider):
             type=str,
             help="AnsibleTower inventory to checkout a host on",
         ),
-        click.option(
-            "--workflow", type=str, help="Name of a workflow used to checkout a host"
-        ),
+        click.option("--workflow", type=str, help="Name of a workflow used to checkout a host"),
     ]
     _execute_options = [
         click.option(
@@ -115,9 +114,7 @@ class AnsibleTower(Provider):
             help="AnsibleTower inventory to execute against",
         ),
         click.option("--workflow", type=str, help="Name of a workflow to execute"),
-        click.option(
-            "--job-template", type=str, help="Name of a job template to execute"
-        ),
+        click.option("--job-template", type=str, help="Name of a job template to execute"),
     ]
     _extend_options = [
         click.option(
@@ -130,15 +127,20 @@ class AnsibleTower(Provider):
     _sensitive_attrs = ["pword", "password", "token"]

     def __init__(self, **kwargs):
+        """Almost all values are taken from Broker's config with the following exceptions.
+
+        kwargs:
+            tower_inventory: AnsibleTower inventory to use for this instance
+            config: awxkit config object
+            root: awxkit api root object
+        """
         super().__init__(**kwargs)
         # get our instance settings
         self.url = settings.ANSIBLETOWER.base_url
         self.uname = settings.ANSIBLETOWER.get("username")
         self.pword = settings.ANSIBLETOWER.get("password")
         self.token = settings.ANSIBLETOWER.get("token")
-        self._inventory = (
-            kwargs.get("tower_inventory") or settings.ANSIBLETOWER.inventory
-        )
+        self._inventory = kwargs.get("tower_inventory") or settings.ANSIBLETOWER.inventory
         # Init the class itself
         config = kwargs.get("config")
         root = kwargs.get("root")
@@ -151,13 +153,14 @@ def __init__(self, **kwargs):
             pword=self.pword,
         )
         # Check to see if we're running AAP (ver 4.0+)
-        self._is_aap = False if self.v2.ping.get().version[0] == "3" else True
+        self._is_aap = self.v2.ping.get().version[0] != "3"

     @staticmethod
     def _pull_params(kwargs):
-        """Given a kwarg dict, separate AT-specific parameters from other kwargs
+        """Given a kwarg dict, separate AT-specific parameters from other kwargs.
+
         AT-specific params must start with double underscores.
-        Example: __page_size
+        Example: __page_size.
         """
         params, new_kwargs = {}, {}
         for key, value in kwargs.items():
             if key.startswith("__"):
@@ -231,7 +234,7 @@ def _translate_inventory(self, inventory):
         )

     def _merge_artifacts(self, at_object, strategy="last", artifacts=None):
-        """Gather and merge all artifacts associated with an object and its children
+        """Gather and merge all artifacts associated with an object and its children.
:param at_object: object you want to merge @@ -258,9 +261,7 @@ def _merge_artifacts(self, at_object, strategy="last", artifacts=None): children = at_object.get_related("workflow_nodes").results # filter out children with no associated job children = list( - filter( - lambda child: getattr(child.summary_fields, "job", None), children - ) + filter(lambda child: getattr(child.summary_fields, "job", None), children) ) children.sort(key=lambda child: child.summary_fields.job.id) if strategy == "last": @@ -273,8 +274,7 @@ def _merge_artifacts(self, at_object, strategy="last", artifacts=None): if child_obj: child_obj = child_obj.pop() artifacts = ( - self._merge_artifacts(child_obj, strategy, artifacts) - or artifacts + self._merge_artifacts(child_obj, strategy, artifacts) or artifacts ) else: logger.warning( @@ -283,21 +283,17 @@ def _merge_artifacts(self, at_object, strategy="last", artifacts=None): return artifacts def _get_failure_messages(self, workflow): - """Find all failure nodes and aggregate failure messages""" + """Find all failure nodes and aggregate failure messages.""" failure_messages = [] # get all failed job nodes (iterate) if "workflow_nodes" in workflow.related: children = workflow.get_related("workflow_nodes").results # filter out children with no associated job children = list( - filter( - lambda child: getattr(child.summary_fields, "job", None), children - ) + filter(lambda child: getattr(child.summary_fields, "job", None), children) ) # filter out children that didn't fail - children = list( - filter(lambda child: child.summary_fields.job.failed, children) - ) + children = list(filter(lambda child: child.summary_fields.job.failed, children)) children.sort(key=lambda child: child.summary_fields.job.id) for child in children[::-1]: if child.type == "workflow_job_node": @@ -321,9 +317,7 @@ def _get_failure_messages(self, workflow): # get all failed job_events for each job (filter failed=true) failed_events = [ ev - for ev in child_obj.get_related( - "job_events", page_size=200 - ).results + for ev in child_obj.get_related("job_events", page_size=200).results if ev.failed ] # find the one(s) with event_data['res']['msg'] @@ -350,14 +344,11 @@ def _get_failure_messages(self, workflow): def _get_expire_date(self, host_id): try: time_stamp = ( - self.v2.hosts.get(id=host_id) - .results[0] - .related.ansible_facts.get() - .expire_date + self.v2.hosts.get(id=host_id).results[0].related.ansible_facts.get().expire_date ) return str(datetime.fromtimestamp(int(time_stamp))) - except Exception: - return None + except Exception as err: # noqa: BLE001 + logger.debug(f"Tell Jake that the exception here is: {err}!") def _compile_host_info(self, host): # attempt to get the hostname from the host variables and then facts @@ -378,9 +369,7 @@ def _compile_host_info(self, host): if expire_time: host_info["expire_time"] = expire_time try: - create_job = self.v2.jobs.get( - id=host.get_related("job_events").results[0].job - ) + create_job = self.v2.jobs.get(id=host.get_related("job_events").results[0].job) create_job = create_job.results[0].get_related("source_workflow_job") host_info["_broker_args"]["workflow"] = create_job.name except IndexError: @@ -391,19 +380,14 @@ def _compile_host_info(self, host): host_info["_broker_args"]["workflow"] = host.get_related( "last_job" ).summary_fields.source_workflow_job.name - except Exception: - logger.warning( - f"Unable to determine workflow for {host_info['hostname']}" - ) + except Exception as err: # noqa: BLE001 + logger.debug(f"Tell Jake that the 
exception here is: {err}!") + logger.warning(f"Unable to determine workflow for {host_info['hostname']}") else: return host_info create_vars = json.loads(create_job.extra_vars) host_info["_broker_args"].update( - { - arg: val - for arg, val in create_vars.items() - if val and isinstance(val, str) - } + {arg: val for arg, val in create_vars.items() if val and isinstance(val, str)} ) # temporary workaround for OSP hosts that have lost their hostname if not host_info["hostname"] and host.variables.get("openstack"): @@ -412,6 +396,7 @@ def _compile_host_info(self, host): @cached_property def inventory(self): + """Return the current tower inventory.""" if not self._inventory: return elif isinstance(self._inventory, int): @@ -421,7 +406,7 @@ def inventory(self): return self._inventory def construct_host(self, provider_params, host_classes, **kwargs): - """Constructs host to be read by Ansible Tower + """Construct a host to be read by Ansible Tower. :param provider_params: dictionary of what the provider returns when initially creating the vm @@ -433,13 +418,11 @@ def construct_host(self, provider_params, host_classes, **kwargs): misc_attrs = {} # used later to add misc attributes to host object if provider_params: job = provider_params - job_attrs = self._merge_artifacts( - job, strategy=kwargs.get("strategy", "last") - ) + job_attrs = self._merge_artifacts(job, strategy=kwargs.get("strategy", "last")) # pull information about the job arguments job_extra_vars = json.loads(job.extra_vars) # and update them if they have resolved values - for key in job_extra_vars.keys(): + for key in job_extra_vars: job_extra_vars[key] = job_attrs.get(key) kwargs.update({key: val for key, val in job_extra_vars.items() if val}) kwargs.update({key: val for key, val in job_attrs.items() if val}) @@ -461,17 +444,15 @@ def construct_host(self, provider_params, host_classes, **kwargs): if not hostname: logger.warning(f"No hostname found in job attributes:\n{job_attrs}") logger.debug(f"hostname: {hostname}, name: {name}, host type: {host_type}") - host_inst = host_classes[host_type]( - **{**kwargs, "hostname": hostname, "name": name} - ) + host_inst = host_classes[host_type](**{**kwargs, "hostname": hostname, "name": name}) else: host_inst = host_classes[kwargs.get("type")](**kwargs) self._set_attributes(host_inst, broker_args=kwargs, misc_attrs=misc_attrs) return host_inst @Provider.register_action("workflow", "job_template") - def execute(self, **kwargs): - """Execute workflow or job template in Ansible Tower + def execute(self, **kwargs): # noqa: PLR0912 - Possible TODO refactor + """Execute workflow or job template in Ansible Tower. 
:param kwargs: workflow or job template name passed in a string
@@ -494,7 +475,7 @@ def execute(self, **kwargs):
         try:
             candidates = get_path.get(name=name).results
         except awxkit.exceptions.Unauthorized as err:
-            raise exceptions.AuthenticationError(err.args[0])
+            raise exceptions.AuthenticationError(err.args[0]) from err
         if candidates:
             target = candidates.pop()
         else:
@@ -505,34 +486,27 @@ def execute(self, **kwargs):
         payload = {}
         if inventory := kwargs.pop("inventory", None):
             payload["inventory"] = inventory
-            logger.info(
-                f"Using tower inventory: {self._translate_inventory(inventory)}"
-            )
+            logger.info(f"Using tower inventory: {self._translate_inventory(inventory)}")
         elif self.inventory:
             payload["inventory"] = self.inventory
-            logger.info(
-                f"Using tower inventory: {self._translate_inventory(self.inventory)}"
-            )
+            logger.info(f"Using tower inventory: {self._translate_inventory(self.inventory)}")
         else:
             logger.info("No inventory specified, Ansible Tower will use a default.")
         payload["extra_vars"] = str(kwargs)
         logger.debug(
-            f"Launching {subject}: {url_parser.urljoin(self.url, str(target.url))}\n"
-            f"{payload=}"
+            f"Launching {subject}: {url_parser.urljoin(self.url, str(target.url))}\n{payload=}"
         )
         job = target.launch(payload=payload)
         job_number = job.url.rstrip("/").split("/")[-1]
         job_api_url = url_parser.urljoin(self.url, str(job.url))
         if self._is_aap:
-            job_ui_url = url_parser.urljoin(
-                self.url, f"/#/jobs/{subject}/{job_number}/output"
-            )
+            job_ui_url = url_parser.urljoin(self.url, f"/#/jobs/{subject}/{job_number}/output")
         else:
             job_ui_url = url_parser.urljoin(self.url, f"/#/{subject}s/{job_number}")
         helpers.emit(api_url=job_api_url, ui_url=job_ui_url)
-        logger.info("Waiting for job: \n" f"API: {job_api_url}\n" f"UI: {job_ui_url}")
+        logger.info(f"Waiting for job: \nAPI: {job_api_url}\nUI: {job_ui_url}")
         job.wait_until_completed(timeout=settings.ANSIBLETOWER.workflow_timeout)
-        if not job.status == "successful":
+        if job.status != "successful":
             message_data = {
                 f"{subject.capitalize()} Status": job.status,
                 "Reason(s)": self._get_failure_messages(job),
@@ -548,7 +522,7 @@ def execute(self, **kwargs):
         return job

     def get_inventory(self, user=None):
-        """Compile a list of hosts based on any inventory a user's name is mentioned"""
+        """Compile a list of hosts from any inventory where the user's name is mentioned."""
         user = user or self.username
         invs = [
             inv
@@ -559,12 +533,12 @@ def get_inventory(self, user=None):
         for inv in invs:
             inv_hosts = inv.get_related("hosts", page_size=200).results
             hosts.extend(inv_hosts)
-        with click.progressbar(hosts, label='Compiling host information') as hosts_bar:
+        with click.progressbar(hosts, label="Compiling host information") as hosts_bar:
             compiled_host_info = [self._compile_host_info(host) for host in hosts_bar]
         return compiled_host_info

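A hedged usage sketch of the execute action above (the workflow name is hypothetical; URL and credentials come from Broker's settings):

```python
from broker.providers.ansible_tower import AnsibleTower

at = AnsibleTower()                      # credentials/URL resolved from settings
job = at.execute(workflow="deploy-vm")   # hypothetical workflow name
# On success, `job` is the completed awxkit job object; its artifacts are
# merged into host attributes when construct_host() runs.
```

     def extend(self, target_vm, new_expire_time=None):
-        """Run the extend workflow with defaults args
+        """Run the extend workflow with default args.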
:param target_vm: This should be a host object
         """
@@ -572,20 +546,28 @@ def extend(self, target_vm, new_expire_time=None):
         if new_inv := target_vm._broker_args.get("tower_inventory"):
             if new_inv != self._inventory:
                 self._inventory = new_inv
-                if hasattr(self.__dict__, 'inventory'):
+                if "inventory" in self.__dict__:
                     del self.inventory  # clear the cached value
         return self.execute(
             workflow=settings.ANSIBLETOWER.extend_workflow,
             target_vm=target_vm.name,
-            new_expire_time=new_expire_time
-            or settings.ANSIBLETOWER.get("new_expire_time"),
+            new_expire_time=new_expire_time or settings.ANSIBLETOWER.get("new_expire_time"),
         )

-    @Provider.register_action("template", "inventory")
-    def nick_help(self, **kwargs):
-        """Get a list of extra vars and their defaults from a workflow"""
+    def provider_help(
+        self,
+        workflows=False,
+        workflow=None,
+        job_templates=False,
+        job_template=None,
+        templates=False,
+        inventories=False,
+        inventory=None,
+        **kwargs,
+    ):
+        """Get a list of extra vars and their defaults from a workflow."""
         results_limit = kwargs.get("results_limit", settings.ANSIBLETOWER.results_limit)
-        if workflow := kwargs.get("workflow"):
+        if workflow:
             wfjt = self.v2.workflow_job_templates.get(name=workflow).results.pop()
             default_inv = self.v2.inventory.get(id=wfjt.inventory).results.pop()
             logger.info(
@@ -593,12 +575,10 @@ def provider_help(self, **kwargs):
                 f"Accepted additional nick fields:\n{helpers.yaml_format(wfjt.extra_vars)}"
                 f"tower_inventory: {default_inv['name']}"
             )
-        elif kwargs.get("workflows"):
+        elif workflows:
             workflows = [
                 workflow.name
-                for workflow in self.v2.workflow_job_templates.get(
-                    page_size=1000
-                ).results
+                for workflow in self.v2.workflow_job_templates.get(page_size=1000).results
                 if workflow.summary_fields.user_capabilities.get("start")
             ]
             if res_filter := kwargs.get("results_filter"):
@@ -606,21 +586,18 @@ def provider_help(self, **kwargs):
                 workflows = workflows if isinstance(workflows, list) else [workflows]
             workflows = "\n".join(workflows[:results_limit])
             logger.info(f"Available workflows:\n{workflows}")
-        elif inventory := kwargs.get("inventory"):
+        elif inventory:
             inv = self.v2.inventory.get(name=inventory, kind="").results.pop()
             inv = {"Name": inv.name, "ID": inv.id, "Description": inv.description}
             logger.info(f"Accepted additional nick fields:\n{helpers.yaml_format(inv)}")
-        elif kwargs.get("inventories"):
-            inv = [
-                inv.name
-                for inv in self.v2.inventory.get(kind="", page_size=1000).results
-            ]
+        elif inventories:
+            inv = [inv.name for inv in self.v2.inventory.get(kind="", page_size=1000).results]
             if res_filter := kwargs.get("results_filter"):
                 inv = eval_filter(inv, res_filter, "res")
                 inv = inv if isinstance(inv, list) else [inv]
             inv = "\n".join(inv[:results_limit])
             logger.info(f"Available Inventories:\n{inv}")
-        elif job_template := kwargs.get("job_template"):
+        elif job_template:
             jt = self.v2.job_templates.get(name=job_template).results.pop()
             default_inv = self.v2.inventory.get(id=jt.inventory).results.pop()
             logger.info(
@@ -628,7 +605,7 @@ def provider_help(self, **kwargs):
                 f"Accepted additional nick fields:\n{helpers.yaml_format(jt.extra_vars)}"
                 f"tower_inventory: {default_inv['name']}"
             )
-        elif kwargs.get("job_templates"):
+        elif job_templates:
             job_templates = [
                 job_template.name
                 for job_template in self.v2.job_templates.get(page_size=1000).results
@@ -636,17 +613,18 @@ def provider_help(self, **kwargs):
             ]
             if res_filter := kwargs.get("results_filter"):
                 job_templates = eval_filter(job_templates, res_filter, "res")
-            job_templates = job_templates if
isinstance(job_templates, list) else [job_templates] + job_templates = ( + job_templates if isinstance(job_templates, list) else [job_templates] + ) job_templates = "\n".join(job_templates[:results_limit]) logger.info(f"Available job templates:\n{job_templates}") - elif kwargs.get("templates"): + elif templates: templates = list( - { - tmpl - for tmpl in self.execute( - workflow="list-templates", artifacts="last" - )["data_out"]["list_templates"] - } + set( + self.execute(workflow="list-templates", artifacts="last")["data_out"][ + "list_templates" + ] + ) ) templates.sort(reverse=True) if res_filter := kwargs.get("results_filter"): @@ -654,10 +632,9 @@ def nick_help(self, **kwargs): templates = templates if isinstance(templates, list) else [templates] templates = "\n".join(templates[:results_limit]) logger.info(f"Available templates:\n{templates}") - else: - logger.warning("That action is not yet implemented.") def release(self, name, broker_args=None): + """Release the host back to the tower instance via the release workflow.""" if broker_args is None: broker_args = {} return self.execute( @@ -668,7 +645,7 @@ def release(self, name, broker_args=None): def awxkit_representer(dumper, data): - """In order to resolve awxkit objects, a custom representer is needed""" + """In order to resolve awxkit objects, a custom representer is needed.""" return dumper.represent_dict(dict(data)) diff --git a/broker/providers/beaker.py b/broker/providers/beaker.py new file mode 100644 index 00000000..955c3d22 --- /dev/null +++ b/broker/providers/beaker.py @@ -0,0 +1,160 @@ +"""Beaker provider implementation.""" +import inspect + +import click +from dynaconf import Validator +from logzero import logger + +from broker import helpers +from broker.binds.beaker import BeakerBind +from broker.exceptions import BrokerError, ProviderError +from broker.hosts import Host +from broker.providers import Provider +from broker.settings import settings + + +@Provider.auto_hide +class Beaker(Provider): + """Beaker provider class providing a Broker interface around the Beaker bind.""" + + _validators = [ + Validator("beaker.hub_url", must_exist=True), + Validator("beaker.max_job_wait", default="24h"), + ] + _checkout_options = [ + click.option( + "--job-xml", + type=click.Path(exists=True, dir_okay=False), + help="Path to the job XML file to submit", + ), + click.option( + "--job-id", + type=str, + help="Beaker job ID to clone", + ), + ] + _execute_options = [ + click.option( + "--job-xml", + type=str, + help="Path to the job XML file to submit", + ), + ] + _extend_options = [ + click.option( + "--extend-duration", + type=click.IntRange(1, 99), + help="Number of hours to extend the job. 
Must be between 1 and 99",
+        )
+    ]
+
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+        self.hub_url = settings.beaker.hub_url
+        self.runtime = kwargs.pop("bind", BeakerBind)(self.hub_url, **kwargs)
+
+    def _host_release(self):
+        caller_host = inspect.stack()[1][0].f_locals["host"]
+        if not (job_id := getattr(caller_host, "job_id", None)):
+            job_id = self.runtime.jobid_from_system(caller_host.hostname)
+        return self.release(caller_host.hostname, job_id)
+
+    def _set_attributes(self, host_inst, broker_args=None, misc_attrs=None):
+        host_inst.__dict__.update(
+            {
+                "_prov_inst": self,
+                "_broker_provider": "Beaker",
+                "_broker_provider_instance": self.instance,
+                "_broker_args": broker_args,
+                "release": self._host_release,
+            }
+        )
+        if isinstance(misc_attrs, dict):
+            host_inst._attrs = misc_attrs
+
+    def _compile_host_info(self, host, broker_info=True):
+        """Compile host information into a dictionary suitable for use in the inventory.
+
+        :param host (beaker.host.Host): The host to compile information for.
+
+        :return: A dictionary containing the compiled host information.
+        """
+        curated_host_info = self.runtime.system_details_curated(host)
+        if broker_info:
+            curated_host_info.update(
+                {
+                    "_broker_provider": "Beaker",
+                    "_broker_provider_instance": self.instance,
+                    "_broker_args": getattr(host, "_broker_args", {}),
+                }
+            )
+        if not curated_host_info.get("job_id"):
+            curated_host_info["job_id"] = self.runtime.jobid_from_system(
+                curated_host_info["hostname"]
+            )
+        return curated_host_info
+
+    def construct_host(self, provider_params, host_classes, **kwargs):
+        """Construct a broker host from a beaker system information.
+
+        :param provider_params: a beaker system information dictionary
+
+        :param host_classes: host object
+
+        :return: constructed broker host object
+        """
+        logger.debug(f"constructing with {provider_params=}\n{host_classes=}\n{kwargs=}")
+        if not provider_params:
+            host_inst = host_classes[kwargs.get("type", "host")](**kwargs)
+            self._set_attributes(host_inst, broker_args=kwargs)
+        else:
+            host_info = self._compile_host_info(provider_params["hostname"], broker_info=False)
+            host_inst = host_classes[kwargs.get("type", "host")](**provider_params)
+            self._set_attributes(host_inst, broker_args=kwargs, misc_attrs=host_info)
+        return host_inst
+
+    @Provider.register_action("job_xml", "job_id")
+    def submit_job(self, max_wait=None, **kwargs):
+        """Submit a job to Beaker and wait for it to complete."""
+        job = kwargs.get("job_xml") or kwargs.get("job_id")
+        max_wait = max_wait or settings.beaker.get("max_job_wait")
+        result = self.runtime.execute_job(job, max_wait)
+        logger.debug(f"Job completed with results: {result}")
+        return result
+
+    def provider_help(self, jobs=False, job=None, **kwargs):
+        """Print useful information from the Beaker provider."""
+        results_limit = kwargs.get("results_limit", settings.container.results_limit)
+        if job:
+            if not job.startswith("J:"):
+                job = f"J:{job}"
+            logger.info(self.runtime.job_clone(job, prettyxml=True, dryrun=True).stdout)
+        elif jobs:
+            result = self.runtime.job_list(**kwargs).stdout.splitlines()
+            if res_filter := kwargs.get("results_filter"):
+                result = helpers.eval_filter(result, res_filter, "res")
+            result = "\n".join(result[:results_limit])
+            logger.info(f"Available jobs:\n{result}")
+
+    def release(self, host_name, job_id):
+        """Release a host reserved from Beaker by cancelling the job."""
+        return self.runtime.job_cancel(job_id)
+        # return self.runtime.system_release(host_name)
+
+    def extend(self, host_name, extend_duration=99):
+        """Extend the duration of a Beaker reservation."""
+        try:
+            Host(hostname=host_name).execute(f"/usr/bin/extendtesttime.sh {extend_duration}")
+        except BrokerError as err:
+            raise ProviderError(
+                f"Failed to extend host {host_name}: {err}\n"
+                f"Try running: root@{host_name} /usr/bin/extendtesttime.sh {extend_duration}"
+            ) from err
+
+    def get_inventory(self, *args):
+        """Get a list of hosts and their information from Beaker."""
+        hosts = self.runtime.user_systems()
+        with click.progressbar(hosts, label="Compiling host information") as hosts_bar:
+            compiled_host_info = [self._compile_host_info(host) for host in hosts_bar]
+        return compiled_host_info
diff --git a/broker/providers/container.py b/broker/providers/container.py
index 0471390f..d0df7598 100644
--- a/broker/providers/container.py
+++ b/broker/providers/container.py
@@ -1,18 +1,21 @@
+"""Container provider implementation."""
 from functools import cache
 import getpass
 import inspect
 from uuid import uuid4
+
 import click
-from logzero import logger
 from dynaconf import Validator
-from broker import exceptions
-from broker import helpers
-from broker.settings import settings
-from broker.providers import Provider
+from logzero import logger
+
+from broker import exceptions, helpers
 from broker.binds import containers
+from broker.providers import Provider
+from broker.settings import settings


 def container_info(container_inst):
+    """Return a dict of container information."""
     return {
         "_broker_provider": "Container",
         "name": container_inst.name,
@@ -26,17 +29,14 @@ def container_info(container_inst):
 def _host_release():
     caller_host = inspect.stack()[1][0].f_locals["host"]
     if not caller_host._cont_inst:
-        caller_host._cont_inst = caller_host._prov_inst._cont_inst_by_name(
-            caller_host.name
-        )
+        caller_host._cont_inst = caller_host._prov_inst._cont_inst_by_name(caller_host.name)
     caller_host._cont_inst.remove(v=True, force=True)
     caller_host._checked_in = True


 @cache
-def get_runtime(
-    runtime_cls=None, host=None, username=None, password=None, port=None, timeout=None
-):
+def get_runtime(runtime_cls=None, host=None, username=None, password=None, port=None, timeout=None):
+    """Return a runtime instance."""
     return runtime_cls(
         host=host,
         username=username,
@@ -46,7 +46,10 @@ def get_runtime(
     )


+@Provider.auto_hide
 class Container(Provider):
+    """Container provider class providing a Broker interface around the container binds."""
+
     _validators = [
         Validator("CONTAINER.runtime", default="podman"),
         Validator("CONTAINER.host", default="localhost"),
@@ -100,7 +103,7 @@ def __init__(self, **kwargs):
         self._name_prefix = settings.container.get("name_prefix", getpass.getuser())

     def _ensure_image(self, name):
-        """Check if an image exists on the provider, attempt a pull if not"""
+        """Check if an image exists on the provider, attempt a pull if not."""
         for image in self.runtime.images:
             if name in image.tags:
                 return
@@ -108,18 +111,18 @@ def _ensure_image(self, name):
             return
         try:
             self.runtime.pull_image(name)
-        except Exception as err:
+        except Exception as err:  # noqa: BLE001 - This could be a few things
             raise exceptions.ProviderError(
                 "Container", f"Unable to find image: {name}\n{err}"
-            )
+            ) from err
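The `_find_ssh_port` helper below normalizes the two port-map shapes the container runtimes return; a sketch of both inputs, with illustrative values taken from its inline comments:

```python
# Podman-style list of port mappings:
podman_style = [{"hostPort": 1337, "containerPort": 22, "protocol": "tcp", "hostIP": ""}]
# Docker-style dict keyed by "<port>/<protocol>":
docker_style = {"22/tcp": [{"HostIp": "", "HostPort": "1337"}]}
# Either shape should resolve to host port 1337 for container port 22.
```

     @staticmethod
-    def _find_ssh_port(port_map):
-        """Go through container port map and find the mapping that corresponds to port 22"""
+    def _find_ssh_port(port_map, ssh_port=22):
+        """Go through container port map and find the mapping that corresponds to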
port 22.""" if isinstance(port_map, list): # [{'hostPort': 1337, 'containerPort': 22, 'protocol': 'tcp', 'hostIP': ''}, for pm in port_map: - if pm["containerPort"] == 22: + if pm["containerPort"] == ssh_port: return pm["hostPort"] elif isinstance(port_map, dict): # {'22/tcp': [{'HostIp': '', 'HostPort': '1337'}], @@ -140,13 +143,15 @@ def _set_attributes(self, host_inst, broker_args=None, cont_inst=None): ) def _port_mapping(self, image, **kwargs): - """ - 22 - 22:1337 - 22/tcp - 22/tcp:1337 - 22,23 - 22:1337 23:1335 + """Create a mapping of ports to expose on the container. + + Accepted `ports` formats: + 22 + 22:1337 + 22/tcp + 22/tcp:1337 + 22,23 + 22:1337 23:1335. """ mapping = {} if ports := kwargs.pop("ports", None): @@ -164,21 +169,19 @@ def _port_mapping(self, image, **kwargs): elif settings.container.auto_map_ports: mapping = { k: v or None - for k, v in self.runtime.image_info(image)["config"][ - "ExposedPorts" - ].items() + for k, v in self.runtime.image_info(image)["config"]["ExposedPorts"].items() } return mapping def _cont_inst_by_name(self, cont_name): - """Attempt to find and return a container by its name""" + """Attempt to find and return a container by its name.""" for cont_inst in self.runtime.containers: if cont_inst.name == cont_name: return cont_inst logger.error(f"Unable to find container by name {cont_name}") def construct_host(self, provider_params, host_classes, **kwargs): - """Constructs broker host from a container instance + """Construct a broker host from a container instance. :param provider_params: a container instance object @@ -186,9 +189,7 @@ def construct_host(self, provider_params, host_classes, **kwargs): :return: broker object of constructed host instance """ - logger.debug( - f"constructing with {provider_params=}\n{host_classes=}\n{kwargs=}" - ) + logger.debug(f"constructing with {provider_params=}\n{host_classes=}\n{kwargs=}") if not provider_params: host_inst = host_classes[kwargs.get("type", "host")](**kwargs) cont_inst = self._cont_inst_by_name(host_inst.name) @@ -204,21 +205,21 @@ def construct_host(self, provider_params, host_classes, **kwargs): raise Exception(f"Could not determine container hostname:\n{cont_attrs}") name = cont_attrs["name"] logger.debug(f"hostname: {hostname}, name: {name}, host type: host") - host_inst = host_classes["host"]( - **{**kwargs, "hostname": hostname, "name": name} - ) + host_inst = host_classes["host"](**{**kwargs, "hostname": hostname, "name": name}) self._set_attributes(host_inst, broker_args=kwargs, cont_inst=cont_inst) return host_inst - def nick_help(self, **kwargs): - """Useful information about container images""" + def provider_help( + self, container_hosts=False, container_host=None, container_apps=False, **kwargs + ): + """Return useful information about container images.""" results_limit = kwargs.get("results_limit", settings.container.results_limit) - if image := kwargs.get("container_host"): + if container_host: logger.info( - f"Information for {image} container-host:\n" - f"{helpers.yaml_format(self.runtime.image_info(image))}" + f"Information for {container_host} container-host:\n" + f"{helpers.yaml_format(self.runtime.image_info(container_host))}" ) - elif kwargs.get("container_hosts"): + elif container_hosts: images = [ img.tags[0] for img in self.runtime.images @@ -229,7 +230,7 @@ def nick_help(self, **kwargs): images = images if isinstance(images, list) else [images] images = "\n".join(images[:results_limit]) logger.info(f"Available host images:\n{images}") - elif 
kwargs.get("container_apps"): + elif container_apps: images = [img.tags[0] for img in self.runtime.images if img.tags] if res_filter := kwargs.get("results_filter"): images = helpers.eval_filter(images, res_filter, "res") @@ -238,7 +239,7 @@ def nick_help(self, **kwargs): logger.info(f"Available app images:\n{images}") def get_inventory(self, name_prefix): - """Get all containers that have a matching name prefix""" + """Get all containers that have a matching name prefix.""" name_prefix = name_prefix or self._name_prefix return [ container_info(cont) @@ -247,22 +248,23 @@ def get_inventory(self, name_prefix): ] def extend(self): - pass + """There is no need to extend a continer-ased host.""" def release(self, host_obj): + """Remove a container-based host from the container host.""" host_obj._cont_inst.remove(force=True) @Provider.register_action("container_host") def run_container(self, container_host, **kwargs): - """Start a container based on an image name (container_host)""" + """Start a container based on an image name (container_host).""" self._ensure_image(container_host) if not kwargs.get("name"): kwargs["name"] = self._gen_name() kwargs["ports"] = self._port_mapping(container_host, **kwargs) - envars = kwargs.get('environment', {}) + envars = kwargs.get("environment", {}) if isinstance(envars, str): - envars = {var.split('=')[0]: var.split('=')[1] for var in envars.split(',')} + envars = {var.split("=")[0]: var.split("=")[1] for var in envars.split(",")} # add some context information about the container's requester origin = helpers.find_origin() @@ -279,10 +281,11 @@ def run_container(self, container_host, **kwargs): @Provider.register_action("container_app") def execute(self, container_app, **kwargs): - """Run a container and return the raw results""" + """Run a container and return the raw results.""" return self.runtime.execute(container_app, **kwargs) def run_wait_container(self, image_name, **kwargs): + """Run a container and wait for it to exit.""" cont_inst = self.run_container(image_name, **kwargs) cont_inst.wait(condition="excited") return self.runtime.get_logs(cont_inst) diff --git a/broker/providers/test_provider.py b/broker/providers/test_provider.py index 33344f0b..38df4b52 100644 --- a/broker/providers/test_provider.py +++ b/broker/providers/test_provider.py @@ -1,9 +1,11 @@ +"""A test provider for use in unit tests.""" import inspect -from broker import helpers + from dynaconf import Validator -from broker.settings import settings -from broker.providers import Provider +from broker import helpers +from broker.providers import Provider +from broker.settings import settings HOST_PROPERTIES = { "basic": { @@ -16,6 +18,8 @@ class TestProvider(Provider): + """Basic TestProvider class to test the Provider interface.""" + __test__ = False # don't use for testing hidden = True # hide from click command generation _validators = [Validator("TESTPROVIDER.foo", must_exist=True)] @@ -41,6 +45,7 @@ def _set_attributes(self, host_inst, broker_args=None): ) def construct_host(self, provider_params, host_classes, **kwargs): + """Construct a host object from the provider_params and kwargs.""" if provider_params: host_params = provider_params.copy() host_params.update(kwargs) @@ -52,6 +57,7 @@ def construct_host(self, provider_params, host_classes, **kwargs): @Provider.register_action() def test_action(self, **kwargs): + """A dummy action for testing.""" action = kwargs.get("test_action") if action == "release": return "released", kwargs @@ -60,15 +66,17 @@ def test_action(self, 
**kwargs): return HOST_PROPERTIES["basic"] def release(self, host_obj): + """Release a host.""" return self.test_action(test_action="release", **host_obj.to_dict()) def extend(self): - pass + """No current implementation for this provider.""" def get_inventory(self, *args, **kwargs): + """Load a filtered local inventory.""" return helpers.load_inventory( - filter=f"_broker_provider={self.__class__.__name__}" + filter=f'@inv._broker_provider == "{self.__class__.__name__}"' ) - def nick_help(self): - pass + def provider_help(self): + """No current implementation for this provider.""" diff --git a/broker/session.py b/broker/session.py index 7a966136..61c14f8b 100644 --- a/broker/session.py +++ b/broker/session.py @@ -1,12 +1,23 @@ +"""Module providing classes to establish ssh or ssh-like connections to hosts. + +Classes: + Session - Wrapper around ssh2-python's auth/connection system. + InteractiveShell - Wrapper around ssh2-python's non-blocking channel system. + ContainerSession - Wrapper around docker-py's exec system. + +Note: You typically want to use a Host object instance to create sessions, + not these classes directly. +""" from contextlib import contextmanager -import os +from pathlib import Path import socket import tempfile -from pathlib import Path + from logzero import logger -from ssh2.session import Session as ssh2_Session from ssh2 import sftp as ssh2_sftp -from broker import helpers +from ssh2.session import Session as ssh2_Session + +from broker import exceptions, helpers SESSIONS = {} @@ -19,13 +30,24 @@ FILE_FLAGS = ssh2_sftp.LIBSSH2_FXF_CREAT | ssh2_sftp.LIBSSH2_FXF_WRITE -class AuthException(Exception): - pass - - class Session: + """Wrapper around ssh2-python's auth/connection system.""" + def __init__(self, **kwargs): - """Wrapper around ssh2-python's auth/connection system""" + """Initialize a Session object. + + kwargs: + hostname (str): The hostname or IP address of the remote host. Defaults to 'localhost'. + username (str): The username to authenticate with. Defaults to 'root'. + timeout (float): The timeout for the connection in seconds. Defaults to 60. + port (int): The port number to connect to. Defaults to 22. + key_filename (str): The path to the private key file to use for authentication. + password (str): The password to use for authentication. + + Raises: + AuthenticationError: If no password or key file is provided. + FileNotFoundError: If the key file is not found.
+ """ host = kwargs.get("hostname", "localhost") user = kwargs.get("username", "root") sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) @@ -43,11 +65,11 @@ def __init__(self, **kwargs): elif kwargs.get("password"): self.session.userauth_password(user, kwargs["password"]) else: - raise AuthException("No password or key file provided.") + raise exceptions.AuthenticationError("No password or key file provided.") @staticmethod def _read(channel): - """read the contents of a channel""" + """Read the contents of a channel.""" size, data = channel.read() results = "" while size > 0: @@ -62,7 +84,7 @@ def _read(channel): ) def run(self, command, timeout=0): - """run a command on the host and return the results""" + """Run a command on the host and return the results.""" self.session.set_timeout(helpers.translate_timeout(timeout)) channel = self.session.open_session() channel.execute( @@ -75,15 +97,15 @@ def run(self, command, timeout=0): return results def shell(self, pty=False): - """Create and return an interactive shell instance""" + """Create and return an interactive shell instance.""" channel = self.session.open_session() return InteractiveShell(channel, pty) @contextmanager def tail_file(self, filename): - """Simulate tailing a file on the remote host + """Simulate tailing a file on the remote host. - example: + Example: with my_host.session.tail_file("/var/log/messages") as res: # do something that creates new messages print(res.stdout) @@ -97,7 +119,7 @@ def tail_file(self, filename): res.__dict__.update(result.__dict__) def sftp_read(self, source, destination=None, return_data=False): - """read a remote file into a local destination or return a bytes object if return_data is True""" + """Read a remote file into a local destination or return a bytes object if return_data is True.""" if not return_data: if not destination: destination = source @@ -111,15 +133,15 @@ def sftp_read(self, source, destination=None, return_data=False): with sftp.open( source, ssh2_sftp.LIBSSH2_FXF_READ, ssh2_sftp.LIBSSH2_SFTP_S_IRUSR ) as remote: - captured_data = bytes() - for rc, data in remote: + captured_data = b"" + for _rc, data in remote: captured_data += data if return_data: return captured_data destination.write_bytes(data) def sftp_write(self, source, destination=None, ensure_dir=True): - """sftp write a local file to a remote destination""" + """Sftp write a local file to a remote destination.""" if not destination: destination = source elif destination.endswith("/"): @@ -131,26 +153,26 @@ def sftp_write(self, source, destination=None, ensure_dir=True): with sftp.open(destination, FILE_FLAGS, SFTP_MODE) as remote: remote.write(data) - def remote_copy(self, source, dest_host, ensure_dir=True): - """Copy a file from this host to another""" + def remote_copy(self, source, dest_host, dest_path=None, ensure_dir=True): + """Copy a file from this host to another.""" + dest_path = dest_path or source sftp_down = self.session.sftp_init() sftp_up = dest_host.session.session.sftp_init() if ensure_dir: - dest_host.run(f"mkdir -p {Path(source).absolute().parent}") + dest_host.session.run(f"mkdir -p {Path(dest_path).absolute().parent}") with sftp_down.open( source, ssh2_sftp.LIBSSH2_FXF_READ, ssh2_sftp.LIBSSH2_SFTP_S_IRUSR - ) as download: - with sftp_up.open(source, FILE_FLAGS, SFTP_MODE) as upload: - for size, data in download: - upload.write(data) + ) as download, sftp_up.open(dest_path, FILE_FLAGS, SFTP_MODE) as upload: + for _size, data in download: + upload.write(data) def scp_write(self, source, 
destination=None, ensure_dir=True): - """scp write a local file to a remote destination""" + """SCP write a local file to a remote destination.""" if not destination: destination = source elif destination.endswith("/"): destination = destination + Path(source).name - fileinfo = os.stat(source) + fileinfo = (source := Path(source)).stat() chan = self.session.scp_send64( destination, fileinfo.st_mode & 0o777, @@ -160,19 +182,21 @@ ) if ensure_dir: self.run(f"mkdir -p {Path(destination).absolute().parent}") - with open(source, "rb") as local: + with source.open("rb") as local: for data in local: chan.write(data) def __enter__(self): + """Return the session object.""" return self def __exit__(self, *args): + """Close the session.""" self.session.disconnect() class InteractiveShell: - """A helper class that provides an interactive shell interface + """A helper class that provides an interactive shell interface. Preferred use of this class is via its context manager @@ -191,28 +215,29 @@ def __init__(self, channel, pty=False): self._chan.shell() def __enter__(self): + """Return the shell object.""" return self def __exit__(self, *exc_args): - """Close the channel and read stdout/stderr and status""" + """Close the channel and read stdout/stderr and status.""" self._chan.close() self.result = Session._read(self._chan) def __getattribute__(self, name): - """Expose non-duplicate attributes from the channel""" + """Expose non-duplicate attributes from the channel.""" try: return object.__getattribute__(self, name) except AttributeError: return getattr(self._chan, name) def send(self, cmd): - """Send a command to the channel, ensuring a newline character""" + """Send a command to the channel, ensuring a newline character.""" if not cmd.endswith("\n"): cmd += "\n" self._chan.write(cmd) def stdout(self): - """read the contents of a channel's stdout""" + """Read the contents of a channel's stdout.""" if not self._chan.eof(): _, data = self._chan.read(65535) results = data.decode("utf-8") @@ -226,13 +251,13 @@ class ContainerSession: - """An approximation of ssh-based functionality from the Session class""" + """An approximation of ssh-based functionality from the Session class.""" def __init__(self, cont_inst): self._cont_inst = cont_inst def run(self, command, demux=True, **kwargs): - """This is the container approximation of Session.run""" + """Container approximation of Session.run.""" kwargs.pop("timeout", None) # Timeouts are set at the client level kwargs["demux"] = demux if "'" in command: @@ -241,7 +266,7 @@ tmp.seek(0) command = f"/bin/bash {tmp.name}" self.sftp_write(tmp.name) - if any([s in command for s in "|&><"]): + if any(s in command for s in "|&><"): # Containers don't handle pipes, redirects, etc well in a bare exec_run command = f"/bin/bash -c '{command}'" result = self._cont_inst._cont_inst.exec_run(command, **kwargs) @@ -252,12 +277,11 @@ return result def disconnect(self): - """Needed for simple compatability with Session""" - pass + """Needed for simple compatibility with Session.""" @contextmanager def tail_file(self, filename): - """Simulate tailing a file on the remote host""" + """Simulate tailing a file on the remote host.""" initial_size = int(self.run(f"stat -c %s {filename}").stdout.strip()) yield (res := helpers.Result()) # get the contents of the file from the initial size to the end @@ -265,7 +289,7
@@ def tail_file(self, filename): res.__dict__.update(result.__dict__) def sftp_write(self, source, destination=None, ensure_dir=True): - """Add one of more files to the container""" + """Add one or more files to the container.""" # ensure source is a list of Path objects if not isinstance(source, list): source = [Path(source)] @@ -278,9 +302,7 @@ destination = destination or f"{source[0].parent}/" # Files need to be added to a tarfile with helpers.temporary_tar(source) as tar: - logger.debug( - f"{self._cont_inst.hostname} adding file(s) {source} to {destination}" - ) + logger.debug(f"{self._cont_inst.hostname} adding file(s) {source} to {destination}") if ensure_dir: if destination.endswith("/"): self.run(f"mkdir -m 666 -p {destination}") @@ -289,7 +311,7 @@ self._cont_inst._cont_inst.put_archive(str(destination), tar.read_bytes()) def sftp_read(self, source, destination=None, return_data=False): - """Get a file or directory from the container""" + """Get a file or directory from the container.""" destination = Path(destination or source) logger.debug(f"{self._cont_inst.hostname} getting file {source}") data, status = self._cont_inst._cont_inst.get_archive(source) @@ -310,16 +332,15 @@ destination.write_bytes(f.read()) else: logger.warning("More than one member was found in the tar file.") - tar.extractall( - destination.parent if destination.is_file() else destination - ) + tar.extractall(destination.parent if destination.is_file() else destination) def shell(self, pty=False): - """Create and return an interactive shell instance""" + """Create and return an interactive shell instance.""" raise NotImplementedError("ContainerSession.shell has not been implemented") def __enter__(self): + """Return the session object.""" return self def __exit__(self, *args): - pass + """Do nothing on exit.""" diff --git a/broker/settings.py b/broker/settings.py index 4bea0854..4396f8fa 100644 --- a/broker/settings.py +++ b/broker/settings.py @@ -1,15 +1,30 @@ -import os -import click +"""Broker settings module. + +Useful items: + settings: The settings object. + init_settings: Function to initialize the settings file. + validate_settings: Function to validate the settings file. + interactive_mode: Whether or not Broker is running in interactive mode. + BROKER_DIRECTORY: The directory where Broker looks for its files. + settings_path: The path to the settings file. + inventory_path: The path to the inventory file.
+""" import inspect +import os from pathlib import Path + +import click from dynaconf import Dynaconf, Validator from dynaconf.validator import ValidationError + from broker.exceptions import ConfigurationError def init_settings(settings_path, interactive=False): """Initialize the broker settings file.""" - raw_url = "https://raw.githubusercontent.com/SatelliteQE/broker/master/broker_settings.yaml.example" + raw_url = ( + "https://raw.githubusercontent.com/SatelliteQE/broker/master/broker_settings.yaml.example" + ) if ( not interactive or click.prompt( @@ -22,7 +37,7 @@ def init_settings(settings_path, interactive=False): import requests click.echo(f"Downloading example file from: {raw_url}") - raw_file = requests.get(raw_url) + raw_file = requests.get(raw_url, timeout=60) settings_path.write_text(raw_file.text) if interative_mode: try: @@ -33,14 +48,12 @@ def init_settings(settings_path, interactive=False): fg="yellow", ) else: - raise ConfigurationError( - f"Broker settings file not found at {settings_path.absolute()}." - ) + raise ConfigurationError(f"Broker settings file not found at {settings_path.absolute()}.") interative_mode = False # GitHub action context -if not "GITHUB_WORKFLOW" in os.environ: +if "GITHUB_WORKFLOW" not in os.environ: # determine if we're being ran from a CLI for frame in inspect.stack()[::-1]: if "/bin/broker" in frame.filename: @@ -62,9 +75,7 @@ def init_settings(settings_path, interactive=False): inventory_path = BROKER_DIRECTORY.joinpath("inventory.yaml") if not settings_path.exists(): - click.secho( - f"Broker settings file not found at {settings_path.absolute()}.", fg="red" - ) + click.secho(f"Broker settings file not found at {settings_path.absolute()}.", fg="red") init_settings(settings_path, interactive=interative_mode) validators = [ @@ -104,6 +115,6 @@ def init_settings(settings_path, interactive=False): except ValidationError as err: raise ConfigurationError( f"Configuration error in {settings_path.absolute()}: {err.args[0]}" - ) + ) from err os.environ.update(vault_vars) diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 00000000..c2e57d9a --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,179 @@ +[build-system] +requires = ["setuptools", "setuptools-scm", "wheel", "twine"] +build-backend = "setuptools.build_meta" + +[project] +name = "broker" +description = "The infrastructure middleman." 
+readme = "README.md" +requires-python = ">=3.10" +license = {file = "LICENSE", name = "GNU General Public License v3"} +keywords = ["broker", "AnsibleTower", "docker", "podman", "beaker"] +authors = [ + {name = "Jacob J Callahan", email = "jacob.callahan05@gmail.com"} +] +classifiers = [ + "Development Status :: 4 - Beta", + "Intended Audience :: Developers", + "License :: OSI Approved :: GNU General Public License v3 (GPLv3)", + "Natural Language :: English", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", +] +dependencies = [ + "awxkit", + "click", + "dynaconf<4.0.0", + "logzero", + "pyyaml", + "setuptools", + "ssh2-python", +] +dynamic = ["version"] # dynamic fields to update on build - version via setuptools_scm + +[project.urls] +Repository = "https://github.com/SatelliteQE/broker" + +[project.optional-dependencies] +dev = [ + "pre-commit", + "pytest", + "ruff" +] +docker = [ + "docker", + "paramiko" +] +podman = ["podman-py"] +beaker = ["beaker-client"] + +[project.scripts] +broker = "broker.commands:cli" + +[tool.setuptools] +platforms = ["any"] +zip-safe = false +include-package-data = true + +[tool.setuptools.packages.find] +include = ["broker"] + +[tool.pytest.ini_options] +testpaths = ["tests"] +addopts = ["-v", "-l", "--color=yes", "--code-highlight=yes"] + +[tool.black] +line-length = 100 +target-version = ["py310", "py311"] +include = '\.pyi?$' +exclude = ''' +/( + \.git + | \.venv + | build + | dist + | tests/data +)/ +''' + +[tool.ruff] +exclude = ["tests/"] +target-version = "py311" +fixable = ["ALL"] + +select = [ + "B002", # Python does not support the unary prefix increment + "B007", # Loop control variable {name} not used within loop body + "B009", # Do not call getattr with a constant attribute value + "B010", # Do not call setattr with a constant attribute value + "B011", # Do not `assert False`, raise `AssertionError` instead + "B013", # Redundant tuple in exception handler + "B014", # Exception handler with duplicate exception + "B023", # Function definition does not bind loop variable {name} + "B026", # Star-arg unpacking after a keyword argument is strongly discouraged + "BLE001", # Using bare except clauses is prohibited + "C", # complexity + "C4", # flake8-comprehensions + "COM818", # Trailing comma on bare tuple prohibited + "D", # docstrings + "E", # pycodestyle + "F", # pyflakes/autoflake + "G", # flake8-logging-format + "I", # isort + "ISC001", # Implicitly concatenated string literals on one line + "N804", # First argument of a class method should be named cls + "N805", # First argument of a method should be named self + "N815", # Variable {name} in class scope should not be mixedCase + "N999", # Invalid module name: '{name}' + "PERF", # Perflint rules + "PGH004", # Use specific rule codes when using noqa + "PLC0414", # Useless import alias. Import alias does not rename original package. 
+ "PLC", # pylint + "PLE", # pylint + "PLR", # pylint + "PLW", # pylint + "PTH", # Use pathlib + "RUF", # Ruff-specific rules + "S103", # bad-file-permissions + "S108", # hardcoded-temp-file + "S110", # try-except-pass + "S112", # try-except-continue + "S113", # Probable use of requests call without timeout + "S306", # suspicious-mktemp-usage + "S307", # suspicious-eval-usage + "S601", # paramiko-call + "S602", # subprocess-popen-with-shell-equals-true + "S604", # call-with-shell-equals-true + "S609", # unix-command-wildcard-injection + "SIM105", # Use contextlib.suppress({exception}) instead of try-except-pass + "SIM117", # Merge with-statements that use the same scope + "SIM118", # Use {key} in {dict} instead of {key} in {dict}.keys() + "SIM201", # Use {left} != {right} instead of not {left} == {right} + "SIM208", # Use {expr} instead of not (not {expr}) + "SIM212", # Use {a} if {a} else {b} instead of {b} if not {a} else {a} + "SIM300", # Yoda conditions. Use 'age == 42' instead of '42 == age'. + "SIM401", # Use get from dict with default instead of an if block + "T100", # Trace found: {name} used + "T20", # flake8-print + "TRY004", # Prefer TypeError exception for invalid type + "TRY200", # Use raise from to specify exception cause + "TRY302", # Remove exception handler; error is immediately re-raised + "PLR0911", # Too many return statements ({returns} > {max_returns}) + "PLR0912", # Too many branches ({branches} > {max_branches}) + "PLR0915", # Too many statements ({statements} > {max_statements}) + "PLR2004", # Magic value used in comparison, consider replacing {value} with a constant variable + "PLW2901", # Outer {outer_kind} variable {name} overwritten by inner {inner_kind} target + "UP", # pyupgrade + "W", # pycodestyle +] + +ignore = [ + "ANN", # flake8-annotations + "PGH001", # No builtin eval() allowed + "D203", # 1 blank line required before class docstring + "D213", # Multi-line docstring summary should start at the second line + "D406", # Section name should end with a newline + "D407", # Section name underlining + "E501", # line too long + "E731", # do not assign a lambda expression, use a def + "PLR0913", # Too many arguments to function call ({c_args} > {max_args}) + "RUF012", # Mutable class attributes should be annotated with typing.ClassVar + "D107", # Missing docstring in __init__ +] + +[tool.ruff.flake8-pytest-style] +fixture-parentheses = false + +[tool.ruff.isort] +force-sort-within-sections = true +known-first-party = [ + "broker", +] +combine-as-imports = true + +[tool.ruff.per-file-ignores] +# None at this time + +[tool.ruff.mccabe] +max-complexity = 25 diff --git a/pytest.ini b/pytest.ini deleted file mode 100644 index e2c5a324..00000000 --- a/pytest.ini +++ /dev/null @@ -1,3 +0,0 @@ -[pytest] -testpaths = tests -addopts = -v diff --git a/setup.cfg b/setup.cfg deleted file mode 100644 index a060a02b..00000000 --- a/setup.cfg +++ /dev/null @@ -1,50 +0,0 @@ -[metadata] -name = broker -description = The infrastructure middleman. 
-long_description = file: README.md -long_description_content_type = text/markdown -author = Jacob J Callahan -author_email = jacob.callahan05@gmail.com -url = https://github.com/SatelliteQE/broker -license = GNU General Public License v3 -keywords = broker, AnsibleTower -classifiers = - Development Status :: 4 - Beta - Intended Audience :: Developers - License :: OSI Approved :: GNU General Public License v3 (GPLv3) - Natural Language :: English - Programming Language :: Python :: 3 - Programming Language :: Python :: 3.9 - Programming Language :: Python :: 3.10 - Programming Language :: Python :: 3.11 - -[options] -install_requires = - awxkit - click - dynaconf<3.2.1 - logzero - pyyaml - setuptools - ssh2-python -packages = find: -zip_safe = False - -[options.extras_require] -test = pytest -setup = - setuptools - setuptools-scm - wheel - twine -docker = - docker - paramiko -podman = podman-py - -[options.entry_points] -console_scripts = - broker = broker.commands:cli - -[flake8] -max-line-length = 110 diff --git a/setup.py b/setup.py deleted file mode 100644 index 76755a44..00000000 --- a/setup.py +++ /dev/null @@ -1,4 +0,0 @@ -#!/usr/bin/env python -from setuptools import setup - -setup(use_scm_version=True) diff --git a/tests/data/beaker/job_result.json b/tests/data/beaker/job_result.json new file mode 100644 index 00000000..104c0317 --- /dev/null +++ b/tests/data/beaker/job_result.json @@ -0,0 +1,68 @@ +{ + "id": "1234567", + "owner": "test_user@testdom.com", + "result": "Pass", + "status": "Running", + "retention_tag": "scratch", + "whiteboard": { + "text": "Test Reserve Workflow" + }, + "recipeSet": { + "priority": "Normal", + "response": "ack", + "id": "7654321", + "recipe": { + "id": "1029384", + "job_id": "1234567", + "recipe_set_id": "7654321", + "result": "Pass", + "status": "Running", + "distro": "RHEL-8.7.0", + "arch": "x86_64", + "family": "RedHatEnterpriseLinux8", + "variant": "BaseOS", + "system": "fake.host.testdom.com", + "distroRequires": { + "and": { + "distro_family": { + "op": "=", + "value": "RedHatEnterpriseLinux8" + }, + "distro_variant": { + "op": "=", + "value": "BaseOS" + }, + "distro_name": { + "op": "=", + "value": "RHEL-8.7.0" + }, + "distro_arch": { + "op": "=", + "value": "x86_64" + } + } + }, + "hostRequires": { + "system_type": { + "value": "Machine" + } + }, + "logs": { + "log": [ + { + "href": "https://beakerhost.testdom.com/recipes/1029384/logs/console.log", + "name": "console.log" + }, + { + "href": "https://beakerhost.testdom.com/recipes/1029384/logs/anaconda.log", + "name": "anaconda.log" + }, + { + "href": "https://beakerhost.testdom.com/recipes/1029384/logs/sys.log", + "name": "sys.log" + } + ] + } + } + } +} diff --git a/tests/data/beaker/test_job.xml b/tests/data/beaker/test_job.xml new file mode 100644 index 00000000..8a3212cb --- /dev/null +++ b/tests/data/beaker/test_job.xml @@ -0,0 +1,26 @@ +<job retention_tag="scratch"> + <whiteboard>Test Reserve Workflow</whiteboard> + <recipeSet priority="Normal"> + <recipe> + <autopick random="false"/> + <watchdog/> + <packages/> + <ks_appends/> + <repos/> + <distroRequires> + <and> + <distro_family op="=" value="RedHatEnterpriseLinux8"/> + <distro_variant op="=" value="BaseOS"/> + <distro_name op="=" value="RHEL-8.7.0"/> + <distro_arch op="=" value="x86_64"/> + </and> + </distroRequires> + <hostRequires> + <system_type value="Machine"/> + </hostRequires> + <partitions/> + <task name="/distribution/check-install" role="STANDALONE"/> + <task name="/distribution/reservesys" role="STANDALONE"/> + </recipe> + </recipeSet> +</job> diff --git a/tests/data/cli_scenarios/beaker/checkout_test_job-2.yaml b/tests/data/cli_scenarios/beaker/checkout_test_job-2.yaml new file mode 100644 index 00000000..3b392214 --- /dev/null +++ b/tests/data/cli_scenarios/beaker/checkout_test_job-2.yaml @@ -0,0 +1,2 @@ +job_xml: "tests/data/beaker/test_job.xml" +count: 2 diff --git a/tests/data/cli_scenarios/beaker/checkout_test_job.yaml b/tests/data/cli_scenarios/beaker/checkout_test_job.yaml new file mode 100644 index 00000000..104525f2 --- /dev/null +++ b/tests/data/cli_scenarios/beaker/checkout_test_job.yaml @@ -0,0 +1 @@
+job_xml: "tests/data/beaker/test_job.xml" diff --git a/tests/data/cli_scenarios/satlab/checkout_latest_sat.yaml b/tests/data/cli_scenarios/satlab/checkout_latest_sat.yaml index fae87fc0..60ad63bc 100644 --- a/tests/data/cli_scenarios/satlab/checkout_latest_sat.yaml +++ b/tests/data/cli_scenarios/satlab/checkout_latest_sat.yaml @@ -1 +1 @@ -workflow: deploy-sat-jenkins +workflow: deploy-satellite diff --git a/tests/data/cli_scenarios/satlab/checkout_rhel78.yaml b/tests/data/cli_scenarios/satlab/checkout_rhel91.yaml similarity index 50% rename from tests/data/cli_scenarios/satlab/checkout_rhel78.yaml rename to tests/data/cli_scenarios/satlab/checkout_rhel91.yaml index b0579213..1754d032 100644 --- a/tests/data/cli_scenarios/satlab/checkout_rhel78.yaml +++ b/tests/data/cli_scenarios/satlab/checkout_rhel91.yaml @@ -1,2 +1,2 @@ workflow: deploy-base-rhel -deploy_rhel_version: "7.8" +deploy_rhel_version: "9.1" diff --git a/tests/data/cli_scenarios/satlab/checkout_sat_611.yaml b/tests/data/cli_scenarios/satlab/checkout_sat_611.yaml new file mode 100644 index 00000000..98bdeb60 --- /dev/null +++ b/tests/data/cli_scenarios/satlab/checkout_sat_611.yaml @@ -0,0 +1,2 @@ +workflow: deploy-satellite +deploy_sat_version: "6.11" diff --git a/tests/data/cli_scenarios/satlab/checkout_sat_69.yaml b/tests/data/cli_scenarios/satlab/checkout_sat_69.yaml deleted file mode 100644 index 754928a2..00000000 --- a/tests/data/cli_scenarios/satlab/checkout_sat_69.yaml +++ /dev/null @@ -1,2 +0,0 @@ -workflow: deploy-sat-jenkins -deploy_sat_version: "6.9" diff --git a/tests/data/cli_scenarios/satlab/execute_test_workflow.yaml b/tests/data/cli_scenarios/satlab/execute_test_workflow.yaml deleted file mode 100644 index 2995cde5..00000000 --- a/tests/data/cli_scenarios/satlab/execute_test_workflow.yaml +++ /dev/null @@ -1,3 +0,0 @@ -workflow: test-workflow -artifacts: last -output-format: yaml diff --git a/tests/functional/README.md b/tests/functional/README.md index be141af4..4e48a9c2 100644 --- a/tests/functional/README.md +++ b/tests/functional/README.md @@ -17,3 +17,10 @@ Setup: - Make sure you have room for at least 4 hosts in your current SLA limit. Note: These tests take a while to run, up to around 45m. + +**Beaker Tests** + +Setup: +- Ensure you have set up both your Beaker and Kerberos config. +- Ensure that your host_username and host_password match what's expected from the Beaker host. +- Tests are currently limited, but still take a while to run. Run times are dependent on Beaker availability.
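For orientation, here is a minimal sketch of the flow these Beaker functional tests exercise through the Broker API. It mirrors the tests/functional/test_rh_beaker.py test added below and assumes a working Beaker and Kerberos setup; the job XML path is the repository's own test fixture.

```python
# A minimal sketch of the Beaker checkout flow exercised by the functional tests,
# mirroring tests/functional/test_rh_beaker.py (assumes Beaker/Kerberos are configured).
from broker import Broker

with Broker(job_xml="tests/data/beaker/test_job.xml") as host:
    # The context manager submits the job and waits for the reservation;
    # on exit, the host is checked back in and the Beaker job is released.
    result = host.execute("hostname")
    assert result.stdout.strip() == host.hostname
```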
diff --git a/tests/functional/test_containers.py b/tests/functional/test_containers.py index 4d6e49c9..61cd5d91 100644 --- a/tests/functional/test_containers.py +++ b/tests/functional/test_containers.py @@ -23,9 +23,7 @@ def temp_inventory(): """Temporarily move the local inventory, then move it back when done""" backup_path = inventory_path.rename(f"{inventory_path.absolute()}.bak") yield - CliRunner().invoke( - cli, ["checkin", "--all", "--filter", "_broker_provider {tailed_file}") @@ -116,9 +108,7 @@ def test_container_e2e_mp(): res = c_host.execute(f"ls {remote_dir}") assert str(loc_settings_path) in res.stdout with NamedTemporaryFile() as tmp: - c_host.session.sftp_read( - f"{remote_dir}/{loc_settings_path.name}", tmp.file.name - ) + c_host.session.sftp_read(f"{remote_dir}/{loc_settings_path.name}", tmp.file.name) data = c_host.session.sftp_read( f"{remote_dir}/{loc_settings_path.name}", return_data=True ) @@ -128,9 +118,7 @@ def test_container_e2e_mp(): assert ( loc_settings_path.read_bytes() == data ), "Local file is different from the received one (return_data=True)" - assert ( - data == Path(tmp.file.name).read_bytes() - ), "Received files do not match" + assert data == Path(tmp.file.name).read_bytes(), "Received files do not match" def test_broker_multi_manager(): @@ -139,7 +127,9 @@ def test_broker_multi_manager(): ubi8={"container_host": "ubi8:latest", "_count": 2}, ubi9={"container_host": "ubi9:latest"}, ) as multi_hosts: - assert "ubi7" in multi_hosts and "ubi8" in multi_hosts and "ubi9" in multi_hosts + assert "ubi7" in multi_hosts + assert "ubi8" in multi_hosts + assert "ubi9" in multi_hosts assert len(multi_hosts["ubi8"]) == 2 assert multi_hosts["ubi7"][0]._cont_inst.top()["Processes"] assert ( diff --git a/tests/functional/test_rh_beaker.py b/tests/functional/test_rh_beaker.py new file mode 100644 index 00000000..5aa39daf --- /dev/null +++ b/tests/functional/test_rh_beaker.py @@ -0,0 +1,92 @@ +from pathlib import Path +from tempfile import NamedTemporaryFile +import pytest +from click.testing import CliRunner +from broker import Broker +from broker.commands import cli +from broker.providers.beaker import Beaker +from broker.settings import inventory_path, settings_path + +SCENARIO_DIR = Path("tests/data/cli_scenarios/beaker") + + +@pytest.fixture(scope="module", autouse=True) +def skip_if_not_configured(): + try: + Beaker() + except Exception as err: + pytest.skip(f"Beaker is not configured correctly: {err}") + + +@pytest.fixture(scope="module") +def temp_inventory(): + """Temporarily move the local inventory, then move it back when done""" + backup_path = inventory_path.rename(f"{inventory_path.absolute()}.bak") + yield + CliRunner().invoke(cli, ["checkin", "--all", "--filter", '@inv._broker_provider == "Beaker"']) + inventory_path.unlink() + backup_path.rename(inventory_path) + + +@pytest.mark.parametrize( + "args_file", [f for f in SCENARIO_DIR.iterdir() if f.name.startswith("checkout_")] +) +def test_checkout_scenarios(args_file, temp_inventory): + result = CliRunner().invoke(cli, ["checkout", "--args-file", args_file]) + assert result.exit_code == 0 + + +# @pytest.mark.parametrize( +# "args_file", [f for f in SCENARIO_DIR.iterdir() if f.name.startswith("execute_")] +# ) +# def test_execute_scenarios(args_file): +# result = CliRunner().invoke(cli, ["execute", "--args-file", args_file]) +# assert result.exit_code == 0 + + +def test_inventory_sync(): + result = CliRunner().invoke(cli, ["inventory", "--sync", "Beaker"]) + assert result.exit_code == 0 + + +def 
test_jobs_list(): + result = CliRunner(mix_stderr=False).invoke(cli, ["providers", "Beaker", "--jobs", "--mine"]) + assert result.exit_code == 0 + + +# def test_job_query(): +# """This isn't possible until we can figure out how to capture logged output""" +# result = CliRunner().invoke( +# cli, ["providers", "Beaker", "--job", ""] +# ) +# assert result.exit_code == 0 + + +# ----- Broker API Tests ----- + + +def test_beaker_host(): + with Broker(job_xml="tests/data/beaker/test_job.xml") as r_host: + res = r_host.execute("hostname") + assert res.stdout.strip() == r_host.hostname + remote_dir = "/tmp/fake" + r_host.session.sftp_write(str(settings_path.absolute()), f"{remote_dir}/") + res = r_host.execute(f"ls {remote_dir}") + assert str(settings_path.name) in res.stdout + with NamedTemporaryFile() as tmp: + r_host.session.sftp_read(f"{remote_dir}/{settings_path.name}", tmp.file.name) + data = r_host.session.sftp_read(f"{remote_dir}/{settings_path.name}", return_data=True) + assert ( + settings_path.read_bytes() == Path(tmp.file.name).read_bytes() + ), "Local file is different from the received one" + assert ( + settings_path.read_bytes() == data + ), "Local file is different from the received one (return_data=True)" + assert data == Path(tmp.file.name).read_bytes(), "Received files do not match" + # test the tail_file context manager + tailed_file = f"{remote_dir}/tail_me.txt" + r_host.execute(f"echo 'hello world' > {tailed_file}") + with r_host.session.tail_file(tailed_file) as tf: + r_host.execute(f"echo 'this is a new line' >> {tailed_file}") + assert "this is a new line" in tf.stdout + assert "hello world" not in tf.stdout diff --git a/tests/functional/test_satlab.py b/tests/functional/test_satlab.py index a8be4e3a..dc48736c 100644 --- a/tests/functional/test_satlab.py +++ b/tests/functional/test_satlab.py @@ -23,9 +23,7 @@ def temp_inventory(): """Temporarily move the local inventory, then move it back when done""" backup_path = inventory_path.rename(f"{inventory_path.absolute()}.bak") yield - CliRunner().invoke( - cli, ["checkin", "--all", "--filter", "_broker_provider {tailed_file}") with r_host.session.tail_file(tailed_file) as tf: r_host.execute(f"echo 'this is a new line' >> {tailed_file}") - assert 'this is a new line' in tf.stdout - assert 'hello world' not in tf.stdout + assert "this is a new line" in tf.stdout + assert "hello world" not in tf.stdout def test_tower_host_mp(): @@ -117,3 +114,11 @@ def test_tower_host_mp(): loc_settings_path.read_bytes() == data ), "Local file is different from the received one (return_data=True)" assert data == Path(tmp.file.name).read_bytes(), "Received files do not match" + # test remote copy from one host to another + r_hosts[0].session.remote_copy( + source=f"{remote_dir}/{loc_settings_path.name}", + dest_host=r_hosts[1], + dest_path=f"/root/{loc_settings_path.name}", + ) + res = r_hosts[1].execute(f"ls /root") + assert loc_settings_path.name in res.stdout diff --git a/tests/providers/test_ansible_tower.py b/tests/providers/test_ansible_tower.py index b50f6890..957015c4 100644 --- a/tests/providers/test_ansible_tower.py +++ b/tests/providers/test_ansible_tower.py @@ -72,19 +72,19 @@ def pop(self, item=None): return super().pop(item) -@pytest.fixture(scope="function") +@pytest.fixture def api_stub(): - yield AwxkitApiStub() + return AwxkitApiStub() -@pytest.fixture(scope="function") +@pytest.fixture def config_stub(): - yield MockStub() + return MockStub() -@pytest.fixture(scope="function") +@pytest.fixture def tower_stub(api_stub, 
config_stub): - yield AnsibleTower(root=api_stub, config=config_stub) + return AnsibleTower(root=api_stub, config=config_stub) def test_execute(tower_stub): diff --git a/tests/providers/test_beaker.py b/tests/providers/test_beaker.py new file mode 100644 index 00000000..7f2d1908 --- /dev/null +++ b/tests/providers/test_beaker.py @@ -0,0 +1,71 @@ +import json +import pytest +from pathlib import Path +from broker.providers.beaker import Beaker +from broker.binds.beaker import _curate_job_info +from broker.helpers import MockStub +from broker.hosts import Host + + +class BeakerBindStub(MockStub): + """This class stubs out the methods of the Beaker bind + + stubbing for: + - self.runtime.jobid_from_system(caller_host.hostname) + - self.runtime.release(caller_host.hostname, job_id) # no-op + - self.runtime.system_details_curated(host) + - self.runtime.execute_job(job_xml, max_wait) + - self.runtime.job_clone(job, prettyxml=True, dryrun=True).stdout + - self.runtime.system_release(host_name) # no-op + - self.runtime.job_cancel(job_id) # no-op + - self.runtime.job_list(**kwargs).stdout.splitlines() + - self.runtime.user_systems() + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.job_id = "1234567" + self.stdout = "1234567\n7654321\n" + + def jobid_from_system(self, hostname): + return self.job_id + + def system_details_curated(self, host): + return { + "hostname": "test.example.com", + "job_id": self.job_id, + "mac_address": "00:00:00:00:00:00", + "owner": "testuser ", + "id": "7654321", + "reservation_id": "1267", + "reserved_on": "2023-01-01 00:00:00", + "expires_on": "2025-01-01 00:00:00", + "reserved_for": "anotheruser ", + } + + def execute_job(self, job_xml, max_wait): + return _curate_job_info(json.loads(Path("tests/data/beaker/job_result.json").read_text())) + + def user_systems(self): + return ["test.example.com", "test2.example.com"] + + +@pytest.fixture +def bind_stub(): + return BeakerBindStub() + + +@pytest.fixture +def beaker_stub(bind_stub): + return Beaker(bind=bind_stub) + + +def test_empty_init(): + assert Beaker() + + +def test_host_creation(beaker_stub): + job_res = beaker_stub.submit_job("tests/data/beaker/test_job.xml") + host = beaker_stub.construct_host(job_res, {"host": Host}) + assert isinstance(host, Host) + assert host.hostname == "fake.host.testdom.com" diff --git a/tests/providers/test_container.py b/tests/providers/test_container.py index e406adf3..9ab3aae6 100644 --- a/tests/providers/test_container.py +++ b/tests/providers/test_container.py @@ -22,9 +22,7 @@ class ContainerApiStub(MockStub): def __init__(self, **kwargs): in_dict = { "images": [MockStub({"tags": "ch-d:ubi8"})], # self.runtime.images - "containers": [ - MockStub({"tags": "f37d3058317f"}) - ], # self.runtime.containers + "containers": [MockStub({"tags": "f37d3058317f"})], # self.runtime.containers "name": "f37d3058317f", # self.runtime.get_attrs(cont_inst)["name"] } if "job_id" in kwargs: @@ -55,14 +53,14 @@ def create_container(self, container_host, **kwargs): return MockStub(container) -@pytest.fixture(scope="function") +@pytest.fixture def api_stub(): - yield ContainerApiStub() + return ContainerApiStub() -@pytest.fixture(scope="function") +@pytest.fixture def container_stub(api_stub): - yield Container(bind=api_stub) + return Container(bind=api_stub) def test_empty_init(): diff --git a/tests/test_broker.py b/tests/test_broker.py index 899f814e..a15427c5 100644 --- a/tests/test_broker.py +++ b/tests/test_broker.py @@ -1,8 +1,17 @@ -from broker import broker, 
Broker, helpers +from broker import broker, Broker, helpers, settings from broker.providers import test_provider import pytest +@pytest.fixture(scope="module") +def temp_inventory(): + """Temporarily move the local inventory, then move it back when done""" + backup_path = settings.inventory_path.rename(f"{settings.inventory_path.absolute()}.bak") + yield + settings.inventory_path.unlink() + backup_path.rename(settings.inventory_path) + + def test_empty_init(): """Broker should be able to init without any arguments""" broker_inst = Broker() @@ -50,25 +59,23 @@ def test_broker_empty_checkin(): broker_inst.checkin() -def test_broker_checkin_n_sync_empty_hostname(): +def test_broker_checkin_n_sync_empty_hostname(temp_inventory): """Test that broker can checkin and sync inventory with a host that has empty hostname""" broker_inst = broker.Broker(nick="test_nick") broker_inst.checkout() - inventory = helpers.load_inventory() + inventory = helpers.load_inventory(filter='@inv._broker_provider == "TestProvider"') assert len(inventory) == 1 inventory[0]["hostname"] = None # remove the host from the inventory helpers.update_inventory(remove="test.host.example.com") # add the host back with no hostname helpers.update_inventory(add=inventory) - hosts = broker_inst.from_inventory() + hosts = broker_inst.from_inventory(filter='@inv._broker_provider == "TestProvider"') assert len(hosts) == 1 assert hosts[0].hostname is None broker_inst = broker.Broker(hosts=hosts) broker_inst.checkin() - assert ( - not broker_inst.from_inventory() - ), "Host was not removed from inventory after checkin" + assert not broker_inst.from_inventory(), "Host was not removed from inventory after checkin" def test_mp_checkout(): @@ -106,7 +113,8 @@ def test_multi_manager(): with Broker.multi_manager( test_1={"nick": "test_nick"}, test_2={"nick": "test_nick", "_count": 2} ) as host_dict: - assert "test_1" in host_dict and "test_2" in host_dict + assert "test_1" in host_dict + assert "test_2" in host_dict assert len(host_dict["test_1"]) == 1 assert len(host_dict["test_2"]) == 2 assert host_dict["test_1"][0].hostname == "test.host.example.com" diff --git a/tests/test_helpers.py b/tests/test_helpers.py index b6355a2c..85c6a040 100644 --- a/tests/test_helpers.py +++ b/tests/test_helpers.py @@ -150,3 +150,16 @@ def test_eval_filter_chain(fake_inventory): """Test that a user can chain multiple filters together""" filtered = helpers.eval_filter(fake_inventory, "@inv[:3] | 'sat-jenkins' in @inv.name") assert len(filtered) == 1 + + +def test_dict_from_paths_nested(): + source_dict = { + "person": { + "name": "John", + "age": 30, + "address": {"street": "123 Main St", "city": "Anytown", "state": "CA", "zip": "12345"}, + } + } + paths = {"person_name": "person/name", "person_zip": "person/address/zip"} + result = helpers.dict_from_paths(source_dict, paths) + assert result == {"person_name": "John", "person_zip": "12345"} diff --git a/tests/test_settings.py b/tests/test_settings.py index 1926babf..771c4c6b 100644 --- a/tests/test_settings.py +++ b/tests/test_settings.py @@ -36,9 +36,7 @@ def test_nested_envar(set_envars): assert test_provider.foo == "baz" -@pytest.mark.parametrize( - "set_envars", [("BROKER_TESTPROVIDER__foo", "envar")], indirect=True -) +@pytest.mark.parametrize("set_envars", [("BROKER_TESTPROVIDER__foo", "envar")], indirect=True) def test_default_envar(set_envars): """Set a top-level instance value via environment variable then verify that the value is not overriden when the provider is selected by default. 
@@ -48,9 +46,7 @@ def test_default_envar(set_envars): assert test_provider.foo == "envar" -@pytest.mark.parametrize( - "set_envars", [("BROKER_TESTPROVIDER__foo", "override me")], indirect=True -) +@pytest.mark.parametrize("set_envars", [("BROKER_TESTPROVIDER__foo", "override me")], indirect=True) def test_nondefault_envar(set_envars): """Set a top-level instance value via environment variable then verify that the value has been overriden when the provider is specified. @@ -60,9 +56,7 @@ def test_nondefault_envar(set_envars): assert test_provider.foo == "baz" -@pytest.mark.parametrize( - "set_envars", [("VAULT_ENABLED_FOR_DYNACONF", "1")], indirect=True -) +@pytest.mark.parametrize("set_envars", [("VAULT_ENABLED_FOR_DYNACONF", "1")], indirect=True) def test_purge_vault_envars(set_envars): """Set dynaconf vault envars and verify that they have no effect""" sys.modules.pop("broker.settings")
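As a closing note on the filter strings these tests now use: below is a short sketch of the eval-style inventory filtering, with the syntax taken verbatim from the tests above (see helpers.load_inventory and the @inv expressions in tests/test_broker.py and the functional tests).

```python
# Sketch of the eval-style inventory filter used throughout the updated tests.
# The filter string selects inventory entries whose _broker_provider field
# matches the given provider name; syntax copied from the tests above.
from broker import helpers

inventory = helpers.load_inventory(filter='@inv._broker_provider == "TestProvider"')
for entry in inventory:
    print(entry.get("hostname"))
```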