Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 57 additions & 7 deletions agent_sdks/python/src/a2ui/core/schema/catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def validator(self) -> "A2uiValidator":

return A2uiValidator(self)

def with_pruned_components(self, allowed_components: List[str]) -> "A2uiCatalog":
def _with_pruned_components(self, allowed_components: List[str]) -> "A2uiCatalog":
"""Returns a new catalog with only allowed components.

Args:
Expand All @@ -99,11 +99,10 @@ def with_pruned_components(self, allowed_components: List[str]) -> "A2uiCatalog"
A copy of the catalog with only allowed components.
"""

schema_copy = copy.deepcopy(self.catalog_schema)

# Allow all components if no allowed components are specified
if not allowed_components:
return self._with_pruned_common_types()
return self

schema_copy = copy.deepcopy(self.catalog_schema)

if CATALOG_COMPONENTS_KEY in schema_copy and isinstance(
schema_copy[CATALOG_COMPONENTS_KEY], dict
Expand Down Expand Up @@ -133,8 +132,59 @@ def with_pruned_components(self, allowed_components: List[str]) -> "A2uiCatalog"

any_comp["oneOf"] = filtered_one_of

pruned_catalog = replace(self, catalog_schema=schema_copy)
return pruned_catalog._with_pruned_common_types()
return replace(self, catalog_schema=schema_copy)

def _with_pruned_messages(self, allowed_messages: List[str]) -> "A2uiCatalog":
"""Returns a new catalog with only allowed messages.

Args:
allowed_messages: List of message names to include in s2c_schema.

Returns:
A copy of the catalog with only allowed messages.
"""
if not allowed_messages:
return self
Comment on lines +193 to +194
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

For better performance, consider converting allowed_messages to a set. This ensures that lookups during the filtering process (lines 156 and 161) are $O(1)$ instead of $O(N)$.

Suggested change
if not allowed_messages:
return self
if not allowed_messages:
return self
allowed_messages = set(allowed_messages)

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This can be done separately. This PR sticks to the existing conventions. The list is also small.


s2c_schema_copy = copy.deepcopy(self.s2c_schema)
if "oneOf" in s2c_schema_copy and isinstance(s2c_schema_copy["oneOf"], list):
s2c_schema_copy["oneOf"] = [
item
for item in s2c_schema_copy["oneOf"]
if "$ref" in item
and item["$ref"].startswith("#/$defs/")
and item["$ref"].split("/")[-1] in allowed_messages
]

if "$defs" in s2c_schema_copy and isinstance(s2c_schema_copy["$defs"], dict):
s2c_schema_copy["$defs"] = {
k: v for k, v in s2c_schema_copy["$defs"].items() if k in allowed_messages
}

return replace(self, s2c_schema=s2c_schema_copy)

def with_pruning(
self,
allowed_components: List[str] = [],
allowed_messages: List[str] = [],
) -> "A2uiCatalog":
"""Returns a new catalog with pruned components and messages.

Args:
allowed_components: List of component names to include.
allowed_messages: List of message names to include in s2c_schema.

Returns:
A copy of the catalog with pruned components and messages.
"""
catalog = self
if allowed_components:
catalog = catalog._with_pruned_components(allowed_components)

if allowed_messages:
catalog = catalog._with_pruned_messages(allowed_messages)

return catalog._with_pruned_common_types()

def _with_pruned_common_types(self) -> "A2uiCatalog":
"""Returns a new catalog with unused common types pruned from the schema."""
Expand Down
6 changes: 4 additions & 2 deletions agent_sdks/python/src/a2ui/core/schema/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,10 +182,11 @@ def get_selected_catalog(
self,
client_ui_capabilities: Optional[dict[str, Any]] = None,
allowed_components: List[str] = [],
allowed_messages: List[str] = [],
) -> A2uiCatalog:
"""Gets the selected catalog after selection and component pruning."""
catalog = self._select_catalog(client_ui_capabilities)
pruned_catalog = catalog.with_pruned_components(allowed_components)
pruned_catalog = catalog.with_pruning(allowed_components, allowed_messages)
return pruned_catalog

def load_examples(self, catalog: A2uiCatalog, validate: bool = False) -> str:
Expand All @@ -203,6 +204,7 @@ def generate_system_prompt(
ui_description: str = "",
client_ui_capabilities: Optional[dict[str, Any]] = None,
allowed_components: List[str] = [],
allowed_messages: List[str] = [],
include_schema: bool = False,
include_examples: bool = False,
validate_examples: bool = False,
Expand All @@ -219,7 +221,7 @@ def generate_system_prompt(
parts.append(f"## UI Description:\n{ui_description}")

selected_catalog = self.get_selected_catalog(
client_ui_capabilities, allowed_components
client_ui_capabilities, allowed_components, allowed_messages
)

if include_schema:
Expand Down
84 changes: 78 additions & 6 deletions agent_sdks/python/tests/core/schema/test_catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ def test_load_examples_none_or_invalid_path():
assert catalog.load_examples("/non/existent/path") == ""


def test_with_pruned_components():
def test_with_pruning_components():
catalog_schema = {
"catalogId": "basic",
"components": {
Expand All @@ -151,7 +151,7 @@ def test_with_pruned_components():
)

# Test basic pruning
pruned_catalog = catalog.with_pruned_components(["Text", "Button"])
pruned_catalog = catalog.with_pruning(allowed_components=["Text", "Button"])
pruned = pruned_catalog.catalog_schema
assert "Text" in pruned["components"]
assert "Button" in pruned["components"]
Expand Down Expand Up @@ -179,13 +179,49 @@ def test_with_pruned_components():
common_types_schema={},
catalog_schema=catalog_schema_with_defs,
)
pruned_catalog_defs = catalog_with_defs.with_pruned_components(["Text"])
pruned_catalog_defs = catalog_with_defs.with_pruning(allowed_components=["Text"])
any_comp = pruned_catalog_defs.catalog_schema["$defs"]["anyComponent"]
assert len(any_comp["oneOf"]) == 1
assert any_comp["oneOf"][0]["$ref"] == "#/components/Text"

# Test empty allowed components (should return original self)
assert catalog.with_pruned_components([]) is catalog
assert catalog.with_pruning(allowed_components=[]) is catalog


def test_with_pruning_messages():
s2c_schema = {
"oneOf": [
{"$ref": "#/$defs/MessageA"},
{"$ref": "#/$defs/MessageB"},
{"$ref": "#/$defs/MessageC"},
],
"$defs": {
"MessageA": {"type": "object", "properties": {"a": {"type": "string"}}},
"MessageB": {"type": "object", "properties": {"b": {"type": "string"}}},
"MessageC": {"type": "object", "properties": {"c": {"type": "string"}}},
},
}
catalog_schema = {"catalogId": "basic"}
catalog = A2uiCatalog(
version=VERSION_0_9,
name=BASIC_CATALOG_NAME,
s2c_schema=s2c_schema,
common_types_schema={},
catalog_schema=catalog_schema,
)

# Prune to only MessageA and MessageC
pruned_catalog = catalog.with_pruning([], allowed_messages=["MessageA", "MessageC"])
pruned_s2c = pruned_catalog.s2c_schema

assert len(pruned_s2c["oneOf"]) == 2
assert {"$ref": "#/$defs/MessageA"} in pruned_s2c["oneOf"]
assert {"$ref": "#/$defs/MessageC"} in pruned_s2c["oneOf"]
assert {"$ref": "#/$defs/MessageB"} not in pruned_s2c["oneOf"]

assert "MessageA" in pruned_s2c["$defs"]
assert "MessageC" in pruned_s2c["$defs"]
assert "MessageB" not in pruned_s2c["$defs"]


def test_render_as_llm_instructions():
Expand Down Expand Up @@ -261,7 +297,7 @@ def test_render_as_llm_instructions_drops_empty_common_types():
assert "### Common Types Schema:" not in schema_str_empty_defs


def test_with_pruned_components_prunes_common_types():
def test_with_pruning_common_types():
common_types = {
"$defs": {
"TypeForCompA": {"type": "string"},
Expand All @@ -283,8 +319,44 @@ def test_with_pruned_components_prunes_common_types():
catalog_schema=catalog_schema,
)

pruned_catalog = catalog.with_pruned_components(["CompA"])
pruned_catalog = catalog.with_pruning(allowed_components=["CompA"])
pruned_defs = pruned_catalog.common_types_schema["$defs"]

assert "TypeForCompA" in pruned_defs
assert "TypeForCompB" not in pruned_defs


def test_with_pruning_s2c_also_prunes_common_types():
common_types = {
"$defs": {
"TypeForA": {"type": "string"},
"TypeForB": {"type": "number"},
}
}
s2c_schema = {
"oneOf": [
{"$ref": "#/$defs/MessageA"},
{"$ref": "#/$defs/MessageB"},
],
"$defs": {
"MessageA": {"$ref": "common_types.json#/$defs/TypeForA"},
"MessageB": {"$ref": "common_types.json#/$defs/TypeForB"},
},
}
catalog_schema = {"catalogId": "basic"}
catalog = A2uiCatalog(
version=VERSION_0_9,
name=BASIC_CATALOG_NAME,
s2c_schema=s2c_schema,
common_types_schema=common_types,
catalog_schema=catalog_schema,
)

# Prune to only MessageA
pruned_catalog = catalog.with_pruning([], allowed_messages=["MessageA"])

assert "MessageA" in pruned_catalog.s2c_schema["$defs"]
assert "MessageB" not in pruned_catalog.s2c_schema["$defs"]

assert "TypeForA" in pruned_catalog.common_types_schema["$defs"]
assert "TypeForB" not in pruned_catalog.common_types_schema["$defs"]
Loading