diff --git a/src/aind_airflow_jobs/submit_slurm_job.py b/src/aind_airflow_jobs/submit_slurm_job.py index b24c181..3dcaa35 100644 --- a/src/aind_airflow_jobs/submit_slurm_job.py +++ b/src/aind_airflow_jobs/submit_slurm_job.py @@ -5,12 +5,10 @@ import logging import sys from argparse import ArgumentParser -from datetime import datetime from enum import Enum from pathlib import Path from time import sleep -from typing import Dict, List -from uuid import uuid4 +from typing import List from aind_slurm_rest import ApiClient as Client from aind_slurm_rest import Configuration as Config @@ -18,7 +16,7 @@ from aind_slurm_rest.api.slurm_api import SlurmApi from aind_slurm_rest.models.v0036_job_properties import V0036JobProperties from aind_slurm_rest.models.v0036_job_submission import V0036JobSubmission -from pydantic import Field, SecretStr +from pydantic import SecretStr from pydantic_settings import BaseSettings, SettingsConfigDict logging.basicConfig(level="INFO") @@ -57,34 +55,6 @@ def create_api_client(self) -> SlurmApi: return slurm -class DefaultSlurmSettings(BaseSettings): - """Configurations with default values or expected to be pulled from env - vars.""" - - model_config = SettingsConfigDict(env_prefix="SLURM_") - log_path: str - partition: str - name: str = Field( - default_factory=lambda: ( - f"job" - f"_{str(int(datetime.utcnow().timestamp()))}" - f"_{str(uuid4())[0:5]}" - ) - ) - qos: str = Field(default="dev") - environment: Dict[str, str] = Field( - default={ - "PATH": "/bin:/usr/bin/:/usr/local/bin/", - "LD_LIBRARY_PATH": "/lib/:/lib64/:/usr/local/lib", - } - ) - memory_per_node: int = Field(default=50000) - tasks: int = Field(default=1) - minimum_cpus_per_node: int = Field(default=1) - nodes: List[int] = Field(default=[1, 1]) - time_limit: int = Field(default=360) - - class JobState(str, Enum): """The possible job_state values in the V0036JobsResponse class. The enums don't appear to be importable from the aind-slurm-rest api.""" @@ -196,54 +166,9 @@ def __init__( """ self.slurm = slurm self.job_properties = job_properties - self._set_default_job_props(self.job_properties) self.script = script self.polling_request_sleep = poll_job_interval - @staticmethod - def _set_default_job_props(job_properties: V0036JobProperties) -> None: - """ - Set default values for the slurm job if they are not explicitly set - in the job_properties. - Parameters - ---------- - job_properties : V0036JobProperties - The job_properties used to initially construct the class. - """ - # Check if any default values need to be set - basic_attributes_to_check = [ - "name", - "memory_per_node", - "tasks", - "minimum_cpus_per_node", - "nodes", - "time_limit", - "qos", - "partition", - ] - for attribute in basic_attributes_to_check: - if getattr(job_properties, attribute) is None: - setattr( - job_properties, - attribute, - getattr(DefaultSlurmSettings(), attribute), - ) - if ( - job_properties.environment is None - or job_properties.environment == {} - ): - job_properties.environment = DefaultSlurmSettings().environment - if job_properties.standard_out is None: - job_properties.standard_out = str( - Path(DefaultSlurmSettings().log_path) - / f"{job_properties.name}.out" - ) - if job_properties.standard_error is None: - job_properties.standard_error = str( - Path(DefaultSlurmSettings().log_path) - / f"{job_properties.name}_error.out" - ) - def _submit_job(self) -> V0036JobSubmissionResponse: """Submit the job to the slurm cluster.""" job_submission = V0036JobSubmission( diff --git a/tests/test_submit_slurm_job.py b/tests/test_submit_slurm_job.py index fde08f8..f427de7 100644 --- a/tests/test_submit_slurm_job.py +++ b/tests/test_submit_slurm_job.py @@ -32,15 +32,24 @@ class TestSubmitSlurmJob(unittest.TestCase): "SLURM_CLIENT_USERNAME": "username", "SLURM_CLIENT_PASSWORD": "password", "SLURM_CLIENT_ACCESS_TOKEN": "abc-123", - "SLURM_LOG_PATH": "/a/dir/to/write/logs/to", - "SLURM_PARTITION": "some_part", } @patch.dict(os.environ, EXAMPLE_ENV_VAR, clear=True) def test_default_job_properties(self): """Tests that default job properties are set correctly.""" slurm_client_settings = SlurmClientSettings() - job_properties = V0036JobProperties(environment={}) + job_properties = V0036JobProperties( + environment={ + "LD_LIBRARY_PATH": "/lib/:/lib64/:/usr/local/lib", + "PATH": "/bin:/usr/bin/:/usr/local/bin/", + }, + partition="some_part", + standard_error="/a/dir/to/write/logs/to/job_123_error.out", + standard_out="/a/dir/to/write/logs/to/job_123.out", + qos="dev", + name="job_123", + time_limit=360, + ) script = " ".join(["#!/bin/bash", "\necho", "'Hello World?'"]) slurm = slurm_client_settings.create_api_client() slurm_job = SubmitSlurmJob( @@ -73,51 +82,6 @@ def test_default_job_properties(self): slurm_job.job_properties.environment, ) self.assertEqual(360, slurm_job.job_properties.time_limit) - self.assertEqual(50000, slurm_job.job_properties.memory_per_node) - self.assertEqual([1, 1], slurm_job.job_properties.nodes) - self.assertEqual(1, slurm_job.job_properties.tasks) - - @patch.dict(os.environ, EXAMPLE_ENV_VAR, clear=True) - def test_mixed_job_properties(self): - """Tests that job properties are not overwritten.""" - slurm_client_settings = SlurmClientSettings() - job_properties = V0036JobProperties( - environment={}, - name="my_job", - partition="part2", - qos="prod", - time_limit=5, - memory_per_node=500, - ) - script = " ".join(["#!/bin/bash", "\necho", "'Hello World?'"]) - slurm = slurm_client_settings.create_api_client() - slurm_job = SubmitSlurmJob( - slurm=slurm, - script=script, - job_properties=job_properties, - ) - self.assertEqual("part2", slurm_job.job_properties.partition) - self.assertEqual("prod", slurm_job.job_properties.qos) - self.assertTrue("my_job", slurm_job.job_properties.name) - self.assertEqual( - "/a/dir/to/write/logs/to/my_job.out", - slurm_job.job_properties.standard_out, - ) - self.assertEqual( - "/a/dir/to/write/logs/to/my_job_error.out", - slurm_job.job_properties.standard_error, - ) - self.assertEqual( - { - "PATH": "/bin:/usr/bin/:/usr/local/bin/", - "LD_LIBRARY_PATH": "/lib/:/lib64/:/usr/local/lib", - }, - slurm_job.job_properties.environment, - ) - self.assertEqual(5, slurm_job.job_properties.time_limit) - self.assertEqual(500, slurm_job.job_properties.memory_per_node) - self.assertEqual([1, 1], slurm_job.job_properties.nodes) - self.assertEqual(1, slurm_job.job_properties.tasks) @patch.dict(os.environ, EXAMPLE_ENV_VAR, clear=True) @patch("aind_slurm_rest.api.slurm_api.SlurmApi.slurmctld_submit_job_0") @@ -129,7 +93,18 @@ def test_submit_job_with_errors(self, mock_submit_job: MagicMock): errors=[V0036Error(error="An error occurred.")] ) slurm_client_settings = SlurmClientSettings() - job_properties = V0036JobProperties(environment={}) + job_properties = V0036JobProperties( + environment={ + "LD_LIBRARY_PATH": "/lib/:/lib64/:/usr/local/lib", + "PATH": "/bin:/usr/bin/:/usr/local/bin/", + }, + partition="some_part", + standard_error="/a/dir/to/write/logs/to/job_123_error.out", + standard_out="/a/dir/to/write/logs/to/job_123.out", + qos="dev", + name="job_123", + time_limit=360, + ) script = " ".join(["#!/bin/bash", "\necho", "'Hello World?'"]) slurm = slurm_client_settings.create_api_client() slurm_job = SubmitSlurmJob( @@ -154,7 +129,18 @@ def test_submit_job(self, mock_submit_job: MagicMock): errors=[], job_id=12345 ) slurm_client_settings = SlurmClientSettings() - job_properties = V0036JobProperties(environment={}) + job_properties = V0036JobProperties( + environment={ + "LD_LIBRARY_PATH": "/lib/:/lib64/:/usr/local/lib", + "PATH": "/bin:/usr/bin/:/usr/local/bin/", + }, + partition="some_part", + standard_error="/a/dir/to/write/logs/to/job_123_error.out", + standard_out="/a/dir/to/write/logs/to/job_123.out", + qos="dev", + name="job_123", + time_limit=360, + ) script = " ".join(["#!/bin/bash", "\necho", "'Hello World?'"]) slurm = slurm_client_settings.create_api_client() slurm_job = SubmitSlurmJob( @@ -216,7 +202,18 @@ def test_monitor_job( ), ] slurm_client_settings = SlurmClientSettings() - job_properties = V0036JobProperties(environment={}, name="mock_job") + job_properties = V0036JobProperties( + environment={ + "LD_LIBRARY_PATH": "/lib/:/lib64/:/usr/local/lib", + "PATH": "/bin:/usr/bin/:/usr/local/bin/", + }, + partition="some_part", + standard_error="/a/dir/to/write/logs/to/job_123_error.out", + standard_out="/a/dir/to/write/logs/to/job_123.out", + qos="dev", + name="mock_job", + time_limit=360, + ) script = " ".join(["#!/bin/bash", "\necho", "'Hello World?'"]) slurm = slurm_client_settings.create_api_client() slurm_job = SubmitSlurmJob( @@ -275,7 +272,18 @@ def test_monitor_job_with_errors( ) ] slurm_client_settings = SlurmClientSettings() - job_properties = V0036JobProperties(environment={}, name="mock_job") + job_properties = V0036JobProperties( + environment={ + "LD_LIBRARY_PATH": "/lib/:/lib64/:/usr/local/lib", + "PATH": "/bin:/usr/bin/:/usr/local/bin/", + }, + partition="some_part", + standard_error="/a/dir/to/write/logs/to/job_123_error.out", + standard_out="/a/dir/to/write/logs/to/job_123.out", + qos="dev", + name="mock_job", + time_limit=360, + ) script = " ".join(["#!/bin/bash", "\necho", "'Hello World?'"]) slurm = slurm_client_settings.create_api_client() slurm_job = SubmitSlurmJob( @@ -353,7 +361,18 @@ def test_monitor_job_with_fail_code( ), ] slurm_client_settings = SlurmClientSettings() - job_properties = V0036JobProperties(environment={}, name="mock_job") + job_properties = V0036JobProperties( + environment={ + "LD_LIBRARY_PATH": "/lib/:/lib64/:/usr/local/lib", + "PATH": "/bin:/usr/bin/:/usr/local/bin/", + }, + partition="some_part", + standard_error="/a/dir/to/write/logs/to/job_123_error.out", + standard_out="/a/dir/to/write/logs/to/job_123.out", + qos="dev", + name="mock_job", + time_limit=360, + ) script = " ".join(["#!/bin/bash", "\necho", "'Hello World?'"]) slurm = slurm_client_settings.create_api_client() slurm_job = SubmitSlurmJob( @@ -395,7 +414,18 @@ def test_monitor_job_with_fail_code( def test_run_job(self, mock_monitor: MagicMock, mock_submit: MagicMock): """Tests that run_job calls right methods.""" slurm_client_settings = SlurmClientSettings() - job_properties = V0036JobProperties(environment={}) + job_properties = V0036JobProperties( + environment={ + "LD_LIBRARY_PATH": "/lib/:/lib64/:/usr/local/lib", + "PATH": "/bin:/usr/bin/:/usr/local/bin/", + }, + partition="some_part", + standard_error="/a/dir/to/write/logs/to/job_123_error.out", + standard_out="/a/dir/to/write/logs/to/job_123.out", + qos="dev", + name="job_123", + time_limit=360, + ) script = " ".join(["#!/bin/bash", "\necho", "'Hello World?'"]) slurm = slurm_client_settings.create_api_client() slurm_job = SubmitSlurmJob( @@ -414,7 +444,16 @@ def test_from_args_script_path(self): slurm_client_settings = SlurmClientSettings() slurm = slurm_client_settings.create_api_client() job_properties_json = V0036JobProperties( - environment={} + environment={ + "LD_LIBRARY_PATH": "/lib/:/lib64/:/usr/local/lib", + "PATH": "/bin:/usr/bin/:/usr/local/bin/", + }, + partition="some_part", + standard_error="/a/dir/to/write/logs/to/job_123_error.out", + standard_out="/a/dir/to/write/logs/to/job_123.out", + qos="dev", + name="job_123", + time_limit=360, ).model_dump_json() sys_args = [ "--script-path", @@ -442,7 +481,16 @@ def test_from_args_script_encoded(self): slurm_client_settings = SlurmClientSettings() slurm = slurm_client_settings.create_api_client() job_properties_json = V0036JobProperties( - environment={} + environment={ + "LD_LIBRARY_PATH": "/lib/:/lib64/:/usr/local/lib", + "PATH": "/bin:/usr/bin/:/usr/local/bin/", + }, + partition="some_part", + standard_error="/a/dir/to/write/logs/to/job_123_error.out", + standard_out="/a/dir/to/write/logs/to/job_123.out", + qos="dev", + name="job_123", + time_limit=360, ).model_dump_json() sys_args = [ "--script-encoded", @@ -460,7 +508,16 @@ def test_from_args_error(self): slurm_client_settings = SlurmClientSettings() slurm = slurm_client_settings.create_api_client() job_properties_json = V0036JobProperties( - environment={} + environment={ + "LD_LIBRARY_PATH": "/lib/:/lib64/:/usr/local/lib", + "PATH": "/bin:/usr/bin/:/usr/local/bin/", + }, + partition="some_part", + standard_error="/a/dir/to/write/logs/to/job_123_error.out", + standard_out="/a/dir/to/write/logs/to/job_123.out", + qos="dev", + name="job_123", + time_limit=360, ).model_dump_json() sys_args = ["--job-properties", job_properties_json] with self.assertRaises(Exception) as e: