diff --git a/composer/tools/composer_dags.py b/composer/tools/composer_dags.py index a5306fa52d5..02fe8e458aa 100644 --- a/composer/tools/composer_dags.py +++ b/composer/tools/composer_dags.py @@ -111,6 +111,22 @@ def pause_dag( logger.info("Unable to pause DAG %s", dag_id) logger.info(command_output[1]) + @staticmethod + def pause_all_dags( + project_name: str, + environment: str, + location: str, + sdk_endpoint: str, + ) -> None: + """Pause all the DAGs in the given environment.""" + command = ( + f"CLOUDSDK_API_ENDPOINT_OVERRIDES_COMPOSER={sdk_endpoint} gcloud composer environments" + f" run {environment} --project={project_name} --location={location}" + f" dags pause -- \"^(?!airflow_monitoring$).*\" --treat-dag-id-as-regex -y" + ) + command_output = DAG._run_shell_command_locally_once(command=command) + logger.info(command_output[1]) + @staticmethod def unpause_dag( project_name: str, @@ -136,6 +152,22 @@ def unpause_dag( logger.info("Unable to Unpause DAG %s", dag_id) logger.info(command_output[1]) + @staticmethod + def unpause_all_dags( + project_name: str, + environment: str, + location: str, + sdk_endpoint: str, + ) -> None: + """UnPause all the DAGs in the given environment.""" + command = ( + f"CLOUDSDK_API_ENDPOINT_OVERRIDES_COMPOSER={sdk_endpoint} gcloud composer environments" + f" run {environment} --project={project_name} --location={location}" + f" dags unpause -- \".*\" --treat-dag-id-as-regex -y" + ) + command_output = DAG._run_shell_command_locally_once(command=command) + logger.info(command_output[1]) + @staticmethod def describe_environment( project_name: str, environment: str, location: str, sdk_endpoint: str @@ -151,9 +183,74 @@ def describe_environment( logger.info("Environment Info:\n %s", environment_json["name"]) return environment_json + @staticmethod + def pause_unpause_dags_individually( + project_name: str, + environment: str, + location: str, + sdk_endpoint: str, + airflow_version: tuple[int, int, int], + operation: str, + ) -> None: + """Pause or unpause DAGs individually.""" + list_of_dags = DAG.get_list_of_dags( + project_name=project_name, + environment=environment, + location=location, + sdk_endpoint=sdk_endpoint, + airflow_version=airflow_version, + ) + logger.info("List of dags : %s", list_of_dags) + + if operation == "pause": + for dag in list_of_dags: + if dag == "airflow_monitoring": + continue + DAG.pause_dag( + project_name=project_name, + environment=environment, + location=location, + sdk_endpoint=sdk_endpoint, + dag_id=dag, + airflow_version=airflow_version, + ) + else: + for dag in list_of_dags: + DAG.unpause_dag( + project_name=project_name, + environment=environment, + location=location, + sdk_endpoint=sdk_endpoint, + dag_id=dag, + airflow_version=airflow_version, + ) + + @staticmethod + def pause_unpause_all_dags_at_once( + project_name: str, + environment: str, + location: str, + sdk_endpoint: str, + operation: str, + ) -> None: + """Pause or unpause all DAGs at once.""" + if operation == "pause": + DAG.pause_all_dags( + project_name=project_name, + environment=environment, + location=location, + sdk_endpoint=sdk_endpoint, + ) + else: + DAG.unpause_all_dags( + project_name=project_name, + environment=environment, + location=location, + sdk_endpoint=sdk_endpoint, + ) def main( - project_name: str, environment: str, location: str, operation: str, sdk_endpoint=str + project_name: str, environment: str, location: str, operation: str, sdk_endpoint: str ) -> int: logger.info("DAG Pause/UnPause Script for Cloud Composer") environment_info = DAG.describe_environment( @@ -170,37 +267,10 @@ def main( environment_info["config"]["softwareConfig"]["imageVersion"], ) airflow_version = (int(versions[3]), int(versions[4]), int(versions[5])) - list_of_dags = DAG.get_list_of_dags( - project_name=project_name, - environment=environment, - location=location, - sdk_endpoint=sdk_endpoint, - airflow_version=airflow_version, - ) - logger.info("List of dags : %s", list_of_dags) - - if operation == "pause": - for dag in list_of_dags: - if dag == "airflow_monitoring": - continue - DAG.pause_dag( - project_name=project_name, - environment=environment, - location=location, - sdk_endpoint=sdk_endpoint, - dag_id=dag, - airflow_version=airflow_version, - ) + if airflow_version < (2, 9, 0): + DAG.pause_unpause_dags_individually(project_name, environment, location, sdk_endpoint, airflow_version, operation) else: - for dag in list_of_dags: - DAG.unpause_dag( - project_name=project_name, - environment=environment, - location=location, - sdk_endpoint=sdk_endpoint, - dag_id=dag, - airflow_version=airflow_version, - ) + DAG.pause_unpause_all_dags_at_once(project_name, environment, location, sdk_endpoint, operation) return 0