From aedee9cfab4ef48c5532d4af2d974c395c1dd808 Mon Sep 17 00:00:00 2001 From: Tzu-ping Chung Date: Fri, 3 Jan 2025 15:22:01 +0800 Subject: [PATCH] Typing fixes --- tests/models/test_dag.py | 2 ++ tests/models/test_dagrun.py | 2 ++ tests/sensors/test_external_task_sensor.py | 22 ++++++++++++++-------- 3 files changed, 18 insertions(+), 8 deletions(-) diff --git a/tests/models/test_dag.py b/tests/models/test_dag.py index 53090f6225bc82..9a1f0dc789b597 100644 --- a/tests/models/test_dag.py +++ b/tests/models/test_dag.py @@ -1733,6 +1733,8 @@ def test_clear_dag( session.merge(dagrun_1) task_instance_1 = dagrun_1.get_task_instance(task_id) + if TYPE_CHECKING: + assert task_instance_1 task_instance_1.state = ti_state_begin task_instance_1.job_id = 123 session.merge(task_instance_1) diff --git a/tests/models/test_dagrun.py b/tests/models/test_dagrun.py index 13614d6abeaea2..4701fcd2b251b6 100644 --- a/tests/models/test_dagrun.py +++ b/tests/models/test_dagrun.py @@ -123,6 +123,8 @@ def create_dag_run( if task_states is not None: for task_id, task_state in task_states.items(): ti = dag_run.get_task_instance(task_id) + if TYPE_CHECKING: + assert ti ti.set_state(task_state, session) session.flush() diff --git a/tests/sensors/test_external_task_sensor.py b/tests/sensors/test_external_task_sensor.py index ba3cfd6b4480cf..eec3b8c3b09ebe 100644 --- a/tests/sensors/test_external_task_sensor.py +++ b/tests/sensors/test_external_task_sensor.py @@ -45,11 +45,12 @@ from airflow.providers.standard.sensors.time import TimeSensor from airflow.providers.standard.triggers.external_task import WorkflowTrigger from airflow.serialization.serialized_objects import SerializedBaseOperator +from airflow.timetables.base import DataInterval from airflow.utils.hashlib_wrapper import md5 from airflow.utils.session import NEW_SESSION, create_session, provide_session from airflow.utils.state import DagRunState, State, TaskInstanceState from airflow.utils.task_group import TaskGroup -from airflow.utils.timezone import datetime +from airflow.utils.timezone import coerce_datetime, datetime from airflow.utils.types import DagRunType from tests.models import TEST_DAGS_FOLDER @@ -1252,19 +1253,24 @@ def run_tasks( """ runs: dict[str, DagRun] = {} tis: dict[str, TaskInstance] = {} - triggered_by = DagRunTriggeredByType.TEST if AIRFLOW_V_3_0_PLUS else None for dag in dag_bag.dags.values(): - dagrun = dag.create_dagrun( - state=DagRunState.RUNNING, + data_interval = DataInterval(coerce_datetime(logical_date), coerce_datetime(logical_date)) + runs[dag.dag_id] = dagrun = dag.create_dagrun( + run_id=dag.timetable.generate_run_id( + run_type=DagRunType.MANUAL, + logical_date=logical_date, + data_interval=data_interval, + ), logical_date=logical_date, - start_date=logical_date, + data_interval=data_interval, run_type=DagRunType.MANUAL, + triggered_by=DagRunTriggeredByType.TEST, + dag_version=None, + state=DagRunState.RUNNING, + start_date=logical_date, session=session, - data_interval=(logical_date, logical_date), - triggered_by=triggered_by, ) - runs[dag.dag_id] = dagrun # we use sorting by task_id here because for the test DAG structure of ours # this is equivalent to topological sort. It would not work in general case # but it works for our case because we specifically constructed test DAGS