diff --git a/secator/ai/actions.py b/secator/ai/actions.py index 5c722f0f9..ba29ee266 100644 --- a/secator/ai/actions.py +++ b/secator/ai/actions.py @@ -726,6 +726,29 @@ def _handle_query(action: Dict, ctx: ActionContext) -> Generator: query_filter = action.get("query", {}) limit = action.get("limit", 100) + # The query_workspace tool schema declares `query` as an object, but some + # models/providers serialize it as a JSON *string* (a known tool-calling + # quirk). Coerce a stringified query back to a dict so the tool works + # regardless of the provider, mirroring the add_finding scalar coercion. + # On a genuinely malformed query, return a clear error the LLM can act on + # instead of crashing _decrypt_dict/search on a non-dict. + if isinstance(query_filter, str): + try: + query_filter = json.loads(query_filter) + except (json.JSONDecodeError, TypeError): + yield Error( + message='query must be a JSON object (e.g. {"_type": "vulnerability"}); ' + f'got an unparseable string: {query_filter[:120]!r}', + _context=context, + ) + return + if not isinstance(query_filter, dict): + yield Error( + message=f'query must be a JSON object; got {type(query_filter).__name__}.', + _context=context, + ) + return + # Decrypt query values if ctx.encryptor: query_filter = _decrypt_dict(query_filter, ctx.encryptor) @@ -1126,6 +1149,11 @@ def _decrypt_dict(d: Dict, encryptor: Any) -> Dict: Returns: Decrypted dictionary """ + # Backstop: callers should pass a dict, but a non-dict (e.g. an LLM that + # stringified an object arg) must not raise `.items()` here — return it + # unchanged rather than crash the whole action. + if not isinstance(d, dict): + return d result = {} for k, v in d.items(): if isinstance(v, str): diff --git a/tests/unit/test_ai_actions.py b/tests/unit/test_ai_actions.py index 07b5ed64b..9e905330d 100644 --- a/tests/unit/test_ai_actions.py +++ b/tests/unit/test_ai_actions.py @@ -39,6 +39,14 @@ def test_decrypt_nested_dict(self): self.assertEqual(result['outer']['inner'], 'VALUE') + def test_decrypt_non_dict_returned_unchanged(self): + """Backstop: a non-dict (e.g. a stringified query arg) must not raise + `.items()` — it is returned unchanged instead of crashing the action.""" + encryptor = MagicMock() + self.assertEqual(_decrypt_dict('{"_type": "url"}', encryptor), '{"_type": "url"}') + self.assertEqual(_decrypt_dict(['a', 'b'], encryptor), ['a', 'b']) + encryptor.decrypt.assert_not_called() + def test_decrypt_list_values(self): encryptor = MagicMock() encryptor.decrypt.side_effect = lambda x: x.upper() @@ -291,6 +299,41 @@ def test_query_success(self, mock_get_engine): for r in result_dicts: self.assertTrue(r['_context'].get('ai_query_result')) + @patch('secator.ai.actions.ActionContext.get_query_engine') + def test_query_stringified_json_is_coerced(self, mock_get_engine): + """A model that passes `query` as a JSON *string* (schema says object) must + still work — coerced to a dict, then searched. Regression for the + AttributeError('str' object has no attribute 'items') in _decrypt_dict.""" + mock_engine = MagicMock() + mock_engine.search.return_value = [{'_type': 'url', '_context': {}}] + mock_get_engine.return_value = mock_engine + # Encryptor active is the exact condition that made the original crash fire. + encryptor = MagicMock() + encryptor.decrypt.side_effect = lambda s: s + ctx = ActionContext(targets=['t.com'], model='m', context={'workspace_id': 'ws1'}, encryptor=encryptor) + + results = list(_handle_query( + {'action': 'query', 'query': '{"_type": "url", "verified": true}'}, ctx)) + + self.assertFalse([r for r in results if isinstance(r, Error)], 'stringified query must not error') + mock_engine.search.assert_called_once_with({'_type': 'url', 'verified': True}, limit=100) + + def test_query_unparseable_string_returns_clean_error(self): + """A non-JSON string yields an Error the LLM can act on — not a crash.""" + ctx = ActionContext(targets=['t.com'], model='m', context={'workspace_id': 'ws1'}) + results = list(_handle_query({'action': 'query', 'query': 'not json at all'}, ctx)) + errors = [r for r in results if isinstance(r, Error)] + self.assertEqual(len(errors), 1) + self.assertIn('JSON object', errors[0].message) + + def test_query_non_dict_returns_clean_error(self): + """A non-dict, non-str query (e.g. a list) yields a clean Error, not a crash.""" + ctx = ActionContext(targets=['t.com'], model='m', context={'workspace_id': 'ws1'}) + results = list(_handle_query({'action': 'query', 'query': ['_type', 'url']}, ctx)) + errors = [r for r in results if isinstance(r, Error)] + self.assertEqual(len(errors), 1) + self.assertIn('JSON object', errors[0].message) + @patch('secator.ai.actions.ActionContext.get_query_engine') def test_query_failure(self, mock_get_engine): mock_engine = MagicMock()