Skip to content

Commit

Permalink
Revert "Don't mistakenly take a lock on DagRun via ti.refresh_from_fb (
Browse files Browse the repository at this point in the history
  • Loading branch information
ephraimbuddy committed Aug 19, 2022
1 parent c93c56c commit 0d08201
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 22 deletions.
28 changes: 10 additions & 18 deletions airflow/models/taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,7 +313,6 @@ def clear_task_instances(
if dag_run_state == DagRunState.QUEUED:
dr.last_scheduling_decision = None
dr.start_date = None
session.flush()


class _LazyXComAccessIterator(collections.abc.Iterator):
Expand Down Expand Up @@ -871,35 +870,28 @@ def refresh_from_db(self, session=NEW_SESSION, lock_for_update=False) -> None:
"""
self.log.debug("Refreshing TaskInstance %s from DB", self)

if self in session:
session.refresh(self, TaskInstance.__mapper__.column_attrs.keys())

qry = (
# To avoid joining any relationships, by default select all
# columns, not the object. This also means we get (effectively) a
# namedtuple back, not a TI object
session.query(*TaskInstance.__table__.columns).filter(
TaskInstance.dag_id == self.dag_id,
TaskInstance.task_id == self.task_id,
TaskInstance.run_id == self.run_id,
TaskInstance.map_index == self.map_index,
)
qry = session.query(TaskInstance).filter(
TaskInstance.dag_id == self.dag_id,
TaskInstance.task_id == self.task_id,
TaskInstance.run_id == self.run_id,
TaskInstance.map_index == self.map_index,
)

if lock_for_update:
for attempt in run_with_db_retries(logger=self.log):
with attempt:
ti: Optional[TaskInstance] = qry.with_for_update().one_or_none()
ti: Optional[TaskInstance] = qry.with_for_update().first()
else:
ti = qry.one_or_none()
ti = qry.first()
if ti:
# Fields ordered per model definition
self.start_date = ti.start_date
self.end_date = ti.end_date
self.duration = ti.duration
self.state = ti.state
# Since we selected columns, not the object, this is the raw value
self.try_number = ti.try_number
# Get the raw value of try_number column, don't read through the
# accessor here otherwise it will be incremented by one already.
self.try_number = ti._try_number
self.max_tries = ti.max_tries
self.hostname = ti.hostname
self.unixname = ti.unixname
Expand Down
6 changes: 2 additions & 4 deletions tests/jobs/test_scheduler_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -458,7 +458,6 @@ def test_execute_task_instances_is_paused_wont_execute(self, session, dag_maker)
ti1.state = State.SCHEDULED

self.scheduler_job._critical_section_enqueue_task_instances(session)
session.flush()
ti1.refresh_from_db(session=session)
assert State.SCHEDULED == ti1.state
session.rollback()
Expand Down Expand Up @@ -1316,9 +1315,8 @@ def test_enqueue_task_instances_sets_ti_state_to_None_if_dagrun_in_finish_state(
session.commit()

with patch.object(BaseExecutor, 'queue_command') as mock_queue_command:
self.scheduler_job._enqueue_task_instances_with_queued_state([ti], session=session)
session.flush()
ti.refresh_from_db(session=session)
self.scheduler_job._enqueue_task_instances_with_queued_state([ti])
ti.refresh_from_db()
assert ti.state == State.NONE
mock_queue_command.assert_not_called()

Expand Down

0 comments on commit 0d08201

Please sign in to comment.