Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 8 additions & 7 deletions secator/ai/prompts.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,13 +241,14 @@ def get_system_prompt(mode: str, workspace_path: str = "", backend=None) -> str:
system_prompt = mode_config["system_prompt"]
ws = workspace_path or "<workspace>"

path_vars = dict(tasks_path=str(TASKS_PATH), workflows_path=str(WORKFLOWS_PATH), profiles_path=str(PROFILES_PATH))
if mode == "attack":
result = system_prompt.safe_substitute(library_reference=build_library_reference(), **path_vars)
elif mode == "exploit":
result = system_prompt.safe_substitute(library_reference=build_library_reference(), **path_vars)
else: # chat mode
result = system_prompt.safe_substitute(output_types_reference=build_output_types_reference())
# The queries.txt constraint (included by every mode) references $query_types and
# $output_types_reference, so they must be substituted for all modes — derive both
# from FINDING_TYPES so they never drift from the registry.
subst = dict(query_types=build_query_types(), output_types_reference=build_output_types_reference())
if mode in ("attack", "exploit"):
path_vars = dict(tasks_path=str(TASKS_PATH), workflows_path=str(WORKFLOWS_PATH), profiles_path=str(PROFILES_PATH))
subst.update(library_reference=build_library_reference(), **path_vars)
result = system_prompt.safe_substitute(**subst)

# Determine interaction rules based on backend
# The mode templates already include ${follow_up} for interactive modes.
Expand Down
2 changes: 1 addition & 1 deletion secator/ai/prompts/constraints/common.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ run_task(name="httpx", targets=["target3.com"], opts={"rate_limit": 30, "proxy":
run_workflow(name="domain_recon", targets=["example.com"])
run_shell(command="curl -sk https://10.0.0.1/ | head -50")
run_task(name="ai", targets=["example.com"], opts={"prompt": "Enumerate subdomains", "mode": "attack", "session_name": "Subdomain enumeration on example.com", "max_iterations": 5})
run_query(query={'vulnerability': {'severity': {'$in': ['high', 'critical']})
query_workspace(query={"_type": "vulnerability", "severity": {"$in": ["high", "critical"]}})
add_finding(name="XSS vuln", matched_at=["http://testphp.vulnweb.com/hpp/?pp=1"], )
</correct>

Expand Down
27 changes: 27 additions & 0 deletions tests/unit/test_ai_prompts.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,33 @@ def test_common_rules_has_no_shouting(self):
self.assertNotIn("NEVER INVENT", COMMON_RULES)
self.assertNotIn("ALWAYS provide", COMMON_RULES)

# === Template-drift regression tests (D1) ===

def test_rendered_prompts_have_no_unsubstituted_template_vars(self):
"""Rendered prompts must not leak $query_types / $output_types_reference (D1)."""
for mode in ("attack", "chat", "exploit"):
prompt = get_system_prompt(mode)
self.assertNotIn("$query_types", prompt, f"$query_types leaked in {mode!r} prompt")
self.assertNotIn("$output_types_reference", prompt, f"$output_types_reference leaked in {mode!r} prompt")

def test_rendered_prompts_substitute_query_types_from_registry(self):
"""$query_types renders to the real FINDING_TYPES names, not a placeholder."""
from secator.ai.prompts import build_query_types
expected = build_query_types()
self.assertIn("vulnerability", expected)
for mode in ("attack", "chat", "exploit"):
self.assertIn(expected, get_system_prompt(mode))

def test_rendered_prompts_have_no_phantom_run_query_tool(self):
"""Examples must call the real query_workspace tool, never a phantom run_query (D1)."""
from secator.ai.tools import TOOL_ACTION_MAP
self.assertEqual(TOOL_ACTION_MAP["query_workspace"], "query")
self.assertNotIn("run_query", TOOL_ACTION_MAP)
for mode in ("attack", "chat", "exploit"):
prompt = get_system_prompt(mode)
self.assertNotIn("run_query", prompt, f"phantom run_query in {mode!r} prompt")
self.assertIn("query_workspace", prompt)


if __name__ == '__main__':
unittest.main()