Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
155 changes: 113 additions & 42 deletions agent_sdks/python/src/a2ui/core/schema/catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,52 @@ def from_path(
)


def _collect_refs(obj: Any) -> set[str]:
"""Recursively collects all $ref values from a JSON object."""
refs = set()
if isinstance(obj, dict):
for k, v in obj.items():
if k == "$ref" and isinstance(v, str):
refs.add(v)
else:
refs.update(_collect_refs(v))
elif isinstance(obj, list):
for item in obj:
refs.update(_collect_refs(item))
return refs


def _prune_defs_by_reachability(
defs: Dict[str, Any],
root_def_names: List[str],
internal_ref_prefix: str = "#/$defs/",
) -> Dict[str, Any]:
"""Prunes definitions not reachable from the provided roots.

Args:
defs: The dictionary of definitions to prune.
root_def_names: The names of the definitions to start the traversal from.
internal_ref_prefix: The prefix used for internal references.

Returns:
A new dictionary containing only reachable definitions.
"""
visited_defs = set()
refs_queue = collections.deque(root_def_names)

while refs_queue:
def_name = refs_queue.popleft()
if def_name in defs and def_name not in visited_defs:
visited_defs.add(def_name)

internal_refs = _collect_refs(defs[def_name])
for ref in internal_refs:
if ref.startswith(internal_ref_prefix):
refs_queue.append(ref.split(internal_ref_prefix)[-1])

return {k: v for k, v in defs.items() if k in visited_defs}


@dataclass(frozen=True)
class A2uiCatalog:
"""Represents a processed component catalog with its schema.
Expand Down Expand Up @@ -89,7 +135,7 @@ def validator(self) -> "A2uiValidator":

return A2uiValidator(self)

def with_pruned_components(self, allowed_components: List[str]) -> "A2uiCatalog":
def _with_pruned_components(self, allowed_components: List[str]) -> "A2uiCatalog":
"""Returns a new catalog with only allowed components.

Args:
Expand All @@ -99,11 +145,10 @@ def with_pruned_components(self, allowed_components: List[str]) -> "A2uiCatalog"
A copy of the catalog with only allowed components.
"""

schema_copy = copy.deepcopy(self.catalog_schema)

# Allow all components if no allowed components are specified
if not allowed_components:
return self._with_pruned_common_types()
return self

schema_copy = copy.deepcopy(self.catalog_schema)

if CATALOG_COMPONENTS_KEY in schema_copy and isinstance(
schema_copy[CATALOG_COMPONENTS_KEY], dict
Expand Down Expand Up @@ -133,56 +178,82 @@ def with_pruned_components(self, allowed_components: List[str]) -> "A2uiCatalog"

any_comp["oneOf"] = filtered_one_of

pruned_catalog = replace(self, catalog_schema=schema_copy)
return pruned_catalog._with_pruned_common_types()
return replace(self, catalog_schema=schema_copy)

def _with_pruned_messages(self, allowed_messages: List[str]) -> "A2uiCatalog":
"""Returns a new catalog with only allowed messages.

Args:
allowed_messages: List of message names to include in s2c_schema.

Returns:
A copy of the catalog with only allowed messages.
"""
if not allowed_messages:
return self
Comment on lines +193 to +194
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

For better performance, consider converting allowed_messages to a set. This ensures that lookups during the filtering process (lines 156 and 161) are $O(1)$ instead of $O(N)$.

Suggested change
if not allowed_messages:
return self
if not allowed_messages:
return self
allowed_messages = set(allowed_messages)

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This can be done separately. This PR sticks to the existing conventions. The list is also small.


s2c_schema_copy = copy.deepcopy(self.s2c_schema)
if "oneOf" in s2c_schema_copy and isinstance(s2c_schema_copy["oneOf"], list):
s2c_schema_copy["oneOf"] = [
item
for item in s2c_schema_copy["oneOf"]
if "$ref" in item
and item["$ref"].startswith("#/$defs/")
and item["$ref"].split("/")[-1] in allowed_messages
]

if "$defs" in s2c_schema_copy and isinstance(s2c_schema_copy["$defs"], dict):
# Start with allowed messages as roots for internal reachability analysis
s2c_schema_copy["$defs"] = _prune_defs_by_reachability(
defs=s2c_schema_copy["$defs"],
root_def_names=allowed_messages,
internal_ref_prefix="#/$defs/",
)

return replace(self, s2c_schema=s2c_schema_copy)

def with_pruning(
    self,
    allowed_components: Optional[List[str]] = None,
    allowed_messages: Optional[List[str]] = None,
) -> "A2uiCatalog":
  """Builds a catalog restricted to the given components and messages.

  The steps run in a fixed order: component pruning, then message pruning,
  then common-type pruning. Each of the first two steps is skipped when its
  argument is ``None`` or empty; the common-type step always runs.

  Args:
    allowed_components: Component names to keep. ``None`` or an empty list
      leaves the components untouched.
    allowed_messages: Message names to keep in s2c_schema. ``None`` or an
      empty list leaves the messages untouched.

  Returns:
    The catalog produced by ``_with_pruned_common_types()`` after the
    requested component and message pruning has been applied.
  """
  after_components = (
      self._with_pruned_components(allowed_components)
      if allowed_components
      else self
  )
  after_messages = (
      after_components._with_pruned_messages(allowed_messages)
      if allowed_messages
      else after_components
  )
  return after_messages._with_pruned_common_types()

def _with_pruned_common_types(self) -> "A2uiCatalog":
"""Returns a new catalog with unused common types pruned from the schema."""
if not self.common_types_schema or "$defs" not in self.common_types_schema:
return self

def _collect_refs(obj: Any) -> set[str]:
refs = set()
if isinstance(obj, dict):
for k, v in obj.items():
if k == "$ref" and isinstance(v, str):
refs.add(v)
else:
refs.update(_collect_refs(v))
elif isinstance(obj, list):
for item in obj:
refs.update(_collect_refs(item))
return refs

visited_defs = set()
internal_refs_queue = collections.deque()

# Initialize queue with ONLY refs targeting common_types.json from external schemas
# Initialize roots with ONLY refs targeting common_types.json from external schemas
external_refs = _collect_refs(self.catalog_schema)
external_refs.update(_collect_refs(self.s2c_schema))

root_common_types = []
for ref in external_refs:
if ref.startswith("common_types.json#/$defs/"):
internal_refs_queue.append(ref.split("#/$defs/")[-1])

while internal_refs_queue:
def_name = internal_refs_queue.popleft()
if def_name in self.common_types_schema["$defs"] and def_name not in visited_defs:
visited_defs.add(def_name)

# Collect internal references (which just use #/$defs/)
internal_refs = _collect_refs(self.common_types_schema["$defs"][def_name])
for ref in internal_refs:
if ref.startswith("#/$defs/"):
# Note: This assumes a flat `$defs` namespace and no escaped
# slashes (~1) or tildes (~0) in the definition names as per RFC 6901.
internal_refs_queue.append(ref.split("#/$defs/")[-1])
root_common_types.append(ref.split("#/$defs/")[-1])

new_common_types_schema = copy.deepcopy(self.common_types_schema)
all_defs = new_common_types_schema["$defs"]
new_common_types_schema["$defs"] = {
k: v for k, v in all_defs.items() if k in visited_defs
}
new_common_types_schema["$defs"] = _prune_defs_by_reachability(
defs=new_common_types_schema["$defs"],
root_def_names=root_common_types,
)

return replace(self, common_types_schema=new_common_types_schema)

Expand Down
13 changes: 8 additions & 5 deletions agent_sdks/python/src/a2ui/core/schema/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,9 +64,10 @@ def _apply_modifiers(self, schema: Dict[str, Any]) -> Dict[str, Any]:
def _load_schemas(
self,
version: str,
catalogs: List[CatalogConfig] = [],
catalogs: Optional[List[CatalogConfig]] = None,
):
"""Loads separate schema components and processes catalogs."""
catalogs = catalogs or []
if version not in SPEC_VERSION_MAP:
raise ValueError(
f"Unknown A2UI specification version: {version}. Supported:"
Expand Down Expand Up @@ -181,11 +182,12 @@ def _select_catalog(
def get_selected_catalog(
self,
client_ui_capabilities: Optional[dict[str, Any]] = None,
allowed_components: List[str] = [],
allowed_components: Optional[List[str]] = None,
allowed_messages: Optional[List[str]] = None,
) -> A2uiCatalog:
"""Gets the selected catalog after selection and component pruning."""
catalog = self._select_catalog(client_ui_capabilities)
pruned_catalog = catalog.with_pruned_components(allowed_components)
pruned_catalog = catalog.with_pruning(allowed_components, allowed_messages)
return pruned_catalog

def load_examples(self, catalog: A2uiCatalog, validate: bool = False) -> str:
Expand All @@ -202,7 +204,8 @@ def generate_system_prompt(
workflow_description: str = "",
ui_description: str = "",
client_ui_capabilities: Optional[dict[str, Any]] = None,
allowed_components: List[str] = [],
allowed_components: Optional[List[str]] = None,
allowed_messages: Optional[List[str]] = None,
include_schema: bool = False,
include_examples: bool = False,
validate_examples: bool = False,
Expand All @@ -219,7 +222,7 @@ def generate_system_prompt(
parts.append(f"## UI Description:\n{ui_description}")

selected_catalog = self.get_selected_catalog(
client_ui_capabilities, allowed_components
client_ui_capabilities, allowed_components, allowed_messages
)

if include_schema:
Expand Down
116 changes: 110 additions & 6 deletions agent_sdks/python/tests/core/schema/test_catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ def test_load_examples_none_or_invalid_path():
assert catalog.load_examples("/non/existent/path") == ""


def test_with_pruned_components():
def test_with_pruning_components():
catalog_schema = {
"catalogId": "basic",
"components": {
Expand All @@ -151,7 +151,7 @@ def test_with_pruned_components():
)

# Test basic pruning
pruned_catalog = catalog.with_pruned_components(["Text", "Button"])
pruned_catalog = catalog.with_pruning(allowed_components=["Text", "Button"])
pruned = pruned_catalog.catalog_schema
assert "Text" in pruned["components"]
assert "Button" in pruned["components"]
Expand Down Expand Up @@ -179,13 +179,81 @@ def test_with_pruned_components():
common_types_schema={},
catalog_schema=catalog_schema_with_defs,
)
pruned_catalog_defs = catalog_with_defs.with_pruned_components(["Text"])
pruned_catalog_defs = catalog_with_defs.with_pruning(allowed_components=["Text"])
any_comp = pruned_catalog_defs.catalog_schema["$defs"]["anyComponent"]
assert len(any_comp["oneOf"]) == 1
assert any_comp["oneOf"][0]["$ref"] == "#/components/Text"

# Test empty allowed components (should return original self)
assert catalog.with_pruned_components([]) is catalog
assert catalog.with_pruning(allowed_components=[]) is catalog


def test_with_pruning_messages():
  """Message pruning filters both ``oneOf`` and ``$defs`` of the s2c schema."""
  catalog = A2uiCatalog(
      version=VERSION_0_9,
      name=BASIC_CATALOG_NAME,
      s2c_schema={
          "oneOf": [
              {"$ref": "#/$defs/MessageA"},
              {"$ref": "#/$defs/MessageB"},
              {"$ref": "#/$defs/MessageC"},
          ],
          "$defs": {
              "MessageA": {"type": "object", "properties": {"a": {"type": "string"}}},
              "MessageB": {"type": "object", "properties": {"b": {"type": "string"}}},
              "MessageC": {"type": "object", "properties": {"c": {"type": "string"}}},
          },
      },
      common_types_schema={},
      catalog_schema={"catalogId": "basic"},
  )

  # Keep MessageA and MessageC; MessageB must disappear everywhere.
  kept = ("MessageA", "MessageC")
  dropped = ("MessageB",)
  result = catalog.with_pruning([], allowed_messages=list(kept)).s2c_schema

  remaining_refs = result["oneOf"]
  assert len(remaining_refs) == 2
  for message in kept:
    assert {"$ref": "#/$defs/" + message} in remaining_refs
  for message in dropped:
    assert {"$ref": "#/$defs/" + message} not in remaining_refs

  remaining_defs = result["$defs"]
  for message in kept:
    assert message in remaining_defs
  for message in dropped:
    assert message not in remaining_defs


def test_with_pruning_messages_internal_reachability():
  """Definitions referenced by a kept message survive; unreferenced ones go."""
  catalog = A2uiCatalog(
      version=VERSION_0_9,
      name=BASIC_CATALOG_NAME,
      s2c_schema={
          "oneOf": [
              {"$ref": "#/$defs/MessageA"},
          ],
          "$defs": {
              # MessageA pulls in SharedType through an internal $ref.
              "MessageA": {
                  "type": "object",
                  "properties": {"shared": {"$ref": "#/$defs/SharedType"}},
              },
              "SharedType": {"type": "string"},
              # Nothing points at UnusedType.
              "UnusedType": {"type": "number"},
          },
      },
      common_types_schema={},
      catalog_schema={"catalogId": "basic"},
  )

  result = catalog.with_pruning([], allowed_messages=["MessageA"])
  remaining_defs = result.s2c_schema["$defs"]

  for expected in ("MessageA", "SharedType"):
    assert expected in remaining_defs
  assert "UnusedType" not in remaining_defs


def test_render_as_llm_instructions():
Expand Down Expand Up @@ -261,7 +329,7 @@ def test_render_as_llm_instructions_drops_empty_common_types():
assert "### Common Types Schema:" not in schema_str_empty_defs


def test_with_pruned_components_prunes_common_types():
def test_with_pruning_common_types():
common_types = {
"$defs": {
"TypeForCompA": {"type": "string"},
Expand All @@ -283,8 +351,44 @@ def test_with_pruned_components_prunes_common_types():
catalog_schema=catalog_schema,
)

pruned_catalog = catalog.with_pruned_components(["CompA"])
pruned_catalog = catalog.with_pruning(allowed_components=["CompA"])
pruned_defs = pruned_catalog.common_types_schema["$defs"]

assert "TypeForCompA" in pruned_defs
assert "TypeForCompB" not in pruned_defs


def test_with_pruning_s2c_also_prunes_common_types():
  """Dropping a message also drops the common types only that message used."""
  catalog = A2uiCatalog(
      version=VERSION_0_9,
      name=BASIC_CATALOG_NAME,
      s2c_schema={
          "oneOf": [
              {"$ref": "#/$defs/MessageA"},
              {"$ref": "#/$defs/MessageB"},
          ],
          # Each message depends on exactly one common type.
          "$defs": {
              "MessageA": {"$ref": "common_types.json#/$defs/TypeForA"},
              "MessageB": {"$ref": "common_types.json#/$defs/TypeForB"},
          },
      },
      common_types_schema={
          "$defs": {
              "TypeForA": {"type": "string"},
              "TypeForB": {"type": "number"},
          }
      },
      catalog_schema={"catalogId": "basic"},
  )

  # Only MessageA is allowed, so MessageB and TypeForB become unreachable.
  result = catalog.with_pruning([], allowed_messages=["MessageA"])

  message_defs = result.s2c_schema["$defs"]
  assert "MessageA" in message_defs
  assert "MessageB" not in message_defs

  common_defs = result.common_types_schema["$defs"]
  assert "TypeForA" in common_defs
  assert "TypeForB" not in common_defs
Loading