Skip to content

Commit

Permalink
AIP-72: Gracefully handle not-found XCOMs in task sdk API client
Browse files Browse the repository at this point in the history
  • Loading branch information
amoghrajesh committed Jan 2, 2025
1 parent 3dd5b0c commit 8862d7d
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 5 deletions.
18 changes: 16 additions & 2 deletions task_sdk/src/airflow/sdk/api/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,14 +212,28 @@ def __init__(self, client: Client):

def get(
self, dag_id: str, run_id: str, task_id: str, key: str, map_index: int | None = None
) -> XComResponse:
) -> XComResponse | ErrorResponse:
"""Get a XCom value from the API server."""
# TODO: check if we need to use map_index as params in the uri
# ref: https://github.com/apache/airflow/blob/v2-10-stable/airflow/api_connexion/openapi/v1.yaml#L1785C1-L1785C81
params = {}
if map_index is not None:
params.update({"map_index": map_index})
resp = self.client.get(f"xcoms/{dag_id}/{run_id}/{task_id}/{key}", params=params)
try:
resp = self.client.get(f"xcoms/{dag_id}/{run_id}/{task_id}/{key}", params=params)
except ServerResponseError as e:
if e.response.status_code == HTTPStatus.NOT_FOUND:
log.error(
"XCom not found",
dag_id=dag_id,
run_id=run_id,
task_id=task_id,
key=key,
map_index=map_index,
detail=e.detail,
status_code=e.response.status_code,
)
return ErrorResponse(error=ErrorType.XCOM_NOT_FOUND, detail={"key": key})
return XComResponse.model_validate_json(resp.read())

def set(
Expand Down
8 changes: 7 additions & 1 deletion task_sdk/src/airflow/sdk/execution_time/supervisor.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@
IntermediateTIState,
TaskInstance,
TerminalTIState,
XComResponse,
)
from airflow.sdk.execution_time.comms import (
ConnectionResult,
Expand Down Expand Up @@ -726,7 +727,12 @@ def _handle_request(self, msg: ToSupervisor, log: FilteringBoundLogger):
resp = var_result.model_dump_json().encode()
elif isinstance(msg, GetXCom):
xcom = self.client.xcoms.get(msg.dag_id, msg.run_id, msg.task_id, msg.key, msg.map_index)
xcom_result = XComResult.from_xcom_response(xcom)
if isinstance(xcom, XComResponse):
xcom_result = XComResult.from_xcom_response(xcom)
else:
# Airflow 2.x just ignores the absence of an XCom and moves on with a return value of None
# Hence making this resp with key as `key` and value as None, so that the message is sent back to task runner.
xcom_result = XComResult.from_xcom_response(XComResponse(key=msg.key, value=None))
resp = xcom_result.model_dump_json().encode()
elif isinstance(msg, DeferTask):
self._terminal_state = IntermediateTIState.DEFERRED
Expand Down
28 changes: 26 additions & 2 deletions task_sdk/tests/execution_time/test_supervisor.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,11 @@
from airflow.sdk.api import client as sdk_client
from airflow.sdk.api.client import ServerResponseError
from airflow.sdk.api.datamodels._generated import TaskInstance, TerminalTIState
from airflow.sdk.exceptions import ErrorType
from airflow.sdk.execution_time.comms import (
ConnectionResult,
DeferTask,
ErrorResponse,
GetConnection,
GetVariable,
GetXCom,
Expand Down Expand Up @@ -762,14 +764,15 @@ def watched_subprocess(self, mocker):
)

@pytest.mark.parametrize(
["message", "expected_buffer", "client_attr_path", "method_arg", "mock_response"],
["message", "expected_buffer", "client_attr_path", "method_arg", "mock_response", "decoded_buffer"],
[
pytest.param(
GetConnection(conn_id="test_conn"),
b'{"conn_id":"test_conn","conn_type":"mysql","type":"ConnectionResult"}\n',
"connections.get",
("test_conn",),
ConnectionResult(conn_id="test_conn", conn_type="mysql"),
None,
id="get_connection",
),
pytest.param(
Expand All @@ -778,6 +781,7 @@ def watched_subprocess(self, mocker):
"variables.get",
("test_key",),
VariableResult(key="test_key", value="test_value"),
None,
id="get_variable",
),
pytest.param(
Expand All @@ -786,6 +790,7 @@ def watched_subprocess(self, mocker):
"variables.set",
("test_key", "test_value", "test_description"),
{"ok": True},
None,
id="set_variable",
),
pytest.param(
Expand All @@ -794,6 +799,7 @@ def watched_subprocess(self, mocker):
"task_instances.defer",
(TI_ID, DeferTask(next_method="execute_callback", classpath="my-classpath")),
"",
None,
id="patch_task_instance_to_deferred",
),
pytest.param(
Expand All @@ -811,6 +817,7 @@ def watched_subprocess(self, mocker):
),
),
"",
None,
id="patch_task_instance_to_up_for_reschedule",
),
pytest.param(
Expand All @@ -819,6 +826,7 @@ def watched_subprocess(self, mocker):
"xcoms.get",
("test_dag", "test_run", "test_task", "test_key", None),
XComResult(key="test_key", value="test_value"),
None,
id="get_xcom",
),
pytest.param(
Expand All @@ -829,8 +837,18 @@ def watched_subprocess(self, mocker):
"xcoms.get",
("test_dag", "test_run", "test_task", "test_key", 2),
XComResult(key="test_key", value="test_value"),
None,
id="get_xcom_map_index",
),
pytest.param(
GetXCom(dag_id="test_dag", run_id="test_run", task_id="test_task", key="test_key"),
b'{"key":"test_key","value":null,"type":"XComResult"}\n',
"xcoms.get",
("test_dag", "test_run", "test_task", "test_key", None),
ErrorResponse(error=ErrorType.XCOM_NOT_FOUND, detail={"key": "test_key"}),
XComResult(key="test_key", value=None, type="XComResult"),
id="get_xcom_not_found",
),
pytest.param(
SetXCom(
dag_id="test_dag",
Expand All @@ -850,6 +868,7 @@ def watched_subprocess(self, mocker):
None,
),
{"ok": True},
None,
id="set_xcom",
),
pytest.param(
Expand All @@ -872,6 +891,7 @@ def watched_subprocess(self, mocker):
2,
),
{"ok": True},
None,
id="set_xcom_with_map_index",
),
# we aren't adding all states under TerminalTIState here, because this test's scope is only to check
Expand All @@ -882,6 +902,7 @@ def watched_subprocess(self, mocker):
"",
(),
"",
None,
id="patch_task_instance_to_skipped",
),
pytest.param(
Expand All @@ -890,6 +911,7 @@ def watched_subprocess(self, mocker):
"task_instances.set_rtif",
(TI_ID, {"field1": "rendered_value1", "field2": "rendered_value2"}),
{"ok": True},
None,
id="set_rtif",
),
],
Expand All @@ -903,6 +925,7 @@ def test_handle_requests(
client_attr_path,
method_arg,
mock_response,
decoded_buffer,
time_machine,
):
"""
Expand Down Expand Up @@ -944,4 +967,5 @@ def test_handle_requests(
# Using BytesIO to simulate a readable stream for CommsDecoder.
input_stream = BytesIO(val)
decoder = CommsDecoder(input=input_stream)
assert decoder.get_message() == mock_response
decoded_buffer_message = decoded_buffer if decoded_buffer else mock_response
assert decoder.get_message() == decoded_buffer_message

0 comments on commit 8862d7d

Please sign in to comment.