Skip to content
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
319 changes: 8 additions & 311 deletions src/codex_plugin_scanner/guard/runtime/command_executors.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

from __future__ import annotations

import json
import tempfile
from collections.abc import Callable
from datetime import datetime, timezone
Expand All @@ -17,7 +16,7 @@
list_harness_setup_items,
uninstall_confirmation_token,
)
from ..config import VALID_RECEIPT_REDACTION_LEVELS, load_guard_config
from ..config import load_guard_config
from ..local_supply_chain import (
build_workspace_audit_payload,
managed_install_audit_workspace_dirs,
Expand All @@ -26,10 +25,8 @@
)
from ..models import DECISION_SCOPE_VALUES, GUARD_ACTION_VALUES, DecisionScope, GuardAction, PolicyDecision
from ..package_shim_status import record_package_shim_audit_result
from ..redaction import redact_text
from ..review_contracts import (
GuardReviewContractError,
build_local_review_request_claim,
guard_review_oauth_metadata,
validate_decision_memory_bundle_target,
validate_remote_approval_request_binding,
Expand All @@ -44,6 +41,7 @@
probe_package_shim_intercepts,
)
from ..store import GuardStore
from . import local_request_snapshots

_GUARD_REVIEW_MEMORY_REGISTRY_SYNC_KEY = "guard_review_memory_registry"
Comment thread
greptile-apps[bot] marked this conversation as resolved.
_GUARD_REVIEW_MEMORY_VERSION_SYNC_KEY = "guard_review_memory_policy_version"
Expand All @@ -70,11 +68,8 @@
)
SUPPORTED_COMMAND_OPERATIONS: tuple[str, ...] = (*PACKAGE_SHIM_OPERATIONS, *APP_OPERATIONS, *APPROVAL_OPERATIONS)
COMMAND_OPERATION_SCHEMA_VERSIONS: dict[str, int] = {operation: 1 for operation in SUPPORTED_COMMAND_OPERATIONS}
LOCAL_REQUEST_PENDING_SNAPSHOT_LIMIT = 200
LOCAL_REQUEST_RESOLVED_SNAPSHOT_LIMIT = 50
LOCAL_REQUEST_SNAPSHOT_MAX_BYTES = 600_000
LOCAL_REQUEST_SNAPSHOT_MAX_STRING_CHARS = 2_000
LOCAL_REQUEST_SNAPSHOT_MAX_LIST_ITEMS = 20
LOCAL_REQUEST_PENDING_SNAPSHOT_LIMIT = local_request_snapshots.LOCAL_REQUEST_PENDING_SNAPSHOT_LIMIT
LOCAL_REQUEST_RESOLVED_SNAPSHOT_LIMIT = local_request_snapshots.LOCAL_REQUEST_RESOLVED_SNAPSHOT_LIMIT


def execute_guard_command_job(
Expand Down Expand Up @@ -424,311 +419,13 @@ def _is_guard_action(value: object) -> TypeGuard[GuardAction]:


def _local_request_snapshot_items(store: GuardStore) -> list[dict[str, object]]:
pending_items, _ = _local_request_snapshot_items_for_status(
store,
status="pending",
limit=100,
)
resolved_items, _ = _local_request_snapshot_items_for_status(
store,
status="resolved",
limit=100,
)
return [*pending_items, *resolved_items]
return local_request_snapshots.local_request_snapshot_items(store)


def _local_request_snapshot_payload(store: GuardStore) -> dict[str, object]:
pending_items, pending_complete = _local_request_snapshot_items_for_status(
store,
status="pending",
limit=LOCAL_REQUEST_PENDING_SNAPSHOT_LIMIT,
)
resolved_items, resolved_complete = _local_request_snapshot_items_for_status(
store,
status="resolved",
limit=LOCAL_REQUEST_RESOLVED_SNAPSHOT_LIMIT,
)
requests, pending_byte_complete, resolved_byte_complete = _local_request_snapshot_byte_capped_statuses(
pending_items,
resolved_items,
max_bytes=LOCAL_REQUEST_SNAPSHOT_MAX_BYTES,
)
return {
"requests": requests,
"pendingComplete": pending_complete and pending_byte_complete,
"resolvedComplete": resolved_complete and resolved_byte_complete,
"pendingLimit": LOCAL_REQUEST_PENDING_SNAPSHOT_LIMIT,
"resolvedLimit": LOCAL_REQUEST_RESOLVED_SNAPSHOT_LIMIT,
"pendingCount": len(pending_items),
"resolvedCount": len(resolved_items),
}


def _local_request_snapshot_byte_capped_statuses(
pending_items: list[dict[str, object]],
resolved_items: list[dict[str, object]],
*,
max_bytes: int,
) -> tuple[list[dict[str, object]], bool, bool]:
selected, pending_complete = _local_request_snapshot_byte_capped_items(
pending_items,
max_bytes=max_bytes,
)
if not pending_complete:
return selected, False, False

selected, resolved_complete = _local_request_snapshot_byte_capped_items(
resolved_items,
existing_items=selected,
max_bytes=max_bytes,
)
return selected, True, resolved_complete


def _local_request_snapshot_byte_capped_items(
items: list[dict[str, object]],
*,
max_bytes: int,
existing_items: list[dict[str, object]] | None = None,
) -> tuple[list[dict[str, object]], bool]:
selected: list[dict[str, object]] = list(existing_items or [])
initial_len = len(selected)
for item in items:
candidate = [*selected, item]
candidate_bytes = len(
json.dumps({"requests": candidate}, separators=(",", ":"), sort_keys=True).encode(
"utf-8",
),
)
if candidate_bytes > max_bytes:
if len(selected) == initial_len:
compact_item = _compact_local_request_snapshot_item(item)
compact_candidate = [*selected, compact_item]
compact_bytes = len(
json.dumps(
{"requests": compact_candidate},
separators=(",", ":"),
sort_keys=True,
).encode("utf-8"),
)
if compact_bytes <= max_bytes:
selected.append(compact_item)
return selected, False
selected.append(item)
return selected, True


def _compact_local_request_snapshot_item(item: dict[str, object]) -> dict[str, object]:
compact = {key: _compact_local_request_snapshot_value(value) for key, value in item.items()}
compact_bytes = len(json.dumps(compact, separators=(",", ":"), sort_keys=True).encode("utf-8"))
if compact_bytes <= LOCAL_REQUEST_SNAPSHOT_MAX_BYTES:
return compact
safe_keys = (
"localRequestId",
"status",
"harness",
"artifactId",
"artifactName",
"artifactType",
"policyAction",
"recommendedScope",
"createdAt",
"lastSeenAt",
"riskHeadline",
"riskSummary",
"rawCommandText",
"reviewCommand",
)
return {key: compact[key] for key in safe_keys if key in compact}


def _compact_local_request_snapshot_value(value: object) -> object:
if isinstance(value, str):
if len(value) <= LOCAL_REQUEST_SNAPSHOT_MAX_STRING_CHARS:
return value
return f"{value[:LOCAL_REQUEST_SNAPSHOT_MAX_STRING_CHARS]}...[truncated]"
if isinstance(value, list):
return [_compact_local_request_snapshot_value(item) for item in value[:LOCAL_REQUEST_SNAPSHOT_MAX_LIST_ITEMS]]
if isinstance(value, dict):
return {str(key): _compact_local_request_snapshot_value(item) for key, item in value.items()}
return value


def _local_request_snapshot_items_for_status(
store: GuardStore,
*,
status: str,
limit: int,
) -> tuple[list[dict[str, object]], bool]:
items: list[dict[str, object]] = []
redaction_level = _resolve_cloud_receipt_redaction_level(store)
try:
oauth = guard_review_oauth_metadata(store)
except GuardReviewContractError:
oauth = None
rows = store.list_approval_requests(status=status, limit=limit + 1)
for item in rows[:limit]:
request_id = item.get("request_id")
if not isinstance(request_id, str) or not request_id:
continue
created_at = str(item.get("created_at") or _now())
last_seen_at = str(item.get("last_seen_at") or created_at)
resolved_at = item.get("resolved_at")
claim = None
if oauth is not None:
try:
claim = build_local_review_request_claim(
request_row=item,
oauth=oauth,
store=store,
)
except GuardReviewContractError:
claim = None
items.append(
{
"claim": claim,
"localRequestId": request_id,
"requestKind": str(item.get("harness") or "guard-review"),
"requestPayload": _cloud_safe_local_request_payload(
item,
redaction_level=redaction_level,
),
"localStatus": str(item.get("status") or status),
"firstSeenAt": created_at,
"lastSeenAt": last_seen_at,
"resolvedAt": str(resolved_at) if isinstance(resolved_at, str) and resolved_at else None,
}
)
return items, len(rows) <= limit


def _resolve_cloud_receipt_redaction_level(store: GuardStore) -> str:
payload = store.get_sync_payload("cloud_receipt_redaction_level")
if isinstance(payload, dict):
level = payload.get("level")
if isinstance(level, str) and level in VALID_RECEIPT_REDACTION_LEVELS:
return level
try:
config = load_guard_config(store.guard_home)
if config.receipt_redaction_level in VALID_RECEIPT_REDACTION_LEVELS:
return config.receipt_redaction_level
except Exception:
pass
return "full"


def _optional_payload_mapping(value: object) -> dict[str, object] | None:
return dict(value) if isinstance(value, dict) else None


def _cloud_safe_local_request_payload(
item: dict[str, object],
*,
redaction_level: str,
) -> dict[str, object]:
payload: dict[str, object] = {}
for key in (
"request_id",
"status",
"harness",
"artifact_id",
"artifact_name",
"artifact_type",
"artifact_hash",
"artifact_label",
"source_label",
"trigger_summary",
"why_now",
"risk_headline",
"risk_summary",
"policy_action",
"recommended_scope",
"created_at",
"last_seen_at",
"queue_group_id",
"review_kind",
"risk_category",
"capability_category",
"publisher",
"package_manager",
"package_name",
):
value = item.get(key)
if isinstance(value, (str, int, float, bool)) or value is None:
payload[key] = value

envelope = _optional_payload_mapping(item.get("action_envelope_json"))
safe_envelope = _cloud_safe_action_envelope(envelope, redaction_level=redaction_level)
if safe_envelope is not None:
payload["action_envelope_json"] = safe_envelope

if redaction_level == "full":
payload["raw_command_text"] = None
payload["command_text"] = None
return payload

command_text = _local_request_command_text(item, envelope)
if command_text:
scrubbed = redact_text(command_text).text
payload["raw_command_text"] = scrubbed
payload["command_text"] = scrubbed
payload_envelope = payload.get("action_envelope_json")
if isinstance(payload_envelope, dict):
payload_envelope["command"] = scrubbed
return payload


def _local_request_command_text(
payload: dict[str, object],
envelope: dict[str, object] | None,
) -> str | None:
for key in ("raw_command_text", "rawCommandText", "command_text", "commandText"):
value = payload.get(key)
if isinstance(value, str) and value.strip():
return value.strip()
if envelope is None:
return None
command = envelope.get("command")
return command.strip() if isinstance(command, str) and command.strip() else None


def _cloud_safe_action_envelope(
envelope: dict[str, object] | None,
*,
redaction_level: str,
) -> dict[str, object] | None:
if envelope is None:
return None
safe: dict[str, object] = {}
for key in (
"schema_version",
"action_id",
"harness",
"event_name",
"action_type",
"workspace_hash",
"tool_name",
"mcp_server",
"mcp_tool",
"target_path_count",
"network_host_count",
"package_manager",
):
value = envelope.get(key)
if isinstance(value, (str, int, float, bool)) or value is None:
safe[key] = value
if redaction_level != "full":
command = envelope.get("command")
if isinstance(command, str) and command.strip():
safe["command"] = redact_text(command).text
if redaction_level == "none":
for key in ("target_paths", "network_hosts", "package_name", "package_targets"):
value = envelope.get(key)
if isinstance(value, list):
safe[key] = [item for item in value if isinstance(item, str)]
elif isinstance(value, str):
safe[key] = value
return safe or None
local_request_snapshots.LOCAL_REQUEST_PENDING_SNAPSHOT_LIMIT = LOCAL_REQUEST_PENDING_SNAPSHOT_LIMIT
local_request_snapshots.LOCAL_REQUEST_RESOLVED_SNAPSHOT_LIMIT = LOCAL_REQUEST_RESOLVED_SNAPSHOT_LIMIT
return local_request_snapshots.local_request_snapshot_payload(store)


def _package_shim_context(
Expand Down
Loading
Loading