Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 48 additions & 34 deletions secator/runners/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,7 @@ def __init__(self, config, inputs=[], results=[], run_opts={}, hooks={}, validat
self.resolved_hooks = {name: [] for name in HOOKS + getattr(self, 'hooks', [])}
self.debug('registering hooks', obj=list(self.resolved_hooks.keys()), sub='init')
self.register_hooks(hooks)
self._apply_context_drivers()

# Validators
self.resolved_validators = {name: [] for name in VALIDATORS + getattr(self, 'validators', [])}
Expand Down Expand Up @@ -464,46 +465,50 @@ def run(self):
"""
return list(self.__iter__())

def __getstate__(self):
"""Custom pickle: strip hook functions so dynamically-loaded modules
(e.g. secator.hooks.cockpit) don't cause ModuleNotFoundError on workers.
Driver names in context['drivers'] are used to re-load hooks in __setstate__.
def _apply_context_drivers(self):
"""Register hooks for drivers named in ``context['drivers']``.

``context['drivers']`` is the cross-process source of truth for driver
hooks. Loading them at construction (rather than on unpickle) means every
runner gets its driver hooks exactly once, wherever it is built: the
initial dispatch, a chunk task rebuilt by ``run_command`` via
``task_cls(targets, **opts)``, and a chord callback runner. Because
``register_hooks()`` is idempotent, drivers also supplied explicitly via
``self._hooks`` (CLI-resolved hooks or library callers passing
``hooks=HOOKS``) are not registered twice.

This replaces the former ``__getstate__``/``__setstate__`` pair: with the
driver modules discovered at worker startup (``celery.py`` ``IN_WORKER``),
hook functions now pickle/unpickle natively by qualified name, so runners
no longer need to strip hooks on pickle and rebuild them on every unpickle
(which, under ``replace``/chord synchronization, re-registered hooks
O(chunks) times and flooded ``SECATOR_DEBUG=runner`` logs).
"""
state = self.__dict__.copy()
state['_hooks'] = {}
state['resolved_hooks'] = {name: [] for name in state['resolved_hooks']}
return state

def __setstate__(self, state):
"""Custom unpickle: restore runner state then re-register hooks."""
self.__dict__.update(state)
drivers = self.context.get('drivers', [])
if drivers:
from secator.loader import discover_external_drivers, order_drivers

discover_external_drivers()
# Order by canonical priority so authoritative backends (e.g. mongodb)
# register their hooks before relay drivers (e.g. api). Hook lists are
# concatenated in driver order, so this decides hook execution order.
drivers = order_drivers(drivers)
if not drivers:
return
from secator.loader import discover_external_drivers, order_drivers

discover_external_drivers()
# Order by canonical priority so authoritative backends (e.g. mongodb)
# register their hooks before relay drivers (e.g. api). Hook lists are
# concatenated in driver order, so this decides hook execution order.
drivers = order_drivers(drivers)
hooks_list = []
for driver in drivers:
driver_hooks = import_dynamic(f'secator.hooks.{driver}', 'HOOKS')
if driver_hooks:
hooks_list.append(driver_hooks)
merged_hooks = {}
if hooks_list:
from secator.utils import deep_merge_dicts
if not hooks_list:
return
from secator.utils import deep_merge_dicts

merged_hooks = deep_merge_dicts(*hooks_list)
merged_hooks = deep_merge_dicts(*hooks_list)
# Driver HOOKS dicts are keyed by base runner class (Scan/Workflow/Task). A task
# runner's class is its command subclass (e.g. ``whois``), never the base ``Task``,
# so register_hooks()' exact ``hooks.get(self.__class__)`` lookup would miss the
# ``Task`` entry and the task's on_end hook would never re-register on unpickle.
# That left chunk-parent tasks — the only tasks pickled into a chord callback —
# stuck in RUNNING because mark_runner_completed() ran zero on_end hooks. Flatten
# to the base runner type's hooks first (same convention as Workflow handing
# ``self._hooks.get(Task)`` to its task signatures) so they restore in the worker.
# ``Task`` entry. Flatten to the base runner type's hooks first (same convention as
# Workflow handing ``self._hooks.get(Task)`` to its task signatures).
from secator.runners import Scan, Task, Workflow

base_cls = {'scan': Scan, 'workflow': Workflow, 'task': Task}.get(self.config.type)
Expand Down Expand Up @@ -1046,24 +1051,33 @@ def run_validators(self, validator_type, *args, error=True, sub='validators'):
def register_hooks(self, hooks):
"""Register hooks.

Idempotent: a hook already present in ``resolved_hooks[key]`` is skipped
(and not re-logged). This lets the same driver be supplied via both
``self._hooks`` (e.g. library callers passing ``hooks=HOOKS``) and
``context['drivers']`` without registering — or logging — it twice.

Args:
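hooks (dict): Hooks to register. Maps a hook name to a list of callables, and may also nest such a mapping under a runner class key (``{RunnerClass: {hook_name: [callables]}}``); both forms are merged.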
"""
for key in self.resolved_hooks:
registered = self.resolved_hooks[key]

# Register class + derived class hooks
class_hook = getattr(self, key, None)
if class_hook:
if class_hook and class_hook not in registered:
fun = self.get_func_path(class_hook)
self.debug('hook registered', obj={'name': key, 'fun': fun}, sub='init')
self.resolved_hooks[key].append(class_hook)
registered.append(class_hook)

# Register user hooks
user_hooks = hooks.get(self.__class__, {}).get(key, [])
# Register user hooks (copy so we never mutate a caller's/shared list)
user_hooks = list(hooks.get(self.__class__, {}).get(key, []))
user_hooks.extend(hooks.get(key, []))
for hook in user_hooks:
if hook in registered:
continue
fun = self.get_func_path(hook)
self.debug('hook registered', obj={'name': key, 'fun': fun}, sub='init')
self.resolved_hooks[key].extend(user_hooks)
registered.append(hook)

def register_validators(self, validators):
"""Register validators.
Expand Down
60 changes: 36 additions & 24 deletions tests/unit/test_celery.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,14 +344,18 @@ def test_run_scan_celery_task(self):
class TestRunnerPickle(unittest.TestCase):
"""Test that Runner objects with dynamic driver hooks can be pickled/unpickled."""

def test_runner_pickle_survives_missing_module_on_worker(self):
"""Reproduces the original bug: unpickling a runner whose hooks reference a
dynamically-loaded module that does not exist on the worker side.
Without the __getstate__/__setstate__ fix this raises ModuleNotFoundError."""
def test_runner_pickle_survives_with_discovered_module(self):
"""Runners pickle/unpickle natively (no __getstate__/__setstate__).

Driver hook functions reference dynamically-loaded modules. The worker
discovers external drivers at startup (celery.py IN_WORKER), so those
modules are in sys.modules and hook functions resolve on unpickle. The
hook survives the round-trip and unpickling does NOT re-register hooks
(which, under replace()/chord synchronization, previously flooded logs)."""
import pickle
import types
import sys
from secator.runners import Workflow
from secator.runners import Runner, Workflow
from secator.loader import get_configs_by_type

workflows = get_configs_by_type('workflow')
Expand All @@ -360,7 +364,7 @@ def test_runner_pickle_survives_missing_module_on_worker(self):

config = workflows[0]

# 1) CLI side: load the driver into sys.modules (as the loader does)
# Driver module present in sys.modules, as on a worker post-discovery
fake_module = types.ModuleType('secator.hooks.testdriver')

def on_start(runner, *args):
Expand All @@ -371,23 +375,31 @@ def on_start(runner, *args):
fake_module.on_start = on_start
sys.modules['secator.hooks.testdriver'] = fake_module

hooks = {Workflow: {'on_start': [on_start]}}
runner = Workflow(config, inputs=['example.com'], run_opts={'dry_run': True}, hooks=hooks, context={})

# Pickle while the module is available (simulates the CLI/sender side)
pickled = pickle.dumps(runner)

# 2) Worker side: remove the module to simulate it not being installed
del sys.modules['secator.hooks.testdriver']

# Without the fix, unpickling would raise:
# ModuleNotFoundError: No module named 'secator.hooks.testdriver'
# With the fix, __getstate__ strips hooks so the bytes contain no reference
# to the dynamic module and unpickling succeeds.
restored = pickle.loads(pickled)
self.assertEqual(restored.name, runner.name)
# Hooks were stripped; dynamic hook not re-registered (driver not in context['drivers'])
self.assertNotIn(on_start, restored.resolved_hooks.get('on_start', []))
try:
hooks = {Workflow: {'on_start': [on_start]}}
runner = Workflow(config, inputs=['example.com'], run_opts={'dry_run': True}, hooks=hooks, context={})
self.assertIn(on_start, runner.resolved_hooks.get('on_start', []))

# Unpickling must NOT call register_hooks (native pickling restores state)
calls = {'n': 0}
orig = Runner.register_hooks

def counting(self, h):
calls['n'] += 1
return orig(self, h)

Runner.register_hooks = counting
try:
restored = pickle.loads(pickle.dumps(runner))
finally:
Runner.register_hooks = orig

self.assertEqual(calls['n'], 0, 'unpickle must not re-register hooks')
self.assertEqual(restored.name, runner.name)
# The dynamically-referenced hook survives natively
self.assertIn(on_start, restored.resolved_hooks.get('on_start', []))
finally:
del sys.modules['secator.hooks.testdriver']

def test_runner_pickle_restores_hooks_from_context_drivers(self):
"""Hooks loaded at init from context['drivers'] survive a pickle round-trip."""
Expand Down Expand Up @@ -426,7 +438,7 @@ def on_end(runner, *args):
pickled = pickle.dumps(runner)
restored = pickle.loads(pickled)
self.assertEqual(restored.name, runner.name)
# __setstate__ re-registers on_end from fakedriver via context['drivers']
# on_end is loaded at init from context['drivers'] and survives pickling natively
self.assertIn(on_end, restored.resolved_hooks.get('on_end', []))

finally:
Expand Down
117 changes: 117 additions & 0 deletions tests/unit/test_runner_hooks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
import pickle
import sys
import tempfile
import unittest
from pathlib import Path

from secator.config import CONFIG
from secator.loader import discover_external_drivers
from secator.runners._base import Runner

# A minimal external driver, as a user would drop into ~/.secator/templates/.
# Kept as source text: the test fixture writes it to ``mydriver.py`` so that
# driver discovery can import it, hence it must be valid, properly indented Python.
CUSTOM_DRIVER = '''
from secator.runners import Task


def cd_on_item(self, item):
    return item


def cd_on_end(self, *args, **kwargs):
    return None


HOOKS = {Task: {'on_item': [cd_on_item], 'on_end': [cd_on_end]}}
'''


def lib_on_item(self, item):
    """Module-level hook (picklable by qualified name), for the library-mode test.

    Args:
        self: Runner the hook is invoked on (unused).
        item: Item passed to the hook.

    Returns:
        The same ``item`` object, unchanged.
    """
    return item


class TestRunnerHooks(unittest.TestCase):
    """Driver hooks are registered once at construction and survive pickling
    natively (no __getstate__/__setstate__), so unpickling never re-registers.
    """

    @classmethod
    def setUpClass(cls):
        """Install a throwaway external driver and point driver discovery at it.

        Every piece of global state touched here is undone through
        ``addClassCleanup``: class cleanups run in LIFO order and, unlike
        ``tearDownClass``, still run when ``setUpClass`` itself fails midway.
        """
        # Registered first => executed last: reset the cache so later tests
        # re-discover against the restored dir.
        cls.addClassCleanup(discover_external_drivers.cache_clear)
        cls.addClassCleanup(sys.modules.pop, 'secator.hooks.mydriver', None)

        cls._tmpdir = tempfile.TemporaryDirectory()
        cls.addClassCleanup(cls._tmpdir.cleanup)
        tmp = Path(cls._tmpdir.name)
        (tmp / 'mydriver.py').write_text(CUSTOM_DRIVER, encoding='utf-8')

        cls._orig_templates = CONFIG.dirs.templates
        # Register the restore before overriding so a later failure cannot leak it
        cls.addClassCleanup(setattr, CONFIG.dirs, 'templates', cls._orig_templates)
        CONFIG.dirs.templates = tmp

        # discover_external_drivers is @cache'd; clear it so it re-scans our tmp dir
        # (an earlier test/import may have populated the cache with the real dir)
        discover_external_drivers.cache_clear()
        discover_external_drivers()

    def _build_task(self, **kwargs):
        """Build a dry-run ``httpx`` task; extra ``kwargs`` go to its constructor."""
        # Bind the class straight from sys.modules so its identity matches what
        # pickle resolves (other suite tests may reload the task module, which
        # would otherwise make pickle reject the "not the same object" class).
        import importlib
        mod = importlib.import_module('secator.tasks.httpx')
        return mod.httpx(['http://localhost'], enable_hooks=False, dry_run=True, **kwargs)

    @staticmethod
    def _hook_names(task, event):
        """Return the ``__name__`` of every hook resolved for ``event`` on ``task``."""
        return [h.__name__ for h in task.resolved_hooks[event]]

    def _count_register_hooks(self, fn):
        """Run ``fn()`` and return how many times ``Runner.register_hooks`` was called.

        ``Runner.register_hooks`` is temporarily replaced by a counting wrapper
        that delegates to the original; the original is restored even if
        ``fn`` raises.
        """
        calls = {'n': 0}
        orig = Runner.register_hooks

        def wrapped(self, hooks):
            calls['n'] += 1
            return orig(self, hooks)

        Runner.register_hooks = wrapped
        try:
            fn()
        finally:
            Runner.register_hooks = orig
        return calls['n']

    def test_driver_hooks_loaded_at_init(self):
        # context['drivers'] hooks must be present right after construction (no pickle)
        task = self._build_task(context={'drivers': ['mydriver']})
        self.assertIn('cd_on_item', self._hook_names(task, 'on_item'))
        self.assertIn('cd_on_end', self._hook_names(task, 'on_end'))

    def test_custom_driver_hook_survives_pickle_without_reregistration(self):
        task = self._build_task(context={'drivers': ['mydriver']})
        # Keep the round-tripped runner so the assertions below inspect the very
        # object whose unpickling was counted.
        restored = []
        # unpickling must NOT call register_hooks (this is what flooded the logs)
        reg = self._count_register_hooks(
            lambda: restored.append(pickle.loads(pickle.dumps(task)))
        )
        self.assertEqual(reg, 0, 'unpickle must not re-register hooks')
        # and the hook must still be there + callable
        back = restored[0]
        self.assertIn('cd_on_item', self._hook_names(back, 'on_item'))
        self.assertTrue(all(callable(h) for h in back.resolved_hooks['on_item']))

    def test_registration_is_idempotent(self):
        task = self._build_task(context={'drivers': ['mydriver']})
        before = len(task.resolved_hooks['on_item'])
        # re-apply the same context drivers -> no duplicate registration
        task._apply_context_drivers()
        self.assertEqual(len(task.resolved_hooks['on_item']), before)

    def test_library_mode_explicit_hooks_survive_pickle(self):
        # Library callers pass hooks explicitly and may set NO context['drivers'].
        task = self._build_task(hooks={'on_item': [lib_on_item]})
        self.assertIn('lib_on_item', self._hook_names(task, 'on_item'))
        back = pickle.loads(pickle.dumps(task))
        self.assertIn('lib_on_item', self._hook_names(back, 'on_item'))


# Allow running this test module directly (python tests/unit/test_runner_hooks.py).
if __name__ == '__main__':
    unittest.main()
Loading