Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 33 additions & 16 deletions secator/hooks/mongodb.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
import threading
import time

import pymongo
Expand All @@ -9,7 +10,7 @@
from secator.hooks._dedup import compute_duplicate_updates
from secator.output_types import OUTPUT_TYPES
from secator.runners import Scan, Task, Workflow
from secator.utils import debug, escape_mongodb_url, should_update
from secator.utils import debug, escape_mongodb_url

# import gevent.monkey
# gevent.monkey.patch_all()
Expand All @@ -23,6 +24,10 @@
# Time-based flushing is throttled by CONFIG.runners.backend_update_frequency.
# Number of buffered operations at which update_finding flushes immediately.
MONGODB_FLUSH_SIZE = 1000

# Guards the per-runner findings buffer: appends happen on the main thread
# (on_item) while flushes can run from the runner's interval thread (on_interval).
_findings_buffer_lock = threading.Lock()

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# NOTE(review): presumably the lazily-created client cached by
# get_mongodb_client(); None until first use -- confirm against that helper.
_mongodb_client = None
Expand Down Expand Up @@ -117,38 +122,50 @@ def update_finding(self, item):
"""
if type(item) not in OUTPUT_TYPES:
return item
# Add all in-memory context HERE (synchronously, before the item is yielded):
# the batched DB write below must never mutate the item, since the flush can
# run later / on another thread, after the item has already been yielded.
if not ObjectId.is_valid(str(item._uuid)):
item._uuid = str(ObjectId())
buffer = getattr(self, '_mongodb_findings_buffer', None)
if buffer is None:
buffer = self._mongodb_findings_buffer = []
buffer.append(
pymongo.UpdateOne({'_id': ObjectId(item._uuid)}, {'$set': item.toDict()}, upsert=True)
)
if len(buffer) >= MONGODB_FLUSH_SIZE:
op = pymongo.UpdateOne({'_id': ObjectId(item._uuid)}, {'$set': item.toDict()}, upsert=True)
with _findings_buffer_lock:
buffer = getattr(self, '_mongodb_findings_buffer', None)
if buffer is None:
buffer = self._mongodb_findings_buffer = []
buffer.append(op)
over_cap = len(buffer) >= MONGODB_FLUSH_SIZE
if over_cap:
flush_findings_buffer(self)
return item


def flush_findings_buffer(self):
    """Write all buffered finding upserts to MongoDB in a single bulk_write.

    The buffer is swapped out under ``_findings_buffer_lock`` so that appends
    made concurrently by ``update_finding`` land in the fresh list instead of
    being lost, then written outside the lock to avoid holding it during DB I/O.

    If the write raises, the pending operations are put back at the front of
    the buffer and the exception is re-raised, so a later flush can retry them
    (they are ``$set`` upserts keyed by ``_id``, so replaying them is harmless).

    Args:
        self: Runner whose ``_mongodb_findings_buffer`` should be drained.

    Returns:
        None. Does nothing when the buffer is missing or empty.
    """
    with _findings_buffer_lock:
        pending = getattr(self, '_mongodb_findings_buffer', None)
        if not pending:
            return
        # Swap in a fresh list: the buffer must NOT be reset again after the
        # write, or operations appended while the write was running are dropped.
        self._mongodb_findings_buffer = []
    start_time = time.time()
    try:
        client = get_mongodb_client()
        client.main.findings.bulk_write(pending, ordered=False)
    except Exception:
        # Don't lose the batch on a failed write: re-queue it ahead of
        # anything appended in the meantime, then let the caller see the error.
        with _findings_buffer_lock:
            newer = getattr(self, '_mongodb_findings_buffer', None) or []
            self._mongodb_findings_buffer = pending + newer
        raise
    self._last_findings_flush = start_time
    debug(f'flushed {len(pending)} findings in {time.time() - start_time:.4f}s', sub='hooks.mongodb', obj_after=False)


def flush_findings(self):
    """on_interval hook: flush buffered findings.

    Cadence is handled by run_hooks' on_interval throttle (backend_update_frequency)
    and the runner's interval thread; this just drains the buffer when called.
    No additional throttling is applied here, so the buffer is drained exactly
    once per invocation.

    Args:
        self: Runner whose findings buffer should be flushed.
    """
    flush_findings_buffer(self)


def flush_findings_final(self):
Expand Down
54 changes: 51 additions & 3 deletions secator/runners/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import logging
import sys
import textwrap
import threading
import uuid

from datetime import datetime, timezone
Expand Down Expand Up @@ -107,6 +108,10 @@ def __init__(self, config, inputs=[], results=[], run_opts={}, hooks={}, validat
self.results = []
self.results_count = 0
self.threads = []
# Interval-hook thread (lazily started in __iter__ so the runner stays
# picklable for Celery, like the monitor thread). None until started.
self._interval_thread = None
self._interval_stop_event = None
self.output = ''
self.started = False
self.done = False
Expand Down Expand Up @@ -458,6 +463,9 @@ def __getstate__(self):
state = self.__dict__.copy()
state['_hooks'] = {}
state['resolved_hooks'] = {name: [] for name in state['resolved_hooks']}
# Threads/Events aren't picklable; they're recreated lazily on run.
state['_interval_thread'] = None
state['_interval_stop_event'] = None
return state

def __setstate__(self, state):
Expand Down Expand Up @@ -545,6 +553,9 @@ def __iter__(self):
else:
self.log_start()

# Fire on_interval hooks on a timer (covers quiet periods); stopped in _finalize
self._start_interval_thread()

# Yield results buffer
yield from self.results_buffer
self.results_buffer = []
Expand Down Expand Up @@ -577,6 +588,9 @@ def __iter__(self):

def _finalize(self):
"""Finalize the runner."""
# Stop interval thread first so it can't flush concurrently with the
# final on_end flush (mark_completed below).
self._stop_interval_thread()
self.join_threads()
gc.collect()
if self.sync:
Expand All @@ -595,6 +609,38 @@ def _finalize(self):
if self.enable_reports:
self.export_reports()

def _start_interval_thread(self):
    """Spawn the daemon thread that fires on_interval hooks on a timer.

    The thread is created here rather than in __init__ so the runner stays
    picklable for Celery, mirroring the monitor thread. Without it, on_interval
    only fires when the runner produces an item and therefore stalls during
    quiet periods. on_interval remains throttled by backend_update_frequency in
    run_hooks; a frequency <= 0 disables time-based backend updates, in which
    case no thread is started. Calling this while a thread already exists is a
    no-op.
    """
    frequency = CONFIG.runners.backend_update_frequency
    if frequency <= 0:
        # Time-based backend updates are disabled.
        return
    if self._interval_thread is not None:
        # An interval thread has already been started for this runner.
        return
    # The stop event must be in place before the worker starts, since the
    # worker loop waits on it.
    stop_event = threading.Event()
    self._interval_stop_event = stop_event
    worker = threading.Thread(target=self._interval_loop, args=(frequency,), daemon=True)
    self._interval_thread = worker
    worker.start()

def _interval_loop(self, frequency):
"""Fire on_interval hooks every `frequency` seconds until stopped."""
while not self._interval_stop_event.wait(frequency):
try:
self.run_hooks('on_interval', sub='interval')
except Exception as e:
self.debug(f'interval thread hook error: {e}', sub='interval')

def _stop_interval_thread(self):
"""Stop the interval thread (called before the final on_end flush)."""
if self._interval_thread and self._interval_stop_event:
self._interval_stop_event.set()
self._interval_thread.join(timeout=2.0)
self._interval_thread = None
self._interval_stop_event = None

def join_threads(self):
"""Wait for all running threads to complete."""
if not self.threads:
Expand Down Expand Up @@ -938,9 +984,11 @@ def toDict(self):
'output': self.output,
'progress': self.progress,
'last_updated_db': self.last_updated_db,
'context': {**self.context, 'celery_ids': list(self.celery_ids_map.keys())},
'errors': [e.toDict() for e in self.errors],
'warnings': [w.toDict() for w in self.warnings],
# Snapshot mutable collections (copy before iterating): toDict can be
# called from the interval thread while the main thread appends.
'context': {**self.context, 'celery_ids': list(self.celery_ids_map.copy())},
'errors': [e.toDict() for e in list(self.errors)],
'warnings': [w.toDict() for w in list(self.warnings)],
}
)
return data
Expand Down