Skip to content

Commit

Permalink
Typing fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
uranusjr committed Jan 3, 2025
1 parent 99e55bd commit aedee9c
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 8 deletions.
2 changes: 2 additions & 0 deletions tests/models/test_dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -1733,6 +1733,8 @@ def test_clear_dag(
session.merge(dagrun_1)

task_instance_1 = dagrun_1.get_task_instance(task_id)
if TYPE_CHECKING:
assert task_instance_1
task_instance_1.state = ti_state_begin
task_instance_1.job_id = 123
session.merge(task_instance_1)
Expand Down
2 changes: 2 additions & 0 deletions tests/models/test_dagrun.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,8 @@ def create_dag_run(
if task_states is not None:
for task_id, task_state in task_states.items():
ti = dag_run.get_task_instance(task_id)
if TYPE_CHECKING:
assert ti
ti.set_state(task_state, session)
session.flush()

Expand Down
22 changes: 14 additions & 8 deletions tests/sensors/test_external_task_sensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,12 @@
from airflow.providers.standard.sensors.time import TimeSensor
from airflow.providers.standard.triggers.external_task import WorkflowTrigger
from airflow.serialization.serialized_objects import SerializedBaseOperator
from airflow.timetables.base import DataInterval
from airflow.utils.hashlib_wrapper import md5
from airflow.utils.session import NEW_SESSION, create_session, provide_session
from airflow.utils.state import DagRunState, State, TaskInstanceState
from airflow.utils.task_group import TaskGroup
from airflow.utils.timezone import datetime
from airflow.utils.timezone import coerce_datetime, datetime
from airflow.utils.types import DagRunType

from tests.models import TEST_DAGS_FOLDER
Expand Down Expand Up @@ -1252,19 +1253,24 @@ def run_tasks(
"""
runs: dict[str, DagRun] = {}
tis: dict[str, TaskInstance] = {}
triggered_by = DagRunTriggeredByType.TEST if AIRFLOW_V_3_0_PLUS else None

for dag in dag_bag.dags.values():
dagrun = dag.create_dagrun(
state=DagRunState.RUNNING,
data_interval = DataInterval(coerce_datetime(logical_date), coerce_datetime(logical_date))
runs[dag.dag_id] = dagrun = dag.create_dagrun(
run_id=dag.timetable.generate_run_id(
run_type=DagRunType.MANUAL,
logical_date=logical_date,
data_interval=data_interval,
),
logical_date=logical_date,
start_date=logical_date,
data_interval=data_interval,
run_type=DagRunType.MANUAL,
triggered_by=DagRunTriggeredByType.TEST,
dag_version=None,
state=DagRunState.RUNNING,
start_date=logical_date,
session=session,
data_interval=(logical_date, logical_date),
triggered_by=triggered_by,
)
runs[dag.dag_id] = dagrun
# we use sorting by task_id here because for the test DAG structure of ours
# this is equivalent to topological sort. It would not work in general case
# but it works for our case because we specifically constructed test DAGS
Expand Down

0 comments on commit aedee9c

Please sign in to comment.