Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions secator/ai/actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -726,6 +726,29 @@ def _handle_query(action: Dict, ctx: ActionContext) -> Generator:
query_filter = action.get("query", {})
limit = action.get("limit", 100)

# The query_workspace tool schema declares `query` as an object, but some
# models/providers serialize it as a JSON *string* (a known tool-calling
# quirk). Coerce a stringified query back to a dict so the tool works
# regardless of the provider, mirroring the add_finding scalar coercion.
# On a genuinely malformed query, return a clear error the LLM can act on
# instead of crashing _decrypt_dict/search on a non-dict.
if isinstance(query_filter, str):
try:
query_filter = json.loads(query_filter)
except (json.JSONDecodeError, TypeError):
yield Error(
message='query must be a JSON object (e.g. {"_type": "vulnerability"}); '
f'got an unparseable string: {query_filter[:120]!r}',
_context=context,
)
return
if not isinstance(query_filter, dict):
yield Error(
message=f'query must be a JSON object; got {type(query_filter).__name__}.',
_context=context,
)
return

# Decrypt query values
if ctx.encryptor:
query_filter = _decrypt_dict(query_filter, ctx.encryptor)
Expand Down Expand Up @@ -1126,6 +1149,11 @@ def _decrypt_dict(d: Dict, encryptor: Any) -> Dict:
Returns:
Decrypted dictionary
"""
# Backstop: callers should pass a dict, but a non-dict (e.g. an LLM that
# stringified an object arg) must not raise `.items()` here — return it
# unchanged rather than crash the whole action.
if not isinstance(d, dict):
return d
result = {}
for k, v in d.items():
if isinstance(v, str):
Expand Down
43 changes: 43 additions & 0 deletions tests/unit/test_ai_actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,14 @@ def test_decrypt_nested_dict(self):

self.assertEqual(result['outer']['inner'], 'VALUE')

def test_decrypt_non_dict_returned_unchanged(self):
"""Backstop: a non-dict (e.g. a stringified query arg) must not raise
`.items()` — it is returned unchanged instead of crashing the action."""
encryptor = MagicMock()
self.assertEqual(_decrypt_dict('{"_type": "url"}', encryptor), '{"_type": "url"}')
self.assertEqual(_decrypt_dict(['a', 'b'], encryptor), ['a', 'b'])
encryptor.decrypt.assert_not_called()

def test_decrypt_list_values(self):
encryptor = MagicMock()
encryptor.decrypt.side_effect = lambda x: x.upper()
Expand Down Expand Up @@ -291,6 +299,41 @@ def test_query_success(self, mock_get_engine):
for r in result_dicts:
self.assertTrue(r['_context'].get('ai_query_result'))

@patch('secator.ai.actions.ActionContext.get_query_engine')
def test_query_stringified_json_is_coerced(self, mock_get_engine):
"""A model that passes `query` as a JSON *string* (schema says object) must
still work — coerced to a dict, then searched. Regression for the
AttributeError('str' object has no attribute 'items') in _decrypt_dict."""
mock_engine = MagicMock()
mock_engine.search.return_value = [{'_type': 'url', '_context': {}}]
mock_get_engine.return_value = mock_engine
# Encryptor active is the exact condition that made the original crash fire.
encryptor = MagicMock()
encryptor.decrypt.side_effect = lambda s: s
ctx = ActionContext(targets=['t.com'], model='m', context={'workspace_id': 'ws1'}, encryptor=encryptor)

results = list(_handle_query(
{'action': 'query', 'query': '{"_type": "url", "verified": true}'}, ctx))

self.assertFalse([r for r in results if isinstance(r, Error)], 'stringified query must not error')
mock_engine.search.assert_called_once_with({'_type': 'url', 'verified': True}, limit=100)

def test_query_unparseable_string_returns_clean_error(self):
"""A non-JSON string yields an Error the LLM can act on — not a crash."""
ctx = ActionContext(targets=['t.com'], model='m', context={'workspace_id': 'ws1'})
results = list(_handle_query({'action': 'query', 'query': 'not json at all'}, ctx))
errors = [r for r in results if isinstance(r, Error)]
self.assertEqual(len(errors), 1)
self.assertIn('JSON object', errors[0].message)

def test_query_non_dict_returns_clean_error(self):
"""A non-dict, non-str query (e.g. a list) yields a clean Error, not a crash."""
ctx = ActionContext(targets=['t.com'], model='m', context={'workspace_id': 'ws1'})
results = list(_handle_query({'action': 'query', 'query': ['_type', 'url']}, ctx))
errors = [r for r in results if isinstance(r, Error)]
self.assertEqual(len(errors), 1)
self.assertIn('JSON object', errors[0].message)

@patch('secator.ai.actions.ActionContext.get_query_engine')
def test_query_failure(self, mock_get_engine):
mock_engine = MagicMock()
Expand Down