Skip to content

Commit

Permalink
Fix DataflowJobLink for Beam operators in deferrable mode (#45023)
Browse files Browse the repository at this point in the history
  • Loading branch information
MaksYermak authored Dec 21, 2024
1 parent 9316ed6 commit 279f1fa
Show file tree
Hide file tree
Showing 6 changed files with 127 additions and 66 deletions.
34 changes: 29 additions & 5 deletions providers/src/airflow/providers/apache/beam/hooks/beam.py
Original file line number Diff line number Diff line change
Expand Up @@ -451,6 +451,7 @@ async def start_python_pipeline_async(
py_interpreter: str = "python3",
py_requirements: list[str] | None = None,
py_system_site_packages: bool = False,
process_line_callback: Callable[[str], None] | None = None,
):
"""
Start Apache Beam python pipeline.
Expand All @@ -470,6 +471,8 @@ async def start_python_pipeline_async(
:param py_system_site_packages: Whether to include system_site_packages in your virtualenv.
See virtualenv documentation for more information.
This option is only relevant if the ``py_requirements`` parameter is not None.
:param process_line_callback: Optional callback which can be used to process
stdout and stderr to detect job id
"""
py_options = py_options or []
if "labels" in variables:
Expand Down Expand Up @@ -518,16 +521,25 @@ async def start_python_pipeline_async(
return_code = await self.start_pipeline_async(
variables=variables,
command_prefix=command_prefix,
process_line_callback=process_line_callback,
)
return return_code

async def start_java_pipeline_async(self, variables: dict, jar: str, job_class: str | None = None):
async def start_java_pipeline_async(
self,
variables: dict,
jar: str,
job_class: str | None = None,
process_line_callback: Callable[[str], None] | None = None,
):
"""
Start Apache Beam Java pipeline.
:param variables: Variables passed to the job.
:param jar: Name of the jar for the pipeline.
:param job_class: Name of the java class for the pipeline.
:param process_line_callback: Optional callback which can be used to process
stdout and stderr to detect job id
:return: Beam command execution return code.
"""
if "labels" in variables:
Expand All @@ -537,6 +549,7 @@ async def start_java_pipeline_async(self, variables: dict, jar: str, job_class:
return_code = await self.start_pipeline_async(
variables=variables,
command_prefix=command_prefix,
process_line_callback=process_line_callback,
)
return return_code

Expand All @@ -545,6 +558,7 @@ async def start_pipeline_async(
variables: dict,
command_prefix: list[str],
working_directory: str | None = None,
process_line_callback: Callable[[str], None] | None = None,
) -> int:
cmd = [*command_prefix, f"--runner={self.runner}"]
if variables:
Expand All @@ -553,20 +567,24 @@ async def start_pipeline_async(
cmd=cmd,
working_directory=working_directory,
log=self.log,
process_line_callback=process_line_callback,
)

async def run_beam_command_async(
self,
cmd: list[str],
log: logging.Logger,
working_directory: str | None = None,
process_line_callback: Callable[[str], None] | None = None,
) -> int:
"""
Run pipeline command in subprocess.
:param cmd: Parts of the command to be run in subprocess
:param working_directory: Working directory
:param log: logger.
:param log: logger
:param process_line_callback: Optional callback which can be used to process
stdout and stderr to detect job id
"""
cmd_str_representation = " ".join(shlex.quote(c) for c in cmd)
log.info("Running command: %s", cmd_str_representation)
Expand All @@ -584,8 +602,8 @@ async def run_beam_command_async(
log.info("Start waiting for Apache Beam process to complete.")

# Creating separate asyncio tasks for stdout and stderr
stdout_task = asyncio.create_task(self.read_logs(process.stdout))
stderr_task = asyncio.create_task(self.read_logs(process.stderr))
stdout_task = asyncio.create_task(self.read_logs(process.stdout, process_line_callback))
stderr_task = asyncio.create_task(self.read_logs(process.stderr, process_line_callback))

# Waiting for both tasks to complete
await asyncio.gather(stdout_task, stderr_task)
Expand All @@ -598,10 +616,16 @@ async def run_beam_command_async(
raise AirflowException(f"Apache Beam process failed with return code {return_code}")
return return_code

async def read_logs(self, stream_reader):
async def read_logs(
self,
stream_reader,
process_line_callback: Callable[[str], None] | None = None,
):
while True:
line = await stream_reader.readline()
if not line:
break
decoded_line = line.decode().strip()
if process_line_callback:
process_line_callback(decoded_line)
self.log.info(decoded_line)
27 changes: 13 additions & 14 deletions providers/src/airflow/providers/apache/beam/operators/beam.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,17 @@ def execute_complete(self, context: Context, event: dict[str, Any]):
self.task_id,
event["message"],
)
self.dataflow_job_id = event["dataflow_job_id"]
self.project_id = event["project_id"]
self.location = event["location"]

DataflowJobLink.persist(
self,
context,
self.project_id,
self.location,
self.dataflow_job_id,
)
return {"dataflow_job_id": self.dataflow_job_id}


Expand Down Expand Up @@ -425,13 +436,6 @@ def execute_sync(self, context: Context):

def execute_async(self, context: Context):
if self.is_dataflow and self.dataflow_hook:
DataflowJobLink.persist(
self,
context,
self.dataflow_config.project_id,
self.dataflow_config.location,
self.dataflow_job_id,
)
with self.dataflow_hook.provide_authorized_gcloud():
self.defer(
trigger=BeamPythonPipelineTrigger(
Expand All @@ -443,6 +447,8 @@ def execute_async(self, context: Context):
py_system_site_packages=self.py_system_site_packages,
runner=self.runner,
gcp_conn_id=self.gcp_conn_id,
project_id=self.dataflow_config.project_id,
location=self.dataflow_config.location,
),
method_name="execute_complete",
)
Expand Down Expand Up @@ -613,13 +619,6 @@ def execute_sync(self, context: Context):

def execute_async(self, context: Context):
if self.is_dataflow and self.dataflow_hook:
DataflowJobLink.persist(
self,
context,
self.dataflow_config.project_id,
self.dataflow_config.location,
self.dataflow_job_id,
)
with self.dataflow_hook.provide_authorized_gcloud():
self.pipeline_options["jobName"] = self.dataflow_job_name
self.defer(
Expand Down
42 changes: 38 additions & 4 deletions providers/src/airflow/providers/apache/beam/triggers/beam.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,15 @@
import asyncio
import contextlib
from collections.abc import AsyncIterator, Sequence
from typing import IO, Any
from typing import IO, Any, Callable

from google.cloud.dataflow_v1beta3 import ListJobsRequest

from airflow.providers.apache.beam.hooks.beam import BeamAsyncHook
from airflow.providers.google.cloud.hooks.dataflow import AsyncDataflowHook
from airflow.providers.apache.beam.hooks.beam import BeamAsyncHook, BeamRunnerType
from airflow.providers.google.cloud.hooks.dataflow import (
AsyncDataflowHook,
process_line_and_extract_dataflow_job_id_callback,
)
from airflow.providers.google.cloud.hooks.gcs import GCSHook
from airflow.triggers.base import BaseTrigger, TriggerEvent

Expand All @@ -40,6 +43,14 @@ def _get_async_hook(*args, **kwargs) -> BeamAsyncHook:
def _get_sync_dataflow_hook(**kwargs) -> AsyncDataflowHook:
return AsyncDataflowHook(**kwargs)

def _get_dataflow_process_callback(self) -> Callable[[str], None]:
    """
    Build the line-processing callback used to capture the Dataflow job id.

    A small setter that stores its argument in ``self.dataflow_job_id`` is passed as
    ``on_new_job_id_callback`` to ``process_line_and_extract_dataflow_job_id_callback``,
    and the callable returned by that helper is returned unchanged.
    """

    def _remember_job_id(job_id):
        # NOTE(review): presumably invoked by the helper once a job id is found in a
        # log line -- confirm against the helper's implementation.
        self.dataflow_job_id = job_id

    line_callback = process_line_and_extract_dataflow_job_id_callback(
        on_new_job_id_callback=_remember_job_id
    )
    return line_callback


class BeamPythonPipelineTrigger(BeamPipelineBaseTrigger):
"""
Expand All @@ -59,6 +70,8 @@ class BeamPythonPipelineTrigger(BeamPipelineBaseTrigger):
:param py_system_site_packages: Whether to include system_site_packages in your virtualenv.
See virtualenv documentation for more information.
This option is only relevant if the ``py_requirements`` parameter is not None.
:param project_id: Optional, the Google Cloud project ID in which to start a job.
:param location: Optional, Job location.
:param runner: Runner on which pipeline will be run. By default, "DirectRunner" is being used.
Other possible options: DataflowRunner, SparkRunner, FlinkRunner, PortableRunner.
See: :class:`~providers.apache.beam.hooks.beam.BeamRunnerType`
Expand All @@ -74,6 +87,8 @@ def __init__(
py_interpreter: str = "python3",
py_requirements: list[str] | None = None,
py_system_site_packages: bool = False,
project_id: str | None = None,
location: str | None = None,
runner: str = "DirectRunner",
gcp_conn_id: str = "google_cloud_default",
):
Expand All @@ -84,6 +99,9 @@ def __init__(
self.py_interpreter = py_interpreter
self.py_requirements = py_requirements
self.py_system_site_packages = py_system_site_packages
self.dataflow_job_id: str | None = None
self.project_id = project_id
self.location = location
self.runner = runner
self.gcp_conn_id = gcp_conn_id

Expand All @@ -98,6 +116,8 @@ def serialize(self) -> tuple[str, dict[str, Any]]:
"py_interpreter": self.py_interpreter,
"py_requirements": self.py_requirements,
"py_system_site_packages": self.py_system_site_packages,
"project_id": self.project_id,
"location": self.location,
"runner": self.runner,
"gcp_conn_id": self.gcp_conn_id,
},
Expand All @@ -106,6 +126,8 @@ def serialize(self) -> tuple[str, dict[str, Any]]:
async def run(self) -> AsyncIterator[TriggerEvent]: # type: ignore[override]
"""Get current pipeline status and yields a TriggerEvent."""
hook = self._get_async_hook(runner=self.runner)
is_dataflow = self.runner.lower() == BeamRunnerType.DataflowRunner.lower()

try:
# Get the current running event loop to manage I/O operations asynchronously
loop = asyncio.get_running_loop()
Expand All @@ -130,6 +152,7 @@ async def run(self) -> AsyncIterator[TriggerEvent]: # type: ignore[override]
py_interpreter=self.py_interpreter,
py_requirements=self.py_requirements,
py_system_site_packages=self.py_system_site_packages,
process_line_callback=self._get_dataflow_process_callback() if is_dataflow else None,
)
except Exception as e:
self.log.exception("Exception occurred while checking for pipeline state")
Expand All @@ -140,6 +163,9 @@ async def run(self) -> AsyncIterator[TriggerEvent]: # type: ignore[override]
{
"status": "success",
"message": "Pipeline has finished SUCCESSFULLY",
"dataflow_job_id": self.dataflow_job_id,
"project_id": self.project_id,
"location": self.location,
}
)
else:
Expand Down Expand Up @@ -205,6 +231,7 @@ def __init__(
self.impersonation_chain = impersonation_chain
self.poll_sleep = poll_sleep
self.cancel_timeout = cancel_timeout
self.dataflow_job_id: str | None = None

def serialize(self) -> tuple[str, dict[str, Any]]:
"""Serialize BeamJavaPipelineTrigger arguments and classpath."""
Expand All @@ -229,6 +256,7 @@ def serialize(self) -> tuple[str, dict[str, Any]]:
async def run(self) -> AsyncIterator[TriggerEvent]: # type: ignore[override]
"""Get current Java pipeline status and yields a TriggerEvent."""
hook = self._get_async_hook(runner=self.runner)
is_dataflow = self.runner.lower() == BeamRunnerType.DataflowRunner.lower()

return_code = 0
if self.check_if_running:
Expand Down Expand Up @@ -271,7 +299,10 @@ async def run(self) -> AsyncIterator[TriggerEvent]: # type: ignore[override]
self.jar = tmp_gcs_file.name

return_code = await hook.start_java_pipeline_async(
variables=self.variables, jar=self.jar, job_class=self.job_class
variables=self.variables,
jar=self.jar,
job_class=self.job_class,
process_line_callback=self._get_dataflow_process_callback() if is_dataflow else None,
)
except Exception as e:
self.log.exception("Exception occurred while starting the Java pipeline")
Expand All @@ -282,6 +313,9 @@ async def run(self) -> AsyncIterator[TriggerEvent]: # type: ignore[override]
{
"status": "success",
"message": "Pipeline has finished SUCCESSFULLY",
"dataflow_job_id": self.dataflow_job_id,
"project_id": self.project_id,
"location": self.location,
}
)
else:
Expand Down
12 changes: 10 additions & 2 deletions providers/tests/apache/beam/hooks/test_beam.py
Original file line number Diff line number Diff line change
Expand Up @@ -486,7 +486,10 @@ async def test_start_pipline_async(self, mock_runner):
)

mock_runner.assert_called_once_with(
cmd=expected_cmd, working_directory=WORKING_DIRECTORY, log=hook.log
cmd=expected_cmd,
working_directory=WORKING_DIRECTORY,
log=hook.log,
process_line_callback=None,
)

@pytest.mark.asyncio
Expand Down Expand Up @@ -516,6 +519,7 @@ async def test_start_python_pipeline(self, mock_create_dir, mock_runner, mocked_
cmd=expected_cmd,
working_directory=None,
log=ANY,
process_line_callback=None,
)

@pytest.mark.asyncio
Expand Down Expand Up @@ -580,6 +584,7 @@ async def test_start_python_pipeline_with_custom_interpreter(
cmd=expected_cmd,
working_directory=None,
log=ANY,
process_line_callback=None,
)

@pytest.mark.asyncio
Expand Down Expand Up @@ -630,6 +635,7 @@ async def test_start_python_pipeline_with_non_empty_py_requirements_and_without_
cmd=expected_cmd,
working_directory=None,
log=ANY,
process_line_callback=None,
)
mock_virtualenv.assert_called_once_with(
venv_directory=mock.ANY,
Expand Down Expand Up @@ -671,5 +677,7 @@ async def test_start_java_pipeline_async(self, mock_start_pipeline, job_class, c
await hook.start_java_pipeline_async(variables=variables, jar=JAR_FILE, job_class=job_class)

mock_start_pipeline.assert_called_once_with(
variables=BEAM_VARIABLES_JAVA_STRING_LABELS, command_prefix=command_prefix
variables=BEAM_VARIABLES_JAVA_STRING_LABELS,
command_prefix=command_prefix,
process_line_callback=None,
)
Loading

0 comments on commit 279f1fa

Please sign in to comment.