diff --git a/task_sdk/src/airflow/sdk/execution_time/task_runner.py b/task_sdk/src/airflow/sdk/execution_time/task_runner.py index 86c51673fcb94..8f636f4fe7673 100644 --- a/task_sdk/src/airflow/sdk/execution_time/task_runner.py +++ b/task_sdk/src/airflow/sdk/execution_time/task_runner.py @@ -425,9 +425,17 @@ def run(ti: RuntimeTaskInstance, log: Logger): ) # TODO: Run task failure callbacks here - except (AirflowTaskTimeout, AirflowException): + except AirflowTaskTimeout: # TODO: handle the case of up_for_retry here + # TODO: coagulate this exception handling with AirflowException + # once https://github.com/apache/airflow/issues/45307 is handled ... + except AirflowException: + # TODO: handle the case of up_for_retry here + msg = TaskState( + state=TerminalTIState.FAILED, + end_date=datetime.now(tz=timezone.utc), + ) except AirflowTaskTerminated: # External state updates are already handled with `ti_heartbeat` and will be # updated already be another UI API. So, these exceptions should ideally never be thrown. diff --git a/task_sdk/tests/execution_time/test_task_runner.py b/task_sdk/tests/execution_time/test_task_runner.py index 54a38c3948067..cb756921e1013 100644 --- a/task_sdk/tests/execution_time/test_task_runner.py +++ b/task_sdk/tests/execution_time/test_task_runner.py @@ -27,6 +27,7 @@ from uuid6 import uuid7 from airflow.exceptions import ( + AirflowException, AirflowFailException, AirflowSensorTimeout, AirflowSkipException, @@ -330,6 +331,46 @@ def test_run_raises_system_exit(time_machine, mocked_parse, make_ti_context, moc ) +def test_run_raises_airflow_exception(time_machine, mocked_parse, make_ti_context, mock_supervisor_comms): + """Test running a basic task that exits with AirflowException.""" + from airflow.providers.standard.operators.python import PythonOperator + + task = PythonOperator( + task_id="af_exception_task", + python_callable=lambda: (_ for _ in ()).throw( + AirflowException("Oops! I am failing with AirflowException!"), + ), + ) + + what = StartupDetails( + ti=TaskInstance( + id=uuid7(), + task_id="af_exception_task", + dag_id="basic_dag_af_exception", + run_id="c", + try_number=1, + ), + file="", + requests_fd=0, + ti_context=make_ti_context(), + ) + + ti = mocked_parse(what, "basic_dag_af_exception", task) + + instant = timezone.datetime(2024, 12, 3, 10, 0) + time_machine.move_to(instant, tick=False) + + run(ti, log=mock.MagicMock()) + + mock_supervisor_comms.send_request.assert_called_once_with( + msg=TaskState( + state=TerminalTIState.FAILED, + end_date=instant, + ), + log=mock.ANY, + ) + + def test_startup_basic_templated_dag(mocked_parse, make_ti_context, mock_supervisor_comms): """Test running a DAG with templated task.""" from airflow.providers.standard.operators.bash import BashOperator