Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 35 additions & 2 deletions hindsight-api-slim/hindsight_api/worker/poller.py
Original file line number Diff line number Diff line change
Expand Up @@ -475,6 +475,11 @@ async def _execute_task_inner(self, task: ClaimedTask):
traceback.print_exc()
await self._mark_failed(task.operation_id, str(e), task.schema)

# Tasks claimed longer than this are considered abandoned by a dead worker.
# The longest observed legitimate task is ~7 minutes (large PDF extraction).
# 30 minutes provides a safe margin.
_STALE_TASK_THRESHOLD_MINUTES = 30

async def recover_own_tasks(self) -> int:
"""
Recover tasks that were assigned to this worker but not completed.
Expand All @@ -483,7 +488,9 @@ async def recover_own_tasks(self) -> int:
On startup, we reset any tasks stuck in 'processing' for this worker_id
back to 'pending' so they can be picked up again.

Also recovers batch API operations that were in-flight.
Also recovers batch API operations that were in-flight, and reclaims
stale tasks from dead workers (other worker_ids whose tasks have been
stuck in 'processing' beyond the stale threshold).

If tenant_extension is configured, recovers across all tenant schemas.

Expand All @@ -501,7 +508,7 @@ async def recover_own_tasks(self) -> int:
batch_count = await self._recover_batch_operations(schema)
total_count += batch_count

# Then reset normal worker tasks
# Then reset normal worker tasks (own worker_id)
result = await self._pool.execute(
f"""
UPDATE {table}
Expand All @@ -514,6 +521,32 @@ async def recover_own_tasks(self) -> int:
# Parse "UPDATE N" to get count
count = int(result.split()[-1]) if result else 0
total_count += count

# Reclaim stale tasks from dead workers.
# When a worker pod is terminated (restart, deploy, OOM, node
# eviction), it may not release its claimed tasks. The new pod
# gets a different worker_id, so the above query won't match
# the old pod's tasks. Any task stuck in 'processing' with a
# claimed_at older than the threshold is assumed abandoned.
stale_result = await self._pool.execute(
f"""
UPDATE {table}
SET status = 'pending', worker_id = NULL, claimed_at = NULL, updated_at = now()
WHERE status = 'processing'
AND worker_id != $1
AND claimed_at < now() - make_interval(mins => $2)
AND result_metadata->>'batch_id' IS NULL
""",
self._worker_id,
self._STALE_TASK_THRESHOLD_MINUTES,
)
stale_count = int(stale_result.split()[-1]) if stale_result else 0
if stale_count > 0:
logger.warning(
f"Worker {self._worker_id} reclaimed {stale_count} stale tasks "
f"from dead workers (claimed_at > {self._STALE_TASK_THRESHOLD_MINUTES}m ago)"
)
total_count += stale_count
except Exception as e:
# Format schema for logging: custom schemas in quotes, None as-is
schema_display = f'"{schema}"' if schema else str(schema)
Expand Down
67 changes: 67 additions & 0 deletions hindsight-api-slim/tests/test_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -701,6 +701,73 @@ async def test_recover_own_tasks_returns_zero_when_no_stale_tasks(self, pool, cl
recovered_count = await poller.recover_own_tasks()
assert recovered_count == 0

@pytest.mark.asyncio
async def test_recover_reclaims_stale_tasks_from_dead_workers(self, pool, clean_operations):
    """Startup recovery resets tasks abandoned by dead workers.

    Two tasks are parked under a worker_id that no longer exists: one
    claimed an hour ago (past the 30-minute stale threshold) and one
    claimed five minutes ago. A freshly started worker must reclaim only
    the hour-old task, resetting it to 'pending' with no owner, while
    leaving the recently claimed task untouched.
    """
    from hindsight_api.worker import WorkerPoller

    bank_id = f"test-worker-{uuid.uuid4().hex[:8]}"
    await _ensure_bank(pool, bank_id)

    # Task claimed by a now-dead worker a full hour ago -> stale, reclaimable.
    abandoned_op = uuid.uuid4()
    await pool.execute(
        """
        INSERT INTO async_operations
        (operation_id, bank_id, operation_type, status, task_payload,
         worker_id, claimed_at)
        VALUES ($1, $2, 'consolidation', 'processing', $3::jsonb,
        'dead-worker-abc123', now() - interval '60 minutes')
        """,
        abandoned_op,
        bank_id,
        json.dumps({"type": "consolidation", "bank_id": bank_id}),
    )

    # Task claimed by the same dead worker only five minutes ago -> not yet stale.
    fresh_op = uuid.uuid4()
    await pool.execute(
        """
        INSERT INTO async_operations
        (operation_id, bank_id, operation_type, status, task_payload,
         worker_id, claimed_at)
        VALUES ($1, $2, 'retain', 'processing', $3::jsonb,
        'dead-worker-abc123', now() - interval '5 minutes')
        """,
        fresh_op,
        bank_id,
        json.dumps({"type": "retain", "bank_id": bank_id}),
    )

    # A brand-new worker id simulates a replacement pod performing startup recovery.
    poller = WorkerPoller(
        pool=pool,
        worker_id="new-worker",
        executor=lambda x: None,
    )

    # Exactly one task (the hour-old one) should be counted as recovered.
    assert await poller.recover_own_tasks() == 1

    # The stale task went back to the pending pool with no owner...
    reclaimed = await pool.fetchrow(
        "SELECT status, worker_id FROM async_operations WHERE operation_id = $1",
        abandoned_op,
    )
    assert (reclaimed["status"], reclaimed["worker_id"]) == ("pending", None)

    # ...while the recently claimed task still belongs to the dead worker.
    untouched = await pool.fetchrow(
        "SELECT status, worker_id FROM async_operations WHERE operation_id = $1",
        fresh_op,
    )
    assert (untouched["status"], untouched["worker_id"]) == ("processing", "dead-worker-abc123")


class TestConcurrentWorkers:
"""Tests for concurrent worker task claiming (FOR UPDATE SKIP LOCKED)."""
Expand Down