diff --git a/backend/app/services/policy/engine.py b/backend/app/services/policy/engine.py index f9fb75b..3feeffb 100644 --- a/backend/app/services/policy/engine.py +++ b/backend/app/services/policy/engine.py @@ -18,6 +18,7 @@ from __future__ import annotations +import ast import hashlib import inspect import json @@ -82,7 +83,7 @@ class PolicyDescriptor: name: str version: str - source_hash: str + policy_hash: str enabled: bool @@ -131,11 +132,6 @@ def evaluate( """ ... - def compute_source_hash(self) -> str: - """SHA-256 of this policy's evaluate() method source code.""" - source = inspect.getsource(self.evaluate) - return hashlib.sha256(source.encode("utf-8")).hexdigest() - class PolicyEngine: """ @@ -182,20 +178,22 @@ def register(self, policy: Policy) -> None: The policy is only active if enabled in policy.yaml. """ - policy_config = self._config.get("policies", {}).get( - policy.name.lower().replace("_policy", "").replace("_", ""), {} - ) + # Strict config lookup by policy_id + policy_config = self._config.get("policies", {}).get(policy.policy_id, {}) + + # Fallback to name-based lookup for backward compatibility (legacy) if not policy_config: - # Also try exact name match - policy_config = self._config.get("policies", {}).get(policy.name, {}) + policy_config = self._config.get("policies", {}).get( + policy.name.lower().replace("_policy", "").replace("_", ""), {} + ) enabled = policy_config.get("enabled", False) if policy_config else False if enabled: self._policies.append(policy) - logger.info("Policy registered: %s (version: %s)", policy.name, policy.version) + logger.info("Policy registered: %s (id: %s, version: %s)", policy.name, policy.policy_id, policy.version) else: - logger.info("Policy skipped (disabled): %s", policy.name) + logger.info("Policy skipped (disabled): %s (id: %s)", policy.name, policy.policy_id) # Invalidate cached policy set self._policy_set = None @@ -211,7 +209,7 @@ def policy_set(self) -> PolicySet: PolicyDescriptor( name=p.name, 
def compute_policy_hash(self, policy: "Policy") -> str:
    """
    Compute the policy hash from normalized class source and canonical config.

    The digest is SHA-256 over ``<normalized source> + "\\n---\\n" + <config JSON>``:

    * Source is the policy's class source, normalized through ``ast.dump`` so
      that comments and formatting do not affect the hash. Docstrings are part
      of the AST and therefore still do.
    * Config is this policy's entry under ``policies`` in the engine config,
      looked up by ``policy_id`` first and then by the legacy name-derived key
      (the same fallback ``register`` applies), serialized with sorted keys
      and no insignificant whitespace.

    Args:
        policy: The policy instance to hash. Must expose ``name`` and
            ``policy_id``.

    Returns:
        Hex-encoded SHA-256 digest.

    Raises:
        RuntimeError: If the source of the policy's class cannot be retrieved.
        ValueError: If the policy has no ``policy_id`` attribute.
    """
    # Local import keeps this method self-contained; textwrap is stdlib.
    import textwrap

    try:
        source = inspect.getsource(policy.__class__)
    except (OSError, TypeError) as e:
        raise RuntimeError(
            f"CRITICAL: Could not retrieve source for policy {policy.name}. "
            "System of Record requires source hashing."
        ) from e

    # Normalize via the AST so formatting and comments do not affect the hash.
    # include_attributes=False (the default, spelled out for clarity) leaves
    # line/column positions out of the dump.
    # The raw source is tried first so classes that already parsed keep their
    # hash; the dedented retry covers nested (indented) classes, which would
    # otherwise raise IndentationError and skip AST normalization entirely.
    # NOTE(review): ast.dump's textual format is not documented as stable
    # across Python versions - confirm all environments run the same version.
    for candidate in (source, textwrap.dedent(source)):
        try:
            tree = ast.parse(candidate)
        except SyntaxError:
            continue
        source_normalized = ast.dump(tree, include_attributes=False)
        break
    else:
        # Last resort: plain string normalization when the source cannot be parsed.
        source_normalized = source.strip().replace("\r\n", "\n")

    if not hasattr(policy, "policy_id"):
        raise ValueError(f"Policy {policy.name} violates contract: missing 'policy_id'")

    # "policies" may be present but empty (None); treat that as no entries.
    policies_config = self._config.get("policies") or {}
    policy_config = policies_config.get(policy.policy_id, {})

    # Legacy fallback, mirroring register(): without it a policy enabled via a
    # name-keyed config entry would be hashed against an empty config and
    # config changes would never show up in the hash.
    if not policy_config:
        legacy_key = policy.name.lower().replace("_policy", "").replace("_", "")
        policy_config = policies_config.get(legacy_key, {})

    # Canonical JSON: sorted keys, no insignificant whitespace.
    config_canonical = json.dumps(
        policy_config or {}, sort_keys=True, separators=(",", ":")
    )

    combined = source_normalized + "\n---\n" + config_canonical
    return hashlib.sha256(combined.encode("utf-8")).hexdigest()