Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 40 additions & 28 deletions backend/app/services/policy/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

from __future__ import annotations

import ast
import hashlib
import inspect
import json
Expand Down Expand Up @@ -82,7 +83,7 @@ class PolicyDescriptor:

name: str
version: str
source_hash: str
policy_hash: str
enabled: bool


Expand Down Expand Up @@ -131,11 +132,6 @@ def evaluate(
"""
...

def compute_source_hash(self) -> str:
    """Return the hex SHA-256 digest of the source text of ``self.evaluate``.

    The source is obtained with ``inspect.getsource`` and encoded as UTF-8
    before hashing, so any textual change to ``evaluate`` (including
    comments and whitespace) produces a different digest.
    """
    digest = hashlib.sha256()
    evaluate_source = inspect.getsource(self.evaluate)
    digest.update(evaluate_source.encode("utf-8"))
    return digest.hexdigest()


class PolicyEngine:
"""
Expand Down Expand Up @@ -182,20 +178,22 @@ def register(self, policy: Policy) -> None:

The policy is only active if enabled in policy.yaml.
"""
policy_config = self._config.get("policies", {}).get(
policy.name.lower().replace("_policy", "").replace("_", ""), {}
)
# Strict config lookup by policy_id
policy_config = self._config.get("policies", {}).get(policy.policy_id, {})

# Fallback to name-based lookup for backward compatibility (legacy)
if not policy_config:
# Also try exact name match
policy_config = self._config.get("policies", {}).get(policy.name, {})
policy_config = self._config.get("policies", {}).get(
policy.name.lower().replace("_policy", "").replace("_", ""), {}
)

enabled = policy_config.get("enabled", False) if policy_config else False

if enabled:
self._policies.append(policy)
logger.info("Policy registered: %s (version: %s)", policy.name, policy.version)
logger.info("Policy registered: %s (id: %s, version: %s)", policy.name, policy.policy_id, policy.version)
else:
logger.info("Policy skipped (disabled): %s", policy.name)
logger.info("Policy skipped (disabled): %s (id: %s)", policy.name, policy.policy_id)

# Invalidate cached policy set
self._policy_set = None
Expand All @@ -211,7 +209,7 @@ def policy_set(self) -> PolicySet:
PolicyDescriptor(
name=p.name,
version=p.version,
source_hash=p.compute_source_hash(),
policy_hash=self.compute_policy_hash(p),
enabled=True,
)
for p in self._policies
Expand All @@ -227,25 +225,39 @@ def policy_set(self) -> PolicySet:

def compute_policy_hash(self, policy: Policy) -> str:
"""
Compute policy_hash = SHA-256(policy source + canonical config subset).

This captures the full evaluation semantics, not just the code.
Config changes (e.g., allowed_tools list) change the hash.
Compute hash based on normalized class source and canonical config.
System of Record invariant: Hash must capture full semantics and be stable across environments.
"""
source = inspect.getsource(policy.evaluate)

# Extract this policy's config subset
policy_config = self._config.get("policies", {}).get(
policy.name.lower().replace("_policy", "").replace("_", ""), {}
)
if not policy_config:
policy_config = self._config.get("policies", {}).get(policy.name, {})

try:
source = inspect.getsource(policy.__class__)
except (OSError, TypeError) as e:
raise RuntimeError(
f"CRITICAL: Could not retrieve source for policy {policy.name}. "
"System of Record requires source hashing."
) from e

# Normalize source using AST to be robust against formatting/comments
try:
tree = ast.parse(source)
# ast.dump with include_attributes=False (Python 3.9+) removes line numbers/cols
# effectively strictly hashing the logic structure
source_normalized = ast.dump(tree, include_attributes=False)
except SyntaxError:
# Fallback to string normalization if AST fails (unlikely if getsource worked)
source_normalized = source.strip().replace("\r\n", "\n")

# Extract config using explicit policy_id
if not hasattr(policy, "policy_id"):
raise ValueError(f"Policy {policy.name} violates contract: missing 'policy_id'")

policy_config = self._config.get("policies", {}).get(policy.policy_id, {})

# Canonicalize config (RFC 8785 style - no spaces)
config_canonical = json.dumps(
policy_config or {}, sort_keys=True, separators=(",", ":")
)

combined = source + "\n---\n" + config_canonical
combined = source_normalized + "\n---\n" + config_canonical
return hashlib.sha256(combined.encode("utf-8")).hexdigest()

def evaluate(self, events: list[CanonicalEvent]) -> list[ViolationRecord]:
Expand Down
Loading