diff --git a/.github/workflows/pre-commit.yml b/.github/workflows/pre-commit.yml deleted file mode 100644 index 80bf3eb2..00000000 --- a/.github/workflows/pre-commit.yml +++ /dev/null @@ -1,54 +0,0 @@ ---- -# Copyright 2024 "Google LLC" -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -name: 'Use pre-commit to validate Pull Request' - -# yamllint disable-line rule:truthy -on: - pull_request: - types: - - edited - - opened - - labeled - - synchronize - branches: - - master - - v5 - -jobs: - pre-commit: - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v4 - - uses: actions/setup-python@v5 - with: - python-version: '3.10' - check-latest: true - cache: 'pip' - - uses: terraform-linters/setup-tflint@v4 - with: - tflint_version: v0.49.0 - - run: tflint --init - env: - # https://github.com/terraform-linters/tflint/blob/master/docs/user-guide/plugins.md#avoiding-rate-limiting - GITHUB_TOKEN: ${{ github.token }} - - uses: actions/setup-go@v5 - with: - go-version: '1.23' - check-latest: true - - run: go install github.com/terraform-docs/terraform-docs@latest - - uses: pre-commit/action@v3.0.1 - with: - extra_args: --show-diff-on-failure --all-files diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 10aa9600..2174bfc2 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -46,28 +46,6 @@ repos: language: system types: [file, text] pass_filenames: false -- repo: https://github.com/antonbabenko/pre-commit-terraform - rev: v1.83.4 - hooks: - - id: terraform_fmt - - id: terraform_validate - - id: terraform_tflint - - id: terraform_docs - args: - - --args=--config=.terraform-docs.yaml - - --hook-config=--create-file-if-not-exist=true - - --hook-config=--path-to-file=README_TF.md -- repo: https://github.com/psf/black - rev: 23.9.1 - hooks: - - id: black - exclude: ^dm/ - language_version: python3 -- repo: https://github.com/pycqa/flake8 - rev: 6.1.0 - hooks: - - id: flake8 - exclude: ^dm/ - repo: https://github.com/codespell-project/codespell rev: v2.2.5 hooks: diff --git a/ansible/roles/slurm/tasks/main.yml b/ansible/roles/slurm/tasks/main.yml index 9a3e8989..b62c1bf0 100644 --- a/ansible/roles/slurm/tasks/main.yml +++ b/ansible/roles/slurm/tasks/main.yml @@ -57,33 +57,6 @@ src: tmpfiles.d/slurm.conf.j2 dest: /etc/tmpfiles.d/slurm.conf -- name: Copy Scripts - copy: - src: scripts/{{item}} - dest: '{{slurm_paths.scripts}}/{{item}}' - owner: '{{slurm_user.user}}' - group: '{{slurm_user.group}}' - mode: 0o755 - with_items: - - conf.py - - resume.py - - setup.py - - setup_network_storage.py - - startup.sh - - slurmsync.py - - suspend.py - - util.py - - load_bq.py - -- name: Copy slurm_gcp_plugins - copy: - src: scripts/slurm_gcp_plugins - dest: '{{ slurm_paths.scripts }}/' - owner: '{{slurm_user.user}}' - group: '{{slurm_user.group}}' - mode: '0644' - directory_mode: '0755' - - name: Copy Jobs copy: src: jobs/ diff --git a/etc/cgroup.conf.tpl b/etc/cgroup.conf.tpl deleted file mode 100644 index ffeb167c..00000000 --- a/etc/cgroup.conf.tpl +++ /dev/null @@ -1,7 +0,0 @@ 
-# cgroup.conf -# https://slurm.schedmd.com/cgroup.conf.html - -ConstrainCores=yes -ConstrainRamSpace=yes -ConstrainSwapSpace=no -ConstrainDevices=yes diff --git a/etc/job_submit.lua.tpl b/etc/job_submit.lua.tpl deleted file mode 100644 index 5cf8ddb7..00000000 --- a/etc/job_submit.lua.tpl +++ /dev/null @@ -1,102 +0,0 @@ -SCRIPTS_DIR = "{scripts_dir}" -NO_VAL = 4294967294 ---util.py exit code -PART_INVALID = -1 --partition does not exists in config.yaml, thus do not exist in slurm -DIFF_VMCOUNTS_SAME_PART = -2 --in the same partition there are nodesets with different vmcounts -DIFF_PART_DIFFERENT_VMCOUNTS = -3 --partition is a list of partitions in which at least two of them have different vmcount -UNKWOWN_ERROR = -4 --util.py did not return a valid response - -function get_part(job_desc,part_list) - if job_desc.partition then - return job_desc.partition - end - for name,val in pairs(part_list) do - if val.flag_default == 1 then - return name - end - end - return nil -end - -function os.capture(cmd, raw) - local handle = assert(io.popen(cmd, 'r')) - local output = assert(handle:read('*a')) - handle:close() - return output -end - -function get_vmcount(part) - local cmd = SCRIPTS_DIR .. "/util.py -p " .. part - local out = os.capture(cmd,true) - for line in out:gmatch("(.-)\r?\n") do - local tag, val = line:match("([^:]+):([^:]+)") - if tag == "VMCOUNT" then - return tonumber(val) - end - end - return UNKWOWN_ERROR -end - - -function slurm_job_submit(job_desc, part_list, submit_uid) - local part = get_part(job_desc,part_list) - local vmcount = get_vmcount(part) - --Only do something if the job is in a TPU partition, if vmcount is 0, it implies that the partition(s) specified are not TPU ones - if vmcount == 0 then - return slurm.SUCCESS - end - --This is a TPU job, but as the vmcount is 1 it can he handled the same way - if vmcount == 1 then - return slurm.SUCCESS - end - --Check for errors - if vmcount == PART_INVALID then - slurm.log_user("Invalid partition specified " .. part) - return slurm.FAILURE - end - if vmcount == DIFF_VMCOUNTS_SAME_PART then - slurm.log_user("In partition(s) " .. part .. " there are more than one tpu nodeset vmcount, this should not happen.") - return slurm.ERROR - end - if vmcount == DIFF_PART_DIFFERENT_VMCOUNTS then - slurm.log_user("In partition list " .. part .. " there are more than one TPU types, cannot determine which is the correct vmcount to use, please retry with only one partition.") - return slurm.FAILURE - end - if vmcount == UNKWOWN_ERROR then - slurm.log_user("Something went wrong while executing util.py to get the vmcount.") - return slurm.ERROR - end - --This is surely a TPU node - if vmcount > 1 then - local min_nodes = job_desc.min_nodes - local max_nodes = job_desc.max_nodes - --if not specified assume it is one, this should be improved taking into account the cpus, mem, and other factors - if min_nodes == NO_VAL then - min_nodes = 1 - max_nodes = 1 - end - --as max_nodes can be higher than the nodes in the partition, we are not able to calculate with certainty the nodes that this job will have if this value is set to something - --different than min_nodes - if min_nodes ~= max_nodes then - slurm.log_user("Max nodes cannot be set different than min nodes for the TPU partitions.") - return slurm.ERROR - end - --Set the number of switches to the number of nodes originally requested by the job, as the job requests "TPU groups" - job_desc.req_switch = min_nodes - - --Apply the node increase into the job description. 
- job_desc.min_nodes = min_nodes * vmcount - job_desc.max_nodes = max_nodes * vmcount - --if job_desc.features then - --slurm.log_user("Features: %s",job_desc.features) - --end - end - - return slurm.SUCCESS -end - -function slurm_job_modify(job_desc, job_rec, part_list, modify_uid) - return slurm.SUCCESS -end - -return slurm.SUCCESS diff --git a/etc/slurm.conf.tpl b/etc/slurm.conf.tpl deleted file mode 100644 index 7d32bed8..00000000 --- a/etc/slurm.conf.tpl +++ /dev/null @@ -1,67 +0,0 @@ -# slurm.conf -# https://slurm.schedmd.com/slurm.conf.html -# https://slurm.schedmd.com/configurator.html - -ProctrackType=proctrack/cgroup -SlurmctldPidFile=/var/run/slurm/slurmctld.pid -SlurmdPidFile=/var/run/slurm/slurmd.pid -TaskPlugin=task/affinity,task/cgroup -MaxNodeCount=64000 - -# -# -# SCHEDULING -SchedulerType=sched/backfill -SelectType=select/cons_tres -SelectTypeParameters=CR_Core_Memory - -# -# -# LOGGING AND ACCOUNTING -AccountingStoreFlags=job_comment -JobAcctGatherFrequency=30 -JobAcctGatherType=jobacct_gather/cgroup -SlurmctldDebug=info -SlurmdDebug=info -DebugFlags=Power - -# -# -# TIMERS -MessageTimeout=60 - -################################################################################ -# vvvvv WARNING: DO NOT MODIFY SECTION BELOW vvvvv # -################################################################################ - -SlurmctldHost={control_host}({control_addr}) - -AuthType=auth/munge -AuthInfo=cred_expire=120 -AuthAltTypes=auth/jwt -CredType=cred/munge -MpiDefault={mpi_default} -ReturnToService=2 -SlurmctldPort={control_host_port} -SlurmdPort=6818 -SlurmdSpoolDir=/var/spool/slurmd -SlurmUser=slurm -StateSaveLocation={state_save} - -# -# -# LOGGING AND ACCOUNTING -AccountingStorageType=accounting_storage/slurmdbd -AccountingStorageHost={control_host} -ClusterName={name} -SlurmctldLogFile={slurmlog}/slurmctld.log -SlurmdLogFile={slurmlog}/slurmd-%n.log - -# -# -# GENERATED CLOUD CONFIGURATIONS -include cloud.conf - -################################################################################ -# ^^^^^ WARNING: DO NOT MODIFY SECTION ABOVE ^^^^^ # -################################################################################ diff --git a/etc/slurmdbd.conf.tpl b/etc/slurmdbd.conf.tpl deleted file mode 100644 index ba06f28b..00000000 --- a/etc/slurmdbd.conf.tpl +++ /dev/null @@ -1,31 +0,0 @@ -# slurmdbd.conf -# https://slurm.schedmd.com/slurmdbd.conf.html - -DebugLevel=info -PidFile=/var/run/slurm/slurmdbd.pid - -################################################################################ -# vvvvv WARNING: DO NOT MODIFY SECTION BELOW vvvvv # -################################################################################ - -AuthType=auth/munge -AuthAltTypes=auth/jwt -AuthAltParameters=jwt_key={state_save}/jwt_hs256.key - -DbdHost={control_host} - -LogFile={slurmlog}/slurmdbd.log - -SlurmUser=slurm - -StorageLoc={db_name} - -StorageType=accounting_storage/mysql -StorageHost={db_host} -StoragePort={db_port} -StorageUser={db_user} -StoragePass={db_pass} - -################################################################################ -# ^^^^^ WARNING: DO NOT MODIFY SECTION ABOVE ^^^^^ # -################################################################################ diff --git a/scripts/.gitignore b/scripts/.gitignore deleted file mode 100644 index 5b6b0720..00000000 --- a/scripts/.gitignore +++ /dev/null @@ -1 +0,0 @@ -config.yaml diff --git a/scripts/Pipfile b/scripts/Pipfile deleted file mode 100644 index 236264fb..00000000 --- a/scripts/Pipfile +++ /dev/null @@ -1,22 
+0,0 @@ -[[source]] -url = "https://pypi.org/simple" -verify_ssl = true -name = "pypi" - -[packages] -requests = "*" -pyyaml = "*" -addict = "*" -google-api-python-client = "*" -google-cloud-bigquery = "*" -google-cloud-storage = "*" -ipython = "<8.11" -more-executors = "*" -prometheus-client = "*" - -[dev-packages] - -[requires] -# Supports: python >= 3.6 -# REF: https://github.com/pypa/pipenv/issues/1050 -python_version = "3" diff --git a/scripts/conf.py b/scripts/conf.py deleted file mode 100755 index 0dd81d29..00000000 --- a/scripts/conf.py +++ /dev/null @@ -1,499 +0,0 @@ -#!/usr/bin/env python3 - -# Copyright (C) SchedMD LLC. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import List, Optional, Iterable, Dict -from itertools import chain -from collections import defaultdict -import json -from pathlib import Path -import util -from util import dirs, slurmdirs - -FILE_PREAMBLE = """ -# Warning: -# This file is managed by a script. Manual modifications will be overwritten. -""" - -login_nodeset = "x-login" - - -def dict_to_conf(conf, delim=" ") -> str: - """convert dict to delimited slurm-style key-value pairs""" - - def filter_conf(pair): - k, v = pair - if isinstance(v, list): - v = ",".join(el for el in v if el is not None) - return k, (v if bool(v) or v == 0 else None) - - return delim.join( - f"{k}={v}" for k, v in map(filter_conf, conf.items()) if v is not None - ) - - -def conflines(cloud_parameters, lkp: util.Lookup) -> str: - scripts_dir = lkp.cfg.install_dir or dirs.scripts - no_comma_params = cloud_parameters.no_comma_params or False - - any_gpus = any( - lkp.template_info(nodeset.instance_template).gpu_count > 0 - for nodeset in lkp.cfg.nodeset.values() - ) - - any_tpu = any( - tpu_nodeset is not None - for part in lkp.cfg.partitions.values() - for tpu_nodeset in part.partition_nodeset_tpu - ) - - any_dynamic = any(bool(p.partition_feature) for p in lkp.cfg.partitions.values()) - comma_params = { - "PrivateData": [ - "cloud", - ], - "LaunchParameters": [ - "enable_nss_slurm", - "use_interactive_step", - ], - "SlurmctldParameters": [ - "cloud_reg_addrs" if any_dynamic or any_tpu else "cloud_dns", - "enable_configless", - "idle_on_node_suspend", - ], - "SchedulerParameters": [ - "bf_continue", - "salloc_wait_nodes", - "ignore_prefer_validation", - ], - "GresTypes": [ - "gpu" if any_gpus else None, - ], - } - prolog_path = Path(dirs.custom_scripts / "prolog.d") - epilog_path = Path(dirs.custom_scripts / "epilog.d") - default_tree_width = 65533 if any_dynamic else None - conf_options = { - **(comma_params if not no_comma_params else {}), - "Prolog": f"{prolog_path}/*" if lkp.cfg.prolog_scripts else None, - "Epilog": f"{epilog_path}/*" if lkp.cfg.epilog_scripts else None, - "SuspendProgram": f"{scripts_dir}/suspend.py", - "ResumeProgram": f"{scripts_dir}/resume.py", - "ResumeFailProgram": f"{scripts_dir}/suspend.py", - "ResumeRate": cloud_parameters.get("resume_rate", 0), - "ResumeTimeout": cloud_parameters.get("resume_timeout", 300), - "SuspendRate": 
cloud_parameters.get("suspend_rate", 0), - "SuspendTimeout": cloud_parameters.get("suspend_timeout", 300), - "TreeWidth": cloud_parameters.get("tree_width", default_tree_width), - "JobSubmitPlugins": "lua" if any_tpu else None, - "TopologyPlugin": cloud_parameters.get("topology_plugin", "topology/tree"), - } - return dict_to_conf(conf_options, delim="\n") - - -def loginlines() -> str: - nodeset = { - "NodeSet": login_nodeset, - "Feature": login_nodeset, - } - partition = { - "PartitionName": login_nodeset, - "Nodes": login_nodeset, - "State": "UP", - "DefMemPerCPU": 1, - "Hidden": "YES", - "RootOnly": "YES", - } - lines = [ - dict_to_conf(nodeset), - dict_to_conf(partition), - ] - return "\n".join(lines) - - -def nodeset_lines(nodeset, lkp: util.Lookup) -> str: - template_info = lkp.template_info(nodeset.instance_template) - machine_conf = lkp.template_machine_conf(nodeset.instance_template) - - # follow https://slurm.schedmd.com/slurm.conf.html#OPT_Boards - # by setting Boards, SocketsPerBoard, CoresPerSocket, and ThreadsPerCore - node_def = { - "NodeName": "DEFAULT", - "State": "UNKNOWN", - "RealMemory": machine_conf.memory, - "Boards": machine_conf.boards, - "SocketsPerBoard": machine_conf.sockets_per_board, - "CoresPerSocket": machine_conf.cores_per_socket, - "ThreadsPerCore": machine_conf.threads_per_core, - "CPUs": machine_conf.cpus, - **nodeset.node_conf, - } - - gres = f"gpu:{template_info.gpu_count}" if template_info.gpu_count else None - nodelist = lkp.nodelist(nodeset) - - return "\n".join( - map( - dict_to_conf, - [ - node_def, - {"NodeName": nodelist, "State": "CLOUD", "Gres": gres}, - {"NodeSet": nodeset.nodeset_name, "Nodes": nodelist}, - ], - ) - ) - - -def nodeset_tpu_lines(nodeset, lkp: util.Lookup) -> str: - node_def = { - "NodeName": "DEFAULT", - "State": "UNKNOWN", - **nodeset.node_conf, - } - nodelist = lkp.nodelist(nodeset) - - return "\n".join( - map( - dict_to_conf, - [ - node_def, - {"NodeName": nodelist, "State": "CLOUD"}, - {"NodeSet": nodeset.nodeset_name, "Nodes": nodelist}, - ], - ) - ) - - -def nodeset_dyn_lines(nodeset): - """generate slurm NodeSet definition for dynamic nodeset""" - return dict_to_conf( - {"NodeSet": nodeset.nodeset_name, "Feature": nodeset.nodeset_feature} - ) - - -def partitionlines(partition, lkp: util.Lookup) -> str: - """Make a partition line for the slurm.conf""" - MIN_MEM_PER_CPU = 100 - - def defmempercpu(nodeset: str) -> int: - template = lkp.cfg.nodeset.get(nodeset).instance_template - machine = lkp.template_machine_conf(template) - return max(MIN_MEM_PER_CPU, machine.memory // machine.cpus) - - defmem = min( - map(defmempercpu, partition.partition_nodeset), default=MIN_MEM_PER_CPU - ) - - nodesets = list( - chain( - partition.partition_nodeset, - partition.partition_nodeset_dyn, - partition.partition_nodeset_tpu, - ) - ) - - is_tpu = len(partition.partition_nodeset_tpu) > 0 - is_dyn = len(partition.partition_nodeset_dyn) > 0 - - oversub_exlusive = partition.enable_job_exclusive or is_tpu - power_down_on_idle = partition.enable_job_exclusive and not is_dyn - - line_elements = { - "PartitionName": partition.partition_name, - "Nodes": ",".join(nodesets), - "State": "UP", - "DefMemPerCPU": defmem, - "SuspendTime": 300, - "Oversubscribe": "Exclusive" if oversub_exlusive else None, - "PowerDownOnIdle": "YES" if power_down_on_idle else None, - **partition.partition_conf, - } - - return dict_to_conf(line_elements) - - -def suspend_exc_lines(lkp: util.Lookup) -> Iterable[str]: - static_nodelists = [] - for ns in 
lkp.power_managed_nodesets(): - if ns.node_count_static: - nodelist = lkp.nodelist_range(ns.nodeset_name, 0, ns.node_count_static) - static_nodelists.append(nodelist) - suspend_exc_nodes = {"SuspendExcNodes": static_nodelists} - - dyn_parts = [ - p.partition_name - for p in lkp.cfg.partitions.values() - if len(p.partition_nodeset_dyn) > 0 - ] - suspend_exc_parts = {"SuspendExcParts": [login_nodeset, *dyn_parts]} - - return filter( - None, - [ - dict_to_conf(suspend_exc_nodes) if static_nodelists else None, - dict_to_conf(suspend_exc_parts), - ], - ) - - -def make_cloud_conf(lkp: util.Lookup) -> str: - """generate cloud.conf snippet""" - lines = [ - FILE_PREAMBLE, - conflines(lkp.cfg.cloud_parameters, lkp), - loginlines(), - *(nodeset_lines(n, lkp) for n in lkp.cfg.nodeset.values()), - *(nodeset_dyn_lines(n) for n in lkp.cfg.nodeset_dyn.values()), - *(nodeset_tpu_lines(n, lkp) for n in lkp.cfg.nodeset_tpu.values()), - *(partitionlines(p, lkp) for p in lkp.cfg.partitions.values()), - *(suspend_exc_lines(lkp)), - ] - return "\n\n".join(filter(None, lines)) - - -def gen_cloud_conf(lkp: util.Lookup) -> None: - content = make_cloud_conf(lkp) - - conf_file = Path(lkp.cfg.output_dir or slurmdirs.etc) / "cloud.conf" - conf_file.write_text(content) - util.chown_slurm(conf_file, mode=0o644) - - -def install_slurm_conf(lkp: util.Lookup) -> None: - """install slurm.conf""" - if lkp.cfg.ompi_version: - mpi_default = "pmi2" - else: - mpi_default = "none" - - conf_options = { - "name": lkp.cfg.slurm_cluster_name, - "control_addr": lkp.control_addr if lkp.control_addr else lkp.hostname_fqdn, - "control_host": lkp.control_host, - "control_host_port": lkp.control_host_port, - "scripts": dirs.scripts, - "slurmlog": dirs.log, - "state_save": slurmdirs.state, - "mpi_default": mpi_default, - } - - conf = lkp.cfg.slurm_conf_tpl.format(**conf_options) - - conf_file = Path(lkp.cfg.output_dir or slurmdirs.etc) / "slurm.conf" - conf_file.write_text(conf) - util.chown_slurm(conf_file, mode=0o644) - - -def install_slurmdbd_conf(lkp: util.Lookup) -> None: - """install slurmdbd.conf""" - conf_options = { - "control_host": lkp.control_host, - "slurmlog": dirs.log, - "state_save": slurmdirs.state, - "db_name": "slurm_acct_db", - "db_user": "slurm", - "db_pass": '""', - "db_host": "localhost", - "db_port": "3306", - } - - if lkp.cfg.cloudsql_secret: - secret_name = f"{lkp.cfg.slurm_cluster_name}-slurm-secret-cloudsql" - payload = json.loads(util.access_secret_version(lkp.project, secret_name)) - - if payload["db_name"] and payload["db_name"] != "": - conf_options["db_name"] = payload["db_name"] - if payload["user"] and payload["user"] != "": - conf_options["db_user"] = payload["user"] - if payload["password"] and payload["password"] != "": - conf_options["db_pass"] = payload["password"] - - db_host_str = payload["server_ip"].split(":") - if db_host_str[0]: - conf_options["db_host"] = db_host_str[0] - conf_options["db_port"] = ( - db_host_str[1] if len(db_host_str) >= 2 else "3306" - ) - - conf = lkp.cfg.slurmdbd_conf_tpl.format(**conf_options) - - conf_file = Path(lkp.cfg.output_dir or slurmdirs.etc) / "slurmdbd.conf" - conf_file.write_text(conf) - util.chown_slurm(conf_file, 0o600) - - -def install_cgroup_conf(lkp: util.Lookup) -> None: - """install cgroup.conf""" - conf_file = Path(lkp.cfg.output_dir or slurmdirs.etc) / "cgroup.conf" - conf_file.write_text(lkp.cfg.cgroup_conf_tpl) - util.chown_slurm(conf_file, mode=0o600) - - -def install_jobsubmit_lua(lkp: util.Lookup) -> None: - """install job_submit.lua if there are 
tpu nodes in the cluster""" - if any( - tpu_nodeset is not None - for part in lkp.cfg.partitions.values() - for tpu_nodeset in part.partition_nodeset_tpu - ): - conf_options = { - "scripts_dir": lkp.cfg.slurm_scripts_dir or dirs.scripts, - } - conf = lkp.cfg.jobsubmit_lua_tpl.format(**conf_options) - - conf_file = Path(lkp.cfg.output_dir or slurmdirs.etc) / "job_submit.lua" - conf_file.write_text(conf) - util.chown_slurm(conf_file, 0o600) - - -def gen_cloud_gres_conf(lkp: util.Lookup) -> None: - """generate cloud_gres.conf""" - - gpu_nodes = defaultdict(list) - for nodeset in lkp.cfg.nodeset.values(): - template_info = lkp.template_info(nodeset.instance_template) - gpu_count = template_info.gpu_count - if gpu_count == 0: - continue - gpu_nodes[gpu_count].append(lkp.nodelist(nodeset)) - - lines = [ - dict_to_conf( - { - "NodeName": names, - "Name": "gpu", - "File": "/dev/nvidia{}".format(f"[0-{i-1}]" if i > 1 else "0"), - } - ) - for i, names in gpu_nodes.items() - ] - lines.append("\n") - content = FILE_PREAMBLE + "\n".join(lines) - - conf_file = Path(lkp.cfg.output_dir or slurmdirs.etc) / "cloud_gres.conf" - conf_file.write_text(content) - util.chown_slurm(conf_file, mode=0o600) - - -def install_gres_conf(lkp: util.Lookup) -> None: - conf_file = Path(lkp.cfg.output_dir or slurmdirs.etc) / "cloud_gres.conf" - gres_conf = Path(lkp.cfg.output_dir or slurmdirs.etc) / "gres.conf" - if not gres_conf.exists(): - gres_conf.symlink_to(conf_file) - util.chown_slurm(gres_conf, mode=0o600) - - -class Switch: - """ - Represents a switch in the topology.conf file. - NOTE: It's class user job to make sure that there is no leaf-less Switches in the tree - """ - - def __init__( - self, - name: str, - nodes: Optional[Iterable[str]] = None, - switches: Optional[Dict[str, "Switch"]] = None, - ): - self.name = name - self.nodes = nodes or [] - self.switches = switches or {} - - def conf_line(self) -> str: - d = {"SwitchName": self.name} - if self.nodes: - d["Nodes"] = util.to_hostlist_fast(self.nodes) - if self.switches: - d["Switches"] = util.to_hostlist_fast(self.switches.keys()) - return dict_to_conf(d) - - def render_conf_lines(self) -> Iterable[str]: - yield self.conf_line() - for s in sorted(self.switches.values(), key=lambda s: s.name): - yield from s.render_conf_lines() - - -class TopologyBuilder: - def __init__(self) -> None: - self._r = Switch("root") - - def add(self, path: List[str], nodes: Iterable[str]) -> None: - n = self._r - assert path - for p in path: - n = n.switches.setdefault(p, Switch(p)) - n.nodes = chain(n.nodes, nodes) - - def render_conf_lines(self) -> Iterable[str]: - if not self._r.switches: - return [] - for s in sorted(self._r.switches.values(), key=lambda s: s.name): - yield from s.render_conf_lines() - - -def add_tpu_nodeset_topology(nodeset: object, bldr: TopologyBuilder, lkp: util.Lookup): - tpuobj = util.TPU(nodeset) - static, dynamic = lkp.nodenames(nodeset) - - pref = ["nodeset_tpu-root", nodeset.nodeset_name] - if tpuobj.vmcount == 1: # Put all nodes in one switch - bldr.add(pref, list(chain(static, dynamic))) - return - - # Chunk nodes into sub-switches of size `vmcount` - chunk_num = 0 - for nodenames in (static, dynamic): - for nodeschunk in util.chunked(nodenames, n=tpuobj.vmcount): - chunk_name = f"{nodeset.nodeset_name}-{chunk_num}" - chunk_num += 1 - bldr.add([*pref, chunk_name], list(nodeschunk)) - - -def add_nodeset_topology( - nodeset: object, bldr: TopologyBuilder, lkp: util.Lookup -) -> None: - path = ["nodeset-root", nodeset.nodeset_name] - nodes = 
list(chain(*lkp.nodenames(nodeset))) - bldr.add(path, nodes) - - -def gen_topology(lkp: util.Lookup) -> TopologyBuilder: - bldr = TopologyBuilder() - for ns in lkp.cfg.nodeset_tpu.values(): - add_tpu_nodeset_topology(ns, bldr, lkp) - for ns in lkp.cfg.nodeset.values(): - add_nodeset_topology(ns, bldr, lkp) - return bldr - - -def gen_topology_conf(lkp: util.Lookup) -> None: - """generate slurm topology.conf from config.yaml""" - bldr = gen_topology(lkp) - conf_file = Path(lkp.cfg.output_dir or slurmdirs.etc) / "cloud_topology.conf" - with open(conf_file, "w") as f: - f.writelines(FILE_PREAMBLE + "\n") - for line in bldr.render_conf_lines(): - f.write(line) - f.write("\n") - f.write("\n") - util.chown_slurm(conf_file, mode=0o600) - - -def install_topology_conf(lkp: util.Lookup) -> None: - conf_file = Path(lkp.cfg.output_dir or slurmdirs.etc) / "cloud_topology.conf" - topo_conf = Path(lkp.cfg.output_dir or slurmdirs.etc) / "topology.conf" - if not topo_conf.exists(): - topo_conf.symlink_to(conf_file) - util.chown_slurm(conf_file, mode=0o600) diff --git a/scripts/destroy_nodes.py b/scripts/destroy_nodes.py deleted file mode 100755 index 4e5c05db..00000000 --- a/scripts/destroy_nodes.py +++ /dev/null @@ -1,131 +0,0 @@ -#!/usr/bin/env python3 -# Copyright (C) SchedMD LLC. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import argparse -import logging -from pathlib import Path -from time import sleep -from suspend import ( - batch_execute, - delete_instance_request, - truncate_iter, - wait_for_operations, -) -from util import lkp, compute, config_root_logger, parse_self_link - -logger_name = Path(__file__).name -log = logging.getLogger(logger_name) - - -def delete_instances(compute_list): - log.info( - "Deleting {0} compute instances:\n{1}".format( - len(compute_list), "\n".join(compute_list) - ) - ) - - ops = {} - for self_link in compute_list: - link_info = parse_self_link(self_link) - ops[self_link] = delete_instance_request( - instance=link_info.instance, project=link_info.project, zone=link_info.zone - ) - done, failed = batch_execute(ops) - if failed: - failed_nodes = [f"{n}: {e}" for n, (_, e) in failed.items()] - node_str = "\n".join(str(el) for el in truncate_iter(failed_nodes, 5)) - log.error(f"some nodes failed to delete: {node_str}") - wait_for_operations(done.values()) - - -def main(args): - required_map = { - "labels.slurm_cluster_name": args.slurm_cluster_name, - "labels.slurm_instance_role": "compute", - } - required_list = [f"{k}={v}" for k, v in required_map.items()] - required_logic = " AND ".join(required_list) - - target_list = ( - " OR ".join([f"name={x}" for x in args.target.split(",")]) - if args.target - else "" - ) - target_logic = f"AND ({target_list})" if args.target else "" - - exclude_list = ( - " AND ".join([f"name!={x}" for x in args.exclude.split(",")]) - if args.exclude - else "" - ) - exclude_logic = f"AND ({exclude_list})" if args.exclude else "" - - filter = f"{required_logic} {target_logic} {exclude_logic}" - log.debug(f'filter = "{filter}"') - - # NOTE: It is not technically possible to filter by metadata or other - # complex nested items - p_id = args.project_id if args.project_id else lkp.project - if not p_id: - print("Error: Project id cannot be determined") - exit(1) - result = compute.instances().aggregatedList(project=p_id, filter=filter).execute() - - compute_list = [] - for item in result["items"].values(): - instances = item.get("instances") - if instances is not None: - for instance in instances: - compute_list.append(instance["selfLink"]) - - delete_instances(compute_list) - - if len(compute_list) > 0: - sleep_dur = 30 - log.info(f"Done. 
Sleeping for {sleep_dur} seconds.") - sleep(sleep_dur) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser( - description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter - ) - parser.add_argument("slurm_cluster_name", help="Slurm cluster name label filter") - parser.add_argument( - "--target", help="NodeNames targeted for destruction", type=str, default=None - ) - parser.add_argument( - "--project_id", help="Google cloud project ID", type=str, default=None - ) - parser.add_argument( - "--exclude", help="NodeNames excluded from destruction", type=str, default=None - ) - parser.add_argument( - "--debug", - "-d", - dest="debug", - action="store_true", - help="Enable debugging output", - ) - - args = parser.parse_args() - - logfile = (Path(__file__).parent / logger_name).with_suffix(".log") - if args.debug: - config_root_logger(logger_name, level="DEBUG", logfile=logfile) - else: - config_root_logger(logger_name, level="INFO", logfile=logfile) - - main(args) diff --git a/scripts/destroy_resource_policies.py b/scripts/destroy_resource_policies.py deleted file mode 100755 index 3bf22d66..00000000 --- a/scripts/destroy_resource_policies.py +++ /dev/null @@ -1,108 +0,0 @@ -#!/usr/bin/env python3 -# Copyright (C) SchedMD LLC. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import argparse -import logging -from pathlib import Path -from suspend import batch_execute, truncate_iter, wait_for_operations -from util import lkp, compute, config_root_logger, parse_self_link - -logger_name = Path(__file__).name -log = logging.getLogger(logger_name) - - -def delete_placement_groups(project, region, resourcePolicy): - request = compute.resourcePolicies().delete( - project=project, region=region, resourcePolicy=resourcePolicy - ) - return request - - -def delete_policies(policy_list): - log.info( - "Deleting {0} resource policies:\n{1}".format( - len(policy_list), "\n".join(policy_list) - ) - ) - - ops = {} - for self_link in policy_list: - link_info = parse_self_link(self_link) - ops[self_link] = delete_placement_groups( - project=link_info.project, - region=link_info.region, - resourcePolicy=link_info.resourcePolicie, - ) - done, failed = batch_execute(ops) - if failed: - failed_items = [f"{n}: {e}" for n, (_, e) in failed.items()] - items_str = "\n".join(str(el) for el in truncate_iter(failed_items, 5)) - log.error(f"some policies failed to delete: {items_str}") - wait_for_operations(done.values()) - - -def main(args): - # NOTE: Resource policies cannot be labeled - if args.partition_name: - filter = f"name={args.slurm_cluster_name}-{args.partition_name}-*" - else: - filter = f"name={args.slurm_cluster_name}-*" - log.debug(f'filter = "{filter}"') - p_id = args.project_id if args.project_id else lkp.project - if not p_id: - print("Error: Project id cannot be determined") - exit(1) - result = ( - compute.resourcePolicies().aggregatedList(project=p_id, filter=filter).execute() - ) - - policy_list = [] - for item in result["items"].values(): - policies = item.get("resourcePolicies") - if policies is not None: - for policy in policies: - policy_list.append(policy["selfLink"]) - - delete_policies(policy_list) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser( - description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter - ) - parser.add_argument("slurm_cluster_name", help="Slurm cluster name filter") - parser.add_argument( - "--partition", "-p", dest="partition_name", help="Slurm partition name filter" - ) - parser.add_argument( - "--project_id", help="Google cloud project ID", type=str, default=None - ) - parser.add_argument( - "--debug", - "-d", - dest="debug", - action="store_true", - help="Enable debugging output", - ) - - args = parser.parse_args() - - logfile = (Path(__file__).parent / logger_name).with_suffix(".log") - if args.debug: - config_root_logger(logger_name, level="DEBUG", logfile=logfile) - else: - config_root_logger(logger_name, level="INFO", logfile=logfile) - - main(args) diff --git a/scripts/load_bq.py b/scripts/load_bq.py deleted file mode 100755 index 70dfa04d..00000000 --- a/scripts/load_bq.py +++ /dev/null @@ -1,329 +0,0 @@ -#!/usr/bin/env python3 - -import argparse -import os -import shelve -import uuid -from collections import namedtuple -from datetime import datetime, timezone, timedelta -from pathlib import Path -from pprint import pprint - -from google.cloud.bigquery import SchemaField -from google.cloud import bigquery as bq -from google.api_core import retry, exceptions - -import util -from util import run -from util import cfg - - -SACCT = "sacct" -script = Path(__file__).resolve() - -DEFAULT_TIMESTAMP_FILE = script.parent / "bq_timestamp" -timestamp_file = Path(os.environ.get("TIMESTAMP_FILE", DEFAULT_TIMESTAMP_FILE)) - -# cluster_id_file = script.parent / 'cluster_uuid' -# try: -# cluster_id = 
cluster_id_file.read_text().rstrip() -# except FileNotFoundError: -# cluster_id = uuid.uuid4().hex -# cluster_id_file.write_text(cluster_id) - -job_idx_cache_path = script.parent / "bq_job_idx_cache" - -SLURM_TIME_FORMAT = r"%Y-%m-%dT%H:%M:%S" - - -def make_datetime(time_string): - return datetime.strptime(time_string, SLURM_TIME_FORMAT).replace( - tzinfo=timezone.utc - ) - - -def make_time_interval(seconds): - sign = 1 - if seconds < 0: - sign = -1 - seconds = abs(seconds) - d, r = divmod(seconds, 60 * 60 * 24) - h, r = divmod(r, 60 * 60) - m, s = divmod(r, 60) - d *= sign - h *= sign - return f"{d}D {h:02}:{m:02}:{s}" - - -converters = { - "DATETIME": make_datetime, - "INTERVAL": make_time_interval, - "STRING": str, - "INT64": lambda n: int(n or 0), -} - - -def schema_field(field_name, data_type, description, required=False): - return SchemaField( - field_name, - data_type, - description=description, - mode="REQUIRED" if required else "NULLABLE", - ) - - -schema_fields = [ - schema_field("cluster_name", "STRING", "cluster name", required=True), - schema_field("cluster_id", "STRING", "UUID for the cluster", required=True), - schema_field("entry_uuid", "STRING", "entry UUID for the job row", required=True), - schema_field( - "job_db_uuid", "INT64", "job db index from the slurm database", required=True - ), - schema_field("job_id_raw", "INT64", "raw job id", required=True), - schema_field("job_id", "STRING", "job id", required=True), - schema_field("state", "STRING", "final job state", required=True), - schema_field("job_name", "STRING", "job name"), - schema_field("partition", "STRING", "job partition"), - schema_field("submit_time", "DATETIME", "job submit time"), - schema_field("start_time", "DATETIME", "job start time"), - schema_field("end_time", "DATETIME", "job end time"), - schema_field("elapsed_raw", "INT64", "STRING", "job run time in seconds"), - # schema_field("elapsed_time", "INTERVAL", "STRING", "job run time interval"), - schema_field("timelimit_raw", "STRING", "job timelimit in minutes"), - schema_field("timelimit", "STRING", "job timelimit"), - # schema_field("num_tasks", "INT64", "number of allocated tasks in job"), - schema_field("nodelist", "STRING", "names of nodes allocated to job"), - schema_field("user", "STRING", "user responsible for job"), - schema_field("uid", "INT64", "uid of job user"), - schema_field("group", "STRING", "group of job user"), - schema_field("gid", "INT64", "gid of job user"), - schema_field("wckey", "STRING", "job wckey"), - schema_field("qos", "STRING", "job qos"), - schema_field("comment", "STRING", "job comment"), - schema_field("admin_comment", "STRING", "job admin comment"), - # extra will be added in 23.02 - # schema_field("extra", "STRING", "job extra field"), - schema_field("exitcode", "STRING", "job exit code"), - schema_field("alloc_cpus", "INT64", "count of allocated CPUs"), - schema_field("alloc_nodes", "INT64", "number of nodes allocated to job"), - schema_field("alloc_tres", "STRING", "allocated trackable resources (TRES)"), - # schema_field("system_cpu", "INTERVAL", "cpu time used by parent processes"), - # schema_field("cpu_time", "INTERVAL", "CPU time used (elapsed * cpu count)"), - schema_field("cpu_time_raw", "INT64", "CPU time used (elapsed * cpu count)"), - # schema_field("ave_cpu", "INT64", "Average CPU time of all tasks in job"), - # schema_field( - # "tres_usage_tot", - # "STRING", - # "Tres total usage by all tasks in job", - # ), -] - - -slurm_field_map = { - "job_db_uuid": "DBIndex", - "job_id_raw": "JobIDRaw", - 
"job_id": "JobID", - "state": "State", - "job_name": "JobName", - "partition": "Partition", - "submit_time": "Submit", - "start_time": "Start", - "end_time": "End", - "elapsed_raw": "ElapsedRaw", - "elapsed_time": "Elapsed", - "timelimit_raw": "TimelimitRaw", - "timelimit": "Timelimit", - "num_tasks": "NTasks", - "nodelist": "Nodelist", - "user": "User", - "uid": "Uid", - "group": "Group", - "gid": "Gid", - "wckey": "Wckey", - "qos": "Qos", - "comment": "Comment", - "admin_comment": "AdminComment", - # "extra": "Extra", - "exit_code": "ExitCode", - "alloc_cpus": "AllocCPUs", - "alloc_nodes": "AllocNodes", - "alloc_tres": "AllocTres", - "system_cpu": "SystemCPU", - "cpu_time": "CPUTime", - "cpu_time_raw": "CPUTimeRaw", - "ave_cpu": "AveCPU", - "tres_usage_tot": "TresUsageInTot", -} - -# new field name is the key for job_schema. Used to lookup the datatype when -# creating the job rows -job_schema = {field.name: field for field in schema_fields} -# Order is important here, as that is how they are parsed from sacct output -Job = namedtuple("Job", job_schema.keys()) - -client = bq.Client( - project=cfg.project, - credentials=util.default_credentials(), - client_options=util.create_client_options(util.ApiEndpoint.BQ), -) -dataset_id = f"{cfg.slurm_cluster_name}_job_data" -dataset = bq.DatasetReference(project=cfg.project, dataset_id=dataset_id) -table = bq.Table( - bq.TableReference(dataset, f"jobs_{cfg.slurm_cluster_name}"), schema_fields -) - - -class JobInsertionFailed(Exception): - pass - - -def make_job_row(job): - job_row = { - field_name: dict.get(converters, field.field_type)(job[field_name]) - for field_name, field in job_schema.items() - if field_name in job - } - job_row["entry_uuid"] = uuid.uuid4().hex - job_row["cluster_id"] = cfg.cluster_id - job_row["cluster_name"] = cfg.slurm_cluster_name - return job_row - - -def load_slurm_jobs(start, end): - states = ",".join( - ( - "BOOT_FAIL", - "CANCELLED", - "COMPLETED", - "DEADLINE", - "FAILED", - "NODE_FAIL", - "OUT_OF_MEMORY", - "PREEMPTED", - "REQUEUED", - "REVOKED", - "TIMEOUT", - ) - ) - start_iso = start.isoformat(timespec="seconds") - end_iso = end.isoformat(timespec="seconds") - # slurm_fields and bq_fields will be in matching order - slurm_fields = ",".join(slurm_field_map.values()) - bq_fields = slurm_field_map.keys() - cmd = ( - f"{SACCT} --start {start_iso} --end {end_iso} -X -D --format={slurm_fields} " - f"--state={states} --parsable2 --noheader --allusers --duplicates" - ) - text = run(cmd).stdout.splitlines() - # zip pairs bq_fields with the value from sacct - jobs = [dict(zip(bq_fields, line.split("|"))) for line in text] - - # The job index cache allows us to avoid sending duplicate jobs. This avoids a race condition with updating the database. 
- with shelve.open(str(job_idx_cache_path), flag="r") as job_idx_cache: - job_rows = [ - make_job_row(job) - for job in jobs - if str(job["job_db_uuid"]) not in job_idx_cache - ] - return job_rows - - -def init_table(): - global dataset - global table - dataset = client.create_dataset(dataset, exists_ok=True) - table = client.create_table(table, exists_ok=True) - until_found = retry.Retry(predicate=retry.if_exception_type(exceptions.NotFound)) - table = client.get_table(table, retry=until_found) - # cannot add required fields to an existing schema - table.schema = schema_fields - table = client.update_table(table, ["schema"]) - - -def purge_job_idx_cache(): - purge_time = datetime.now() - timedelta(minutes=30) - with shelve.open(str(job_idx_cache_path), writeback=True) as cache: - to_delete = [] - for idx, stamp in cache.items(): - if stamp < purge_time: - to_delete.append(idx) - for idx in to_delete: - del cache[idx] - - -def bq_submit(jobs): - try: - result = client.insert_rows(table, jobs) - except exceptions.NotFound as e: - print(f"failed to upload job data, table not yet found: {e}") - raise e - except Exception as e: - print(f"failed to upload job data: {e}") - raise e - if result: - pprint(jobs) - pprint(result) - raise JobInsertionFailed("failed to upload job data to big query") - print(f"successfully loaded {len(jobs)} jobs") - - -def get_time_window(): - if not timestamp_file.is_file(): - timestamp_file.touch() - try: - timestamp = datetime.strptime( - timestamp_file.read_text().rstrip(), SLURM_TIME_FORMAT - ) - # time window will overlap the previous by 10 minutes. Duplicates will be filtered out by the job_idx_cache - start = timestamp - timedelta(minutes=10) - except ValueError: - # timestamp 1 is 1 second after the epoch; timestamp 0 is special for sacct - start = datetime.fromtimestamp(1) - # end is now() truncated to the last second - end = datetime.now().replace(microsecond=0) - return start, end - - -def write_timestamp(time): - timestamp_file.write_text(time.isoformat(timespec="seconds")) - - -def update_job_idx_cache(jobs, timestamp): - with shelve.open(str(job_idx_cache_path), writeback=True) as job_idx_cache: - for job in jobs: - job_idx = str(job["job_db_uuid"]) - job_idx_cache[job_idx] = timestamp - - -def main(): - if not cfg.enable_bigquery_load: - print("bigquery load is not currently enabled") - exit(0) - init_table() - - start, end = get_time_window() - jobs = load_slurm_jobs(start, end) - # on failure, an exception will cause the timestamp not to be rewritten. So - # it will try again next time. If some writes succeed, we don't currently - # have a way to not submit duplicates next time. - if jobs: - bq_submit(jobs) - write_timestamp(end) - update_job_idx_cache(jobs, end) - - -parser = argparse.ArgumentParser(description="submit slurm job data to big query") -parser.add_argument( - "timestamp_file", - nargs="?", - action="store", - type=Path, - help="specify timestamp file for reading and writing the time window start. 
Precedence over TIMESTAMP_FILE env var.", -) - -purge_job_idx_cache() -if __name__ == "__main__": - args = parser.parse_args() - if args.timestamp_file: - timestamp_file = args.timestamp_file.resolve() - main() diff --git a/scripts/requirements.txt b/scripts/requirements.txt deleted file mode 100644 index 48b69757..00000000 --- a/scripts/requirements.txt +++ /dev/null @@ -1,49 +0,0 @@ -addict==2.4.0 -backcall==0.2.0 -cachetools==5.3.1 -certifi==2023.7.22 -charset-normalizer==3.2.0 -decorator==5.1.1 -docopt==0.6.2 -google-api-core==2.19.0 -google-api-python-client==2.93.0 -google-auth==2.22.0 -google-auth-httplib2==0.1.0 -google-cloud-bigquery==3.11.3 -google-cloud-core==2.3.3 -google-cloud-storage==2.10.0 -google-cloud-tpu==1.10.0 -google-crc32c==1.5.0 -google-resumable-media==2.5.0 -googleapis-common-protos==1.59.1 -grpcio==1.56.2 -grpcio-status==1.56.0 -httplib2==0.22.0 -idna==3.7 -ipython>=8.10 -jedi==0.17.2 -more-executors==2.11.4 -packaging==23.1 -parso==0.7.1 -pexpect==4.8.0 -pickleshare==0.7.5 -pipreqs==0.4.13 -prometheus-client==0.17.1 -prompt-toolkit==3.0.39 -proto-plus==1.22.3 -protobuf==4.23.4 -ptyprocess==0.7.0 -pyasn1==0.5.0 -pyasn1-modules==0.3.0 -Pygments==2.15.1 -pyparsing==3.1.0 -python-dateutil==2.8.2 -PyYAML==6.0 -requests==2.31.0 -rsa==4.9 -six==1.16.0 -traitlets==5.9.0 -uritemplate==4.1.1 -urllib3==1.26.18 -wcwidth==0.2.6 -yarg==0.1.9 diff --git a/scripts/resume.py b/scripts/resume.py deleted file mode 100755 index 01fac856..00000000 --- a/scripts/resume.py +++ /dev/null @@ -1,708 +0,0 @@ -#!/usr/bin/env python3 - -# Copyright (C) SchedMD LLC. -# Copyright 2015 Google Inc. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -from typing import List -import argparse -import collections -import json -import logging -import os -import sys -import yaml -from itertools import chain -from pathlib import Path - -import util -from util import ( - chunked, - dirs, - ensure_execute, - execute_with_futures, - get_insert_operations, - log_api_request, - map_with_futures, - run, - separate, - to_hostlist, - to_hostlist_fast, - trim_self_link, - wait_for_operation, -) -from util import cfg, lkp, NSDict, TPU - -# from util import cfg, lkp, NSDict -import slurm_gcp_plugins - - -filename = Path(__file__).name -LOGFILE = (Path(cfg.slurm_log_dir if cfg else ".") / filename).with_suffix(".log") - -log = logging.getLogger(filename) - - -global_resume_data = None - -PLACEMENT_MAX_CNT = 150 -# Placement group needs to be the same for an entire bulk_insert hence -# if placement is used the actual BULK_INSERT_LIMIT will be -# max([1000, PLACEMENT_MAX_CNT]) -BULK_INSERT_LIMIT = 5000 - - -def instance_properties(nodeset, model, placement_group, labels=None): - template = lkp.node_template(model) - template_info = lkp.template_info(template) - - props = NSDict() - - slurm_metadata = { - "slurm_cluster_name": cfg.slurm_cluster_name, - "slurm_instance_role": "compute", - "startup-script": ( - Path(cfg.slurm_scripts_dir or util.dirs.scripts) / "startup.sh" - ).read_text(), - } - info_metadata = { - item.get("key"): item.get("value") for item in template_info.metadata["items"] - } - - props_metadata = {**info_metadata, **slurm_metadata} - props.metadata = { - "items": [NSDict({"key": k, "value": v}) for k, v in props_metadata.items()] - } - - labels = { - "slurm_cluster_name": cfg.slurm_cluster_name, - "slurm_instance_role": "compute", - **(labels or {}), - } - props.labels = {**template_info.labels, **labels} - - for disk in template_info.disks: - # do not label local ssd - if ( - "diskType" not in disk.initializeParams - or disk.initializeParams.diskType == "local-ssd" - ): - continue - disk.initializeParams.labels.update(labels) - props.disks = template_info.disks - - if placement_group: - props.scheduling = { - "onHostMaintenance": "TERMINATE", - "automaticRestart": False, - } - props.resourcePolicies = [ - placement_group, - ] - - if nodeset.reservation_name: - reservation_name = nodeset.reservation_name - - zones = list(nodeset.zone_policy_allow or []) - assert len(zones) == 1, "Only single zone is supported if using a reservation" - - reservation = lkp.reservation(reservation_name, zones[0]) - - props.reservationAffinity = { - "consumeReservationType": "SPECIFIC_RESERVATION", - "key": f"compute.{util.universe_domain()}/reservation-name", - "values": [reservation_name], - } - - policies = util.reservation_resource_policies(reservation) - if policies: - props.scheduling = { - "onHostMaintenance": "TERMINATE", - "automaticRestart": False, - } - props.resourcePolicies = policies - log.info( - f"reservation {reservation_name} is being used with policies {props.resourcePolicies}" - ) - else: - props.resourcePolicies = [] - log.info( - f"reservation {reservation_name} is being used without any policies" - ) - - if nodeset.maintenance_interval: - props.scheduling = props.scheduling or {} - props.scheduling["maintenanceInterval"] = nodeset.maintenance_interval - - return props - - -def per_instance_properties(node): - props = NSDict() - # No properties beyond name are supported yet. 
- - return props - - -def create_instances_request(nodes, partition_name, placement_group, job_id=None): - """Call regionInstances.bulkInsert to create instances""" - assert len(nodes) > 0 - if placement_group: - assert len(nodes) <= min(PLACEMENT_MAX_CNT, BULK_INSERT_LIMIT) - else: - assert len(nodes) <= BULK_INSERT_LIMIT - - # model here indicates any node that can be used to describe the rest - model = next(iter(nodes)) - nodeset = lkp.node_nodeset(model) - template = lkp.node_template(model) - region = lkp.node_region(model) - partition = cfg.partitions[partition_name] - log.debug(f"create_instances_request: {model} placement: {placement_group}") - - body = NSDict() - body.count = len(nodes) - body.minCount = 1 - - # source of instance properties - body.sourceInstanceTemplate = template - - labels = ( - dict(slurm_job_id=job_id) - if job_id is not None and partition.enable_job_exclusive - else None - ) - # overwrites properties across all instances - body.instanceProperties = instance_properties( - nodeset, model, placement_group, labels - ) - - # key is instance name, value overwrites properties - body.perInstanceProperties = {k: per_instance_properties(k) for k in nodes} - - zones = { - **{ - f"zones/{zone}": {"preference": "ALLOW"} - for zone in nodeset.zone_policy_allow or [] - }, - **{ - f"zones/{zone}": {"preference": "DENY"} - for zone in nodeset.zone_policy_deny or [] - }, - } - body.locationPolicy.targetShape = cfg.zone_target_shape or "ANY_SINGLE_ZONE" - if zones: - body.locationPolicy.locations = zones - - if lkp.cfg.enable_slurm_gcp_plugins: - slurm_gcp_plugins.pre_instance_bulk_insert( - lkp=lkp, - nodes=nodes, - placement_group=placement_group, - request_body=body, - ) - - request = util.compute.regionInstances().bulkInsert( - project=cfg.project, region=region, body=body.to_dict() - ) - - if log.isEnabledFor(logging.DEBUG): - log.debug( - f"new request: endpoint={request.methodId} nodes={to_hostlist_fast(nodes)}" - ) - log_api_request(request) - return request - - -def group_nodes_bulk(nodes, resume_data=None): - """group nodes by job_id, placement_group, node_group, and max bulkInsert size""" - if resume_data is None: - # all nodes will be considered jobless - jobs = {} - else: - jobs = {job.job_id: job for job in resume_data.jobs} - - # expand all job nodelists - for job in jobs.values(): - job.nodelist_alloc = job.nodes_alloc - job.nodes_alloc = util.to_hostnames(job.nodelist_alloc) - job.nodelist_resume = job.nodes_resume - job.nodes_resume = util.to_hostnames(job.nodelist_resume) - job.tpu = util.part_is_tpu(job.partition) - if not job.tpu: - # create placement groups if nodes for job need it - job.placement_groups = create_placement_groups( - node_list=job.nodes_alloc, - job_id=job.job_id, - ) - # placement group assignment is based on all allocated nodes, but we only want to - # handle nodes in nodes_resume in this run. 
- for pg, pg_nodes in job.placement_groups.items(): - job.placement_groups[pg] = list( - set(pg_nodes).intersection(job.nodes_resume) - ) - # a bit of a hack, but nodes resumed using scontrol instead of through job scheduling do not have a job - jobless_nodes = list( - set(nodes).difference( - chain.from_iterable(job.nodes_resume for job in jobs.values()) - ) - ) - jobless_nodes_tpu = [] - for jobless_node in jobless_nodes[:]: - if lkp.node_is_tpu(jobless_node): - jobless_nodes.remove(jobless_node) - jobless_nodes_tpu.append(jobless_node) - - jobs["Normal_None"] = NSDict( - job_id=None, - nodes_resume=jobless_nodes, - nodes_alloc=jobless_nodes, - placement_groups=create_placement_groups(node_list=jobless_nodes), - partition=None, - tpu=False, - ) - jobs["TPU_None"] = NSDict( - job_id=None, - nodes_resume=jobless_nodes_tpu, - nodes_alloc=jobless_nodes_tpu, - partition=None, - tpu=True, - ) - - BulkChunk = collections.namedtuple( - "BulkChunk", - ["prefix", "job_id", "partition_name", "placement_group", "nodes", "i"], - ) - BulkChunkTPU = collections.namedtuple( - "BulkChunkTPU", - ["prefix", "job_id", "partition_name", "nodes", "i"], - ) - grouped_nodes = [ - BulkChunk( - prefix, - job_id if job_id != "Normal_None" else None, - jobs[job_id].partition, - placement_group, - chunk_nodes, - i, - ) - for job_id, job in jobs.items() - if not job.tpu - for placement_group, pg_nodes in job.placement_groups.items() - for prefix, nodes in util.groupby_unsorted(pg_nodes, lkp.node_prefix) - for i, chunk_nodes in enumerate(chunked(nodes, n=BULK_INSERT_LIMIT)) - ] - grouped_nodes_tpu = [ - BulkChunkTPU( - prefix, - job_id if job_id != "TPU_None" else None, - jobs[job_id].partition, - chunk_nodes, - i, - ) - for job_id, job in jobs.items() - if job.tpu - for prefix, nodes in util.groupby_unsorted(job.nodes_resume, lkp.node_prefix) - for i, chunk_nodes in enumerate(lkp.chunk_tpu_nodes(list(nodes))) - ] - - def group_name(chunk: BulkChunk): - if chunk.placement_group is not None: - return f"{chunk.prefix}:job{chunk.job_id}:{chunk.placement_group}:{chunk.i}" - if chunk.job_id is not None: - return f"{chunk.prefix}:job{chunk.job_id}:{chunk.i}" - return f"{chunk.prefix}:{chunk.i}" - - def group_name_tpu(chunk: BulkChunkTPU): - if chunk.job_id is not None: - return f"{chunk.prefix}:job{chunk.job_id}:{chunk.i}" - return f"{chunk.prefix}:{chunk.i}" - - grouped_nodes = {group_name(chunk): chunk for chunk in grouped_nodes} - grouped_nodes_tpu = {group_name_tpu(chunk): chunk for chunk in grouped_nodes_tpu} - return grouped_nodes, grouped_nodes_tpu - - -def start_tpu(data): - tpu = data["tpu"] - node = data["node"] - if len(node) == 1: - node = node[0] - log.debug( - f"Will create a TPU of type {tpu.node_type} tf_version {tpu.tf_version} in zone {tpu.zone} with name {node}" - ) - tpunode = tpu.get_node(node) - if tpunode is None: - if not tpu.create_node(nodename=node): - log.error("Error creating tpu node {node}") - else: - if tpu.preserve_tpu: - if not tpu.start_node(nodename=node): - log.error("Error starting tpu node {node}") - else: - log.info( - f"Tpu node {node} is already created, but will not start it because nodeset does not have preserve_tpu option active." 
- ) - else: - log.debug( - f"Will create a multi-vm TPU of type {tpu.node_type} tf_version {tpu.tf_version} in zone {tpu.zone} with name {node[0]}" - ) - if not tpu.create_node(nodename=node): - log.error("Error creating tpu node {node}") - - -def resume_nodes(nodes: List[str], resume_data=None): - """resume nodes in nodelist""" - if not nodes: - log.info("No nodes to resume") - return - - if resume_data is None and global_resume_data is not None: - resume_data = global_resume_data.deepcopy() - - nodes = sorted(nodes, key=lkp.node_prefix) - grouped_nodes, grouped_tpu_nodes = group_nodes_bulk(nodes, resume_data) - - if log.isEnabledFor(logging.DEBUG): - # grouped_nodelists is used in later debug logs too - grouped_nodelists = { - group: to_hostlist(chunk.nodes) for group, chunk in grouped_nodes.items() - } - grouped_tpu_nodelists = { - group: to_hostlist(chunk.nodes) - for group, chunk in grouped_tpu_nodes.items() - } - log.debug( - "node bulk groups: \n{}".format(yaml.safe_dump(grouped_nodelists).rstrip()) - ) - log.debug( - "TPU node bulk groups: \n{}".format( - yaml.safe_dump(grouped_tpu_nodelists).rstrip() - ) - ) - tpu_start_data = [] - tpu_objs = {} - for group, chunk in grouped_tpu_nodes.items(): - # do not create multiple tpu_objs if nodes with the same prefix are used - if chunk.prefix not in tpu_objs.keys(): - model = chunk.nodes[0] - tpu_objs[chunk.prefix] = TPU(lkp.node_nodeset(model)) - - tpu_start_data.append({"tpu": tpu_objs[chunk.prefix], "node": chunk.nodes}) - - # make all bulkInsert requests and execute with batch - inserts = { - group: create_instances_request( - chunk.nodes, chunk.partition_name, chunk.placement_group, chunk.job_id - ) - for group, chunk in grouped_nodes.items() - } - - bulk_ops = dict( - zip(inserts.keys(), map_with_futures(ensure_execute, inserts.values())) - ) - log.debug(f"bulk_ops={yaml.safe_dump(bulk_ops)}") - started = { - group: op for group, op in bulk_ops.items() if not isinstance(op, Exception) - } - failed = { - group: err for group, err in bulk_ops.items() if isinstance(err, Exception) - } - if failed: - failed_reqs = [str(e) for e in failed.items()] - log.error("bulkInsert API failures: {}".format("; ".join(failed_reqs))) - for ident, exc in failed.items(): - down_nodes(grouped_nodes[ident].nodes, f"GCP Error: {exc._get_reason()}") - - if log.isEnabledFor(logging.DEBUG): - for group, op in started.items(): - group_nodes = grouped_nodelists[group] - name = op["name"] - gid = op["operationGroupId"] - log.debug( - f"new bulkInsert operation started: group={group} nodes={group_nodes} name={name} operationGroupId={gid}" - ) - # wait for all bulkInserts to complete and log any errors - bulk_operations = {group: wait_for_operation(op) for group, op in started.items()} - - # Start TPU after regular nodes so that regular nodes are not affected by the slower TPU nodes - log.debug(f"tpu_start_data={yaml.safe_dump(tpu_start_data)}") - execute_with_futures(start_tpu, tpu_start_data) - - all_successful_inserts = [] - - for group, bulk_op in bulk_operations.items(): - group_id = bulk_op["operationGroupId"] - bulk_op_name = bulk_op["name"] - if "error" in bulk_op: - error = bulk_op["error"]["errors"][0] - group_nodes = to_hostlist_fast(grouped_nodes[group].nodes) - log.warning( - f"bulkInsert operation errors: {error['code']} name={bulk_op_name} operationGroupId={group_id} nodes={group_nodes}" - ) - successful_inserts, failed_inserts = separate( - lambda op: "error" in op, get_insert_operations(group_id) - ) - # Apparently multiple errors are possible... 
so join with +. - by_error_inserts = util.groupby_unsorted( - failed_inserts, - lambda op: "+".join(err["code"] for err in op["error"]["errors"]), - ) - for code, failed_ops in by_error_inserts: - failed_nodes = {trim_self_link(op["targetLink"]): op for op in failed_ops} - hostlist = util.to_hostlist(failed_nodes) - count = len(failed_nodes) - log.error( - f"{count} instances failed to start: {code} ({hostlist}) operationGroupId={group_id}" - ) - failed_node, failed_op = next(iter(failed_nodes.items())) - msg = "; ".join( - f"{err['code']}: {err['message'] if 'message' in err else 'no message'}" - for err in failed_op["error"]["errors"] - ) - if code != "RESOURCE_ALREADY_EXISTS": - down_nodes(hostlist, f"GCP Error: {msg}") - log.error( - f"errors from insert for node '{failed_node}' ({failed_op['name']}): {msg}" - ) - - ready_nodes = {trim_self_link(op["targetLink"]) for op in successful_inserts} - if len(ready_nodes) > 0: - ready_nodelist = to_hostlist_fast(ready_nodes) - log.info(f"created {len(ready_nodes)} instances: nodes={ready_nodelist}") - all_successful_inserts.extend(successful_inserts) - - -def update_job_comment(nodelist: str, comment: str): - if global_resume_data is None: - log.warning( - "Cannot update and notify jobs with API failures as no valid resume file is present." - ) - return - - nodes = util.to_hostnames(nodelist) - job_list = ( - job - for job in global_resume_data.jobs - if any(map(lambda node: node in nodes, util.to_hostnames(job.nodelist_resume))) - ) - for job in job_list: - run(f"{lkp.scontrol} update jobid={job.job_id} admincomment='{comment}'") - run(f"{lkp.scontrol} notify {job.job_id} '{comment}'") - - -def down_nodes(nodelist, reason): - """set nodes down with reason""" - if isinstance(nodelist, list): - nodelist = util.to_hostlist(nodelist) - update_job_comment(nodelist, reason) - run(f"{lkp.scontrol} update nodename={nodelist} state=down reason='{reason}'") - - -def hold_job(job_id, reason): - """hold job, set comment to reason""" - run(f"{lkp.scontrol} hold jobid={job_id}") - run(f"{lkp.scontrol} update jobid={job_id} comment='{reason}'") - - -def create_placement_request(pg_name, region): - config = { - "name": pg_name, - "region": region, - "groupPlacementPolicy": { - "collocation": "COLLOCATED", - }, - } - if lkp.cfg.enable_slurm_gcp_plugins: - slurm_gcp_plugins.pre_placement_group_insert( - lkp=lkp, pg_name=pg_name, region=region, request_body=config - ) - request = util.compute.resourcePolicies().insert( - project=cfg.project, region=region, body=config - ) - log_api_request(request) - return request - - -def create_placement_groups(node_list: list, job_id=0): - pgs = {} - node_map = lkp.nodeset_map(node_list) - for _, nodes in node_map.items(): - pgs.update(create_nodeset_placement_groups(nodes, job_id=job_id)) - return pgs - - -def create_nodeset_placement_groups(node_list: list, job_id=0): - model = next(iter(node_list)) - nodeset = lkp.node_nodeset(model) - if not nodeset.enable_placement: - return {None: node_list} - if not valid_placement_nodes(node_list): - return {None: node_list} - region = lkp.node_region(model) - - groups = { - f"{cfg.slurm_cluster_name}-{nodeset.nodeset_name}-{job_id}-{i}": nodes - for i, nodes in enumerate(chunked(node_list, n=PLACEMENT_MAX_CNT)) - } - - if log.isEnabledFor(logging.DEBUG): - debug_groups = { - group: to_hostlist_fast(nodes) for group, nodes in groups.items() - } - log.debug( - f"creating {len(groups)} placement groups: \n{yaml.safe_dump(debug_groups).rstrip()}" - ) - requests = { - group: 
create_placement_request(group, region) - for group, incl_nodes in groups.items() - } - ops = dict( - zip(requests.keys(), map_with_futures(ensure_execute, requests.values())) - ) - - def classify_result(item): - op = item[1] - if not isinstance(op, Exception): - return "submitted" - if all(e.get("reason") == "alreadyExists" for e in op.error_details): - return "redundant" - return "failed" - - grouped_ops = dict(util.groupby_unsorted(list(ops.items()), classify_result)) - submitted, redundant, failed = ( - dict(grouped_ops.get(key, {})) for key in ("submitted", "redundant", "failed") - ) - if redundant: - log.warning( - "placement policies already exist: {}".format(",".join(redundant.keys())) - ) - if failed: - reqs = [f"{e}" for _, e in failed.values()] - log.fatal("failed to create placement policies: {}".format("; ".join(reqs))) - operations = {group: wait_for_operation(op) for group, op in submitted.items()} - for group, op in operations.items(): - if "error" in op: - msg = "; ".join( - f"{err['code']}: {err['message'] if 'message' in err else 'no message'}" - for err in op["error"]["errors"] - ) - log.error( - f"placement group failed to create: '{group}' ({op['name']}): {msg}" - ) - - log.info( - f"created {len(operations)} placement groups ({to_hostlist_fast(operations.keys())})" - ) - return groups - - -def valid_placement_nodes(nodelist): - invalid_types = frozenset(["e2", "t2d", "n1", "t2a", "m1", "m2", "m3"]) - for node in nodelist: - mt = lkp.node_template_info(node).machineType - if mt.split("-")[0] in invalid_types: - log.warn(f"Unsupported machine type for placement policy: {mt}.") - log.warn( - f"Please do not use any the following machine types with placement policy: ({','.join(invalid_types)})" - ) - return False - return True - - -def get_resume_file_data(): - SLURM_RESUME_FILE = os.getenv("SLURM_RESUME_FILE") - if SLURM_RESUME_FILE is None: - log.warning( - "SLURM_RESUME_FILE was not in environment. Cannot get detailed job, node, partition allocation data." - ) - return None - resume_file = Path(SLURM_RESUME_FILE) - resume_json = resume_file.read_text() - if args.loglevel == logging.DEBUG: - (dirs.scripts / "resume_data.json").write_text(resume_json) - return NSDict(json.loads(resume_json)) - - -def main(nodelist, force=False): - """main called when run as script""" - log.debug(f"ResumeProgram {nodelist}") - # Filter out nodes not in config.yaml - other_nodes, pm_nodes = separate( - lkp.is_power_managed_node, util.to_hostnames(nodelist) - ) - if other_nodes: - log.debug( - f"Ignoring non-power-managed nodes '{to_hostlist_fast(other_nodes)}' from '{nodelist}'" - ) - - pm_nodelist = util.to_hostlist_fast(pm_nodes) - if pm_nodes: - log.debug(f"Resuming nodes '{pm_nodelist}' from '{nodelist}'") - else: - log.debug("No nodes to resume") - return - - log.info(f"resume {pm_nodelist}") - resume_nodes(pm_nodes, global_resume_data) - # TODO only run below if resume_nodes succeeds but - # resume_nodes does not currently return any status. 
- if lkp.cfg.enable_slurm_gcp_plugins: - slurm_gcp_plugins.post_main_resume_nodes( - nodelist=nodelist, global_resume_data=global_resume_data - ) - - -parser = argparse.ArgumentParser( - description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter -) -parser.add_argument("nodelist", help="list of nodes to resume") -parser.add_argument( - "--force", - "-f", - "--static", - action="store_true", - help="Force attempted creation of the nodelist, whether nodes are exclusive or not.", -) -parser.add_argument( - "--debug", - "-d", - dest="loglevel", - action="store_const", - const=logging.DEBUG, - default=logging.INFO, - help="Enable debugging output", -) -parser.add_argument( - "--trace-api", - "-t", - action="store_true", - help="Enable detailed api request output", -) - - -if __name__ == "__main__": - args = parser.parse_args() - - if cfg.enable_debug_logging: - args.loglevel = logging.DEBUG - if args.trace_api: - cfg.extra_logging_flags = list(cfg.extra_logging_flags) - cfg.extra_logging_flags.append("trace_api") - util.chown_slurm(LOGFILE, mode=0o600) - util.config_root_logger(filename, level=args.loglevel, logfile=LOGFILE) - sys.excepthook = util.handle_exception - - global_resume_data = get_resume_file_data() - main(args.nodelist, args.force) diff --git a/scripts/setup.py b/scripts/setup.py deleted file mode 100755 index 92d14bc0..00000000 --- a/scripts/setup.py +++ /dev/null @@ -1,545 +0,0 @@ -#!/usr/bin/env python3 - -# Copyright (C) SchedMD LLC. -# Copyright 2024 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import argparse -import logging -import os -import shutil -import subprocess -import sys -import stat -import time -from pathlib import Path - -import util -from util import ( - lkp, - cfg, - dirs, - slurmdirs, - run, - install_custom_scripts, -) - -from conf import ( - install_slurm_conf, - install_slurmdbd_conf, - gen_cloud_conf, - gen_cloud_gres_conf, - gen_topology_conf, - install_gres_conf, - install_cgroup_conf, - install_topology_conf, - install_jobsubmit_lua, - login_nodeset, -) -from slurmsync import sync_slurm - -from setup_network_storage import ( - setup_network_storage, - setup_nfs_exports, -) - -SETUP_SCRIPT = Path(__file__) -filename = SETUP_SCRIPT.name -LOGFILE = ((cfg.slurm_log_dir if cfg else ".") / SETUP_SCRIPT).with_suffix(".log") -log = logging.getLogger(filename) - - -MOTD_HEADER = """ - SSSSSSS - SSSSSSSSS - SSSSSSSSS - SSSSSSSSS - SSSS SSSSSSS SSSS - SSSSSS SSSSSS - SSSSSS SSSSSSS SSSSSS - SSSS SSSSSSSSS SSSS - SSS SSSSSSSSS SSS - SSSSS SSSS SSSSSSSSS SSSS SSSSS - SSS SSSSSS SSSSSSSSS SSSSSS SSS - SSSSSS SSSSSSS SSSSSS - SSS SSSSSS SSSSSS SSS - SSSSS SSSS SSSSSSS SSSS SSSSS - S SSS SSSSSSSSS SSS S - SSS SSSS SSSSSSSSS SSSS SSS - S SSS SSSSSS SSSSSSSSS SSSSSS SSS S - SSSSS SSSSSS SSSSSSSSS SSSSSS SSSSS - S SSSSS SSSS SSSSSSS SSSS SSSSS S - S SSS SSS SSS SSS S - S S S S - SSS - SSS - SSS - SSS - SSSSSSSSSSSS SSS SSSS SSSS SSSSSSSSS SSSSSSSSSSSSSSSSSSSS -SSSSSSSSSSSSS SSS SSSS SSSS SSSSSSSSSS SSSSSSSSSSSSSSSSSSSSSS -SSSS SSS SSSS SSSS SSSS SSSS SSSS SSSS -SSSS SSS SSSS SSSS SSSS SSSS SSSS SSSS -SSSSSSSSSSSS SSS SSSS SSSS SSSS SSSS SSSS SSSS - SSSSSSSSSSSS SSS SSSS SSSS SSSS SSSS SSSS SSSS - SSSS SSS SSSS SSSS SSSS SSSS SSSS SSSS - SSSS SSS SSSS SSSS SSSS SSSS SSSS SSSS -SSSSSSSSSSSSS SSS SSSSSSSSSSSSSSS SSSS SSSS SSSS SSSS -SSSSSSSSSSSS SSS SSSSSSSSSSSSS SSSS SSSS SSSS SSSS - -""" - - -def start_motd(): - """advise in motd that slurm is currently configuring""" - wall_msg = "*** Slurm is currently being configured in the background. ***" - motd_msg = MOTD_HEADER + wall_msg + "\n\n" - Path("/etc/motd").write_text(motd_msg) - util.run(f"wall -n '{wall_msg}'", timeout=30) - - -def end_motd(broadcast=True): - """modify motd to signal that setup is complete""" - Path("/etc/motd").write_text(MOTD_HEADER) - - if not broadcast: - return - - run( - "wall -n '*** Slurm {} setup complete ***'".format(lkp.instance_role), - timeout=30, - ) - if lkp.instance_role != "controller": - run( - """wall -n ' -/home on the controller was mounted over the existing /home. -Log back in to ensure your home directory is correct. -'""", - timeout=30, - ) - - -def failed_motd(): - """modify motd to signal that setup is failed""" - wall_msg = f"*** Slurm setup failed! 
Please view log: {LOGFILE} ***" - motd_msg = MOTD_HEADER + wall_msg + "\n\n" - Path("/etc/motd").write_text(motd_msg) - util.run(f"wall -n '{wall_msg}'", timeout=30) - - -def run_custom_scripts(): - """run custom scripts based on instance_role""" - custom_dir = dirs.custom_scripts - if lkp.instance_role == "controller": - # controller has all scripts, but only runs controller.d - custom_dirs = [custom_dir / "controller.d"] - elif lkp.instance_role == "compute": - # compute setup with compute.d and nodeset.d - custom_dirs = [custom_dir / "compute.d", custom_dir / "nodeset.d"] - elif lkp.instance_role == "login": - # login setup with only login.d - custom_dirs = [custom_dir / "login.d"] - else: - # Unknown role: run nothing - custom_dirs = [] - custom_scripts = [ - p - for d in custom_dirs - for p in d.rglob("*") - if p.is_file() and not p.name.endswith(".disabled") - ] - print_scripts = ",".join(str(s.relative_to(custom_dir)) for s in custom_scripts) - log.debug(f"custom scripts to run: {custom_dir}/({print_scripts})") - - try: - for script in custom_scripts: - if "/controller.d/" in str(script): - timeout = lkp.cfg.get("controller_startup_scripts_timeout", 300) - elif "/compute.d/" in str(script) or "/nodeset.d/" in str(script): - timeout = lkp.cfg.get("compute_startup_scripts_timeout", 300) - elif "/login.d/" in str(script): - timeout = lkp.cfg.get("login_startup_scripts_timeout", 300) - else: - timeout = 300 - timeout = None if not timeout or timeout < 0 else timeout - log.info(f"running script {script.name} with timeout={timeout}") - result = run(str(script), timeout=timeout, check=False, shell=True) - runlog = ( - f"{script.name} returncode={result.returncode}\n" - f"stdout={result.stdout}stderr={result.stderr}" - ) - log.info(runlog) - result.check_returncode() - except OSError as e: - log.error(f"script {script} is not executable") - raise e - except subprocess.TimeoutExpired as e: - log.error(f"script {script} did not complete within timeout={timeout}") - raise e - except Exception as e: - log.error(f"script {script} encountered an exception") - log.exception(e) - raise e - - -def setup_secondary_disks(): - """Format and mount secondary disk""" - run( - "sudo mkfs.ext4 -m 0 -F -E lazy_itable_init=0,lazy_journal_init=0,discard /dev/sdb" - ) - with open("/etc/fstab", "a") as f: - f.write( - "\n/dev/sdb {0} ext4 discard,defaults,nofail 0 2".format( - dirs.secdisk - ) - ) - - -def setup_jwt_key(): - jwt_key = Path(slurmdirs.state / "jwt_hs256.key") - - if jwt_key.exists(): - log.info("JWT key already exists. Skipping key generation.") - else: - run("dd if=/dev/urandom bs=32 count=1 > " + str(jwt_key), shell=True) - - util.chown_slurm(jwt_key, mode=0o400) - - -def setup_munge_key(): - munge_key = Path(dirs.munge / "munge.key") - - if munge_key.exists(): - log.info("Munge key already exists. 
Skipping key generation.") - else: - run("create-munge-key -f", timeout=30) - - shutil.chown(munge_key, user="munge", group="munge") - os.chmod(munge_key, stat.S_IRUSR) - run("systemctl restart munge", timeout=30) - - -def setup_nss_slurm(): - """install and configure nss_slurm""" - # setup nss_slurm - util.mkdirp(Path("/var/spool/slurmd")) - run( - "ln -s {}/lib/libnss_slurm.so.2 /usr/lib64/libnss_slurm.so.2".format( - slurmdirs.prefix - ), - check=False, - ) - run(r"sed -i 's/\(^\(passwd\|group\):\s\+\)/\1slurm /g' /etc/nsswitch.conf") - - -def setup_sudoers(): - content = """ -# Allow SlurmUser to manage the slurm daemons -slurm ALL= NOPASSWD: /usr/bin/systemctl restart slurmd.service -slurm ALL= NOPASSWD: /usr/bin/systemctl restart slurmctld.service -""" - sudoers_file = Path("/etc/sudoers.d/slurm") - sudoers_file.write_text(content) - sudoers_file.chmod(0o0440) - - -def update_system_config(file, content): - """Add system defaults options for service files""" - sysconfig = Path("/etc/sysconfig") - default = Path("/etc/default") - - if sysconfig.exists(): - conf_dir = sysconfig - elif default.exists(): - conf_dir = default - else: - raise Exception("Cannot determine system configuration directory.") - - slurmd_file = Path(conf_dir, file) - slurmd_file.write_text(content) - - -def configure_mysql(): - cnfdir = Path("/etc/my.cnf.d") - if not cnfdir.exists(): - cnfdir = Path("/etc/mysql/conf.d") - if not (cnfdir / "mysql_slurm.cnf").exists(): - (cnfdir / "mysql_slurm.cnf").write_text( - """ -[mysqld] -bind-address=127.0.0.1 -innodb_buffer_pool_size=1024M -innodb_log_file_size=64M -innodb_lock_wait_timeout=900 -""" - ) - run("systemctl enable mariadb", timeout=30) - run("systemctl restart mariadb", timeout=30) - - mysql = "mysql -u root -e" - run(f"""{mysql} "drop user 'slurm'@'localhost'";""", timeout=30, check=False) - run(f"""{mysql} "create user 'slurm'@'localhost'";""", timeout=30) - run( - f"""{mysql} "grant all on slurm_acct_db.* TO 'slurm'@'localhost'";""", - timeout=30, - ) - run( - f"""{mysql} "drop user 'slurm'@'{lkp.control_host}'";""", - timeout=30, - check=False, - ) - run(f"""{mysql} "create user 'slurm'@'{lkp.control_host}'";""", timeout=30) - run( - f"""{mysql} "grant all on slurm_acct_db.* TO 'slurm'@'{lkp.control_host}'";""", - timeout=30, - ) - - -def configure_dirs(): - for p in dirs.values(): - util.mkdirp(p) - util.chown_slurm(dirs.slurm) - util.chown_slurm(dirs.scripts) - - for p in slurmdirs.values(): - util.mkdirp(p) - util.chown_slurm(p) - - etc_slurm = Path("/etc/slurm") - if etc_slurm.exists() and etc_slurm.is_symlink(): - etc_slurm.unlink() - etc_slurm.symlink_to(slurmdirs.etc) - - scripts_etc = dirs.scripts / "etc" - if scripts_etc.exists() and scripts_etc.is_symlink(): - scripts_etc.unlink() - scripts_etc.symlink_to(slurmdirs.etc) - - scripts_log = dirs.scripts / "log" - if scripts_log.exists() and scripts_log.is_symlink(): - scripts_log.unlink() - scripts_log.symlink_to(dirs.log) - - -def setup_controller(args): - """Run controller setup""" - log.info("Setting up controller") - util.chown_slurm(dirs.scripts / "config.yaml", mode=0o600) - install_custom_scripts() - - install_slurm_conf(lkp) - install_slurmdbd_conf(lkp) - - gen_cloud_conf(lkp) - gen_cloud_gres_conf(lkp) - gen_topology_conf(lkp) - install_gres_conf(lkp) - install_cgroup_conf(lkp) - install_topology_conf(lkp) - install_jobsubmit_lua(lkp) - - setup_jwt_key() - setup_munge_key() - setup_sudoers() - - if cfg.controller_secondary_disk: - setup_secondary_disks() - setup_network_storage(log) - - 
run_custom_scripts() - - if not cfg.cloudsql_secret: - configure_mysql() - - run("systemctl enable slurmdbd", timeout=30) - run("systemctl restart slurmdbd", timeout=30) - - # Wait for slurmdbd to come up - time.sleep(5) - - sacctmgr = f"{slurmdirs.prefix}/bin/sacctmgr -i" - result = run( - f"{sacctmgr} add cluster {cfg.slurm_cluster_name}", timeout=30, check=False - ) - if "already exists" in result.stdout: - log.info(result.stdout) - elif result.returncode > 1: - result.check_returncode() # will raise error - - run("systemctl enable slurmctld", timeout=30) - run("systemctl restart slurmctld", timeout=30) - - run("systemctl enable slurmrestd", timeout=30) - run("systemctl restart slurmrestd", timeout=30) - - # Export at the end to signal that everything is up - run("systemctl enable nfs-server", timeout=30) - run("systemctl start nfs-server", timeout=30) - - setup_nfs_exports() - run("systemctl enable --now slurmcmd.timer", timeout=30) - - log.info("Check status of cluster services") - run("systemctl status munge", timeout=30) - run("systemctl status slurmdbd", timeout=30) - run("systemctl status slurmctld", timeout=30) - run("systemctl status slurmrestd", timeout=30) - - sync_slurm() - run("systemctl enable slurm_load_bq.timer", timeout=30) - run("systemctl start slurm_load_bq.timer", timeout=30) - run("systemctl status slurm_load_bq.timer", timeout=30) - - log.info("Done setting up controller") - pass - - -def setup_login(args): - """run login node setup""" - log.info("Setting up login") - slurmctld_host = f"{lkp.control_host}" - if lkp.control_addr: - slurmctld_host = f"{lkp.control_host}({lkp.control_addr})" - slurmd_options = [ - f'--conf-server="{slurmctld_host}:{lkp.control_host_port}"', - f'--conf="Feature={login_nodeset}"', - "-Z", - ] - sysconf = f"""SLURMD_OPTIONS='{" ".join(slurmd_options)}'""" - update_system_config("slurmd", sysconf) - install_custom_scripts() - - setup_network_storage(log) - setup_sudoers() - run("systemctl restart munge") - run("systemctl enable slurmd", timeout=30) - run("systemctl restart slurmd", timeout=30) - run("systemctl enable --now slurmcmd.timer", timeout=30) - - run_custom_scripts() - - log.info("Check status of cluster services") - run("systemctl status munge", timeout=30) - run("systemctl status slurmd", timeout=30) - - log.info("Done setting up login") - - -def setup_compute(args): - """run compute node setup""" - log.info("Setting up compute") - util.chown_slurm(dirs.scripts / "config.yaml", mode=0o600) - slurmctld_host = f"{lkp.control_host}" - if lkp.control_addr: - slurmctld_host = f"{lkp.control_host}({lkp.control_addr})" - slurmd_options = [ - f'--conf-server="{slurmctld_host}:{lkp.control_host_port}"', - ] - if args.slurmd_feature is not None: - slurmd_options.append(f'--conf="Feature={args.slurmd_feature}"') - slurmd_options.append("-Z") - sysconf = f"""SLURMD_OPTIONS='{" ".join(slurmd_options)}'""" - update_system_config("slurmd", sysconf) - install_custom_scripts() - - setup_nss_slurm() - setup_network_storage(log) - - has_gpu = run("lspci | grep --ignore-case 'NVIDIA' | wc -l", shell=True).returncode - if has_gpu: - run("nvidia-smi") - - run_custom_scripts() - - setup_sudoers() - run("systemctl restart munge", timeout=30) - run("systemctl enable slurmd", timeout=30) - run("systemctl restart slurmd", timeout=30) - run("systemctl enable --now slurmcmd.timer", timeout=30) - - log.info("Check status of cluster services") - run("systemctl status munge", timeout=30) - run("systemctl status slurmd", timeout=30) - - log.info("Done setting 
up compute") - - -def main(args): - start_motd() - configure_dirs() - - # call the setup function for the instance type - setup = dict.get( - { - "controller": setup_controller, - "compute": setup_compute, - "login": setup_login, - }, - lkp.instance_role, - lambda: log.fatal(f"Unknown node role: {lkp.instance_role}"), - ) - setup(args) - - end_motd() - - -if __name__ == "__main__": - util.chown_slurm(LOGFILE, mode=0o600) - - parser = argparse.ArgumentParser( - description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter - ) - parser.add_argument( - "--slurmd-feature", - dest="slurmd_feature", - help="Feature for slurmd to register with. Controller ignores this option.", - ) - args = parser.parse_args() - - util.config_root_logger(filename, logfile=LOGFILE) - sys.excepthook = util.handle_exception - - lkp = util.Lookup(cfg) # noqa F811 - - try: - main(args) - except subprocess.TimeoutExpired as e: - log.error( - f"""TimeoutExpired: - command={e.cmd} - timeout={e.timeout} - stdout: -{e.stdout.strip()} - stderr: -{e.stderr.strip()} -""" - ) - log.error("Aborting setup...") - failed_motd() - except subprocess.CalledProcessError as e: - log.error( - f"""CalledProcessError: - command={e.cmd} - returncode={e.returncode} - stdout: -{e.stdout.strip()} - stderr: -{e.stderr.strip()} -""" - ) - log.error("Aborting setup...") - failed_motd() - except Exception as e: - log.exception(e) - log.error("Aborting setup...") - failed_motd() diff --git a/scripts/setup_hybrid.py b/scripts/setup_hybrid.py deleted file mode 100755 index 3abe00cc..00000000 --- a/scripts/setup_hybrid.py +++ /dev/null @@ -1,67 +0,0 @@ -#!/usr/bin/env python3 - -# Copyright (C) SchedMD LLC. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- - -import argparse -import logging -import sys -from pathlib import Path -import setup -import util -from util import lkp, config_root_logger, handle_exception - - -filename = Path(__file__).name -logfile = Path(filename).with_suffix(".log") -log = logging.getLogger(filename) -setup.log.disabled = False -util.log.disabled = False - - -def main(args): - log.info("Generating new cloud.conf for slurm.conf") - setup.gen_cloud_conf(lkp) - - log.info("Generating new cloud_gres.conf for gres.conf") - setup.gen_cloud_gres_conf(lkp) - - log.info("Generating new cloud_topology.conf for topology.conf") - setup.gen_topology_conf(lkp) - - log.info("Done.") - - -if __name__ == "__main__": - parser = argparse.ArgumentParser( - description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter - ) - parser.add_argument( - "--debug", - "-d", - dest="debug", - action="store_true", - help="Enable debugging output", - ) - - args = parser.parse_args() - - if args.debug: - config_root_logger(filename, level="DEBUG", logfile=logfile) - else: - config_root_logger(filename, level="INFO", logfile=logfile) - sys.excepthook = handle_exception - - main(args) diff --git a/scripts/setup_network_storage.py b/scripts/setup_network_storage.py deleted file mode 100755 index b3283dd3..00000000 --- a/scripts/setup_network_storage.py +++ /dev/null @@ -1,307 +0,0 @@ -#!/usr/bin/env python3 - -# Copyright (C) SchedMD LLC. -# Copyright 2024 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import os -import sys -import stat -import time - -import shutil -from pathlib import Path -from concurrent.futures import as_completed -from addict import Dict as NSDict - -import util -from util import lkp, run, cfg, dirs, separate - - -def mounts_by_local(mounts): - """convert list of mounts to dict of mounts, local_mount as key""" - return {str(Path(m.local_mount).resolve()): m for m in mounts} - - -def resolve_network_storage(nodeset=None): - """Combine appropriate network_storage fields to a single list""" - - if lkp.instance_role == "compute": - try: - nodeset = lkp.node_nodeset() - except Exception: - # External nodename, skip lookup - nodeset = None - - # seed mounts with the default controller mounts - if cfg.disable_default_mounts: - default_mounts = [] - else: - default_mounts = [ - NSDict( - { - "server_ip": lkp.control_addr or lkp.control_host, - "remote_mount": str(path), - "local_mount": str(path), - "fs_type": "nfs", - "mount_options": "defaults,hard,intr", - } - ) - for path in ( - dirs.home, - dirs.apps, - ) - ] - - # create dict of mounts, local_mount: mount_info - mounts = mounts_by_local(default_mounts) - - # On non-controller instances, entries in network_storage could overwrite - # default exports from the controller. 
Be careful, of course - mounts.update(mounts_by_local(cfg.network_storage)) - if lkp.instance_role in ("login", "controller"): - mounts.update(mounts_by_local(cfg.login_network_storage)) - - if nodeset is not None: - mounts.update(mounts_by_local(nodeset.network_storage)) - return list(mounts.values()) - - -def separate_external_internal_mounts(mounts): - """separate into cluster-external and internal mounts""" - - def internal_mount(mount): - # NOTE: Valid Lustre server_ip can take the form of '@tcp' - server_ip = mount.server_ip.split("@")[0] - mount_addr = util.host_lookup(server_ip) - return mount_addr == lkp.control_host_addr - - return separate(internal_mount, mounts) - - -def setup_network_storage(log): - """prepare network fs mounts and add them to fstab""" - log.info("Set up network storage") - # filter mounts into two dicts, cluster-internal and external mounts - - all_mounts = resolve_network_storage() - ext_mounts, int_mounts = separate_external_internal_mounts(all_mounts) - - if lkp.instance_role == "controller": - mounts = ext_mounts - else: - mounts = ext_mounts + int_mounts - - # Determine fstab entries and write them out - fstab_entries = [] - for mount in mounts: - local_mount = Path(mount.local_mount) - remote_mount = mount.remote_mount - fs_type = mount.fs_type - server_ip = mount.server_ip or "" - util.mkdirp(local_mount) - - log.info( - "Setting up mount ({}) {}{} to {}".format( - fs_type, - server_ip + ":" if fs_type != "gcsfuse" else "", - remote_mount, - local_mount, - ) - ) - - mount_options = mount.mount_options.split(",") if mount.mount_options else [] - if not mount_options or "_netdev" not in mount_options: - mount_options += ["_netdev"] - - if fs_type == "gcsfuse": - fstab_entries.append( - "{0} {1} {2} {3} 0 0".format( - remote_mount, local_mount, fs_type, ",".join(mount_options) - ) - ) - else: - fstab_entries.append( - "{0}:{1} {2} {3} {4} 0 0".format( - server_ip, - remote_mount, - local_mount, - fs_type, - ",".join(mount_options), - ) - ) - - fstab = Path("/etc/fstab") - if not Path(fstab.with_suffix(".bak")).is_file(): - shutil.copy2(fstab, fstab.with_suffix(".bak")) - shutil.copy2(fstab.with_suffix(".bak"), fstab) - with open(fstab, "a") as f: - f.write("\n") - for entry in fstab_entries: - f.write(entry) - f.write("\n") - - mount_fstab(mounts_by_local(mounts), log) - munge_mount_handler(log) - - -def mount_fstab(mounts, log): - """Wait on each mount, then make sure all fstab is mounted""" - from more_executors import Executors, ExceptionRetryPolicy - - def mount_path(path): - log.info(f"Waiting for '{path}' to be mounted...") - try: - run(f"mount {path}", timeout=120) - except Exception as e: - exc_type, _, _ = sys.exc_info() - log.error(f"mount of path '{path}' failed: {exc_type}: {e}") - raise e - log.info(f"Mount point '{path}' was mounted.") - - MAX_MOUNT_TIMEOUT = 60 * 5 - future_list = [] - retry_policy = ExceptionRetryPolicy( - max_attempts=40, exponent=1.6, sleep=1.0, max_sleep=16.0 - ) - with Executors.thread_pool().with_timeout(MAX_MOUNT_TIMEOUT).with_retry( - retry_policy=retry_policy - ) as exe: - for path in mounts: - future = exe.submit(mount_path, path) - future_list.append(future) - - # Iterate over futures, checking for exceptions - for future in as_completed(future_list): - try: - future.result() - except Exception as e: - raise e - - -def munge_mount_handler(log): - if not cfg.munge_mount: - log.error("Missing munge_mount in cfg") - elif lkp.instance_role == "controller": - return - - mount = cfg.munge_mount - server_ip = ( - 
mount.server_ip - if mount.server_ip - else (cfg.slurm_control_addr or cfg.slurm_control_host) - ) - remote_mount = mount.remote_mount - local_mount = Path("/mnt/munge") - fs_type = mount.fs_type if mount.fs_type is not None else "nfs" - mount_options = ( - mount.mount_options - if mount.mount_options is not None - else "defaults,hard,intr,_netdev" - ) - - munge_key = Path(dirs.munge / "munge.key") - - log.info(f"Mounting munge share to: {local_mount}") - local_mount.mkdir() - if fs_type.lower() == "gcsfuse".lower(): - if remote_mount is None: - remote_mount = "" - cmd = [ - "gcsfuse", - f"--only-dir={remote_mount}" if remote_mount != "" else None, - server_ip, - str(local_mount), - ] - else: - if remote_mount is None: - remote_mount = Path("/etc/munge") - cmd = [ - "mount", - f"--types={fs_type}", - f"--options={mount_options}" if mount_options != "" else None, - f"{server_ip}:{remote_mount}", - str(local_mount), - ] - # wait max 120s for munge mount - timeout = 120 - for retry, wait in enumerate(util.backoff_delay(0.5, timeout), 1): - try: - run(cmd, timeout=timeout) - break - except Exception as e: - log.error( - f"munge mount failed: '{cmd}' {e}, try {retry}, waiting {wait:0.2f}s" - ) - time.sleep(wait) - err = e - continue - else: - raise err - - log.info(f"Copy munge.key from: {local_mount}") - shutil.copy2(Path(local_mount / "munge.key"), munge_key) - - log.info("Restrict permissions of munge.key") - shutil.chown(munge_key, user="munge", group="munge") - os.chmod(munge_key, stat.S_IRUSR) - - log.info(f"Unmount {local_mount}") - if fs_type.lower() == "gcsfuse".lower(): - run(f"fusermount -u {local_mount}", timeout=120) - else: - run(f"umount {local_mount}", timeout=120) - shutil.rmtree(local_mount) - - -def setup_nfs_exports(): - """nfs export all needed directories""" - # The controller only needs to set up exports for cluster-internal mounts - # switch the key to remote mount path since that is what needs exporting - mounts = resolve_network_storage() - # manually add munge_mount - mounts.append( - NSDict( - { - "server_ip": cfg.munge_mount.server_ip, - "remote_mount": cfg.munge_mount.remote_mount, - "local_mount": Path(f"{dirs.munge}_tmp"), - "fs_type": cfg.munge_mount.fs_type, - "mount_options": cfg.munge_mount.mount_options, - } - ) - ) - # controller mounts - _, con_mounts = separate_external_internal_mounts(mounts) - con_mounts = {m.remote_mount: m for m in con_mounts} - for nodeset in cfg.nodeset.values(): - # get internal mounts for each nodeset by calling - # resolve_network_storage as from a node in each nodeset - ns_mounts = resolve_network_storage(nodeset=nodeset) - _, int_mounts = separate_external_internal_mounts(ns_mounts) - con_mounts.update({m.remote_mount: m for m in int_mounts}) - - # export path if corresponding selector boolean is True - exports = [] - for path in con_mounts: - util.mkdirp(Path(path)) - run(rf"sed -i '\#{path}#d' /etc/exports", timeout=30) - exports.append(f"{path} *(rw,no_subtree_check,no_root_squash)") - - exportsd = Path("/etc/exports.d") - util.mkdirp(exportsd) - with (exportsd / "slurm.exports").open("w") as f: - f.write("\n") - f.write("\n".join(exports)) - run("exportfs -a", timeout=30) diff --git a/scripts/slurm_gcp_plugins/README.md b/scripts/slurm_gcp_plugins/README.md deleted file mode 100644 index 7c739363..00000000 --- a/scripts/slurm_gcp_plugins/README.md +++ /dev/null @@ -1,107 +0,0 @@ -# Plugin mechanism for slurm-gcp - -## Introduction - -Slurm in general provides many hooks for customization of its various functions. 
-In fact, slurm-gcp already uses one of these customization points, PrologSlurmctld,
-to perform tasks related to VM instance creation in response to job node
-allocation.
-
-The plugin mechanism in this directory similarly allows deployment-specific
-customizations to slurm-gcp by dropping Python modules into
-`/scripts/slurm_gcp_plugins` and enabling plugins with the
-configuration directive `enable_slurm_gcp_plugins = true` in
-`/scripts/config.yaml`.
-
-A very basic plugin, `test_plugin`, is provided as an example.
-
-## Plugins
-
-Callbacks to registered plugins can be made from various places in resume.py and
-suspend.py. The following callbacks are currently made:
-
-### Callback function signature
-
-Callback functions in the plugins are recommended to be declared as follows:
-
-```python
-def post_main_resume_nodes(*pos_args, **keyword_args):
-...
-```
-
-and to extract arguments from `keyword_args`. Check the callback sites to
-understand which values are available.
-
-### Current callback sites
-
-Callbacks are currently performed from the following places:
-
-#### scripts/resume.py:main_resume_nodes
-
-At the end of main, the following callback is called:
-
-```python
-def post_main_resume_nodes(*pos_args, **keyword_args):
-```
-
-The primary intention is to allow a plugin to record details about the instances
-and/or to set up or change properties that require the VMs to be up and running.
-
-Currently the call is made regardless of whether the resume node operation
-succeeded.
-
-#### scripts/resume.py:create_instances_request
-
-In create_instances_request, just before the bulk instance insert is issued, the
-following callback is called:
-
-```python
-def pre_instance_bulk_insert(*pos_args, **keyword_args):
-```
-
-The primary intention is to allow a plugin to modify the instance creation request.
-
-#### scripts/resume.py:create_placement_request
-
-In create_placement_request, just before the resource policy is created, the
-following callback is called:
-
-```python
-def pre_placement_group_insert(*pos_args, **keyword_args):
-```
-
-The primary intention is to allow a plugin to modify the resource policy creation
-request.
-
-#### scripts/suspend.py:main_suspend_nodes
-
-In main, just before the VMs are deleted but while they still (should) exist, the
-following callback is called:
-
-```python
-def pre_main_suspend_nodes(*pos_args, **keyword_args):
-```
-
-The primary intention is to allow a plugin to clean up or record details while the
-node still exists.
-
-#### scripts/util.py:instances
-
-Just before the per-instance information is requested, the following callback is
-called:
-
-```python
-def register_instance_information_fields(*pos_args, **keyword_args):
-```
-
-The primary intention is to allow a plugin to add information to the per-instance
-lookup.
-
-### Logging and error handling
-
-Plugin functions are recommended to use `logging` to communicate information,
-warnings, and errors. The `slurm_gcp_plugins` registry tries to isolate the
-callers of the callbacks (i.e. resume.py and suspend.py) from the effects of errors
-with a general try/except wrapper around each plugin callback. However, as the
-callback happens in the same process, there are notable limits on how much
-isolation can be achieved.
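The recommended signature above is easiest to see in a concrete module. Below is a minimal sketch of a hypothetical plugin; the directory name `log_resume` and the logged message are illustrative only, not part of slurm-gcp. It mirrors the structure of the bundled `test_plugin` and relies only on resume.py passing a `nodelist` keyword argument to `post_main_resume_nodes`, as described above.

```python
import logging

# Hypothetical plugin module, e.g. scripts/slurm_gcp_plugins/log_resume/__init__.py
# (the plugin name "log_resume" is illustrative only).


def post_main_resume_nodes(*pos_args, **keyword_args):
    # All values arrive as keyword arguments; which keys exist depends on the
    # callback site. resume.py passes the resumed nodelist to this callback.
    nodelist = keyword_args.get("nodelist")
    logging.info("post_main_resume_nodes called for nodelist: %s", nodelist)


__all__ = [
    "post_main_resume_nodes",
]
```

Enabling it would follow the same pattern as the bundled plugins: list the plugin name under `enable_slurm_gcp_plugins` in config.yaml (the max_hops README further below shows a concrete configuration example).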
diff --git a/scripts/slurm_gcp_plugins/__init__.py b/scripts/slurm_gcp_plugins/__init__.py deleted file mode 100644 index a4f11079..00000000 --- a/scripts/slurm_gcp_plugins/__init__.py +++ /dev/null @@ -1,135 +0,0 @@ -import importlib -import pkgutil -import logging -import inspect - -# Only perform discovery at init -discovered_plugins = { - name.lstrip("."): importlib.import_module(name=name, package="slurm_gcp_plugins") - for finder, name, ispkg in pkgutil.iter_modules(path=__path__, prefix=".") - if name.lstrip(".") != "utils" -} - -logging.info( - ( - "slurm_gcp_plugins found:" - + ", ".join( - [ - "slurm_gcp_plugins" + plugin - for plugin in sorted(discovered_plugins.keys()) - ] - ) - ) -) - - -def get_plugins(): - return discovered_plugins - - -def get_plugins_function(function_name): - plugins = get_plugins() - - return { - plugin: function - for plugin in sorted(plugins.keys()) - for name, function in inspect.getmembers(plugins[plugin], inspect.isfunction) - if name == function_name - } - - -def run_plugins_for_function(plugin_function_name, pos_args, keyword_args): - if "lkp" not in keyword_args: - logging.error( - ( - f"Plugin callback {plugin_function_name} called" - + 'without a "lkp" argument need to get obtain deployment' - + "information" - ) - ) - return - - if not keyword_args["lkp"].cfg: - logging.error( - ( - f"Plugin callback {plugin_function_name} called" - + 'with "lkp.cfg" unpopulated. lkp.cfg is needed' - + "to argument need to get obtain deployment" - + "information" - ) - ) - return - - cfg = keyword_args["lkp"].cfg - if cfg.enable_slurm_gcp_plugins: - for plugin, function in get_plugins_function(plugin_function_name).items(): - if plugin in cfg.enable_slurm_gcp_plugins: - logging.debug(f"Running {function} from plugin {plugin}") - try: - function(*pos_args, **keyword_args) - except BaseException as e: - logging.error( - f"Plugin callback {plugin}:{function} caused an exception: {e}" - ) - else: - logging.debug( - f"Not running {function} from non-enabled plugin {plugin}" - ) - - -# Implement this function to add fields to the cached VM instance lookup -def register_instance_information_fields(*pos_args, **keyword_args): - run_plugins_for_function( - plugin_function_name="register_instance_information_fields", - pos_args=pos_args, - keyword_args=keyword_args, - ) - - -# Called just after VM instances have been created and are up -def post_main_resume_nodes(*pos_args, **keyword_args): - run_plugins_for_function( - plugin_function_name="post_main_resume_nodes", - pos_args=pos_args, - keyword_args=keyword_args, - ) - - -# Called just before VM instances are deleted should be still up -# (NOTE: if a node has failed it might not be up or unresponsive) -def pre_main_suspend_nodes(*pos_args, **keyword_args): - run_plugins_for_function( - plugin_function_name="pre_main_suspend_nodes", - pos_args=pos_args, - keyword_args=keyword_args, - ) - - -# Called just before VM instances are created are created with -# bulkInsert- this function can be implemented to inspect and/or -# modify the insertion request. -def pre_instance_bulk_insert(*pos_args, **keyword_args): - run_plugins_for_function( - plugin_function_name="pre_instance_bulk_insert", - pos_args=pos_args, - keyword_args=keyword_args, - ) - - -# Called just before placement groups are created - this function can -# be implemented to inspect and/or modify the insertion request. 
-def pre_placement_group_insert(*pos_args, **keyword_args): - run_plugins_for_function( - plugin_function_name="pre_placement_group_insert", - pos_args=pos_args, - keyword_args=keyword_args, - ) - - -__all__ = [ - "post_main_resume_nodes", - "pre_main_suspend_nodes", - "register_instance_information_fields", - "pre_instance_bulk_insert", - "pre_placement_group_insert", -] diff --git a/scripts/slurm_gcp_plugins/max_hops/README.md b/scripts/slurm_gcp_plugins/max_hops/README.md deleted file mode 100644 index 9e8ad4af..00000000 --- a/scripts/slurm_gcp_plugins/max_hops/README.md +++ /dev/null @@ -1,38 +0,0 @@ -# max_hops slurm_gcp_plugin plugin - -## Overview - -This plugin allows placement parameters to be set controlling the max number of -network hops between nodes in dynamic jobs. - -## Usage - -### Configuration - -This plugin can be enabled by adding the following to the slurm-gcp config: - -```yaml -enable_slurm_gcp_plugins: - #possibly other plugins - max_hops: - max_hops: 1 -``` - -to set the default max_hops to, in this example, 1 for _all_ jobs. - -### Per job setting - -The max hops setting can be changed on a per job basis using the --prefer -argument e.g. as follows: - -salloc --prefer=max_hops.max_hops=1 - -to allow at most one network hop. For this to work the -`ignore_prefer_validation` needs to be added to the slurm `SchedulerParameters` -configuration item. - -## Callbacks used - -### pre_placement_group_insert - -Used to change the placement group creation request. diff --git a/scripts/slurm_gcp_plugins/max_hops/__init__.py b/scripts/slurm_gcp_plugins/max_hops/__init__.py deleted file mode 100644 index 6505f8f4..00000000 --- a/scripts/slurm_gcp_plugins/max_hops/__init__.py +++ /dev/null @@ -1,58 +0,0 @@ -import logging -import sys -import slurm_gcp_plugins.utils as sgp_utils - -# Allows setting a specific max_hop for jobs -# -# To enable: -# * add this directory to the slurm-gcp plugin path (usually /slurm/scripts/slurm-gcp-plugins) -# * add the following to the slurm-gcp config (usually /slurm/scripts/config.yaml): -# -# enable_slurm_gcp_plugins: -# -# max_hops: -# max_hops: -# -# -# Where can be either of 1,2,3 (in increasing order of distance) -# If no max_hops is provided but the plugins is still enabled the default level is 3 - - -def pre_placement_group_insert(*pos_args, **keyword_args): - logging.info("Trying to enable max hop") - # Avoid circular import (util imports the plugins) - if "util" in sys.modules: - logging.info("Setting compute service version to beta") - sys.modules["util"].compute = sys.modules["util"].compute_service( - version="beta" - ) - max_distance = sgp_utils.get_plugin_setting( - plugin="max_hops", - setting="max_hops", - job=get_job_from_placement_group_name(keyword_args["pg_name"]), - lkp=keyword_args["lkp"], - default=3, - ) - logging.debug(f"Setting max hop for placement policy to {max_distance}") - keyword_args["request_body"]["groupPlacementPolicy"][ - "collocation=" - ] = "COLLOCATED" - keyword_args["request_body"]["groupPlacementPolicy"][ - "maxDistance" - ] = max_distance - else: - logging.error( - "max_hops can not be set (slurm_gcp util.py must be imported by the caller of the plugin callback)" - ) - - -__all__ = [ - "pre_placement_group_insert", -] - - -# This should be replaced if the job id becomes available in the context of this plugin hook -def get_job_from_placement_group_name(pg_name): - # f"{cfg.slurm_cluster_name}-{partition_name}-{job_id}-{i}" - - return pg_name.split("-")[2] diff --git 
a/scripts/slurm_gcp_plugins/test_plugin/README.md b/scripts/slurm_gcp_plugins/test_plugin/README.md deleted file mode 100644 index c3a46ca4..00000000 --- a/scripts/slurm_gcp_plugins/test_plugin/README.md +++ /dev/null @@ -1,16 +0,0 @@ -# Test slurm_gcp_plugin plugin - -## Overview - -This is a very basic but still useful test plugin that records the VM instance -id of the nodes used for jobs (when dynamic nodes are used). - -## Callbacks used - -### post_main_resume_nodes - -Used to log the instance id of created VMs - -### register_instance_information_fields - -Used to add the instance id to the information collected for VM instances. diff --git a/scripts/slurm_gcp_plugins/test_plugin/__init__.py b/scripts/slurm_gcp_plugins/test_plugin/__init__.py deleted file mode 100644 index deb53f7a..00000000 --- a/scripts/slurm_gcp_plugins/test_plugin/__init__.py +++ /dev/null @@ -1,27 +0,0 @@ -import logging - -instance_information_fields = ["resourceStatus", "id"] - - -def register_instance_information_fields(*pos_args, **keyword_args): - logging.debug("register_instance_information_fields called from test_plugin") - keyword_args["instance_information_fields"].extend(instance_information_fields) - - -def post_main_resume_nodes(*pos_args, **keyword_args): - logging.debug("post_main_resume_nodes called from test_plugin") - for node in keyword_args["nodelist"]: - logging.info( - ( - "test_plugin:" - + f"nodename:{node} " - + f"instance_id:{keyword_args['lkp'].instance(node)['id']} " - + f"physicalHost:{keyword_args['lkp'].instance(node)['resourceStatus']['physicalHost']}" - ) - ) - - -__all__ = [ - "register_instance_information_fields", - "post_main_resume_nodes", -] diff --git a/scripts/slurm_gcp_plugins/utils/__init__.py b/scripts/slurm_gcp_plugins/utils/__init__.py deleted file mode 100644 index 6977fb5c..00000000 --- a/scripts/slurm_gcp_plugins/utils/__init__.py +++ /dev/null @@ -1,56 +0,0 @@ -import subprocess -import logging - -# Various plugin utility functions - -# Plugin helper function to get plugin settings in the following order: -# -# 1. from job features with -# 2. from slurm-gcp config -# 3. If provided, the default -# 4. None - - -def get_plugin_setting(plugin, setting, lkp, job, default=None): - features = get_job_features(job) - if f"{plugin}.{setting}" in features: - return features[f"{plugin}.{setting}"] - - if "enable_slurm_gcp_plugins" in lkp.cfg: - if plugin in lkp.cfg.enable_slurm_gcp_plugins: - try: - iter(lkp.cfg.enable_slurm_gcp_plugins[plugin]) - except TypeError: - # not iterable - 1 - else: - if setting in lkp.cfg.enable_slurm_gcp_plugins[plugin]: - return lkp.cfg.enable_slurm_gcp_plugins[plugin][setting] - - return default - - -# Plugin helper function to get job features -def get_job_features(job): - if job is None: - return {} - - features = {} - res, output = subprocess.getstatusoutput(f"squeue -h -o %f -j {job}") - if res == 0: - for feature in output.split("&"): - kv = feature.split("=", 1) - v = None - if len(kv) == 2: - v = kv[1] - features[kv[0]] = v - else: - logging.error("Unable to retrieve features of job:{job}") - - return features - - -__all__ = [ - "get_plugin_setting", - "get_job_features", -] diff --git a/scripts/slurmsync.py b/scripts/slurmsync.py deleted file mode 100755 index 53af894c..00000000 --- a/scripts/slurmsync.py +++ /dev/null @@ -1,575 +0,0 @@ -#!/usr/bin/env python3 - -# Copyright (C) SchedMD LLC. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import argparse -import datetime -import fcntl -import hashlib -import json -import logging -import re -import sys -from enum import Enum -from itertools import chain -from pathlib import Path -import yaml - -import util -from util import ( - batch_execute, - ensure_execute, - execute_with_futures, - fetch_config_yaml, - fetch_config_yaml_md5, - install_custom_scripts, - load_config_file, - run, - save_config, - separate, - to_hostlist_fast, - Lookup, - NSDict, - TPU, - chunked, -) -from util import lkp, cfg, compute, CONFIG_FILE -from suspend import delete_instances -from resume import start_tpu -from conf import ( - gen_cloud_conf, - gen_cloud_gres_conf, - gen_topology_conf, - install_slurm_conf, - install_slurmdbd_conf, - install_gres_conf, - install_cgroup_conf, - install_topology_conf, -) - -filename = Path(__file__).name -LOGFILE = (Path(cfg.slurm_log_dir if cfg else ".") / filename).with_suffix(".log") - -log = logging.getLogger(filename) - -TOT_REQ_CNT = 1000 - - -NodeStatus = Enum( - "NodeStatus", - ( - "orphan", - "power_down", - "preempted", - "restore", - "resume", - "terminated", - "unbacked", - "unchanged", - "unknown", - ), -) - - -def start_instance_op(inst, project=None): - project = project or lkp.project - return compute.instances().start( - project=project, - zone=lkp.instance(inst).zone, - instance=inst, - ) - - -def start_instances(node_list): - log.info("{} instances to start ({})".format(len(node_list), ",".join(node_list))) - - normal, tpu_nodes = separate(lkp.node_is_tpu, node_list) - invalid, valid = separate(lambda inst: bool(lkp.instance), normal) - - ops = {inst: start_instance_op(inst) for inst in valid} - - done, failed = batch_execute(ops) - - tpu_start_data = [] - for ns, nodes in util.groupby_unsorted(tpu_nodes, lkp.node_nodeset_name): - tpuobj = TPU(lkp.cfg.nodeset_tpu[ns]) - for snodes in chunked(nodes, n=tpuobj.vmcount): - tpu_start_data.append({"tpu": tpuobj, "node": snodes}) - execute_with_futures(start_tpu, tpu_start_data) - - -def _find_dynamic_node_status() -> NodeStatus: - # TODO: cover more cases: - # * delete dead dynamic nodes - # * delete orhpaned instances - return NodeStatus.unchanged # don't touch dynamic nodes - - -def _find_tpu_node_status(nodename, state): - ns = lkp.node_nodeset(nodename) - tpuobj = TPU(ns) - inst = tpuobj.get_node(nodename) - # If we do not find the node but it is from a Tpu that has multiple vms look for the master node - if inst is None and tpuobj.vmcount > 1: - # Get the tpu slurm nodelist of the nodes in the same tpu group as nodename - nodelist = run( - f"{lkp.scontrol} show topo {nodename}" - + " | awk -F'=' '/Level=0/ { print $NF }'", - shell=True, - ).stdout - l_nodelist = util.to_hostnames(nodelist) - group_names = set(l_nodelist) - # get the list of all the existing tpus in the nodeset - tpus_list = set(tpuobj.list_node_names()) - # In the intersection there must be only one node that is the master - tpus_int = list(group_names.intersection(tpus_list)) - if len(tpus_int) > 1: - log.error( - f"More than one cloud tpu node for tpu group {nodelist}, there should be only one that should be 
{l_nodelist[0]}, but we have found {tpus_int}" - ) - return NodeStatus.unknown - if len(tpus_int) == 1: - inst = tpuobj.get_node(tpus_int[0]) - # if len(tpus_int ==0) this case is not relevant as this would be the case always that a TPU group is not running - if inst is None: - if state.base == "DOWN" and "POWERED_DOWN" in state.flags: - return NodeStatus.restore - if "POWERING_DOWN" in state.flags: - return NodeStatus.restore - if "COMPLETING" in state.flags: - return NodeStatus.unbacked - if state.base != "DOWN" and not ( - set(("POWER_DOWN", "POWERING_UP", "POWERING_DOWN", "POWERED_DOWN")) - & state.flags - ): - return NodeStatus.unbacked - if lkp.is_static_node(nodename): - return NodeStatus.resume - elif ( - state is not None - and "POWERED_DOWN" not in state.flags - and "POWERING_DOWN" not in state.flags - and inst.state == TPU.State.STOPPED - ): - if tpuobj.preemptible: - return NodeStatus.preempted - if not state.base.startswith("DOWN"): - return NodeStatus.terminated - elif ( - state is None or "POWERED_DOWN" in state.flags - ) and inst.state == TPU.State.READY: - return NodeStatus.orphan - elif state is None: - # if state is None here, the instance exists but it's not in Slurm - return NodeStatus.unknown - - return NodeStatus.unchanged - - -def allow_power_down(state): - config = run(f"{lkp.scontrol} show config").stdout.rstrip() - m = re.search(r"SuspendExcStates\s+=\s+(?P[\w\(\)]+)", config) - if not m: - log.warning("SuspendExcStates not found in Slurm config") - return True - states = set(m.group("states").split(",")) - if "(null)" in states or bool(state & state.flags.union(state.base)): - return False - return True - - -def find_node_status(nodename): - """Determine node/instance status that requires action""" - state = lkp.slurm_node(nodename) - - if lkp.node_is_dyn(nodename): - return _find_dynamic_node_status() - - if lkp.node_is_tpu(nodename): - return _find_tpu_node_status(nodename, state) - - # split below is workaround for VMs whose hostname is FQDN - inst = lkp.instance(nodename.split(".")[0]) - power_flags = frozenset( - ("POWER_DOWN", "POWERING_UP", "POWERING_DOWN", "POWERED_DOWN") - ) & (state.flags if state is not None else set()) - - if inst is None: - if "POWERING_UP" in state.flags: - return NodeStatus.unchanged - if state.base == "DOWN" and "POWERED_DOWN" in state.flags: - return NodeStatus.restore - if "POWERING_DOWN" in state.flags: - return NodeStatus.restore - if "COMPLETING" in state.flags: - return NodeStatus.unbacked - if state.base != "DOWN" and not power_flags: - return NodeStatus.unbacked - if state.base == "DOWN" and not power_flags and allow_power_down(state): - return NodeStatus.power_down - if "POWERED_DOWN" in state.flags and lkp.is_static_node(nodename): - return NodeStatus.resume - elif ( - state is not None - and "POWERED_DOWN" not in state.flags - and "POWERING_DOWN" not in state.flags - and inst.status == "TERMINATED" - ): - if inst.scheduling.preemptible: - return NodeStatus.preempted - if not state.base.startswith("DOWN"): - return NodeStatus.terminated - elif (state is None or "POWERED_DOWN" in state.flags) and inst.status == "RUNNING": - log.info("%s is potential orphan node", nodename) - age_threshold_seconds = 90 - inst_seconds_old = _seconds_since_timestamp(inst.creationTimestamp) - log.info("%s state: %s, age: %0.1fs", nodename, state, inst_seconds_old) - if inst_seconds_old < age_threshold_seconds: - log.info( - "%s not marked as orphan, it started less than %ds ago (%0.1fs)", - nodename, - age_threshold_seconds, - 
inst_seconds_old, - ) - return NodeStatus.unchanged - return NodeStatus.orphan - elif state is None: - # if state is None here, the instance exists but it's not in Slurm - return NodeStatus.unknown - - return NodeStatus.unchanged - - -def _seconds_since_timestamp(timestamp): - """Returns duration in seconds since a timestamp - Args: - timestamp: A formatted timestamp string (%Y-%m-%dT%H:%M:%S.%f%z) - Returns: - number of seconds that have past since the timestamp (float) - """ - if timestamp[-3] == ":": # python 36 datetime does not support the colon - timestamp = timestamp[:-3] + timestamp[-2:] - creation_dt = datetime.datetime.strptime(timestamp, "%Y-%m-%dT%H:%M:%S.%f%z") - return datetime.datetime.now().timestamp() - creation_dt.timestamp() - - -def do_node_update(status, nodes): - """update node/instance based on node status""" - if status == NodeStatus.unchanged: - return - count = len(nodes) - hostlist = util.to_hostlist(nodes) - - def nodes_down(): - """down nodes""" - log.info( - f"{count} nodes set down due to node status '{status.name}' ({hostlist})" - ) - run( - f"{lkp.scontrol} update nodename={hostlist} state=down reason='Instance stopped/deleted'" - ) - - def nodes_restart(): - """start instances for nodes""" - log.info(f"{count} instances restarted ({hostlist})") - start_instances(nodes) - - def nodes_idle(): - """idle nodes""" - log.info(f"{count} nodes to idle ({hostlist})") - run(f"{lkp.scontrol} update nodename={hostlist} state=resume") - - def nodes_resume(): - """resume nodes via scontrol""" - log.info(f"{count} instances to resume ({hostlist})") - run(f"{lkp.scontrol} update nodename={hostlist} state=power_up") - - def nodes_delete(): - """delete instances for nodes""" - log.info(f"{count} instances to delete ({hostlist})") - delete_instances(nodes) - - def nodes_power_down(): - """power_down node in slurm""" - log.info(f"{count} instances to power down ({hostlist})") - run(f"{lkp.scontrol} update nodename={hostlist} state=power_down") - - def nodes_unknown(): - """Error status, nodes shouldn't get in this status""" - log.error(f"{count} nodes have unexpected status: ({hostlist})") - first = next(iter(nodes)) - state = lkp.slurm_node(first) - state = "{}+{}".format(state.base, "+".join(state.flags)) if state else "None" - inst = lkp.instance(first) - log.error(f"{first} state: {state}, instance status:{inst.status}") - - update = dict.get( - { - NodeStatus.orphan: nodes_delete, - NodeStatus.power_down: nodes_power_down, - NodeStatus.preempted: lambda: (nodes_down(), nodes_restart()), - NodeStatus.restore: nodes_idle, - NodeStatus.resume: nodes_resume, - NodeStatus.terminated: nodes_down, - NodeStatus.unbacked: nodes_down, - NodeStatus.unchanged: lambda: None, - NodeStatus.unknown: nodes_unknown, - }, - status, - ) - update() - - -def delete_placement_groups(placement_groups): - def delete_placement_request(pg_name, region): - return compute.resourcePolicies().delete( - project=lkp.project, region=region, resourcePolicy=pg_name - ) - - requests = { - pg.name: delete_placement_request(pg["name"], util.trim_self_link(pg["region"])) - for pg in placement_groups - } - - def swallow_err(_: str) -> None: - pass - - done, failed = batch_execute(requests, log_err=swallow_err) - if failed: - # Filter out resourceInUseByAnotherResource errors , they are expected to happen - def ignore_err(e) -> bool: - return "resourceInUseByAnotherResource" in str(e) - - failures = [f"{n}: {e}" for n, (_, e) in failed.items() if not ignore_err(e)] - if failures: - log.error(f"some placement 
groups failed to delete: {failures}")
-    log.info(
-        f"deleted {len(done)} of {len(placement_groups)} placement groups ({to_hostlist_fast(done.keys())})"
-    )
-
-
-def sync_placement_groups():
-    """Delete placement policies that are for jobs that have completed/terminated"""
-    keep_states = frozenset(
-        [
-            "RUNNING",
-            "CONFIGURING",
-            "STOPPED",
-            "SUSPENDED",
-            "COMPLETING",
-        ]
-    )
-
-    if lkp.instance_role_safe != "controller":
-        return
-
-    keep_jobs = {
-        str(job["job_id"])
-        for job in json.loads(run(f"{lkp.scontrol} show jobs --json").stdout)["jobs"]
-        if "job_state" in job and set(job["job_state"]) & keep_states
-    }
-    keep_jobs.add("0")  # Job 0 is a placeholder for static node placement
-
-    fields = "items.regions.resourcePolicies,nextPageToken"
-    flt = f"name={lkp.cfg.slurm_cluster_name}-*"
-    act = compute.resourcePolicies()
-    op = act.aggregatedList(project=lkp.project, fields=fields, filter=flt)
-    placement_groups = {}
-    pg_regex = re.compile(
-        rf"{lkp.cfg.slurm_cluster_name}-(?P<partition>[^\s\-]+)-(?P<job_id>\d+)-(?P<index>\d+)"
-    )
-    while op is not None:
-        result = ensure_execute(op)
-        # merge placement group info from API and job_id,partition,index parsed from the name
-        pgs = (
-            NSDict({**pg, **pg_regex.match(pg["name"]).groupdict()})
-            for pg in chain.from_iterable(
-                item["resourcePolicies"]
-                for item in result.get("items", {}).values()
-                if item
-            )
-            if pg_regex.match(pg["name"]) is not None
-        )
-        placement_groups.update(
-            {pg["name"]: pg for pg in pgs if pg.get("job_id") not in keep_jobs}
-        )
-        op = act.aggregatedList_next(op, result)
-
-    if len(placement_groups) > 0:
-        delete_placement_groups(list(placement_groups.values()))
-
-
-def sync_slurm():
-    if lkp.instance_role_safe != "controller":
-        return
-
-    compute_instances = [
-        name for name, inst in lkp.instances().items() if inst.role == "compute"
-    ]
-    slurm_nodes = list(lkp.slurm_nodes().keys())
-
-    all_nodes = list(
-        set(
-            chain(
-                compute_instances,
-                slurm_nodes,
-            )
-        )
-    )
-    log.debug(
-        f"reconciling {len(compute_instances)} ({len(all_nodes)-len(compute_instances)}) GCP instances and {len(slurm_nodes)} Slurm nodes ({len(all_nodes)-len(slurm_nodes)})."
-    )
-    node_statuses = {
-        k: list(v) for k, v in util.groupby_unsorted(all_nodes, find_node_status)
-    }
-    if log.isEnabledFor(logging.DEBUG):
-        status_nodelist = {
-            status.name: to_hostlist_fast(nodes)
-            for status, nodes in node_statuses.items()
-        }
-        log.debug(f"node statuses: \n{yaml.safe_dump(status_nodelist).rstrip()}")
-
-    for status, nodes in node_statuses.items():
-        do_node_update(status, nodes)
-
-
-def read_hash(filename):
-    filename = Path(filename)
-    if not filename.exists():
-        return None
-    with open(filename, "r", encoding="utf-8") as file:
-        return file.readline()
-
-
-def save_hash(filename, hash):
-    with open(filename, "w+", encoding="utf-8") as file:
-        file.write(hash)
-
-
-def reconfigure_slurm():
-    CONFIG_HASH = Path("/slurm/scripts/.config.hash")
-    update_msg = "*** slurm configuration was updated ***"
-    cfg_old = load_config_file(CONFIG_FILE)
-
-    if cfg_old.hybrid:
-        # terraform handles generating the config.yaml, don't do it here
-        return
-
-    hash_new: hashlib.md5 = fetch_config_yaml_md5()
-    hash_old: str = read_hash(CONFIG_HASH)
-
-    if hash_new.hexdigest() != hash_old:
-        log.debug("Delta detected.
Reconfiguring Slurm now.") - cfg_new = fetch_config_yaml() - save_hash(CONFIG_HASH, hash_new.hexdigest()) - save_config(cfg_new, CONFIG_FILE) - cfg_new = load_config_file(CONFIG_FILE) - lkp = Lookup(cfg_new) - util.lkp = lkp - if lkp.instance_role_safe == "controller": - install_slurm_conf(lkp) - install_slurmdbd_conf(lkp) - gen_cloud_conf(lkp) - gen_cloud_gres_conf(lkp) - gen_topology_conf(lkp) - install_gres_conf(lkp) - install_cgroup_conf(lkp) - install_topology_conf(lkp) - log.info("Restarting slurmctld to make changes take effect.") - try: - run("sudo systemctl restart slurmctld.service", check=False) - run(f"{lkp.scontrol} reconfigure", timeout=30) - except Exception as e: - log.error(e) - util.run(f"wall '{update_msg}'", timeout=30) - log.debug("Done.") - elif lkp.instance_role_safe in ["compute", "login"]: - log.info("Restarting slurmd to make changes take effect.") - run("systemctl restart slurmd") - util.run(f"wall '{update_msg}'", timeout=30) - log.debug("Done.") - - -def main(): - try: - reconfigure_slurm() - except Exception: - log.exception("failed to reconfigure slurm") - - try: - sync_slurm() - except Exception: - log.exception("failed to sync instances") - - try: - sync_placement_groups() - except Exception: - log.exception("failed to sync placement groups") - - try: - install_custom_scripts(check_hash=True) - except Exception: - log.exception("failed to sync custom scripts") - - -parser = argparse.ArgumentParser( - description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter -) -parser.add_argument( - "--debug", - "-d", - dest="loglevel", - action="store_const", - const=logging.DEBUG, - default=logging.INFO, - help="Enable debugging output", -) -parser.add_argument( - "--trace-api", - "-t", - action="store_true", - help="Enable detailed api request output", -) -parser.add_argument( - "--force", - "-f", - action="store_true", - help="Force tasks to run, regardless of lock.", -) - -if __name__ == "__main__": - args = parser.parse_args() - util.chown_slurm(LOGFILE, mode=0o600) - - if cfg.enable_debug_logging: - args.loglevel = logging.DEBUG - if args.trace_api: - cfg.extra_logging_flags = list(cfg.extra_logging_flags) - cfg.extra_logging_flags.append("trace_api") - util.config_root_logger(filename, level=args.loglevel, logfile=LOGFILE) - - sys.excepthook = util.handle_exception - - # only run one instance at a time unless --force - if args.force: - main() - else: - pid_file = (Path("/tmp") / Path(__file__).name).with_suffix(".pid") - with pid_file.open("w") as fp: - try: - fcntl.lockf(fp, fcntl.LOCK_EX | fcntl.LOCK_NB) - main() - except BlockingIOError: - sys.exit(0) diff --git a/scripts/startup.sh b/scripts/startup.sh deleted file mode 100755 index 9918411a..00000000 --- a/scripts/startup.sh +++ /dev/null @@ -1,147 +0,0 @@ -#!/bin/bash -# Copyright (C) SchedMD LLC. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -set -e - -SLURM_DIR=/slurm -FLAGFILE=$SLURM_DIR/slurm_configured_do_not_remove -SCRIPTS_DIR=$SLURM_DIR/scripts -if [[ -z "$HOME" ]]; then - # google-startup-scripts.service lacks environment variables - HOME="$(getent passwd "$(whoami)" | cut -d: -f6)" -fi - -METADATA_SERVER="metadata.google.internal" -URL="http://$METADATA_SERVER/computeMetadata/v1" -CURL="curl -sS --fail --header Metadata-Flavor:Google" - -PING_METADATA="ping -q -w1 -c1 $METADATA_SERVER" -echo "INFO: $PING_METADATA" -for i in $(seq 10); do - [ $i -gt 1 ] && sleep 5; - $PING_METADATA > /dev/null && s=0 && break || s=$?; - echo "ERROR: Failed to contact metadata server, will retry" -done -if [ $s -ne 0 ]; then - echo "ERROR: Unable to contact metadata server, aborting" - wall -n '*** Slurm setup failed in the startup script! see `journalctl -u google-startup-scripts` ***' - exit 1 -else - echo "INFO: Successfully contacted metadata server" -fi - -PING_GOOGLE="ping -q -w1 -c1 8.8.8.8" -echo "INFO: $PING_GOOGLE" -for i in $(seq 5); do - [ $i -gt 1 ] && sleep 2; - $PING_GOOGLE > /dev/null && s=0 && break || s=$?; - echo "failed to ping Google DNS, will retry" -done -if [ $s -ne 0 ]; then - echo "WARNING: No internet access detected" -else - echo "INFO: Internet access detected" -fi - -mkdir -p $SCRIPTS_DIR -UNIVERSE_DOMAIN="$($CURL $URL/instance/attributes/universe_domain)" -BUCKET="$($CURL $URL/instance/attributes/slurm_bucket_path)" -if [[ -z $BUCKET ]]; then - echo "ERROR: No bucket path detected." - exit 1 -fi - -SCRIPTS_ZIP="$HOME/slurm-gcp-scripts.zip" -export CLOUDSDK_CORE_UNIVERSE_DOMAIN="$UNIVERSE_DOMAIN" -until gcloud storage cp "$BUCKET/slurm-gcp-devel.zip" "$SCRIPTS_ZIP"; do - echo "WARN: Could not download SlurmGCP scripts, retrying in 5 seconds." - sleep 5 -done -unzip -o "$SCRIPTS_ZIP" -d "$SCRIPTS_DIR" -rm -rf "$SCRIPTS_ZIP" - -#temporary hack to not make the script fail on TPU vm -chown slurm:slurm -R "$SCRIPTS_DIR" || true -chmod 700 -R "$SCRIPTS_DIR" - - -if [ -f $FLAGFILE ]; then - echo "WARNING: Slurm was previously configured, quitting" - exit 0 -fi -touch $FLAGFILE - -function tpu_setup { - #allow the following command to fail, as this attribute does not exist for regular nodes - docker_image=$($CURL $URL/instance/attributes/slurm_docker_image 2> /dev/null || true) - if [ -z $docker_image ]; then #Not a tpu node, do not do anything - return - fi - if [ "$OS_ENV" == "slurm_container" ]; then #Already inside the slurm container, we should continue starting - return - fi - - #given a input_string like "WORKER_0:Joseph;WORKER_1:richard;WORKER_2:edward;WORKER_3:john" and a number 1, this function will print richard - parse_metadata() { - local number=$1 - local input_string=$2 - local word=$(echo "$input_string" | awk -v n="$number" -F ':|;' '{ for (i = 1; i <= NF; i+=2) if ($(i) == "WORKER_"n) print $(i+1) }') - echo "$word" - } - - input_string=$($CURL $URL/instance/attributes/slurm_names) - worker_id=$($CURL $URL/instance/attributes/tpu-env | awk '/WORKER_ID/ {print $2}' | tr -d \') - real_name=$(parse_metadata $worker_id $input_string) - - #Prepare to docker pull with gcloud - mkdir -p /root/.docker - cat << EOF > /root/.docker/config.json -{ - "credHelpers": { - "gcr.io": "gcloud", - "us-docker.pkg.dev": "gcloud" - } -} -EOF - #cgroup detection - CGV=1 - CGROUP_FLAGS="-v /sys/fs/cgroup:/sys/fs/cgroup:rw" - if [ -f /sys/fs/cgroup/cgroup.controllers ]; then #CGV2 - CGV=2 - fi - if [ $CGV == 2 ]; then - CGROUP_FLAGS="--cgroup-parent=docker.slice --cgroupns=private --tmpfs /run --tmpfs /run/lock --tmpfs 
/tmp" - if [ ! -f /etc/systemd/system/docker.slice ]; then #In case that there is no slice prepared for hosting the containers create it - printf "[Unit]\nDescription=docker slice\nBefore=slices.target\n[Slice]\nCPUAccounting=true\nMemoryAccounting=true" > /etc/systemd/system/docker.slice - systemctl start docker.slice - fi - fi - #for the moment always use --privileged, as systemd might not work properly otherwise - TPU_FLAGS="--privileged" - # TPU_FLAGS="--cap-add SYS_RESOURCE --device /dev/accel0 --device /dev/accel1 --device /dev/accel2 --device /dev/accel3" - # if [ $CGV == 2 ]; then #In case that we are in CGV2 for systemd to work correctly for the moment we go with privileged - # TPU_FLAGS="--privileged" - # fi - - docker run -d $CGROUP_FLAGS $TPU_FLAGS --net=host --name=slurmd --hostname=$real_name --entrypoint=/usr/bin/systemd --restart unless-stopped $docker_image - exit 0 -} - -tpu_setup #will do nothing for normal nodes or the container spawned inside TPU - -echo "INFO: Running python cluster setup script" -SETUP_SCRIPT_FILE=$SCRIPTS_DIR/setup.py -chmod +x $SETUP_SCRIPT_FILE -exec $SETUP_SCRIPT_FILE diff --git a/scripts/suspend.py b/scripts/suspend.py deleted file mode 100755 index af70d976..00000000 --- a/scripts/suspend.py +++ /dev/null @@ -1,184 +0,0 @@ -#!/usr/bin/env python3 - -# Copyright (C) SchedMD LLC. -# Copyright 2015 Google Inc. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import List -import argparse -import logging -import sys -from pathlib import Path - -import util -from util import ( - groupby_unsorted, - log_api_request, - batch_execute, - to_hostlist_fast, - wait_for_operations, - separate, - execute_with_futures, -) -from util import lkp, cfg, compute, TPU - -import slurm_gcp_plugins - -filename = Path(__file__).name -LOGFILE = (Path(cfg.slurm_log_dir if cfg else ".") / filename).with_suffix(".log") -log = logging.getLogger(filename) - -TOT_REQ_CNT = 1000 - - -def truncate_iter(iterable, max_count): - end = "..." 
-    _iter = iter(iterable)
-    for i, el in enumerate(_iter, start=1):
-        if i >= max_count:
-            yield end
-            break
-        yield el
-
-
-def delete_instance_request(instance, project=None, zone=None):
-    project = project or lkp.project
-    request = compute.instances().delete(
-        project=project,
-        zone=(zone or lkp.instance(instance).zone),
-        instance=instance,
-    )
-    log_api_request(request)
-    return request
-
-
-def stop_tpu(data):
-    tpu_nodeset = data["nodeset"]
-    node = data["node"]
-    tpu = data["tpu"]
-    if tpu_nodeset.preserve_tpu and tpu.vmcount == 1:
-        log.info(f"stopping node {node}")
-        if tpu.stop_node(node):
-            return
-        log.error(f"Error stopping node {node}, will delete instead")
-    log.info(f"deleting node {node}")
-    if not tpu.delete_node(node):
-        log.error(f"Error deleting node {node}")
-
-
-def delete_tpu_instances(instances):
-    stop_data = []
-    for prefix, nodes in util.groupby_unsorted(instances, lkp.node_prefix):
-        log.info(f"Deleting TPU nodes from prefix {prefix}")
-        lnodes = list(nodes)
-        tpu_nodeset = lkp.node_nodeset(lnodes[0])
-        tpu = TPU(tpu_nodeset)
-        stop_data.extend(
-            [{"tpu": tpu, "node": node, "nodeset": tpu_nodeset} for node in lnodes]
-        )
-    execute_with_futures(stop_tpu, stop_data)
-
-
-def delete_instances(instances):
-    """delete instances individually"""
-    invalid, valid = separate(lambda inst: bool(lkp.instance(inst)), instances)
-    if len(invalid) > 0:
-        log.debug("instances do not exist: {}".format(",".join(invalid)))
-    if len(valid) == 0:
-        log.debug("No instances to delete")
-        return
-
-    requests = {inst: delete_instance_request(inst) for inst in valid}
-
-    log.info(f"delete {len(valid)} instances ({to_hostlist_fast(valid)})")
-    done, failed = batch_execute(requests)
-    if failed:
-        for err, nodes in groupby_unsorted(list(failed.keys()), lambda n: failed[n][1]):
-            log.error(f"instances failed to delete: {err} ({to_hostlist_fast(nodes)})")
-    wait_for_operations(done.values())
-    # TODO do we need to check each operation for success?
That is a lot more API calls - log.info(f"deleted {len(done)} instances {to_hostlist_fast(done.keys())}") - - -def suspend_nodes(nodes: List[str]) -> None: - tpu_nodes, other_nodes = [], [] - for node in nodes[:]: - if lkp.node_is_tpu(node): - tpu_nodes.append(node) - else: - other_nodes.append(node) - - delete_instances(other_nodes) - delete_tpu_instances(tpu_nodes) - - -def main(nodelist): - """main called when run as script""" - log.debug(f"SuspendProgram {nodelist}") - - # Filter out nodes not in config.yaml - other_nodes, pm_nodes = separate( - lkp.is_power_managed_node, util.to_hostnames(nodelist) - ) - if other_nodes: - log.debug( - f"Ignoring non-power-managed nodes '{to_hostlist_fast(other_nodes)}' from '{nodelist}'" - ) - if pm_nodes: - log.debug(f"Suspending nodes '{to_hostlist_fast(pm_nodes)}' from '{nodelist}'") - else: - log.debug("No cloud nodes to suspend") - return - - log.info(f"suspend {nodelist}") - if lkp.cfg.enable_slurm_gcp_plugins: - slurm_gcp_plugins.pre_main_suspend_nodes(lkp=lkp, nodelist=nodelist) - suspend_nodes(pm_nodes) - - -parser = argparse.ArgumentParser( - description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter -) -parser.add_argument("nodelist", help="list of nodes to suspend") -parser.add_argument( - "--debug", - "-d", - dest="loglevel", - action="store_const", - const=logging.DEBUG, - default=logging.INFO, - help="Enable debugging output", -) -parser.add_argument( - "--trace-api", - "-t", - action="store_true", - help="Enable detailed api request output", -) - - -if __name__ == "__main__": - args = parser.parse_args() - - if cfg.enable_debug_logging: - args.loglevel = logging.DEBUG - if args.trace_api: - cfg.extra_logging_flags = list(cfg.extra_logging_flags) - cfg.extra_logging_flags.append("trace_api") - util.chown_slurm(LOGFILE, mode=0o600) - util.config_root_logger(filename, level=args.loglevel, logfile=LOGFILE) - log = logging.getLogger(Path(__file__).name) - sys.excepthook = util.handle_exception - - main(args.nodelist) diff --git a/scripts/tests/README.md b/scripts/tests/README.md deleted file mode 100644 index 8452813f..00000000 --- a/scripts/tests/README.md +++ /dev/null @@ -1,6 +0,0 @@ -# Unit tests - -```sh -# cwd is scripts/tests -$ pytest -W ignore::DeprecationWarning -``` diff --git a/scripts/tests/test_topology.py b/scripts/tests/test_topology.py deleted file mode 100644 index fc1f249c..00000000 --- a/scripts/tests/test_topology.py +++ /dev/null @@ -1,120 +0,0 @@ -from typing import Optional -import mock -import sys - -if ".." 
not in sys.path: - sys.path.append("..") # TODO: make this more robust -import util -import conf - -from dataclasses import dataclass, field -import tempfile - - -# TODO: use "real" classes once they are defined (instead of NSDict) -@dataclass -class TstNodeset: - nodeset_name: str - node_count_static: int = 0 - node_count_dynamic_max: int = 0 - - -@dataclass -class TstCfg: - slurm_cluster_name: str = "m22" - nodeset: dict[str, TstNodeset] = field(default_factory=dict) - nodeset_tpu: dict[str, TstNodeset] = field(default_factory=dict) - output_dir: Optional[str] = None - - -@dataclass -class TstTPU: # to prevent client initialization durint "TPU.__init__" - vmcount: int - - -def make_to_hostnames_mock(tbl: Optional[dict[str, list[str]]]): - tbl = tbl or {} - - def se(k: str) -> list[str]: - if k not in tbl: - raise AssertionError(f"to_hostnames mock: unexpected nodelist: '{k}'") - return tbl[k] - - return se - - -def test_gen_topology_conf_empty(): - cfg = TstCfg(output_dir=tempfile.mkdtemp()) - conf.gen_topology_conf(util.Lookup(cfg)) - assert ( - open(cfg.output_dir + "/cloud_topology.conf").read() - == """ -# Warning: -# This file is managed by a script. Manual modifications will be overwritten. - - -""" - ) - - -@mock.patch("util.TPU") -@mock.patch( - "util.to_hostnames", - side_effect=make_to_hostnames_mock( - { - "m22-bold-[0-3]": ["m22-bold-0", "m22-bold-1", "m22-bold-2", "m22-bold-3"], - "m22-bold-[4-8]": [ - "m22-bold-4", - "m22-bold-5", - "m22-bold-6", - "m22-bold-7", - "m22-bold-8", - ], - "m22-slim-[0-2]": ["m22-slim-0", "m22-slim-1", "m22-slim-2"], - } - ), -) -def test_gen_topology_conf(to_hostnames_mock, tpu_mock): - cfg = TstCfg( - nodeset_tpu={ - "a": TstNodeset("bold", node_count_static=4, node_count_dynamic_max=5), - "b": TstNodeset("slim", node_count_dynamic_max=3), - }, - nodeset={ - "c": TstNodeset("green", node_count_static=2, node_count_dynamic_max=3), - "d": TstNodeset("blue", node_count_static=7), - "e": TstNodeset("pink", node_count_dynamic_max=4), - }, - output_dir=tempfile.mkdtemp(), - ) - - def tpu_se(ns: TstNodeset) -> TstTPU: - if ns.nodeset_name == "bold": - return TstTPU(vmcount=3) - if ns.nodeset_name == "slim": - return TstTPU(vmcount=1) - raise AssertionError(f"unexpected TPU name: '{ns.nodeset_name}'") - - tpu_mock.side_effect = tpu_se - - conf.gen_topology_conf(util.Lookup(cfg)) - assert ( - open(cfg.output_dir + "/cloud_topology.conf").read() - == """ -# Warning: -# This file is managed by a script. Manual modifications will be overwritten. - -SwitchName=nodeset-root Switches=blue,green,pink -SwitchName=blue Nodes=m22-blue-[0-6] -SwitchName=green Nodes=m22-green-[0-4] -SwitchName=pink Nodes=m22-pink-[0-3] -SwitchName=nodeset_tpu-root Switches=bold,slim -SwitchName=bold Switches=bold-[0-3] -SwitchName=bold-0 Nodes=m22-bold-[0-2] -SwitchName=bold-1 Nodes=m22-bold-3 -SwitchName=bold-2 Nodes=m22-bold-[4-6] -SwitchName=bold-3 Nodes=m22-bold-[7-8] -SwitchName=slim Nodes=m22-slim-[0-2] - -""" - ) diff --git a/scripts/tests/test_util.py b/scripts/tests/test_util.py deleted file mode 100644 index 9c3a03c2..00000000 --- a/scripts/tests/test_util.py +++ /dev/null @@ -1,148 +0,0 @@ -import sys -import pytest - -if ".." 
not in sys.path: - sys.path.append("..") # TODO: make this more robust -import util -from google.api_core.client_options import ClientOptions # noqa: E402 - -# Note: need to install pytest-mock - - -@pytest.mark.parametrize( - "name,expected", - [ - ( - "az-buka-23", - { - "cluster": "az", - "nodeset": "buka", - "node": "23", - "prefix": "az-buka", - "range": None, - "suffix": "23", - }, - ), - ( - "az-buka-xyzf", - { - "cluster": "az", - "nodeset": "buka", - "node": "xyzf", - "prefix": "az-buka", - "range": None, - "suffix": "xyzf", - }, - ), - ( - "az-buka-[2-3]", - { - "cluster": "az", - "nodeset": "buka", - "node": "[2-3]", - "prefix": "az-buka", - "range": "[2-3]", - "suffix": None, - }, - ), - ], -) -def test_node_desc(name, expected): - assert util.lkp._node_desc(name) == expected - - -@pytest.mark.parametrize( - "name", - [ - "az-buka", - ], -) -def test_node_desc_fail(name): - with pytest.raises(Exception): - util.lkp._node_desc(name) - - -@pytest.mark.parametrize( - "names,expected", - [ - ("pedro,pedro-1,pedro-2,pedro-01,pedro-02", "pedro,pedro-[1-2,01-02]"), - ("pedro,,pedro-1,,pedro-2", "pedro,pedro-[1-2]"), - ("pedro-8,pedro-9,pedro-10,pedro-11", "pedro-[8-9,10-11]"), - ("pedro-08,pedro-09,pedro-10,pedro-11", "pedro-[08-11]"), - ("pedro-08,pedro-09,pedro-8,pedro-9", "pedro-[8-9,08-09]"), - ("pedro-10,pedro-08,pedro-09,pedro-8,pedro-9", "pedro-[8-9,08-10]"), - ("pedro-8,pedro-9,juan-10,juan-11", "juan-[10-11],pedro-[8-9]"), - ("az,buki,vedi", "az,buki,vedi"), - ("a0,a1,a2,a3,a4,a5,a6,a7,a8,a9,a10,a11,a12", "a[0-9,10-12]"), - ("a0,a2,a4,a6,a7,a8,a11,a12", "a[0,2,4,6-8,11-12]"), - ("seas7-0,seas7-1", "seas7-[0-1]"), - ], -) -def test_to_hostlist_fast(names, expected): - assert util.to_hostlist_fast(names.split(",")) == expected - - -@pytest.mark.parametrize( - "api,ep_ver,expected", - [ - ( - util.ApiEndpoint.BQ, - "v1", - ClientOptions( - api_endpoint="https://bq.googleapis.com/v1/", - universe_domain="googleapis.com", - ), - ), - ( - util.ApiEndpoint.COMPUTE, - "staging_v1", - ClientOptions( - api_endpoint="https://compute.googleapis.com/staging_v1/", - universe_domain="googleapis.com", - ), - ), - ( - util.ApiEndpoint.SECRET, - "v1", - ClientOptions( - api_endpoint="https://secret_manager.googleapis.com/v1/", - universe_domain="googleapis.com", - ), - ), - ( - util.ApiEndpoint.STORAGE, - "beta", - ClientOptions( - api_endpoint="https://storage.googleapis.com/beta/", - universe_domain="googleapis.com", - ), - ), - ( - util.ApiEndpoint.TPU, - "alpha", - ClientOptions( - api_endpoint="https://tpu.googleapis.com/alpha/", - universe_domain="googleapis.com", - ), - ), - ], -) -def test_create_client_options( - api: util.ApiEndpoint, ep_ver: str, expected: ClientOptions, mocker -): - ud_mock = mocker.patch("util.universe_domain") - ep_mock = mocker.patch("util.endpoint_version") - ud_mock.return_value = "googleapis.com" - ep_mock.return_value = ep_ver - co = util.create_client_options(api) - assert ( - co.api_endpoint == expected.api_endpoint - and co.universe_domain == expected.universe_domain - ) - ud_mock.return_value = None - ep_mock.return_value = None - co = util.create_client_options(api) - assert ( - co.api_endpoint != expected.api_endpoint - and co.universe_domain != expected.universe_domain - ) diff --git a/scripts/util.py b/scripts/util.py deleted file mode 100755 index e2d9c710..00000000 --- a/scripts/util.py +++ /dev/null @@ -1,2083 +0,0 @@ -#!/usr/bin/env python3 - -# Copyright (C) SchedMD LLC. 
-# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import Iterable, List, Tuple, Optional -import argparse -import base64 -import collections -import hashlib -import importlib.util -import inspect -import json -import logging -import logging.config -import math -import os -import re -import shelve -import shlex -import shutil -import socket -import subprocess -import sys -import tempfile -from enum import Enum -from collections import defaultdict, namedtuple -from concurrent.futures import ThreadPoolExecutor, as_completed -from contextlib import contextmanager -from functools import lru_cache, reduce, wraps -from itertools import chain, compress, islice -from pathlib import Path -from time import sleep, time - -import slurm_gcp_plugins - -required_modules = [ - ("googleapiclient", "google-api-python-client"), - ("requests", "requests"), - ("yaml", "yaml"), - ("addict", "addict"), - ("httplib2", "httplib2"), - ("google.cloud.tpu_v2", "google-cloud-tpu"), -] -missing_imports = False -can_tpu = True -for module, name in required_modules: - if importlib.util.find_spec(module) is None: - if module == "google.cloud.tpu_v2": - can_tpu = False - print( - f"WARNING: Missing Python module '{module} (pip:{name})', TPU support will not work." - ) - else: - missing_imports = True - print(f"ERROR: Missing Python module '{module} (pip:{name})'") -if missing_imports: - print("Aborting due to missing Python modules") - exit(1) - -import google.auth # noqa: E402 -from google.oauth2 import service_account # noqa: E402 -import googleapiclient.discovery # noqa: E402 -import google_auth_httplib2 # noqa: E402 -from googleapiclient.http import set_user_agent # noqa: E402 -from google.api_core.client_options import ClientOptions # noqa: E402 -import httplib2 # noqa: E402 - -if can_tpu: - from google.cloud import tpu_v2 as tpu # noqa: E402 -import google.api_core.exceptions as gExceptions # noqa: E402 - -from requests import get as get_url # noqa: E402 -from requests.exceptions import RequestException # noqa: E402 - -import yaml # noqa: E402 -from addict import Dict as NSDict # noqa: E402 - -optional_modules = [ - ("google.cloud.secretmanager", "google-cloud-secret-manager"), -] -for module, name in optional_modules: - if importlib.util.find_spec(module) is None: - print(f"WARNING: Missing Python module '{module}' (pip:{name}) ") - -USER_AGENT = "Slurm_GCP_Scripts/1.5 (GPN:SchedMD)" -ENV_CONFIG_YAML = os.getenv("SLURM_CONFIG_YAML") -if ENV_CONFIG_YAML: - CONFIG_FILE = Path(ENV_CONFIG_YAML) -else: - CONFIG_FILE = Path(__file__).with_name("config.yaml") -API_REQ_LIMIT = 2000 -URI_REGEX = r"[a-z]([-a-z0-9]*[a-z0-9])?" 
- - -def mkdirp(path: Path) -> None: - path.mkdir(parents=True, exist_ok=True) - - -scripts_dir = next( - p for p in (Path(__file__).parent, Path("/slurm/scripts")) if p.is_dir() -) - -# readily available compute api handle -compute = None -# slurm-gcp config object, could be empty if not available -cfg = NSDict() -# caching Lookup object -lkp = None - -# load all directories as Paths into a dict-like namespace -dirs = NSDict( - { - n: Path(p) - for n, p in dict.items( - { - "home": "/home", - "apps": "/opt/apps", - "slurm": "/slurm", - "scripts": scripts_dir, - "custom_scripts": "/slurm/custom_scripts", - "munge": "/etc/munge", - "secdisk": "/mnt/disks/sec", - "log": "/var/log/slurm", - } - ) - } -) - -slurmdirs = NSDict( - { - n: Path(p) - for n, p in dict.items( - { - "prefix": "/usr/local", - "etc": "/usr/local/etc/slurm", - "state": "/var/spool/slurm", - } - ) - } -) - - -yaml.SafeDumper.yaml_representers[ - None -] = lambda self, data: yaml.representer.SafeRepresenter.represent_str(self, str(data)) - - -class ApiEndpoint(Enum): - COMPUTE = "compute" - BQ = "bq" - STORAGE = "storage" - TPU = "tpu" - SECRET = "secret_manager" - - -@lru_cache(maxsize=1) -def default_credentials(): - return google.auth.default()[0] - - -@lru_cache(maxsize=1) -def authentication_project(): - return google.auth.default()[1] - - -DEFAULT_UNIVERSE_DOMAIN = "googleapis.com" - - -def universe_domain() -> str: - try: - return instance_metadata("attributes/universe_domain") - except Exception: - return DEFAULT_UNIVERSE_DOMAIN - - -def endpoint_version(api: ApiEndpoint) -> Optional[str]: - if api and api.value in lkp.endpoint_versions: - return lkp.endpoint_versions[api.value] - return None - - -@lru_cache(maxsize=1) -def get_credentials() -> Optional[service_account.Credentials]: - """Get credentials for service account""" - key_path = os.environ.get("GOOGLE_APPLICATION_CREDENTIALS") - if key_path is not None: - credentials = service_account.Credentials.from_service_account_file( - key_path, scopes=[f"https://www.{universe_domain()}/auth/cloud-platform"] - ) - else: - credentials = default_credentials() - - return credentials - - -def create_client_options(api: ApiEndpoint = None) -> ClientOptions: - """Create client options for cloud endpoints""" - ver = endpoint_version(api) - ud = universe_domain() - options = {} - if ud and ud != DEFAULT_UNIVERSE_DOMAIN: - options["universe_domain"] = ud - if ver: - options["api_endpoint"] = f"https://{api.value}.{ud}/{ver}/" - co = ClientOptions(**options) - log.debug(f"Using ClientOptions = {co} for API: {api.value}") - return co - - -class LogFormatter(logging.Formatter): - """adds logging flags to the levelname in log records""" - - def format(self, record): - new_fmt = self._fmt - flag = getattr(record, "flag", None) - if flag is not None: - start, level, end = new_fmt.partition("%(levelname)s") - if level: - new_fmt = f"{start}{level}(%(flag)s){end}" - # insert function name if record level is DEBUG - if record.levelno < logging.INFO: - prefix, msg, suffix = new_fmt.partition("%(message)s") - new_fmt = f"{prefix}%(funcName)s: {msg}{suffix}" - self._style._fmt = new_fmt - return super().format(record) - - -class FlagLogAdapter(logging.LoggerAdapter): - """creates log adapters that add a flag to the log record, - allowing it to be filtered""" - - def __init__(self, logger, flag, extra=None): - if extra is None: - extra = {} - self.flag = flag - super().__init__(logger, extra) - - @property - def enabled(self): - return cfg.extra_logging_flags.get(self.flag, False) - - 
def process(self, msg, kwargs):
-        extra = kwargs.setdefault("extra", {})
-        extra.update(self.extra)
-        extra["flag"] = self.flag
-        return msg, kwargs
-
-
-logging.basicConfig(level=logging.INFO, stream=sys.stdout)
-log = logging.getLogger(__name__)
-logging_flags = [
-    "trace_api",
-    "subproc",
-    "hostlists",
-]
-log_trace_api = FlagLogAdapter(log, "trace_api")
-log_subproc = FlagLogAdapter(log, "subproc")
-log_hostlists = FlagLogAdapter(log, "hostlists")
-
-
-def access_secret_version(project_id, secret_id, version_id="latest"):
-    """
-    Access the payload for the given secret version if one exists. The version
-    can be a version number as a string (e.g. "5") or an alias (e.g. "latest").
-    """
-    from google.cloud import secretmanager
-    from google.api_core import exceptions
-
-    co = create_client_options(ApiEndpoint.SECRET)
-    client = secretmanager.SecretManagerServiceClient(client_options=co)
-    name = f"projects/{project_id}/secrets/{secret_id}/versions/{version_id}"
-    try:
-        response = client.access_secret_version(request={"name": name})
-        log.debug(f"Secret '{name}' was found.")
-        payload = response.payload.data.decode("UTF-8")
-    except exceptions.NotFound:
-        log.debug(f"Secret '{name}' was not found!")
-        payload = None
-
-    return payload
-
-
-def parse_self_link(self_link: str):
-    """Parse a selfLink url, extracting all useful values
-    https://.../v1/projects/<project>/regions/<region>/...
-    {'project': <project>, 'region': <region>, ...}
-    can also extract zone, instance (name), image, etc
-    """
-    link_patt = re.compile(r"(?P<key>[^\/\s]+)s\/(?P<value>[^\s\/]+)")
-    return NSDict(link_patt.findall(self_link))
-
-
-def parse_bucket_uri(uri: str):
-    """
-    Parse a bucket url
-    E.g. gs://<bucket_name>/<path>
-    """
-    pattern = re.compile(r"gs://(?P<bucket>[^/\s]+)/(?P<path>([^/\s]+)(/[^/\s]+)*)")
-    matches = pattern.match(uri)
-    return matches.group("bucket"), matches.group("path")
-
-
-def trim_self_link(link: str):
-    """get resource name from self link url, eg.
-    https://.../v1/projects/<project>/regions/<region>
-    -> <region>
-    """
-    try:
-        return link[link.rindex("/") + 1 :]
-    except ValueError:
-        raise Exception(f"'/' not found, not a self link: '{link}' ")
-
-
-def execute_with_futures(func, seq):
-    with ThreadPoolExecutor() as exe:
-        futures = []
-        for i in seq:
-            future = exe.submit(func, i)
-            futures.append(future)
-        for future in as_completed(futures):
-            result = future.exception()
-            if result is not None:
-                raise result
-
-
-def map_with_futures(func, seq):
-    with ThreadPoolExecutor() as exe:
-        futures = []
-        for i in seq:
-            future = exe.submit(func, i)
-            futures.append(future)
-        for future in futures:
-            # Will be result or raise Exception
-            res = None
-            try:
-                res = future.result()
-            except Exception as e:
-                res = e
-            yield res
-
-
-def blob_get(file, project=None):
-    from google.cloud import storage
-
-    if project is None:
-        project = lkp.project
-    uri = instance_metadata("attributes/slurm_bucket_path")
-    bucket_name, path = parse_bucket_uri(uri)
-    blob_name = f"{path}/{file}"
-    co = create_client_options(ApiEndpoint.STORAGE)
-    storage_client = storage.Client(project=project, client_options=co)
-    return storage_client.get_bucket(bucket_name).blob(blob_name)
-
-
-def blob_list(prefix="", delimiter=None, project=None):
-    from google.cloud import storage
-
-    if project is None:
-        project = lkp.project
-    uri = instance_metadata("attributes/slurm_bucket_path")
-    bucket_name, path = parse_bucket_uri(uri)
-    blob_prefix = f"{path}/{prefix}"
-    co = create_client_options(ApiEndpoint.STORAGE)
-    storage_client = storage.Client(project=project, client_options=co)
-    # Note: The call returns a response only when the iterator is consumed.
-    blobs = storage_client.list_blobs(
-        bucket_name, prefix=blob_prefix, delimiter=delimiter
-    )
-    return [blob for blob in blobs]
-
-
-def _hash_file(fullpath):
-    with open(fullpath, "rb") as f:
-        file_hash = hashlib.md5()
-        chunk = f.read(8192)
-        while chunk:
-            file_hash.update(chunk)
-            chunk = f.read(8192)
-    return base64.b64encode(file_hash.digest()).decode("utf-8")
-
-
-def install_custom_scripts(check_hash=False):
-    """download custom scripts from gcs bucket"""
-
-    compute_tokens = ["compute", "prolog", "epilog"]
-    if lkp.instance_role == "compute":
-        try:
-            compute_tokens.append(f"nodeset-{lkp.node_nodeset_name()}")
-        except Exception as e:
-            log.error(f"Failed to lookup nodeset: {e}")
-
-    prefix_tokens = dict.get(
-        {
-            "login": ["login"],
-            "compute": compute_tokens,
-            "controller": ["controller", "prolog", "epilog"],
-        },
-        lkp.instance_role,
-        [],
-    )
-    prefixes = [f"slurm-{tok}-script" for tok in prefix_tokens]
-    blobs = list(chain.from_iterable(blob_list(prefix=p) for p in prefixes))
-
-    script_pattern = re.compile(r"slurm-(?P<path>\S+)-script-(?P<name>\S+)")
-    for blob in blobs:
-        m = script_pattern.match(Path(blob.name).name)
-        if not m:
-            log.warning(f"found blob that doesn't match expected pattern: {blob.name}")
-            continue
-        path_parts = m["path"].split("-")
-        path_parts[0] += ".d"
-        stem, _, ext = m["name"].rpartition("_")
-        filename = ".".join((stem, ext))
-
-        path = Path(*path_parts, filename)
-        fullpath = (dirs.custom_scripts / path).resolve()
-        mkdirp(fullpath.parent)
-
-        for par in path.parents:
-            chown_slurm(dirs.custom_scripts / par)
-        need_update = True
-        if check_hash and fullpath.exists():
-            need_update = _hash_file(fullpath) != blob.md5_hash
-        if need_update:
-            log.info(f"installing custom script: {path} from {blob.name}")
-            with fullpath.open("wb") as f:
-                blob.download_to_file(f)
-            chown_slurm(fullpath, mode=0o755)
-
-
-def
reservation_resource_policies(reservation): - """ - Inspects reservation object, returns list of resource policies names. - Converts policy URLs to names, e.g.: - projects/111111/regions/us-central1/resourcePolicies/zebra -> zebra - """ - return [u.split("/")[-1] for u in reservation.get("resourcePolicies", {}).values()] - - -def compute_service(credentials=None, user_agent=USER_AGENT, version="beta"): - """Make thread-safe compute service handle - creates a new Http for each request - """ - - credentials = get_credentials() - - def build_request(http, *args, **kwargs): - new_http = httplib2.Http() - if user_agent is not None: - new_http = set_user_agent(new_http, user_agent) - if credentials is not None: - new_http = google_auth_httplib2.AuthorizedHttp(credentials, http=new_http) - return googleapiclient.http.HttpRequest(new_http, *args, **kwargs) - - ver = endpoint_version(ApiEndpoint.COMPUTE) - disc_url = googleapiclient.discovery.DISCOVERY_URI - if ver: - version = ver - disc_url = disc_url.replace(DEFAULT_UNIVERSE_DOMAIN, universe_domain()) - - log.debug(f"Using version={version} of Google Compute Engine API") - return googleapiclient.discovery.build( - "compute", - version, - requestBuilder=build_request, - credentials=credentials, - discoveryServiceUrl=disc_url, - ) - - -def load_config_data(config): - """load dict-like data into a config object""" - cfg = NSDict(config) - if not cfg.slurm_log_dir: - cfg.slurm_log_dir = dirs.log - if not cfg.slurm_bin_dir: - cfg.slurm_bin_dir = slurmdirs.prefix / "bin" - if not cfg.slurm_control_host: - cfg.slurm_control_host = f"{cfg.slurm_cluster_name}-controller" - if not cfg.slurm_control_host_port: - cfg.slurm_control_host_port = "6820-6830" - if not cfg.munge_mount: - # NOTE: should only happen with cloud controller - cfg.munge_mount = NSDict( - { - "server_ip": cfg.slurm_control_addr or cfg.slurm_control_host, - "remote_mount": "/etc/munge", - "fs_type": "nfs", - "mount_options": "defaults,hard,intr,_netdev", - } - ) - - if not cfg.enable_debug_logging and isinstance(cfg.enable_debug_logging, NSDict): - cfg.enable_debug_logging = False - cfg.extra_logging_flags = NSDict( - {flag: cfg.extra_logging_flags.get(flag, False) for flag in logging_flags} - ) - return cfg - - -def new_config(config): - """initialize a new config object - necessary defaults are handled here - """ - cfg = load_config_data(config) - - network_storage_iter = filter( - None, - ( - *cfg.network_storage, - *cfg.login_network_storage, - *chain.from_iterable(ns.network_storage for ns in cfg.nodeset.values()), - *chain.from_iterable(ns.network_storage for ns in cfg.nodeset_dyn.values()), - *chain.from_iterable(ns.network_storage for ns in cfg.nodeset_tpu.values()), - ), - ) - for netstore in network_storage_iter: - if netstore != "gcsfuse" and ( - netstore.server_ip is None or netstore.server_ip == "$controller" - ): - netstore.server_ip = cfg.slurm_control_host - return cfg - - -def fetch_config_yaml(): - """Fetch config.yaml from bucket""" - config_yaml = blob_get("config.yaml").download_as_text() - cfg = new_config(yaml.safe_load(config_yaml)) - return cfg - - -def fetch_config_yaml_md5(): - """Fetch config.yaml blob md5 from bucket""" - import hashlib - - blob = blob_get("config.yaml") - blob.reload() # Populate blob with metadata - hash_str = str(blob.md5_hash).encode(encoding="utf-8") - return hashlib.md5(hash_str) - - -def load_config_file(path): - """load config from file""" - content = None - try: - content = yaml.safe_load(Path(path).read_text()) - except 
FileNotFoundError: - log.warning(f"config file not found: {path}") - return NSDict() - return load_config_data(content) - - -def save_config(cfg, path): - """save given config to file at path""" - Path(path).write_text(yaml.dump(cfg, Dumper=Dumper)) - - -def filter_logging_flags(record): - """logging filter for flags - if there are no flags, always pass. If there are flags, only pass if a flag - matches an enabled flag in cfg.extra_logging_flags""" - flag = getattr(record, "flag", None) - if flag is None: - return True - return cfg.extra_logging_flags.get(flag, False) - - -def owned_file_handler(filename): - """create file handler""" - if filename is None: - return None - chown_slurm(filename) - return logging.handlers.WatchedFileHandler(filename, delay=True) - - -def config_root_logger(caller_logger, level="DEBUG", stdout=True, logfile=None): - """configure the root logger, disabling all existing loggers""" - handlers = list(compress(("stdout_handler", "file_handler"), (stdout, logfile))) - - config = { - "version": 1, - "disable_existing_loggers": True, - "formatters": { - "standard": { - "()": LogFormatter, - "fmt": "%(levelname)s: %(message)s", - }, - "stamp": { - "()": LogFormatter, - "fmt": "%(asctime)s %(levelname)s: %(message)s", - }, - }, - "filters": { - "logging_flags": {"()": lambda: filter_logging_flags}, - }, - "handlers": { - "stdout_handler": { - "level": logging.DEBUG, - "formatter": "standard", - "class": "logging.StreamHandler", - "stream": sys.stdout, - "filters": ["logging_flags"], - }, - "file_handler": { - "()": owned_file_handler, - "level": logging.DEBUG, - "formatter": "stamp", - "filters": ["logging_flags"], - "filename": logfile, - }, - }, - "root": { - "handlers": handlers, - "level": level, - }, - } - if not logfile: - del config["handlers"]["file_handler"] - logging.config.dictConfig(config) - loggers = ( - __name__, - "resume", - "suspend", - "slurmsync", - "setup", - caller_logger, - ) - for logger in map(logging.getLogger, loggers): - logger.disabled = False - - -def log_api_request(request): - """log.trace info about a compute API request""" - if log_trace_api.enabled: - # output the whole request object as pretty yaml - # the body is nested json, so load it as well - rep = json.loads(request.to_json()) - if rep.get("body", None) is not None: - rep["body"] = json.loads(rep["body"]) - pretty_req = yaml.safe_dump(rep).rstrip() - # label log message with the calling function - log_trace_api.debug(f"{inspect.stack()[1].function}:\n{pretty_req}") - - -def handle_exception(exc_type, exc_value, exc_trace): - """log exceptions other than KeyboardInterrupt""" - # TODO does this work? 
- if not issubclass(exc_type, KeyboardInterrupt): - log.exception("Fatal exception", exc_info=(exc_type, exc_value, exc_trace)) - sys.__excepthook__(exc_type, exc_value, exc_trace) - - -def run( - args, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - shell=False, - timeout=None, - check=True, - universal_newlines=True, - **kwargs, -): - """Wrapper for subprocess.run() with convenient defaults""" - if isinstance(args, list): - args = list(filter(lambda x: x is not None, args)) - args = " ".join(args) - if not shell and isinstance(args, str): - args = shlex.split(args) - log_subproc.debug(f"run: {args}") - result = subprocess.run( - args, - stdout=stdout, - stderr=stderr, - shell=shell, - timeout=timeout, - check=check, - universal_newlines=universal_newlines, - **kwargs, - ) - return result - - -def spawn(cmd, quiet=False, shell=False, **kwargs): - """nonblocking spawn of subprocess""" - if not quiet: - log_subproc.debug(f"spawn: {cmd}") - args = cmd if shell else shlex.split(cmd) - return subprocess.Popen(args, shell=shell, **kwargs) - - -def chown_slurm(path: Path, mode=None) -> None: - if path.exists(): - if mode: - path.chmod(mode) - else: - mkdirp(path.parent) - if mode: - path.touch(mode=mode) - else: - path.touch() - try: - shutil.chown(path, user="slurm", group="slurm") - except LookupError: - log.warning(f"User 'slurm' does not exist. Cannot 'chown slurm:slurm {path}'.") - except PermissionError: - log.warning(f"Not authorized to 'chown slurm:slurm {path}'.") - except Exception as err: - log.error(err) - - -@contextmanager -def cd(path): - """Change working directory for context""" - prev = Path.cwd() - os.chdir(path) - try: - yield - finally: - os.chdir(prev) - - -def cached_property(f): - return property(lru_cache()(f)) - - -def retry(max_retries: int, init_wait_time: float, warn_msg: str, exc_type: Exception): - """Retries functions that raises the exception exc_type. - Retry time is increased by a factor of two for every iteration. 
- - Args: - max_retries (int): Maximum number of retries - init_wait_time (float): Initial wait time in secs - warn_msg (str): Message to print during retries - exc_type (Exception): Exception type to check for - """ - - if max_retries <= 0: - raise ValueError("Incorrect value for max_retries, must be >= 1") - if init_wait_time <= 0.0: - raise ValueError("Invalid value for init_wait_time, must be > 0.0") - - def decorator(f): - @wraps(f) - def wrapper(*args, **kwargs): - retry = 0 - secs = init_wait_time - captured_exc = None - while retry < max_retries: - try: - return f(*args, **kwargs) - except exc_type as e: - captured_exc = e - log.warn(f"{warn_msg}, retrying in {secs}") - sleep(secs) - retry += 1 - secs *= 2 - raise captured_exc - - return wrapper - - return decorator - - -def separate(pred, coll): - """filter into 2 lists based on pred returning True or False - returns ([False], [True]) - """ - return reduce(lambda acc, el: acc[pred(el)].append(el) or acc, coll, ([], [])) - - -def chunked(iterable, n=API_REQ_LIMIT): - """group iterator into chunks of max size n""" - it = iter(iterable) - while True: - chunk = list(islice(it, n)) - if not chunk: - return - yield chunk - - -def groupby_unsorted(seq, key): - indices = defaultdict(list) - for i, el in enumerate(seq): - indices[key(el)].append(i) - for k, idxs in indices.items(): - yield k, (seq[i] for i in idxs) - - -@lru_cache(maxsize=32) -def find_ratio(a, n, s, r0=None): - """given the start (a), count (n), and sum (s), find the ratio required""" - if n == 2: - return s / a - 1 - an = a * n - if n == 1 or s == an: - return 1 - if r0 is None: - # we only need to know which side of 1 to guess, and the iteration will work - r0 = 1.1 if an < s else 0.9 - - # geometric sum formula - def f(r): - return a * (1 - r**n) / (1 - r) - s - - # derivative of f - def df(r): - rm1 = r - 1 - rn = r**n - return (a * (rn * (n * rm1 - r) + r)) / (r * rm1**2) - - MIN_DR = 0.0001 # negligible change - r = r0 - # print(f"r(0)={r0}") - MAX_TRIES = 64 - for i in range(1, MAX_TRIES + 1): - try: - dr = f(r) / df(r) - except ZeroDivisionError: - log.error(f"Failed to find ratio due to zero division! Returning r={r0}") - return r0 - r = r - dr - # print(f"r({i})={r}") - # if the change in r is small, we are close enough - if abs(dr) < MIN_DR: - break - else: - log.error(f"Could not find ratio after {MAX_TRIES}! Returning r={r0}") - return r0 - return r - - -def backoff_delay(start, timeout=None, ratio=None, count: int = 0): - """generates `count` waits starting at `start` - sum of waits is `timeout` or each one is `ratio` bigger than the last - the last wait is always 0""" - # timeout or ratio must be set but not both - assert (timeout is None) ^ (ratio is None) - assert ratio is None or ratio > 0 - assert timeout is None or timeout >= start - assert (count > 1 or timeout is not None) and isinstance(count, int) - assert start > 0 - - if count == 0: - # Equation for auto-count is tuned to have a max of - # ~int(timeout) counts with a start wait of <0.01. - # Increasing start wait decreases count eg. 
- # backoff_delay(10, timeout=60) -> count = 5 - count = int( - (timeout / ((start + 0.05) ** (1 / 2)) + 2) // math.log(timeout + 2) - ) - - yield start - # if ratio is set: - # timeout = start * (1 - ratio**(count - 1)) / (1 - ratio) - if ratio is None: - ratio = find_ratio(start, count - 1, timeout) - - wait = start - # we have start and 0, so we only need to generate count - 2 - for _ in range(count - 2): - wait *= ratio - yield wait - yield 0 - return - - -ROOT_URL = "http://metadata.google.internal/computeMetadata/v1" - - -def get_metadata(path, root=ROOT_URL): - """Get metadata relative to metadata/computeMetadata/v1""" - HEADERS = {"Metadata-Flavor": "Google"} - url = f"{root}/{path}" - try: - resp = get_url(url, headers=HEADERS) - resp.raise_for_status() - return resp.text - except RequestException: - log.debug(f"metadata not found ({url})") - raise Exception(f"failed to get_metadata from {url}") - - -@lru_cache(maxsize=None) -def instance_metadata(path): - """Get instance metadata""" - return get_metadata(path, root=f"{ROOT_URL}/instance") - - -@lru_cache(maxsize=None) -def project_metadata(key): - """Get project metadata project/attributes/-""" - return get_metadata(key, root=f"{ROOT_URL}/project/attributes") - - -def bucket_blob_download(bucket_name, blob_name): - from google.cloud import storage - - co = create_client_options("storage") - storage_client = storage.Client(client_options=co) - bucket = storage_client.bucket(bucket_name) - blob = bucket.blob(blob_name) - contents = None - with tempfile.NamedTemporaryFile(mode="w+t") as tmp: - blob.download_to_filename(tmp.name) - with open(tmp.name, "r") as f: - contents = f.read() - return contents - - -def natural_sort(text): - def atoi(text): - return int(text) if text.isdigit() else text - - return [atoi(w) for w in re.split(r"(\d+)", text)] - - -# TODO: replace with to_hostlist_fast -def to_hostlist(nodenames) -> str: - """make hostlist from list of node names""" - # use tmp file because list could be large - tmp_file = tempfile.NamedTemporaryFile(mode="w+t", delete=False) - tmp_file.writelines("\n".join(sorted(nodenames, key=natural_sort))) - tmp_file.close() - - hostlist = run(f"{lkp.scontrol} show hostlist {tmp_file.name}").stdout.rstrip() - log_hostlists.debug(f"hostlist({len(nodenames)}): {hostlist}".format(hostlist)) - os.remove(tmp_file.name) - return hostlist - - -def to_hostlist_fast(names: Iterable[str]) -> str: - """ - Fast implementation of to_hostlist that doesn't invoke `scontrol` - IMPORTANT: - * Acts as `scontrol show hostlistsorted`, i.e. 
original order is not preserved - * Achieves worse compression than `to_hostlist` for some cases - """ - pref = defaultdict(list) - tokenizer = re.compile(r"^(.*?)(\d*)$") - for name in filter(None, names): - p, s = tokenizer.match(name).groups() - pref[p].append(s) - - def _compress_suffixes(ss: List[str]) -> List[str]: - cur, res = None, [] - - def cur_repr(): - nums, strs = cur - if nums[0] == nums[1]: - return strs[0] - return f"{strs[0]}-{strs[1]}" - - for s in sorted(ss, key=int): - n = int(s) - if cur is None: - cur = ((n, n), (s, s)) - continue - - nums, strs = cur - if n == nums[1] + 1: - cur = ((nums[0], n), (strs[0], s)) - else: - res.append(cur_repr()) - cur = ((n, n), (s, s)) - if cur: - res.append(cur_repr()) - return res - - res = [] - for p in sorted(pref.keys()): - sl = defaultdict(list) - for s in pref[p]: - sl[len(s)].append(s) - cs = [] - for ln in sorted(sl.keys()): - if ln == 0: - res.append(p) - else: - cs.extend(_compress_suffixes(sl[ln])) - if not cs: - continue - if len(cs) == 1 and "-" not in cs[0]: - res.append(f"{p}{cs[0]}") - else: - res.append(f"{p}[{','.join(cs)}]") - return ",".join(res) - - -def part_is_tpu(part): - """check if partition with name part contains a nodeset of type tpu""" - return len(lkp.cfg.partitions[part].partition_nodeset_tpu) > 0 - - -def get_vmcount_of_tpu_part(part): - res = 0 - for ns in lkp.cfg.partitions[part].partition_nodeset_tpu: - tpu_obj = TPU(lkp.cfg.nodeset_tpu[ns]) - if res == 0: - res = tpu_obj.vmcount - else: - if res != tpu_obj.vmcount: - # this should not happen, that in the same partition there are different vmcount nodesets - return -1 - return res - - -def to_hostnames(nodelist: str) -> List[str]: - """make list of hostnames from hostlist expression""" - if not nodelist: - return [] # avoid degenerate invocation of scontrol - if isinstance(nodelist, str): - hostlist = nodelist - else: - hostlist = ",".join(nodelist) - hostnames = run(f"{lkp.scontrol} show hostnames {hostlist}").stdout.splitlines() - log_hostlists.debug(f"hostnames({len(hostnames)}) from {hostlist}") - return hostnames - - -def retry_exception(exc): - """return true for exceptions that should always be retried""" - retry_errors = ( - "Rate Limit Exceeded", - "Quota Exceeded", - ) - return any(e in str(exc) for e in retry_errors) - - -def ensure_execute(request): - """Handle rate limits and socket time outs""" - - for retry, wait in enumerate(backoff_delay(0.5, timeout=10 * 60, count=20)): - try: - return request.execute() - except googleapiclient.errors.HttpError as e: - if retry_exception(e): - log.error(f"retry:{retry} '{e}'") - sleep(wait) - continue - raise - - except socket.timeout as e: - # socket timed out, try again - log.debug(e) - - except Exception as e: - log.error(e, exc_info=True) - raise - - break - - -def batch_execute(requests, retry_cb=None, log_err=log.error): - """execute list or dict as batch requests - retry if retry_cb returns true - """ - - compute = globals()["compute"] - BATCH_LIMIT = 1000 - if not isinstance(requests, dict): - requests = {str(k): v for k, v in enumerate(requests)} # rid generated here - done = {} - failed = {} - timestamps = [] - rate_limited = False - - def batch_callback(rid, resp, exc): - nonlocal rate_limited - if exc is not None: - log_err(f"compute request exception {rid}: {exc}") - if retry_exception(exc): - rate_limited = True - else: - req = requests.pop(rid) - failed[rid] = (req, exc) - else: - # if retry_cb is set, don't move to done until it returns false - if retry_cb is None or not 
retry_cb(resp): - requests.pop(rid) - done[rid] = resp - - def batch_request(reqs): - batch = compute.new_batch_http_request(callback=batch_callback) - for rid, req in reqs: - batch.add(req, request_id=rid) - return batch - - while requests: - if timestamps: - timestamps = [stamp for stamp in timestamps if stamp > time()] - if rate_limited and timestamps: - stamp = next(iter(timestamps)) - sleep(max(stamp - time(), 0)) - rate_limited = False - # up to API_REQ_LIMIT (2000) requests - # in chunks of up to BATCH_LIMIT (1000) - batches = [ - batch_request(chunk) - for chunk in chunked(islice(requests.items(), API_REQ_LIMIT), BATCH_LIMIT) - ] - timestamps.append(time() + 100) - with ThreadPoolExecutor() as exe: - futures = [] - for batch in batches: - future = exe.submit(ensure_execute, batch) - futures.append(future) - for future in futures: - result = future.exception() - if result is not None: - raise result - - return done, failed - - -def wait_request(operation, project=None, compute=None): - """makes the appropriate wait request for a given operation""" - if not compute: - compute = globals()["compute"] - if project is None: - project = lkp.project - if "zone" in operation: - req = compute.zoneOperations().wait( - project=project, - zone=trim_self_link(operation["zone"]), - operation=operation["name"], - ) - elif "region" in operation: - req = compute.regionOperations().wait( - project=project, - region=trim_self_link(operation["region"]), - operation=operation["name"], - ) - else: - req = compute.globalOperations().wait( - project=project, operation=operation["name"] - ) - return req - - -def wait_for_operation(operation, project=None, compute=None): - """wait for given operation""" - if not compute: - compute = globals()["compute"] - if project is None: - project = parse_self_link(operation["selfLink"]).project - wait_req = wait_request(operation, project=project, compute=compute) - - while True: - result = ensure_execute(wait_req) - if result["status"] == "DONE": - log_errors = " with errors" if "error" in result else "" - log.debug( - f"operation complete{log_errors}: type={result['operationType']}, name={result['name']}" - ) - return result - - -def wait_for_operations(operations, project=None, compute=None): - if not compute: - compute = globals()["compute"] - return [ - wait_for_operation(op, project=project, compute=compute) for op in operations - ] - - -def get_filtered_operations( - op_filter, - zone=None, - region=None, - only_global=False, - project=None, - compute=None, -): - """get list of operations associated with group id""" - - if not compute: - compute = globals()["compute"] - if project is None: - project = lkp.project - operations = [] - - def get_aggregated_operations(items): - # items is a dict of location key to value: dict(operations=) or an empty dict - operations.extend( - chain.from_iterable( - ops["operations"] for ops in items.values() if "operations" in ops - ) - ) - - def get_list_operations(items): - operations.extend(items) - - handle_items = get_list_operations - if only_global: - act = compute.globalOperations() - op = act.list(project=project, filter=op_filter) - nxt = act.list_next - elif zone is not None: - act = compute.zoneOperations() - op = act.list(project=project, zone=zone, filter=op_filter) - nxt = act.list_next - elif region is not None: - act = compute.regionOperations() - op = act.list(project=project, region=region, filter=op_filter) - nxt = act.list_next - else: - act = compute.globalOperations() - op = act.aggregatedList( - 
project=project, filter=op_filter, fields="items.*.operations,nextPageToken" - ) - nxt = act.aggregatedList_next - handle_items = get_aggregated_operations - while op is not None: - result = ensure_execute(op) - handle_items(result["items"]) - op = nxt(op, result) - return operations - - -def get_insert_operations(group_ids, flt=None, project=None, compute=None): - """get all insert operations from a list of operationGroupId""" - if not compute: - compute = globals()["compute"] - if project is None: - project = lkp.project - if isinstance(group_ids, str): - group_ids = group_ids.split(",") - filters = [ - "operationType=insert", - flt, - " OR ".join(f"(operationGroupId={id})" for id in group_ids), - ] - return get_filtered_operations(" AND ".join(f"({f})" for f in filters if f)) - - -def machine_type_sockets(template): - pattern = re.compile("^(?P[^-]+)") - m = pattern.match(template.machineType) - if not m: - raise Exception(f"template {template} does not match expected regex") - family = m.group("family") - guestCpus: int = int(template.machine_info.guestCpus) - socket_count = dict.get( - { - "h3": 2, - "c2d": 2 if guestCpus > 56 else 1, - "a3": 2, - }, - family, - 1, # assume 1 socket for all other families - ) - return socket_count - - -def isSmt(template): - machineType: str = template.machineType - guestCpus: int = int(template.machine_info.guestCpus) - - pattern = re.compile("^(?P[^-]+)") - matches = pattern.match(machineType) - machineTypeFamily: str = matches["family"] - - # https://cloud.google.com/compute/docs/cpu-platforms - noSmtFamily = [ - "t2a", - "t2d", - "h3", - ] - if machineTypeFamily in noSmtFamily: - return False - elif guestCpus == 1: - return False - return True - - -def getThreadsPerCore(template): - threadsPerCore: int = template.advancedMachineFeatures.threadsPerCore - - if not isSmt(template): - return 1 - elif threadsPerCore: - return threadsPerCore - else: - return 2 - - -@retry( - max_retries=9, - init_wait_time=1, - warn_msg="Temporary failure in name resolution", - exc_type=socket.gaierror, -) -def host_lookup(host_name: str) -> str: - return socket.gethostbyname(host_name) - - -class Dumper(yaml.SafeDumper): - """Add representers for pathlib.Path and NSDict for yaml serialization""" - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.add_representer(NSDict, self.represent_nsdict) - self.add_multi_representer(Path, self.represent_path) - - @staticmethod - def represent_nsdict(dumper, data): - return dumper.represent_mapping("tag:yaml.org,2002:map", data.items()) - - @staticmethod - def represent_path(dumper, path): - return dumper.represent_scalar("tag:yaml.org,2002:str", str(path)) - - -class TPU: - """Class for handling the TPU-vm nodes""" - - if can_tpu: - State = tpu.types.cloud_tpu.Node.State - TPUS_PER_VM = 4 - __expected_states = { - "create": State.READY, - "start": State.READY, - "stop": State.STOPPED, - } - - __tpu_version_mapping = { - "V2": tpu.AcceleratorConfig().Type.V2, - "V3": tpu.AcceleratorConfig().Type.V3, - "V4": tpu.AcceleratorConfig().Type.V4, - } - - def __init__(self, nodeset): - if not can_tpu: - raise Exception("TPU pip package not installed") - self._nodeset = nodeset - self._parent = f"projects/{lkp.project}/locations/{nodeset.zone}" - co = create_client_options(ApiEndpoint.TPU) - self._client = tpu.TpuClient(client_options=co) - self.data_disks = [] - for data_disk in nodeset.data_disks: - ad = tpu.AttachedDisk() - ad.source_disk = data_disk - ad.mode = 
tpu.AttachedDisk.DiskMode.DISK_MODE_UNSPECIFIED - self.data_disks.append(ad) - ns_ac = nodeset.accelerator_config - if ns_ac.topology != "" and ns_ac.version != "": - ac = tpu.AcceleratorConfig() - ac.topology = ns_ac.topology - ac.type_ = self.__tpu_version_mapping[ns_ac.version] - self.ac = ac - else: - req = tpu.GetAcceleratorTypeRequest( - name=f"{self._parent}/acceleratorTypes/{nodeset.node_type}" - ) - self.ac = self._client.get_accelerator_type(req).accelerator_configs[0] - self.vmcount = self.__calc_vm_from_topology(self.ac.topology) - - @property - def nodeset(self): - return self._nodeset - - @property - def preserve_tpu(self): - return self._nodeset.preserve_tpu - - @property - def node_type(self): - return self._nodeset.node_type - - @property - def tf_version(self): - return self._nodeset.tf_version - - @property - def enable_public_ip(self): - return self._nodeset.enable_public_ip - - @property - def preemptible(self): - return self._nodeset.preemptible - - @property - def reserved(self): - return self._nodeset.reserved - - @property - def service_account(self): - return self._nodeset.service_account - - @property - def zone(self): - return self._nodeset.zone - - def check_node_type(self): - if self.node_type is None: - return False - try: - request = tpu.GetAcceleratorTypeRequest( - name=f"{self._parent}/acceleratorTypes/{self.node_type}" - ) - return self._client.get_accelerator_type(request=request) is not None - except Exception: - return False - - def check_tf_version(self): - try: - request = tpu.GetRuntimeVersionRequest( - name=f"{self._parent}/runtimeVersions/{self.tf_version}" - ) - return self._client.get_runtime_version(request=request) is not None - except Exception: - return False - - def __calc_vm_from_topology(self, topology): - topo = topology.split("x") - tot = 1 - for num in topo: - tot = tot * int(num) - return tot // self.TPUS_PER_VM - - def __check_resp(self, response, op_name): - des_state = self.__expected_states.get(op_name) - # If the state is not in the table just print the response - if des_state is None: - return False - if response.__class__.__name__ != "Node": # If the response is not a node fail - return False - if response.state == des_state: - return True - return False - - def list_nodes(self): - try: - request = tpu.ListNodesRequest(parent=self._parent) - res = self._client.list_nodes(request=request) - except gExceptions.NotFound: - res = None - return res - - def list_node_names(self): - return [node.name.split("/")[-1] for node in self.list_nodes()] - - def start_node(self, nodename): - request = tpu.StartNodeRequest(name=f"{self._parent}/nodes/{nodename}") - resp = self._client.start_node(request=request).result() - return self.__check_resp(resp, "start") - - def stop_node(self, nodename): - request = tpu.StopNodeRequest(name=f"{self._parent}/nodes/{nodename}") - resp = self._client.stop_node(request=request).result() - return self.__check_resp(resp, "stop") - - def get_node(self, nodename): - try: - request = tpu.GetNodeRequest(name=f"{self._parent}/nodes/{nodename}") - res = self._client.get_node(request=request) - except gExceptions.NotFound: - res = None - return res - - def _register_node(self, nodename, ip_addr): - dns_name = socket.getnameinfo((ip_addr, 0), 0)[0] - run( - f"{lkp.scontrol} update nodename={nodename} nodeaddr={ip_addr} nodehostname={dns_name}" - ) - - def create_node(self, nodename): - if self.vmcount > 1 and not isinstance(nodename, list): - log.error( - f"Tried to create a {self.vmcount} node TPU on nodeset 
{self._nodeset.nodeset_name} but only received one nodename {nodename}" - ) - return False - if self.vmcount > 1 and ( - isinstance(nodename, list) and len(nodename) != self.vmcount - ): - log.error( - f"Expected to receive a list of {self.vmcount} nodenames for TPU node creation in nodeset {self._nodeset.nodeset_name}, but received this list {nodename}" - ) - return False - - node = tpu.Node() - node.accelerator_config = self.ac - node.runtime_version = f"tpu-vm-tf-{self.tf_version}" - startup_script = """ - #!/bin/bash - echo "startup script not found > /var/log/startup_error.log" - """ - with open( - Path(cfg.slurm_scripts_dir or dirs.scripts) / "startup.sh", "r" - ) as script: - startup_script = script.read() - if isinstance(nodename, list): - node_id = nodename[0] - slurm_names = [] - wid = 0 - for node_wid in nodename: - slurm_names.append(f"WORKER_{wid}:{node_wid}") - wid += 1 - else: - node_id = nodename - slurm_names = [f"WORKER_0:{nodename}"] - node.metadata = { - "slurm_docker_image": self.nodeset.docker_image, - "startup-script": startup_script, - "slurm_instance_role": "compute", - "slurm_cluster_name": lkp.cfg.slurm_cluster_name, - "slurm_bucket_path": lkp.cfg.bucket_path, - "slurm_names": ";".join(slurm_names), - "universe_domain": universe_domain(), - } - node.tags = [lkp.cfg.slurm_cluster_name] - if self.nodeset.service_account: - node.service_account.email = self.nodeset.service_account.email - node.service_account.scope = self.nodeset.service_account.scopes - node.scheduling_config.preemptible = self.preemptible - node.scheduling_config.reserved = self.reserved - node.network_config.subnetwork = self.nodeset.subnetwork - node.network_config.enable_external_ips = self.enable_public_ip - if self.data_disks: - node.data_disks = self.data_disks - - request = tpu.CreateNodeRequest(parent=self._parent, node=node, node_id=node_id) - resp = self._client.create_node(request=request).result() - if not self.__check_resp(resp, "create"): - return False - if isinstance(nodename, list): - for node_id, net_endpoint in zip(nodename, resp.network_endpoints): - self._register_node(node_id, net_endpoint.ip_address) - else: - ip_add = resp.network_endpoints[0].ip_address - self._register_node(nodename, ip_add) - return True - - def delete_node(self, nodename): - request = tpu.DeleteNodeRequest(name=f"{self._parent}/nodes/{nodename}") - try: - resp = self._client.delete_node(request=request).result() - if resp: - return self.get_node(nodename=nodename) is None - return False - except gExceptions.NotFound: - # log only error if vmcount is 1 as for other tpu vm count, this could be "phantom" nodes - if self.vmcount == 1: - log.error(f"Tpu single node {nodename} not found") - else: - # for the TPU nodes that consist in more than one vm, only the first node of the TPU a.k.a. the master node will - # exist as real TPU nodes, so the other ones are expected to not be found, check the hostname of the node that has - # not been found, and if it ends in 0, it means that is the master node and it should have been found, and in consequence - # log an error - nodehostname = yaml.safe_load( - run(f"{lkp.scontrol} --yaml show node {nodename}").stdout.rstrip() - )["nodes"][0]["hostname"] - if nodehostname.split("-")[-1] == "0": - log.error(f"TPU master node {nodename} not found") - else: - log.info(f"Deleted TPU 'phantom' node {nodename}") - # If the node is not found it is tecnichally deleted, so return success. 
- return True - - -class Lookup: - """Wrapper class for cached data access""" - - def __init__(self, cfg=None): - self._cfg = cfg or NSDict() - self.template_cache_path = Path(__file__).parent / "template_info.cache" - - @property - def cfg(self): - return self._cfg - - @property - def project(self): - return self.cfg.project or authentication_project() - - @property - def control_addr(self): - return self.cfg.slurm_control_addr - - @property - def control_host(self): - return self.cfg.slurm_control_host - - @cached_property - def control_host_addr(self): - return host_lookup(self.cfg.slurm_control_host) - - @property - def control_host_port(self): - return self.cfg.slurm_control_host_port - - @property - def endpoint_versions(self): - return self.cfg.endpoint_versions - - @property - def scontrol(self): - return Path(self.cfg.slurm_bin_dir if cfg else "") / "scontrol" - - @cached_property - def instance_role(self): - return instance_metadata("attributes/slurm_instance_role") - - @cached_property - def instance_role_safe(self): - try: - role = self.instance_role - except Exception as e: - log.error(e) - role = None - return role - - @cached_property - def compute(self): - # TODO evaluate when we need to use google_app_cred_path - if self.cfg.google_app_cred_path: - os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = self.cfg.google_app_cred_path - return compute_service() - - @cached_property - def hostname(self): - return socket.gethostname() - - @cached_property - def hostname_fqdn(self): - return socket.getfqdn() - - @cached_property - def zone(self): - return instance_metadata("zone") - - node_desc_regex = re.compile( - r"^(?P(?P[^\s\-]+)-(?P\S+))-(?P(?P\w+)|(?P\[[\d,-]+\]))$" - ) - - @lru_cache(maxsize=None) - def _node_desc(self, node_name): - """Get parts from node name""" - if not node_name: - node_name = self.hostname - # workaround below is for VMs whose hostname is FQDN - node_name_short = node_name.split(".")[0] - m = self.node_desc_regex.match(node_name_short) - if not m: - raise Exception(f"node name {node_name} is not valid") - return m.groupdict() - - def node_prefix(self, node_name=None): - return self._node_desc(node_name)["prefix"] - - def node_nodeset_name(self, node_name=None): - return self._node_desc(node_name)["nodeset"] - - def node_nodeset(self, node_name=None): - nodeset_name = self.node_nodeset_name(node_name) - ns = self.cfg.nodeset.get(nodeset_name) - if ns: - return ns - return self.cfg.nodeset_tpu.get(nodeset_name) - - def node_is_tpu(self, node_name=None): - nodeset_name = self.node_nodeset_name(node_name) - return self.cfg.nodeset_tpu.get(nodeset_name) is not None - - def node_is_dyn(self, node_name=None) -> bool: - nodeset = self.node_nodeset_name(node_name) - return self.cfg.nodeset_dyn.get(nodeset) is not None - - def chunk_tpu_nodes(self, tpu_nodes): - model = tpu_nodes[0] - tpu = TPU(self.node_nodeset(model)) - return chunked(tpu_nodes, n=tpu.vmcount) - - def node_template(self, node_name=None): - return self.node_nodeset(node_name).instance_template - - def node_template_info(self, node_name=None): - return self.template_info(self.node_template(node_name)) - - def node_region(self, node_name=None): - nodeset = self.node_nodeset(node_name) - return parse_self_link(nodeset.subnetwork).region - - def nodeset_prefix(self, nodeset_name): - return f"{self.cfg.slurm_cluster_name}-{nodeset_name}" - - def nodelist_range(self, nodeset_name: str, start: int, count: int) -> str: - assert 0 <= start and 0 < count - pref = self.nodeset_prefix(nodeset_name) - if count 
== 1: - return f"{pref}-{start}" - return f"{pref}-[{start}-{start + count - 1}]" - - def static_dynamic_sizes(self, nodeset: object) -> int: - return (nodeset.node_count_static or 0, nodeset.node_count_dynamic_max or 0) - - def nodelist(self, nodeset) -> str: - cnt = sum(self.static_dynamic_sizes(nodeset)) - if cnt == 0: - return "" - return self.nodelist_range(nodeset.nodeset_name, 0, cnt) - - def nodenames(self, nodeset) -> Tuple[Iterable[str], Iterable[str]]: - pref = self.nodeset_prefix(nodeset.nodeset_name) - s_count, d_count = self.static_dynamic_sizes(nodeset) - return ( - (f"{pref}-{i}" for i in range(s_count)), - (f"{pref}-{i}" for i in range(s_count, s_count + d_count)), - ) - - def power_managed_nodesets(self) -> Iterable[object]: - return chain(self.cfg.nodeset.values(), self.cfg.nodeset_tpu.values()) - - def is_power_managed_node(self, node_name: str) -> bool: - try: - ns = self.node_nodeset(node_name) - if ns is None: - return False - idx = int(self._node_desc(node_name)["suffix"]) - return idx < sum(self.static_dynamic_sizes(ns)) - except Exception: - return False - - def is_static_node(self, node_name: str) -> bool: - if not self.is_power_managed_node(node_name): - return False - idx = int(self._node_desc(node_name)["suffix"]) - return idx < self.node_nodeset(node_name).node_count_static - - @lru_cache(maxsize=None) - def slurm_nodes(self): - StateTuple = namedtuple("StateTuple", "base,flags") - - def make_node_tuple(node_line): - """turn node,state line to (node, StateTuple(state))""" - # state flags include: CLOUD, COMPLETING, DRAIN, FAIL, POWERED_DOWN, - # POWERING_DOWN - node, fullstate = node_line.split(",") - state = fullstate.split("+") - state_tuple = StateTuple(state[0], set(state[1:])) - return (node, state_tuple) - - cmd = ( - f"{self.scontrol} show nodes | " - r"grep -oP '^NodeName=\K(\S+)|\s+State=\K(\S+)' | " - r"paste -sd',\n'" - ) - node_lines = run(cmd, shell=True).stdout.rstrip().splitlines() - nodes = { - node: state - for node, state in map(make_node_tuple, node_lines) - if "CLOUD" in state.flags or "DYNAMIC_NORM" in state.flags - } - return nodes - - def slurm_node(self, nodename): - return self.slurm_nodes().get(nodename) - - @lru_cache(maxsize=1) - def instances(self, project=None, slurm_cluster_name=None): - slurm_cluster_name = slurm_cluster_name or self.cfg.slurm_cluster_name - project = project or self.project - instance_information_fields = [ - "advancedMachineFeatures", - "cpuPlatform", - "creationTimestamp", - "disks", - "disks", - "fingerprint", - "guestAccelerators", - "hostname", - "id", - "kind", - "labelFingerprint", - "labels", - "lastStartTimestamp", - "lastStopTimestamp", - "lastSuspendedTimestamp", - "machineType", - "metadata", - "name", - "networkInterfaces", - "resourceStatus", - "scheduling", - "selfLink", - "serviceAccounts", - "shieldedInstanceConfig", - "shieldedInstanceIntegrityPolicy", - "sourceMachineImage", - "status", - "statusMessage", - "tags", - "zone", - # "deletionProtection", - # "startRestricted", - ] - if lkp.cfg.enable_slurm_gcp_plugins: - slurm_gcp_plugins.register_instance_information_fields( - lkp=lkp, - project=project, - slurm_cluster_name=slurm_cluster_name, - instance_information_fields=instance_information_fields, - ) - instance_information_fields = sorted(set(instance_information_fields)) - instance_fields = ",".join(instance_information_fields) - fields = f"items.zones.instances({instance_fields}),nextPageToken" - flt = f"labels.slurm_cluster_name={slurm_cluster_name} AND name:{slurm_cluster_name}-*" - 
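        # The filter built above restricts the aggregated list to instances that
        # carry this cluster's label and name prefix. For a hypothetical cluster
        # named "c1" it would expand to (illustrative only, not from the source):
        #   labels.slurm_cluster_name=c1 AND name:c1-*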
act = self.compute.instances() - op = act.aggregatedList(project=project, fields=fields, filter=flt) - - def properties(inst): - """change instance properties to a preferred format""" - inst["zone"] = trim_self_link(inst["zone"]) - inst["machineType"] = trim_self_link(inst["machineType"]) - # metadata is fetched as a dict of dicts like: - # {'key': key, 'value': value}, kinda silly - metadata = {i["key"]: i["value"] for i in inst["metadata"].get("items", [])} - if "slurm_instance_role" not in metadata: - return None - inst["role"] = metadata["slurm_instance_role"] - inst["metadata"] = metadata - # del inst["metadata"] # no need to store all the metadata - return NSDict(inst) - - instances = {} - while op is not None: - result = ensure_execute(op) - instance_iter = ( - (inst["name"], properties(inst)) - for inst in chain.from_iterable( - m["instances"] for m in result.get("items", {}).values() - ) - ) - instances.update( - {name: props for name, props in instance_iter if props is not None} - ) - op = act.aggregatedList_next(op, result) - return instances - - def instance(self, instance_name, project=None, slurm_cluster_name=None): - instances = self.instances( - project=project, slurm_cluster_name=slurm_cluster_name - ) - return instances.get(instance_name) - - @lru_cache() - def reservation(self, name: str, zone: str) -> object: - """See https://cloud.google.com/compute/docs/reference/rest/v1/reservations""" - try: - _, project, _, short_name = name.split("/") - except ValueError: - raise ValueError( - f"Invalid reservation name: '{name}', expected format is 'projects/PROJECT/reservations/NAME'" - ) - - return ( - self.compute.reservations() - .get(project=project, zone=zone, reservation=short_name) - .execute() - ) - - @lru_cache(maxsize=1) - def machine_types(self, project=None): - project = project or self.project - field_names = "name,zone,guestCpus,memoryMb,accelerators" - fields = f"items.zones.machineTypes({field_names}),nextPageToken" - - machines = defaultdict(dict) - act = self.compute.machineTypes() - op = act.aggregatedList(project=project, fields=fields) - while op is not None: - result = ensure_execute(op) - machine_iter = chain.from_iterable( - m["machineTypes"] - for m in result["items"].values() - if "machineTypes" in m - ) - for machine in machine_iter: - name = machine["name"] - zone = machine["zone"] - machines[name][zone] = machine - - op = act.aggregatedList_next(op, result) - return machines - - def machine_type(self, machine_type, project=None, zone=None): - """ """ - custom_patt = re.compile( - r"((?P\w+)-)?custom-(?P\d+)-(?P\d+)" - ) - custom_match = custom_patt.match(machine_type) - if zone: - project = project or self.project - machine_info = ensure_execute( - self.compute.machineTypes().get( - project=project, zone=zone, machineType=machine_type - ) - ) - elif custom_match is not None: - groups = custom_match.groupdict() - cpus, mem = (groups[k] for k in ["cpus", "mem"]) - machine_info = { - "guestCpus": int(cpus), - "memoryMb": int(mem), - } - else: - machines = self.machine_types(project=project) - machine_info = next(iter(machines[machine_type].values()), None) - if machine_info is None: - raise Exception(f"machine type {machine_type} not found") - return NSDict(machine_info) - - def template_machine_conf(self, template_link, project=None, zone=None): - template = self.template_info(template_link) - if not template.machineType: - temp_name = trim_self_link(template_link) - raise Exception(f"instance template {temp_name} has no machine type") - 
template.machine_info = self.machine_type(template.machineType, zone=zone) - machine = template.machine_info - - machine_conf = NSDict() - machine_conf.boards = 1 # No information, assume 1 - machine_conf.sockets = machine_type_sockets(template) - # the value below for SocketsPerBoard must be type int - machine_conf.sockets_per_board = machine_conf.sockets // machine_conf.boards - machine_conf.threads_per_core = 1 - _div = 2 if getThreadsPerCore(template) == 1 else 1 - machine_conf.cpus = ( - int(machine.guestCpus / _div) if isSmt(template) else machine.guestCpus - ) - machine_conf.cores_per_socket = int(machine_conf.cpus / machine_conf.sockets) - # Because the actual memory on the host will be different than - # what is configured (e.g. kernel will take it). From - # experiments, about 16 MB per GB are used (plus about 400 MB - # buffer for the first couple of GB's. Using 30 MB to be safe. - gb = machine.memoryMb // 1024 - machine_conf.memory = machine.memoryMb - (400 + (30 * gb)) - return machine_conf - - @contextmanager - def template_cache(self, writeback=False): - flag = "c" if writeback else "r" - err = None - for wait in backoff_delay(0.125, timeout=60, count=20): - try: - cache = shelve.open( - str(self.template_cache_path), flag=flag, writeback=writeback - ) - break - except OSError as e: - err = e - log.debug(f"Failed to access template info cache: {e}") - sleep(wait) - continue - else: - # reached max_count of waits - raise Exception(f"Failed to access cache file. latest error: {err}") - try: - yield cache - finally: - cache.close() - - @lru_cache(maxsize=None) - def template_info(self, template_link, project=None): - project = project or self.project - template_name = trim_self_link(template_link) - # split read and write access to minimize write-lock. This might be a - # bit slower? TODO measure - if self.template_cache_path.exists(): - with self.template_cache() as cache: - if template_name in cache: - return NSDict(cache[template_name]) - - template = ensure_execute( - self.compute.instanceTemplates().get( - project=project, instanceTemplate=template_name - ) - ).get("properties") - template = NSDict(template) - # name and link are not in properties, so stick them in - template.name = template_name - template.link = template_link - # TODO delete metadata to reduce memory footprint? 
- # del template.metadata - - # translate gpus into an easier-to-read format - machine_info = self.machine_type(template.machineType, project=project) - if machine_info.accelerators: - template.gpu_type = machine_info.accelerators[0].guestAcceleratorType - template.gpu_count = machine_info.accelerators[0].guestAcceleratorCount - elif template.guestAccelerators: - template.gpu_type = template.guestAccelerators[0].acceleratorType - template.gpu_count = template.guestAccelerators[0].acceleratorCount - else: - template.gpu_type = None - template.gpu_count = 0 - - # keep write access open for minimum time - with self.template_cache(writeback=True) as cache: - cache[template_name] = template.to_dict() - # cache should be owned by slurm - chown_slurm(self.template_cache_path) - - return template - - def nodeset_map(self, hostnames: list): - """Convert a list of nodes into a map of nodeset_name to hostnames""" - nodeset_map = collections.defaultdict(list) - for node in hostnames: - nodeset_map[self.node_nodeset_name(node)].append(node) - return nodeset_map - - -# Define late globals -lkp = Lookup() -cfg = load_config_file(CONFIG_FILE) -if not cfg: - try: - cfg = fetch_config_yaml() - except Exception as e: - log.warning(f"config not found in bucket: {e}") - if cfg: - save_config(cfg, CONFIG_FILE) - -lkp = Lookup(cfg) - -# Needs to be run after the lookup is complete to get endpoint versions -compute = compute_service() - - -if __name__ == "__main__": - parser = argparse.ArgumentParser( - description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter - ) - parser.add_argument( - "--partitions", - "-p", - help="The partition(s) to retrieve the TPU vmcount value for.", - ) - args = parser.parse_args() - if args.partitions: - # useful exit code - # partition does not exists in config.yaml, thus do not exist in slurm - PART_INVALID = -1 - # in the same partition there are nodesets with different vmcounts - DIFF_VMCOUNTS_SAME_PART = -2 - # partition is a list of partitions in which at least two of them have different vmcount - DIFF_PART_DIFFERENT_VMCOUNTS = -3 - vmcounts = [] - # valid equals to 0 means that we are ok, otherwise it will be set to one of the previously defined exit codes - valid = 0 - for part in args.partitions.split(","): - if part not in lkp.cfg.partitions: - valid = PART_INVALID - break - else: - if part_is_tpu(part): - vmcount = get_vmcount_of_tpu_part(part) - if vmcount == -1: - valid = DIFF_VMCOUNTS_SAME_PART - break - vmcounts.append(vmcount) - else: - vmcounts.append(0) - # this means that there are different vmcounts for these partitions - if valid == 0 and len(set(vmcounts)) != 1: - valid = DIFF_PART_DIFFERENT_VMCOUNTS - if valid != 0: - print(f"VMCOUNT:{valid}") - else: - print(f"VMCOUNT:{vmcounts[0]}") diff --git a/test/Pipfile b/test/Pipfile deleted file mode 100644 index 21901586..00000000 --- a/test/Pipfile +++ /dev/null @@ -1,28 +0,0 @@ -[[source]] -url = "https://pypi.org/simple" -verify_ssl = true -name = "pypi" - -[packages] -addict = "*" -cython = "<3.0" -google-api-python-client = "*" -ipdb = "*" -ipython = "*" -more-executors = "*" -paramiko = "*" -psutil = "*" -pyaml = "*" -pytest = "*" -pytest-benchmark = "*" -pytest-xdist = {extras = ["psutil"], version = "*"} -python-hostlist = "*" -pyyaml = "*" -requests = "*" -tftest = "*" -wheel = "*" - -[dev-packages] - -[requires] -python_version = "3.8" diff --git a/test/README.md b/test/README.md deleted file mode 100644 index defaef48..00000000 --- a/test/README.md +++ /dev/null @@ -1,96 +0,0 @@ -The tests 
are written with pytest, with terraform handled by the python library -tftest. It is run from the `test` directory as - -`pytest -vs --project_id= --cluster_name= --image-project= --image-family= --image=` - -The Pipfile shows the dependencies. Only one of `--image-family` and `--image` -needs to be specified. The env var `GOOGLE_APPLICATION_CREDENTIALS` should be -set to a json file containing service account credentials. A private key -authorized to the GCP/service account should be in a file `test/gcp_login_id` or -in an env var `GCP_LOGIN_ID`. - -pytest will create a cluster and run the following tests on it. - -## Tests - -Lines with \[ \] are not currently implemented but are planned. - -test_config.py - -- test_gpu_config - - Check that the number of GPUs requested on nodes in terraform is equal to - the number of Slurm gres:gpu configured. -- test_ops_agent - - Check all running instances in the cluster for active ops agent fluentd - service. -- test_controller_custom_scripts - - check that the configured startup script ran -- test_login_custom_scripts - - check that the configured login startup script ran -- \[ \] Network Storage Testing - - Test that Lustre, NFS, and GCSFuse install and can mount storage correctly - on the latest OS image, and/or the HPC image as available. -- \[ \] Reconfigure - did the slurm.conf change? - -test_jobs.py - -- test_job - - Verify that a simple 3-node job running `srun hostname` completes - successfully. -- test_gpu_job - - On every partition with a GPU, run an sbatch `srun nvidia-smi` job and - verify that the job completes successfully. -- test_shielded - - On every partition with nodes configured for shielded VMs, run a simple job. - - If the partition has a GPU _and_ the image OS is Ubuntu 20.04, run a GPU job - instead. - - skip shielded GPU partitions otherwise to avoid spinning up a GPU instance - needlessly. -- test_openmpi - - Run a simple 3-node MPI job and verify that it completes successfully. -- test_placement_groups - - \[ \] start a 1-node job and verify it does _not_ get a placement group - - On any placement group partitions, - 1. start a 10 minute 2-node job - 1. wait for the instances to start and the job to begin - 1. check the instances in the job for `resouceStatus.physicalHost` (topology - information) - - Make sure at least one one part of the topology tag matches between all - instances - 4. cancel the job - 1. wait for the node to finish being powered down again -- test_preemption - - on partitions with preemptible nodes - 1. start a long job - 1. wait for the job to start running - 1. stop an instance in the job allocation - 1. verify that the node goes down with reason set - 1. wait for the node to return to idle (slurmsync handles this) - ###### TODO check that the job was requeued? - 6. cancel the job -- test_prolog_scripts - - run a job and check that prolog and epilog scripts ran -- \[ \] test exclusive nodes - - Since placement group nodes are also exclusive, this is being tested. But a - test just for exclusive nodes would be good. - - job runs on node - - node is torn down after job -- \[ \] Regional placement - - create 3 partitions that force direct jobs to different regions - - submit job to each partition and verify region - - verify that nodes get deleted after job is done. 
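
The placement-group steps above amount to collecting `resourceStatus.physicalHost` for every instance in the job and requiring that at least one segment of the topology tag is shared by all of them. A minimal sketch of that check, assuming "/"-separated topology tags (the node names and host IDs below are fabricated for illustration and are not from the test suite):

```python
# Illustrative only: do all instances in a job share at least one
# physicalHost segment? The "/xxxx/yyyy/zzzz" format is an assumption.
physical_hosts = {
    "cluster-part-0": "/ab12/cd34/ef56",
    "cluster-part-1": "/ab12/cd34/gh78",
    "cluster-part-2": "/ab12/zz99/yy88",
}

def shared_topology_parts(hosts: dict) -> set:
    """Return the topology segments common to every instance."""
    segment_sets = (set(filter(None, h.split("/"))) for h in hosts.values())
    return set.intersection(*segment_sets)

# Passes: "ab12" is common to all three fabricated hosts.
assert shared_topology_parts(physical_hosts), "no common placement topology segment"
```

For comparison, the deleted `test_placement_groups` later in this diff intersects the raw strings with `set.intersection(*map(set, ...))` (i.e. per-character sets), which may be why its inline comment notes the matching "isn't working sometimes".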
- -test_nodes.py - -- test_static - - get the list of static nodes in the cluster and wait for them all to be - powered up and idle -- test_compute_startup_scripts - - Check that the custom compute startup script ran on the static nodes -- test_exclusive_labels - - run job on exclusive nodes and check the instances for the correct job ID - label -- \[ \] Catching leaked nodes - - Tear down cluster while nodes are still provisioned, verify all nodes get - torn down - - Start up node outside of a job - verify slurmsync brings it down diff --git a/test/cleanup.py b/test/cleanup.py deleted file mode 100755 index 984f68ed..00000000 --- a/test/cleanup.py +++ /dev/null @@ -1,29 +0,0 @@ -#!/usr/bin/env python3 -import argparse -from pathlib import Path - -import tftest - - -def cleanup(cluster_name): - terraform_dir = Path("../terraform") - cluster_vars = [ - path - for path in terraform_dir.rglob(f"{cluster_name}-*.tfvars") - if path.is_symlink() - ] - for path in cluster_vars: - moduledir, tfvars = path.parent, path.name - print(f"destroy {moduledir/tfvars}") - tf = tftest.TerraformTest(moduledir) - print(tf.setup(output=True)) - print(tf.destroy(tf_var_file=tfvars, output=True)) - path.unlink() - - -parser = argparse.ArgumentParser(description="Cleanup any tftest clusters left around") -parser.add_argument("cluster_name", help="name of the cluster to clean up") - -if __name__ == "__main__": - args = parser.parse_args() - cleanup(args.cluster_name) diff --git a/test/conftest.py b/test/conftest.py deleted file mode 100644 index 77bf9b06..00000000 --- a/test/conftest.py +++ /dev/null @@ -1,181 +0,0 @@ -import logging -import re -import sys -from pathlib import Path - -import pytest - -sys.path.append("../scripts") -import util # noqa: E402 - -from deploy import Cluster, Configuration # noqa: E402 - -logging.basicConfig(level=logging.INFO) -log = logging.getLogger() - -root = Path(__file__).parent.parent -tf_path = root / "terraform" -test_path = root / "test" -tfvars_path = test_path / "tfvars" - - -def pytest_addoption(parser): - parser.addoption( - "--project-id", action="store", help="GCP project to deploy the cluster to" - ) - parser.addoption("--cluster-name", action="store", help="cluster name to deploy") - none_list = set( - [ - "null", - "none", - ] - ) - parser.addoption( - "--image", - action="store", - nargs="?", - type=lambda a: None if a.lower() in none_list else a, - help="image name to use for test cluster", - ) - parser.addoption( - "--image-family", - action="store", - nargs="?", - help="image family to use for test cluster", - ) - parser.addoption( - "--image-project", action="store", help="image project to use for test cluster" - ) - parser.addoption( - "--image-marker", action="store", nargs="?", type=str, help="image marker label" - ) - - -@pytest.hookimpl(tryfirst=True, hookwrapper=True) -def pytest_runtest_makereport(item, call): - # execute all other hooks to obtain the report object - outcome = yield - rep = outcome.get_result() - - # set a report attribute for each phase of a call, which can - # be "setup", "call", "teardown" - - setattr(item, "rep_" + rep.when, rep) - - -# We need to discriminate between arch (arm64, x86_64) because a configuration -# is mapped to a tfvars file that contains instances for an input image. -# Image arch must match instance arch, otherwise instance boot failure. -# -# NOTE: config key should follow: pytest.param("${ARCH}-${TF_CONFIG}") -# -# Use the following to discriminate arch for testing. 
-# $ pytest -k EXPR -CONFIGS = [ - dict( - marks=("x86_64", "basic"), - moduledir=tf_path / "slurm_cluster/examples/slurm_cluster/test_cluster", - tfvars_file=tfvars_path / "x86_64-basic.tfvars", - tfvars={}, - ), - dict( - marks=("arm64", "basic"), - moduledir=tf_path / "slurm_cluster/examples/slurm_cluster/test_cluster", - tfvars_file=tfvars_path / "arm64-basic.tfvars", - tfvars={}, - ), -] -CONFIGS = {"-".join(conf["marks"]): conf for conf in CONFIGS} -image_pattern = re.compile( - r"^(?:(?P\w+)-)?slurm-gcp-(?:(?P(?P\d+)-(?P\d+)(?:-(?P\d+))?)|(?P\w+))-(?P[\w\-]+?)(?:-(?P\w{10}))?$" -) - -params = ( - pytest.param(k, marks=[getattr(pytest.mark, mark) for mark in conf["marks"]]) - for k, conf in CONFIGS.items() -) - - -@pytest.fixture(scope="session") -def image_marker(request): - from_image = next( - m.group("marker") - for m in ( - image_pattern.match(request.config.getoption("image") or ""), - image_pattern.match(request.config.getoption("image_family" or "")), - ) - if m - ) - return request.config.getoption("image_marker") or from_image - - -@pytest.fixture(params=params, scope="session") -def configuration(request, image_marker): - """fixture providing terraform cluster configuration""" - project_id = request.config.getoption("project_id") - cluster_name = request.config.getoption("cluster_name") - image_project = request.config.getoption("image_project") - image_family = request.config.getoption("image_family") - image = request.config.getoption("image") - request.applymarker(image_marker) - - config = Configuration( - cluster_name=cluster_name, - project_id=project_id, - image_project=image_project, - image_family=image_family, - image=image, - **CONFIGS[request.param], - ) - log.info(f"init cluster {str(config)}") - config.setup() - return config - - -@pytest.fixture(scope="session") -def plan(configuration): - return configuration.tf.plan( - tf_vars=configuration.tfvars, - tf_var_file=configuration.tfvars_file.name, - output=True, - ) - - -@pytest.fixture(scope="session") -def applied(request, configuration): - """fixture providing applied terraform handle""" - request.addfinalizer(configuration.destroy) - log.info(f"apply deployment {str(configuration)}") - configuration.apply() - return configuration.tf - - -@pytest.fixture(scope="session", autouse=True) -def cluster(request, applied): - """fixture providing deploy.Cluster communication handle for the cluster""" - cluster = Cluster(applied) - - def disconnect(): - nonlocal cluster - cluster.save_logs() - log.info("tearing down cluster") - cluster.disconnect() - # TODO verify all instances are removed - - request.addfinalizer(disconnect) - log.info("waiting for cluster to be available") - cluster.activate() - log.info("cluster is now responding") - return cluster - - -@pytest.fixture(scope="session") -def cfg(cluster: Cluster): - """fixture providing util config for the cluster""" - return cluster.cfg - - -@pytest.fixture(scope="session") -def lkp(cfg: util.NSDict): - """fixture providing util.Lookup for the cluster""" - return util.Lookup(cfg) diff --git a/test/deploy.py b/test/deploy.py deleted file mode 100644 index 496c51db..00000000 --- a/test/deploy.py +++ /dev/null @@ -1,451 +0,0 @@ -import json -import logging -import os -import pty -import re -import select -import socket -import subprocess -import sys -import time -from collections import defaultdict -from contextlib import closing -from dataclasses import dataclass, field -from pathlib import Path - -import paramiko -from tftest import TerraformTest -from testutils 
import backoff_delay, spawn, term_proc, NSDict - -sys.path.append("../scripts") -import util # noqa: E402 - - -log = logging.getLogger() -log.setLevel("INFO") -log.handlers = [] -handler = logging.StreamHandler(sys.stdout) -handler.setLevel("INFO") -# formatter = logging.Formatter() -log.addHandler(handler) - -logging.getLogger("tftest").setLevel("WARNING") -logging.getLogger("paramiko").setLevel("WARNING") - -cred_file = os.getenv("GOOGLE_APPLICATION_CREDENTIALS") -if cred_file: - with Path(cred_file).open("r") as f: - credentials = json.load(f) -else: - credentials = {} - - -def get_sa_user(): - return f"sa_{credentials['client_id']}" - - -def trim_self_link(link: str): - """get resource name from self link url, eg. - https://.../v1/projects//regions/ - -> - """ - try: - return link[link.rindex("/") + 1 :] - except ValueError: - raise Exception(f"'/' not found, not a self link: '{link}' ") - - -class NoPortFoundError(Exception): - pass - - -def find_open_port(): - while True: - with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s: - s.bind(("localhost", 0)) - s.listen(1) - s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - port = s.getsockname()[1] - yield port - - -def start_service(service): - for port in find_open_port(): - try: - s = service(port) - except Exception: - continue - if s is not None: - return port, s - else: - raise NoPortFoundError("No available port found") - return -1 - - -class DefaultKeyDict(defaultdict): - def __missing__(self, key): - if self.default_factory: - dict.__setitem__(self, key, self.default_factory(key)) - return self[key] - else: - defaultdict.__missing__(self, key) - - def __repr__(self): - return "{}({})".format( - type(self).__name__, ", ".join(f"{k}:{v}" for k, v in self.items()) - ) - - -@dataclass -class Configuration: - cluster_name: str - project_id: str - moduledir: Path - tfvars_file: Path - tfvars: dict - tf: TerraformTest = field(init=False) - image_project: str - image: str = None - image_family: str = None - marks: list = field(default_factory=list) - - def __post_init__(self): - self.tf = TerraformTest(self.moduledir) - tmp = self._template_tfvars = self.tfvars_file - # basic.tfvars -> -basic.tfvars - self.tfvars_file = self._template_tfvars.with_name( - f"{self.cluster_name}-{tmp.stem}" - ) - self.tfvars_file.write_text(tmp.read_text()) - - vars = { - "project_id": self.project_id, - "slurm_cluster_name": self.cluster_name, - "source_image_family": self.image_family, - "source_image_project": self.image_project, - } - self.tfvars = {**self.tfvars, **vars} - - def __str__(self): - image = self.image or self.image_family - return f"{self.cluster_name}: project={self.project_id} image={self.image_project}/{image} tfmodule={self.moduledir} tfvars={self.tfvars_file} vars={self.tfvars}" - - def setup(self, **kwargs): - all_args = dict(extra_files=[self.tfvars_file], cleanup_on_exit=False) - all_args.update(kwargs) - return self.tf.setup(**all_args) - - def apply(self, **kwargs): - all_args = dict(tf_vars=self.tfvars, tf_var_file=self.tfvars_file.name) - all_args.update(kwargs) - return self.tf.apply(**all_args) - - def destroy(self, **kwargs): - all_args = dict(tf_vars=self.tfvars, tf_var_file=self.tfvars_file.name) - all_args.update(kwargs) - return self.tf.destroy(**all_args) - - -class Tunnel: - def __init__(self, host, target_port=22): - self.host = host - self.target_port = target_port - self.port, self.proc = self._start_tunnel(host, target_port) - self.closed = False - - def __del__(self): - self.close() - - def 
__repr__(self): - return f"Tunnel({self.port}:{self.host}:{self.target_port}<{self.proc.pid}>)" - - def close(self): - term_proc(self.proc) - - def _start_tunnel(self, instance, target_port): - listen = re.compile(r"^Listening on port \[\d+\].\n$") - log.info(f"start tunnel {instance}:{target_port}") - - def tunnel(port): - """Attempt to create an iap tunnel on the local port""" - # the pty makes gcloud output a message on success, allowing us to - # proceed faster - stdoutfd, peer = pty.openpty() - stdout = os.fdopen(stdoutfd) - proc = spawn( - f"gcloud compute start-iap-tunnel {instance} {target_port} --local-host-port=localhost:{port}", - stderr=subprocess.PIPE, - stdout=peer, - stdin=subprocess.DEVNULL, - ) - stdout_sel = select.poll() - stdout_sel.register(stdout, select.POLLIN) - for w in backoff_delay(0.5, timeout=30): - if proc.poll() is None: - if stdout_sel.poll(1): - out = stdout.readline() - log.debug(f"gcloud iap-tunnel: {out}") - if listen.match(out): - log.debug(f"gcloud iap-tunnel created on port {port}") - return proc - else: - stderr = proc.stderr.read() - log.debug( - f"gcloud iap-tunnel failed on port {port}, rc: {proc.returncode}, stderr: {stderr}" - ) - return None - time.sleep(w) - log.error(f"gcloud iap-tunnel timed out on port {port}") - proc.kill() - return None - - return start_service(tunnel) - - -class Cluster: - def __init__(self, tf, user=None): - self.user = user or get_sa_user() - - self.tf = tf - - self.tunnels = DefaultKeyDict(lambda host: Tunnel(host)) # type: ignore - self.ssh_conns = {} - - self.connected = False - self.active = False - - self.keyfile = Path("gcp_login_id") - if not self.keyfile.exists(): - self.keyfile.write_text(os.environ["GCP_LOGIN_ID"]) - self.keyfile.chmod(0o400) - - def activate(self): - if not self.active: - self.wait_on_active() - - def wait_on_active(self): - for wait in backoff_delay(5, timeout=600): - nodes = [] - try: - nodes = self.get_nodes() - if all(node["state"][0] == "IDLE" for node in nodes): - break - except Exception as e: - log.error( - f"Error getting node state {'+'.join(nodes[0]) if nodes else ''}: {e}" - ) - time.sleep(wait) - else: - raise Exception("Cluster never came up") - self.active = True - - def power_down(self): - all_nodes = ",".join( - p.nodes for p in self.api.slurmctld_get_partitions().partitions - ) - self.login_exec( - f"sudo $(which scontrol) update nodename={all_nodes} state=power_down" - ) - - def ssh(self, instance): - if instance in self.ssh_conns: - return self.ssh_conns[instance] - - ssh = paramiko.SSHClient() - key = paramiko.RSAKey.from_private_key_file(self.keyfile) - ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy()) - for wait in backoff_delay(1, timeout=300): - tun = self.tunnels[instance] - log.info( - f"start ssh connection to {self.user}@{trim_self_link(instance)} port {tun.port}" - ) - try: - ssh.connect("127.0.0.1", username=self.user, pkey=key, port=tun.port) - break - except paramiko.ssh_exception.NoValidConnectionsError: - log.error("ssh connection failed") - time.sleep(wait) - tun = self.tunnels.pop(instance) - tun.close() - continue - except Exception as e: - log.error(f"error on start ssh connection: {e}") - else: - log.error(f"Cannot connect through tunnel: {instance}") - raise Exception(f"Cannot connect through tunnel: {instance}") - self.ssh_conns[instance] = ssh - self.connected = True - return ssh - - def _close_ssh(self, instance): - ssh = self.ssh_conns.pop(instance, None) - if ssh: - ssh.close() - tun = self.tunnels.pop(instance, None) - if tun: - 
tun.close() - - def disconnect(self): - for instance in list(self.ssh_conns): - self._close_ssh(instance) - - @property - def controller_ssh(self): - return self.ssh(self.controller_link) - - @property - def login_ssh(self): - return self.ssh(self.login_link) - - def exec_cmd( - self, - ssh, - cmd, - input="", - prefix="", - timeout=60, - quiet=True, - check=False, - **kwargs, - ): - if not quiet: - log.info(f"{prefix}: {cmd}") - start = time.time() - - stdin, stdout, stderr = ssh.exec_command(cmd, timeout, **kwargs) - if input: - stdin.write(input) - stdin.flush() - stdin.channel.shutdown_write() - status = stdout.channel.recv_exit_status() - stdout = stdout.read().decode() - stderr = stderr.read().decode() - if status and check: - raise Exception(f"Error running command '{cmd}' stderr:{stderr}") - - duration = round(time.time() - start, 3) - start = round(start, 3) - - if not quiet: - log.debug(f"{stdout}") - if status: - log.debug(f"{stderr}") - - result = NSDict( - { - "command": cmd, - "start_time": start, - "duration": duration, - "exit_status": status, - "stdout": stdout, - "stderr": stderr, - } - ) - return result - - def login_exec_output(self, *args, **kwargs): - r = self.login_exec(*args, **kwargs) - return r.stdout or r.stderr - - def controller_exec_output(self, *args, **kwargs): - r = self.controller_exec(*args, **kwargs) - return r.stdout or r.stderr - - def login_exec(self, *args, **kwargs): - return self.exec_cmd(self.login_ssh, *args, prefix=self.login_name, **kwargs) - - def controller_exec(self, *args, **kwargs): - return self.exec_cmd( - self.controller_ssh, *args, prefix=self.controller_name, **kwargs - ) - - def partitions(self): - return self.tf.output()["slurm_partitions"] - - @property - def controller_link(self): - return self.tf.output()["slurm_controller_instance_self_links"][0] - - @property - def controller_name(self): - return trim_self_link(self.controller_link) - - @property - def login_link(self): - return self.tf.output()["slurm_login_instance_self_links"][0] - - @property - def login_name(self): - return trim_self_link(self.login_link) - - @util.cached_property - def cfg(self): - # download the config.yaml from the controller and load it locally - cluster_name = self.tf.output()["slurm_cluster_name"] - cfgfile = Path(f"{cluster_name}-config.yaml") - cfgfile.write_text( - self.controller_exec_output("sudo cat /slurm/scripts/config.yaml") - ) - return util.load_config_file(cfgfile) - - def get_jobs(self): - out = self.login_exec("scontrol show jobs --json")["stdout"] - try: - return json.loads(out)["jobs"] - except Exception as e: - log.error(f"failed to get jobs: {out}") - raise e - - def get_job(self, job_id): - return next((j for j in self.get_jobs() if j["job_id"] == job_id), None) - - def get_nodes(self): - out = self.login_exec("scontrol show nodes --json")["stdout"] - try: - return json.loads(out)["nodes"] - except Exception as e: - log.error(f"failed to get nodes: {out}") - raise e - - def get_node(self, nodename): - return next((n for n in self.get_nodes() if n["name"] == nodename), None) - - def get_file(self, ssh, path): - with ssh.open_sftp() as sftp: - with sftp.file(str(path), "r") as f: - return f.read().decode() - - def login_get_file(self, path): - return self.get_file(self.login_ssh, path) - - def controller_get_file(self, path): - return self.get_file(self.controller_ssh, path) - - def save_logs(self): - local_dir = Path("cluster_logs") - cl_dir = Path("/slurm/scripts/") - paths = [ - "log/slurmdbd.log", - "log/slurmctld.log", - 
"log/resume.log", - "log/suspend.log", - "log/slurmsync.log", - "etc/slurm.conf", - "etc/cloud.conf", - "etc/gres.conf", - "setup.log", - "config.yaml", - ] - with self.controller_ssh.open_sftp() as sftp: - for path in paths: - clpath = cl_dir / path - fpath = local_dir / clpath.name - self.controller_exec(f"sudo chmod 777 {clpath}") - try: - with sftp.file(str(clpath), "r") as f: - content = f.read().decode() - fpath.parent.mkdir(parents=True, exist_ok=True) - fpath.write_text(content) - log.info(f"saved {fpath.name}") - except IOError: - log.error(f"failed to save file {clpath}") diff --git a/test/requirements.txt b/test/requirements.txt deleted file mode 100644 index 3aeb44ce..00000000 --- a/test/requirements.txt +++ /dev/null @@ -1,60 +0,0 @@ --i https://pypi.org/simple -addict==2.4.0 -asttokens==2.2.1 -backcall==0.2.0 -bcrypt==4.0.1; python_version >= '3.6' -cachetools==5.3.1; python_version >= '3.7' -certifi==2023.7.22; python_version >= '3.6' -cffi==1.15.1 -charset-normalizer==3.2.0; python_full_version >= '3.7.0' -cryptography==43.0.1; python_version >= '3.7' -decorator==5.1.1; python_version >= '3.11' -execnet==2.0.2; python_version >= '3.7' -executing==1.2.0 -google-api-core==2.19.0; python_version >= '3.7' -google-api-python-client==2.93.0 -google-auth==2.22.0; python_version >= '3.6' -google-auth-httplib2==0.1.0 -google-cloud-tpu==1.10.0 -googleapis-common-protos==1.59.1; python_version >= '3.7' -httplib2==0.22.0; python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3' -idna==3.7; python_version >= '3.5' -iniconfig==2.0.0; python_version >= '3.7' -ipdb==0.13.13 -ipython==8.14.0 -jedi==0.18.2; python_version >= '3.6' -matplotlib-inline==0.1.6; python_version >= '3.5' -more-executors==2.11.4 -packaging==23.1; python_version >= '3.7' -paramiko==3.4.0 -parso==0.8.3; python_version >= '3.6' -pexpect==4.8.0; sys_platform != 'win32' -pickleshare==0.7.5 -pluggy==1.2.0; python_version >= '3.7' -prompt-toolkit==3.0.39; python_full_version >= '3.7.0' -protobuf==4.23.4; python_version >= '3.7' -psutil==5.9.5 -ptyprocess==0.7.0 -pure-eval==0.2.2 -py-cpuinfo==9.0.0 -pyaml==23.7.0 -pyasn1==0.5.0; python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3, 3.4, 3.5' -pyasn1-modules==0.3.0; python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3, 3.4, 3.5' -pycparser==2.21 -pygments==2.15.1; python_version >= '3.7' -pynacl==1.5.0; python_version >= '3.6' -pyparsing==3.1.0; python_version >= '3.1' -pytest==7.4.0 -pytest-benchmark==4.0.0 -pytest-xdist[psutil]==3.3.1 -python-hostlist==1.23.0 -pyyaml==6.0 -requests==2.32.2 -rsa==4.9; python_version >= '3.6' and python_version < '4' -six==1.16.0; python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3' -stack-data==0.6.2 -tftest==1.8.4 -traitlets==5.9.0; python_version >= '3.7' -uritemplate==4.1.1; python_version >= '3.6' -urllib3==1.26.18; python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3, 3.4, 3.5' -wcwidth==0.2.6 diff --git a/test/test_config.py b/test/test_config.py deleted file mode 100644 index eada8f92..00000000 --- a/test/test_config.py +++ /dev/null @@ -1,112 +0,0 @@ -import logging - -import pytest - -# from hostlist import expand_hostlist as expand -from testutils import ( - util, - # get_file, -) - -log = logging.getLogger() - - -def test_gpu_config(cluster, lkp): - gpu_groups = {} - for nodeset_name, nodeset in lkp.cfg.nodeset.items(): - template = lkp.template_info(nodeset.instance_template) - if template.gpu_count > 0: - 
gpu_groups[lkp.nodeset_prefix(nodeset_name)] = template - if not gpu_groups: - pytest.skip("no gpu partitions found") - return - - for prefix, template in gpu_groups.items(): - node = cluster.get_node(f"{prefix}-0") - count = next(g for g in node["gres"].split(",") if g.startswith("gpu")).split( - ":" - )[1] - assert int(count) == template.gpu_count - - -def test_ops_agent(cluster, lkp): - ops_agent_service = "google-cloud-ops-agent-fluent-bit.service" - - def check_ops_agent(inst): - log.info(f"checking if ops agent is active on {inst.name}") - ssh = cluster.ssh(inst.selfLink) - result = cluster.exec_cmd(ssh, f"sudo systemctl status {ops_agent_service}") - if "could not be found" in result.stderr: - pytest.skip(f"Service {ops_agent_service} is not installed.") - result = cluster.exec_cmd(ssh, f"sudo systemctl is-active {ops_agent_service}") - assert result.exit_status == 0 - - lkp.instances.cache_clear() - util.execute_with_futures(check_ops_agent, lkp.instances().values()) - - -def test_controller_custom_scripts(cluster): - check = cluster.controller_exec("ls /slurm/out/controller") - log.debug(f"{check.command}: {check.stdout or check.stderr}") - assert check.exit_status == 0 - - -def test_login_custom_scripts(cluster): - check = cluster.login_exec("ls /slurm/out/login") - log.debug(f"{check.command}: {check.stdout or check.stderr}") - assert check.exit_status == 0 - check = cluster.login_exec("ls /slurm/out/login2") - log.debug(f"{check.command}: {check.stdout or check.stderr}") - assert check.exit_status == 0 - - -# def test_network_mounts(cluster): -# """test cluster-wide and login network storage -# Ignores partition-only network storage for now -# """ -# get_mounts = ( -# "df -h --output=source,target -t nfs4 -t lustre -t gcsfuse -t cfs " -# "| awk '{if (NR!=1) {print $1 \" \" $2}}'" -# ) -# -# def parse_mounts(df): -# return {tuple(mount.split(" ")) for mount in df.splitlines()} -# -# login_mounts = parse_mounts(cluster.login_exec_output(get_mounts)) -# -# # TODO might not work for gcsfuse -# network_storage = { -# (f"{cluster.controller_name}:/home", "/home"), -# (f"{cluster.controller_name}:/usr/local/etc/slurm", "/usr/local/etc/slurm"), -# (f"{cluster.controller_name}:/etc/munge", "/etc/munge"), -# (f"{cluster.controller_name}:/apps", "/apps"), -# } -# network_storage.update( -# { -# ( -# "{}:{}".format( -# cluster.controller_name -# if e["server_ip"] == "$controller" -# else e["server_ip"], -# e["remote_mount"], -# ), -# e["local_mount"], -# ) -# for e in chain( -# cluster.config["network_storage"], -# cluster.config["login_network_storage"], -# ) -# } -# ) -# -# assert network_storage == login_mounts - - -# def test_partitions(cluster, config_partitions, cluster_partitions): -# # The same partition names (keys) should be in config and cluster -# assert set(config_partitions) == set(cluster_partitions) - -# for name, part in cluster_partitions.items(): -# config = config_partitions[name] -# nodelist = expand(part.nodes, sort=True) -# assert len(nodelist) == config['max_node_count'] diff --git a/test/test_jobs.py b/test/test_jobs.py deleted file mode 100644 index 498d91b5..00000000 --- a/test/test_jobs.py +++ /dev/null @@ -1,224 +0,0 @@ -import logging - -import pytest - -from hostlist import expand_hostlist as expand -from deploy import Cluster -from testutils import ( - wait_job_state, - wait_node_state, - wait_node_flags_any, - sbatch, - run, - util, -) -from util import Lookup - -log = logging.getLogger() - - -def test_job(cluster): - job_id = sbatch(cluster, "sbatch -N3 
--wrap='srun hostname'") - job = wait_job_state(cluster, job_id, "COMPLETED", "FAILED", "CANCELLED") - assert job["job_state"][0] == "COMPLETED" - - -def test_openmpi(cluster): - prog = r""" -#include -#include - -int main(int argc, char **argv) -{ - int node; - - MPI_Init(&argc, &argv); - MPI_Comm_rank(MPI_COMM_WORLD, &node); - - printf("Hello World from Node %d\n", node); - - MPI_Finalize(); -} -""" - cluster.login_exec("tee hello.c", input=prog) - cluster.login_exec( - "bash --login -c 'module load openmpi && mpicc -o hello hello.c'" - ) - job_id = sbatch(cluster, "sbatch -N3 --wrap='srun hello'") - job = wait_job_state(cluster, job_id, "COMPLETED", "FAILED", "CANCELLED") - log.info(cluster.login_get_file(f"slurm-{job_id}.out")) - assert job["job_state"][0] == "COMPLETED" - - -def test_gpu_job(cluster, lkp): - gpu_parts = {} - for part_name, partition in lkp.cfg.partitions.items(): - for nodeset_name in partition.partition_nodeset: - nodeset = lkp.cfg.nodeset.get(nodeset_name) - template = lkp.template_info(nodeset.instance_template) - if ( - template.gpu_count > 0 - and not template.shieldedInstanceConfig.enableSecureBoot - ): - gpu_parts[part_name] = partition - if not gpu_parts: - pytest.skip("no gpu partitions found") - return - - for part_name, partition in gpu_parts.items(): - job_id = sbatch( - cluster, - f"sbatch --partition={part_name} --gpus=1 --wrap='srun nvidia-smi'", - ) - job = wait_job_state(cluster, job_id, "COMPLETED", "FAILED", "CANCELLED") - assert job["job_state"][0] == "COMPLETED" - log.info(cluster.login_exec_output(f"cat slurm-{job_id}.out")) - - -def test_shielded(image_marker, cluster: Cluster, lkp: Lookup): - # only run test for ubuntu-2004 - log.info(f"detected image_marker:{image_marker}") - if image_marker == "debian-11-arm64": - pytest.skip("shielded not supported on {image_marker}") - skip_gpus = "ubuntu-2004" not in image_marker - - shielded_parts = {} - for part_name, partition in lkp.cfg.partitions.items(): - has_gpus = any( - lkp.template_info( - lkp.cfg.nodeset.get(nodeset_name).instance_template - ).gpu_count - > 0 - for nodeset_name in partition.partition_nodeset - ) - if skip_gpus and has_gpus: - continue - for nodeset_name in partition.partition_nodeset: - nodeset = lkp.cfg.nodeset.get(nodeset_name) - template = lkp.template_info(nodeset.instance_template) - if template.shieldedInstanceConfig.enableSecureBoot: - shielded_parts[part_name] = partition - partition.has_gpus = has_gpus - if not shielded_parts: - pytest.skip("No viable partitions with shielded instances found") - return - - for part_name, partition in shielded_parts.items(): - if partition.has_gpus: - job_id = sbatch( - cluster, - f"sbatch --partition={part_name} --gpus=1 --wrap='srun nvidia-smi'", - ) - else: - job_id = sbatch( - cluster, - f"sbatch --partition={part_name} --wrap='srun hostname'", - ) - job = wait_job_state(cluster, job_id, "COMPLETED", "FAILED", "CANCELLED") - assert job["job_state"][0] == "COMPLETED" - log.info(cluster.login_exec_output(f"cat slurm-{job_id}.out")) - - -def test_placement_groups(cluster, lkp): - nodesets = [] - for nodeset_name, nodeset in lkp.cfg.nodeset.items(): - if nodeset.enable_placement: - nodesets.append(nodeset_name) - partitions = [] - for part_name, partition in lkp.cfg.partitions.items(): - if any(item in nodesets for item in partition.partition_nodeset): - partitions.append(part_name) - if not partitions: - pytest.skip("no partitions with placement groups enabled") - return - - def placement_job(part_name): - job_id = sbatch( - cluster, 
f"sbatch -N3 --partition={part_name} --wrap='sleep 600'" - ) - job = wait_job_state(cluster, job_id, "RUNNING", max_wait=300) - nodes = expand(job["nodes"]) - physical_hosts = { - node: lkp.describe_instance(node).resourceStatus.physicalHost or None - for node in nodes - } - # this isn't working sometimes now. None are matching - log.debug( - "matching physicalHost IDs: {}".format( - set.intersection(*map(set, physical_hosts.values())) - ) - ) - # assert bool(set.intersection(*physical_hosts)) - assert all(host is not None for node, host in physical_hosts.items()) - cluster.login_exec(f"scancel {job_id}") - job = wait_job_state(cluster, job_id, "CANCELLED") - for node in nodes: - wait_node_flags_any(cluster, node, "idle", "POWERED_DOWN", max_wait=240) - - util.execute_with_futures(placement_job, partitions) - - -# def test_partition_jobs(cluster): -# jobs = [] -# for name, part in cluster_partitions.items(): -# job_id = sbatch( -# cluster, f"sbatch -N2 --partition={name} --wrap='srun hostname'" -# ) -# jobs.append(job_id) -# for job_id in jobs: -# job = wait_job_state(cluster, job_id, "COMPLETED", "FAILED", "CANCELLED") -# assert job["job_state"][0] == "COMPLETED" - - -def test_preemption(cluster: Cluster, lkp: Lookup): - partitions = [] - for part_name, partition in lkp.cfg.partitions.items(): - for nodeset_name in partition.partition_nodeset: - nodeset = lkp.cfg.nodeset.get(nodeset_name) - template = lkp.template_info(nodeset.instance_template) - if template.scheduling.preemptible: - partitions.append(part_name) - break - - if not partitions: - pytest.skip("no partitions with preemptible nodes") - return - - def preemptible_job(part_name): - job_id = sbatch( - cluster, f"sbatch -N2 --partition={part_name} --wrap='srun sleep 9999'" - ) - job = wait_job_state(cluster, job_id, "RUNNING") - last_node = expand(job["nodes"])[-1] - - lkp.instances.cache_clear() - inst = lkp.instance(last_node) - run(f"gcloud compute instances stop {last_node} --zone={inst.zone}") - node = wait_node_state(cluster, last_node, "down", max_wait=180) - assert node["reason"] == "Instance stopped/deleted" - wait_node_state(cluster, last_node, "idle") - cluster.login_exec(f"scancel {job_id}") - wait_job_state(cluster, job_id, "CANCELLED") - - util.execute_with_futures(preemptible_job, partitions) - - -def test_prolog_scripts(cluster: Cluster, lkp: Lookup): - """check that the prolog and epilog scripts ran""" - # The partition this runs on must not be job exclusive so the VM stays - # after job completion - job_id = sbatch(cluster, "sbatch -N1 --wrap='srun sleep 999'") - job = wait_job_state(cluster, job_id, "RUNNING", max_wait=300) - node = next(iter(expand(job["nodes"]))) - - node_ssh = cluster.ssh(lkp.instance(node).selfLink) - check = cluster.exec_cmd(node_ssh, f"ls /slurm/out/prolog_{job_id}") - log.debug(f"{check.command}: {check.stdout or check.stderr}") - assert check.exit_status == 0 - - cluster.login_exec(f"scancel {job_id}") - wait_job_state(cluster, job_id, "CANCELLED", max_wait=300) - - check = cluster.exec_cmd(node_ssh, f"ls /slurm/out/epilog_{job_id}") - log.debug(f"{check.command}: {check.stdout or check.stderr}") - assert check.exit_status == 0 diff --git a/test/test_nodes.py b/test/test_nodes.py deleted file mode 100644 index 69d98ec9..00000000 --- a/test/test_nodes.py +++ /dev/null @@ -1,86 +0,0 @@ -import logging -import hostlist - -from itertools import chain -import pytest - -from deploy import Cluster -from hostlist import expand_hostlist as expand -from testutils import ( - wait_node_flags_any, - 
wait_until, - wait_job_state, - sbatch, - util, -) -from util import Lookup - - -log = logging.getLogger() - - -def test_static(cluster: Cluster, lkp: util.Lookup): - power_states = set( - ( - "POWERING_DOWN", - "POWERED_DOWN", - "POWERING_UP", - ) - ) - - def is_node_up(node): - info = cluster.get_node(node) - state, *flags = info["state"] - state = state.lower() - flags = set(flags) - log.info( - f"waiting for static node {node} to be up; state={state} flags={','.join(flags)}" - ) - return state == "idle" and not (power_states & flags) - - for node in chain.from_iterable( - hostlist.expand_hostlist(nodes) for nodes in lkp.static_nodelist() - ): - assert wait_until(is_node_up, node) - - -def test_compute_startup_scripts(cluster: Cluster, lkp: Lookup): - """check that custom compute startup scripts ran on static nodes""" - # TODO check non static too? - for node in chain.from_iterable( - hostlist.expand_hostlist(nodes) for nodes in lkp.static_nodelist() - ): - node_ssh = cluster.ssh(lkp.instance(node).selfLink) - check = cluster.exec_cmd(node_ssh, "ls /slurm/out/compute") - log.debug(f"{check.command}: {check.stdout or check.stderr}") - assert check.exit_status == 0 - - -def test_exclusive_labels(cluster: Cluster, lkp: util.Lookup): - partitions = [] - for part_name, partition in lkp.cfg.partitions.items(): - if partition.enable_job_exclusive: - partitions.append(part_name) - if not partitions: - pytest.skip("no partitions with enable_job_exclusive") - return - - def check_node_labels(partition): - job_id = sbatch( - cluster, f"sbatch -N2 --partition={partition} --wrap='sleep 600'" - ) - job = wait_job_state(cluster, job_id, "RUNNING", max_wait=300) - nodes = expand(job["nodes"]) - - node_labels = [lkp.describe_instance(node).labels for node in nodes] - assert all( - "slurm_job_id" in labels and int(labels.slurm_job_id) == job_id - for labels in node_labels - ) - - cluster.login_exec(f"scancel {job_id}") - job = wait_job_state(cluster, job_id, "CANCELLED") - for node in nodes: - wait_node_flags_any(cluster, node, "idle", "POWERED_DOWN") - - util.execute_with_futures(check_node_labels, partitions) diff --git a/test/testutils.py b/test/testutils.py deleted file mode 100644 index 74128fdc..00000000 --- a/test/testutils.py +++ /dev/null @@ -1,183 +0,0 @@ -import logging -import re -import shlex -import subprocess as sp -import sys -import time -from itertools import chain - -import psutil -import yaml - -scripts = "../scripts" -if scripts not in sys.path: - sys.path.append(scripts) -import util # noqa: F401 E402 -from util import backoff_delay, NSDict # noqa: F401 E402 - - -log = logging.getLogger() -log.setLevel(logging.INFO) - - -def term_proc(proc): - try: - psproc = psutil.Process(proc.pid) - except psutil.NoSuchProcess: - log.debug(f"process with pid {proc.pid} doesn't exist") - return - for child in psproc.children(recursive=True): - child.terminate() - try: - child.wait(timeout=1) - except psutil.TimeoutExpired: - log.error(f"killing {child.pid}") - child.kill() - proc.terminate() - try: - proc.wait(timeout=1) - except sp.TimeoutExpired: - log.error(f"killing {proc.pid}") - proc.kill() - - -def run( - cmd, - wait=0, - quiet=False, - get_stdout=False, - shell=False, - universal_newlines=True, - **kwargs, -): - """run in subprocess. 
Optional wait after return.""" - if not quiet: - log.debug(f"run: {cmd}") - if get_stdout: - kwargs["stdout"] = sp.PIPE - - args = cmd if shell else shlex.split(cmd) - ret = sp.run(args, shell=shell, universal_newlines=universal_newlines, **kwargs) - if wait: - time.sleep(wait) - return ret - - -def run_out(cmd, **kwargs): - kwargs["get_stdout"] = True - kwargs["universal_newlines"] = True - return run(cmd, **kwargs).stdout - - -def spawn(cmd, quiet=False, shell=False, **kwargs): - """nonblocking spawn of subprocess""" - if not quiet: - log.debug(f"spawn: {cmd}") - kwargs["universal_newlines"] = True - args = cmd if shell else shlex.split(cmd) - return sp.Popen(args, shell=shell, **kwargs) - - -def wait_until(check, *args, max_wait=None): - if max_wait is None: - max_wait = 360 - for wait in backoff_delay(1, timeout=max_wait): - if check(*args): - return True - time.sleep(wait) - return False - - -def wait_job_state(cluster, job_id, *states, max_wait=None): - states = set(states) - states_str = "{{ {} }}".format(", ".join(states)) - - def is_job_state(): - state_arr = cluster.get_job(job_id)["job_state"] - log.info(f"job {job_id}: {state_arr} waiting for {states_str}") - return len(state_arr) == 1 and state_arr[0] in states - - assert wait_until(is_job_state, max_wait=max_wait) - return cluster.get_job(job_id) - - -def wait_node_flags_subset(cluster, nodename, state, *flags, max_wait=None): - flags = set(flags) - flags_str = "+".join(chain([state], flags)) - - def check_node_flags(): - info = cluster.get_node(nodename) - node_state, *node_flags = info["state"] - node_state = node_state.lower() - node_flags = set(node_flags) - log.info( - f"waiting for node {nodename} to be {flags_str}; state={node_state} flags={','.join(node_flags)}" - ) - return node_state == state and (flags <= node_flags) - - assert wait_until(check_node_flags, max_wait=max_wait) - return cluster.get_node(nodename) - - -def wait_node_flags_any(cluster, nodename, state, *flags, max_wait=None): - flags = set(flags) - flags_str = "{state}+{flags}".format(state=state, flags=" or ".join(flags)) - - def check_node_flags(): - info = cluster.get_node(nodename) - node_state, *node_flags = info["state"] - node_state = node_state.lower() - node_flags = set(node_flags) - log.info( - f"waiting for node {nodename} to be {flags_str}; state={node_state} flags={','.join(node_flags)}" - ) - return node_state == state and (flags & node_flags) - - assert wait_until(check_node_flags, max_wait=max_wait) - return cluster.get_node(nodename) - - -def wait_node_state(cluster, nodename, *states, max_wait=None): - states = set(states) - states_str = "{{ {} }}".format(", ".join(states)) - - def is_node_state(): - state, *flags = cluster.get_node(nodename)["state"] - state = state.lower() - log.info(f"node {nodename}: {state} waiting for {states_str}") - return state in states - - assert wait_until(is_node_state, max_wait=max_wait) - return cluster.get_node(nodename) - - -# https://stackoverflow.com/questions/3844801/check-if-all-elements-in-a-list-are-identical -def all_equal(coll): - """Return true if coll is empty or all elements are equal""" - it = iter(coll) - try: - first = next(it) - except StopIteration: - return True - return all(first == x for x in it) - - -batch_id = re.compile(r"^Submitted batch job (\d+)$") - - -def sbatch(cluster, cmd): - log.info(cmd) - submit = cluster.login_exec(cmd) - m = batch_id.match(submit.stdout) - if submit.exit_status or m is None: - raise Exception(f"job submit failed: {yaml.safe_dump(submit.to_dict())}") - 
assert m is not None - job_id = int(m[1]) - return job_id - - -def get_zone(instance): - zone = yaml.safe_load( - run_out(f"gcloud compute instances describe {instance} --format=yaml(zone)") - )["zone"] - return zone diff --git a/test/tfvars/.gitignore b/test/tfvars/.gitignore deleted file mode 100644 index 9477745f..00000000 --- a/test/tfvars/.gitignore +++ /dev/null @@ -1 +0,0 @@ -!*.tfvars diff --git a/test/tfvars/arm64-basic.tfvars b/test/tfvars/arm64-basic.tfvars deleted file mode 100644 index 23a6fc64..00000000 --- a/test/tfvars/arm64-basic.tfvars +++ /dev/null @@ -1,200 +0,0 @@ -region = "us-central1" - -enable_bigquery_load = true -enable_cleanup_compute = true - -create_bucket = false -bucket_name = "slurm-test" - -controller_instance_config = { - disk_size_gb = 32 - disk_type = "pd-ssd" - machine_type = "t2a-standard-8" -} - -controller_startup_scripts = [ - { - filename = "hello_controller.sh" - content = <<-EOF - #!/bin/bash - set -ex - echo "Hello, $(hostname) from $(dirname $0) !" - mkdir -p /slurm/out - touch /slurm/out/controller - EOF - }, -] - -login_startup_scripts = [ - { - filename = "hello_login.sh" - content = <<-EOF - #!/bin/bash - set -ex - echo "Hello, $(hostname) from $(dirname $0) !" - mkdir -p /slurm/out - touch /slurm/out/login - EOF - }, - { - filename = "login2.sh" - content = <<-EOF - #!/bin/bash - set -ex - echo "Hello, $(hostname) from $(dirname $0) !" - mkdir -p /slurm/out - touch /slurm/out/login2 - EOF - }, -] - -compute_startup_scripts = [ - { - filename = "hello_compute.sh" - content = <<-EOF - #!/bin/bash - set -ex - echo "Hello, $(hostname) from $(dirname $0) !" - mkdir -p /slurm/out - touch /slurm/out/compute - EOF - }, -] - -prolog_scripts = [ - { - filename = "hello_prolog.sh" - content = <<-EOF - #!/bin/bash - set -ex - echo "Hello, $(hostname) from $(dirname $0) !" - mkdir -p /slurm/out - touch /slurm/out/prolog_$SLURM_JOBID - EOF - }, -] - -epilog_scripts = [ - { - filename = "hello_epilog.sh" - content = <<-EOF - #!/bin/bash - set -ex - echo "Hello, $(hostname) from $(dirname $0) !" 
- mkdir -p /slurm/out - touch /slurm/out/epilog_$SLURM_JOBID - EOF - }, -] - -login_nodes = [ - { - # Group Definition - group_name = "frontend" - - # Template By Definition - disk_size_gb = 32 - disk_type = "pd-standard" - machine_type = "t2a-standard-1" - service_account = { - email = "default" - scopes = ["https://www.googleapis.com/auth/cloud-platform"] - } - - # Instance Definition - num_instances = 1 - }, -] - -nodeset = [ - { - # Group Definition - nodeset_name = "t2a2" - node_count_dynamic_max = 20 - node_count_static = 1 - - # Template By Definition - disk_size_gb = 32 - disk_type = "pd-standard" - machine_type = "t2a-standard-2" - service_account = { - email = "default" - scopes = ["https://www.googleapis.com/auth/cloud-platform"] - } - }, - { - # Group Definition - nodeset_name = "t2s2spot" - node_count_dynamic_max = 10 - - # Template By Definition - disk_size_gb = 32 - disk_type = "pd-standard" - machine_type = "t2a-standard-2" - preemptible = true - service_account = { - email = "default" - scopes = [ - "https://www.googleapis.com/auth/cloud-platform", - ] - } - - # Instance Definition - spot_instance_config = { - termination_action = "STOP" - } - }, - { - # Group Definition - nodeset_name = "t2shield" - node_count_dynamic_max = 10 - - # Template By Definition - disk_size_gb = 32 - disk_type = "pd-standard" - enable_shielded_vm = true - machine_type = "t2a-standard-2" - service_account = { - email = "default" - scopes = [ - "https://www.googleapis.com/auth/cloud-platform", - ] - } - shielded_instance_config = { - enable_integrity_monitoring = true - enable_secure_boot = true - enable_vtpm = true - } - }, -] - -partitions = [ - { - partition_name = "debug" - partition_nodeset = ["t2a2", ] - # Options - default = true - enable_job_exclusive = false - resume_timeout = 300 - suspend_timeout = 300 - suspend_time = 300 - }, - { - partition_name = "spot" - partition_nodeset = ["t2s2spot", ] - # Options - enable_job_exclusive = false - resume_timeout = 300 - suspend_timeout = 300 - suspend_time = 300 - }, - { - partition_name = "shield" - partition_nodeset = ["t2shield", ] - # Options - enable_job_exclusive = false - resume_timeout = 300 - suspend_timeout = 300 - suspend_time = 300 - }, -] diff --git a/test/tfvars/x86_64-basic.tfvars b/test/tfvars/x86_64-basic.tfvars deleted file mode 100644 index ae25be23..00000000 --- a/test/tfvars/x86_64-basic.tfvars +++ /dev/null @@ -1,286 +0,0 @@ -region = "us-central1" - -enable_bigquery_load = true -enable_cleanup_compute = true - -create_bucket = false -bucket_name = "slurm-test" - -controller_instance_config = { - disk_size_gb = 32 - disk_type = "pd-ssd" - machine_type = "n1-standard-8" -} - -controller_startup_scripts = [ - { - filename = "hello_controller.sh" - content = <<-EOF - #!/bin/bash - set -ex - echo "Hello, $(hostname) from $(dirname $0) !" - mkdir -p /slurm/out - touch /slurm/out/controller - EOF - }, -] - -login_startup_scripts = [ - { - filename = "hello_login.sh" - content = <<-EOF - #!/bin/bash - set -ex - echo "Hello, $(hostname) from $(dirname $0) !" - mkdir -p /slurm/out - touch /slurm/out/login - EOF - }, - { - filename = "login2.sh" - content = <<-EOF - #!/bin/bash - set -ex - echo "Hello, $(hostname) from $(dirname $0) !" - mkdir -p /slurm/out - touch /slurm/out/login2 - EOF - }, -] - -compute_startup_scripts = [ - { - filename = "hello_compute.sh" - content = <<-EOF - #!/bin/bash - set -ex - echo "Hello, $(hostname) from $(dirname $0) !" 
- mkdir -p /slurm/out - touch /slurm/out/compute - EOF - }, -] - -prolog_scripts = [ - { - filename = "hello_prolog.sh" - content = <<-EOF - #!/bin/bash - set -ex - echo "Hello, $(hostname) from $(dirname $0) !" - mkdir -p /slurm/out - touch /slurm/out/prolog_$SLURM_JOBID - EOF - }, -] - -epilog_scripts = [ - { - filename = "hello_epilog.sh" - content = <<-EOF - #!/bin/bash - set -ex - echo "Hello, $(hostname) from $(dirname $0) !" - mkdir -p /slurm/out - touch /slurm/out/epilog_$SLURM_JOBID - EOF - }, -] - -login_nodes = [ - { - # Group Definition - group_name = "frontend" - - # Template By Definition - disk_size_gb = 32 - disk_type = "pd-standard" - machine_type = "n1-standard-1" - service_account = { - email = "default" - scopes = ["https://www.googleapis.com/auth/cloud-platform"] - } - - # Instance Definition - num_instances = 1 - }, -] - -nodeset = [ - { - # Group Definition - nodeset_name = "n1s2" - node_count_dynamic_max = 20 - node_count_static = 1 - - # Template By Definition - disk_size_gb = 32 - disk_type = "pd-standard" - machine_type = "n1-standard-2" - service_account = { - email = "default" - scopes = ["https://www.googleapis.com/auth/cloud-platform"] - } - }, - { - # Group Definition - nodeset_name = "v100" - node_count_dynamic_max = 10 - - # Template By Definition - disk_size_gb = 32 - disk_type = "pd-standard" - gpu = { - count = 1 - type = "nvidia-tesla-v100" - } - machine_type = "n1-standard-4" - service_account = { - email = "default" - scopes = ["https://www.googleapis.com/auth/cloud-platform"] - } - }, - { - # Group Definition - nodeset_name = "c2s4" - node_count_dynamic_max = 10 - - # Template By Definition - disk_size_gb = 32 - disk_type = "pd-standard" - machine_type = "c2-standard-4" - service_account = { - email = "default" - scopes = ["https://www.googleapis.com/auth/cloud-platform"] - } - - # Instance Definition - enable_placement = true - }, - { - # Group Definition - nodeset_name = "n1s2spot" - node_count_dynamic_max = 10 - - # Template By Definition - disk_size_gb = 32 - disk_type = "pd-standard" - machine_type = "n1-standard-2" - preemptible = true - service_account = { - email = "default" - scopes = [ - "https://www.googleapis.com/auth/cloud-platform", - ] - } - - # Instance Definition - spot_instance_config = { - termination_action = "STOP" - } - }, - { - # Group Definition - nodeset_name = "v100shield" - node_count_dynamic_max = 10 - - # Template By Definition - disk_size_gb = 32 - disk_type = "pd-standard" - enable_shielded_vm = true - gpu = { - count = 1 - type = "nvidia-tesla-v100" - } - machine_type = "n1-standard-4" - service_account = { - email = "default" - scopes = ["https://www.googleapis.com/auth/cloud-platform"] - } - shielded_instance_config = { - enable_integrity_monitoring = true - enable_secure_boot = true - enable_vtpm = true - } - }, - { - # Group Definition - nodeset_name = "n1s4shield" - node_count_dynamic_max = 10 - - # Template By Definition - disk_size_gb = 32 - disk_type = "pd-standard" - enable_shielded_vm = true - machine_type = "n1-standard-4" - service_account = { - email = "default" - scopes = [ - "https://www.googleapis.com/auth/cloud-platform", - ] - } - shielded_instance_config = { - enable_integrity_monitoring = true - enable_secure_boot = true - enable_vtpm = true - } - }, -] - -partitions = [ - { - partition_name = "debug" - partition_nodeset = ["n1s2", ] - # Options - default = true - enable_job_exclusive = false - resume_timeout = 300 - suspend_timeout = 300 - suspend_time = 300 - }, - { - partition_name = "gpu" - 
partition_nodeset = ["v100", ] - # Options - enable_job_exclusive = false - resume_timeout = 300 - suspend_timeout = 300 - suspend_time = 300 - }, - { - partition_name = "c2" - partition_nodeset = ["c2s4", ] - # Options - enable_job_exclusive = true - resume_timeout = 300 - suspend_timeout = 300 - suspend_time = 300 - }, - { - partition_name = "spot" - partition_nodeset = ["n1s2spot", ] - # Options - enable_job_exclusive = false - resume_timeout = 300 - suspend_timeout = 300 - suspend_time = 300 - }, - { - partition_name = "shgpu" - partition_nodeset = ["v100shield", ] - # Options - enable_job_exclusive = false - resume_timeout = 300 - suspend_timeout = 300 - suspend_time = 300 - }, - { - partition_name = "shield" - partition_nodeset = ["n1s4shield", ] - # Options - enable_job_exclusive = false - resume_timeout = 300 - suspend_timeout = 300 - suspend_time = 300 - }, -]