diff --git a/agent_sdks/python/src/a2ui/core/schema/catalog.py b/agent_sdks/python/src/a2ui/core/schema/catalog.py index 15898a3a6..13239c565 100644 --- a/agent_sdks/python/src/a2ui/core/schema/catalog.py +++ b/agent_sdks/python/src/a2ui/core/schema/catalog.py @@ -26,6 +26,7 @@ A2UI_SCHEMA_BLOCK_END, CATALOG_COMPONENTS_KEY, CATALOG_ID_KEY, + VERSION_0_8, ) @@ -59,6 +60,52 @@ def from_path( ) +def _collect_refs(obj: Any) -> set[str]: + """Recursively collects all $ref values from a JSON object.""" + refs = set() + if isinstance(obj, dict): + for k, v in obj.items(): + if k == "$ref" and isinstance(v, str): + refs.add(v) + else: + refs.update(_collect_refs(v)) + elif isinstance(obj, list): + for item in obj: + refs.update(_collect_refs(item)) + return refs + + +def _prune_defs_by_reachability( + defs: Dict[str, Any], + root_def_names: List[str], + internal_ref_prefix: str = "#/$defs/", +) -> Dict[str, Any]: + """Prunes definitions not reachable from the provided roots. + + Args: + defs: The dictionary of definitions to prune. + root_def_names: The names of the definitions to start the traversal from. + internal_ref_prefix: The prefix used for internal references. + + Returns: + A new dictionary containing only reachable definitions. + """ + visited_defs = set() + refs_queue = collections.deque(root_def_names) + + while refs_queue: + def_name = refs_queue.popleft() + if def_name in defs and def_name not in visited_defs: + visited_defs.add(def_name) + + internal_refs = _collect_refs(defs[def_name]) + for ref in internal_refs: + if ref.startswith(internal_ref_prefix): + refs_queue.append(ref.split(internal_ref_prefix)[-1]) + + return {k: v for k, v in defs.items() if k in visited_defs} + + @dataclass(frozen=True) class A2uiCatalog: """Represents a processed component catalog with its schema. @@ -89,7 +136,7 @@ def validator(self) -> "A2uiValidator": return A2uiValidator(self) - def with_pruned_components(self, allowed_components: List[str]) -> "A2uiCatalog": + def _with_pruned_components(self, allowed_components: List[str]) -> "A2uiCatalog": """Returns a new catalog with only allowed components. Args: @@ -99,11 +146,10 @@ def with_pruned_components(self, allowed_components: List[str]) -> "A2uiCatalog" A copy of the catalog with only allowed components. """ - schema_copy = copy.deepcopy(self.catalog_schema) - - # Allow all components if no allowed components are specified if not allowed_components: - return self._with_pruned_common_types() + return self + + schema_copy = copy.deepcopy(self.catalog_schema) if CATALOG_COMPONENTS_KEY in schema_copy and isinstance( schema_copy[CATALOG_COMPONENTS_KEY], dict @@ -133,56 +179,94 @@ def with_pruned_components(self, allowed_components: List[str]) -> "A2uiCatalog" any_comp["oneOf"] = filtered_one_of - pruned_catalog = replace(self, catalog_schema=schema_copy) - return pruned_catalog._with_pruned_common_types() + return replace(self, catalog_schema=schema_copy) + + def _with_pruned_messages(self, allowed_messages: List[str]) -> "A2uiCatalog": + """Returns a new catalog with only allowed messages. + + Args: + allowed_messages: List of message names to include in s2c_schema. + + Returns: + A copy of the catalog with only allowed messages. + """ + if not allowed_messages: + return self + + s2c_schema_copy = copy.deepcopy(self.s2c_schema) + + if self.version == VERSION_0_8: + # 0.8 style: Messages are in root properties. + if "properties" in s2c_schema_copy and isinstance( + s2c_schema_copy["properties"], dict + ): + s2c_schema_copy["properties"] = _prune_defs_by_reachability( + defs=s2c_schema_copy["properties"], + root_def_names=allowed_messages, + internal_ref_prefix="#/properties/", + ) + else: + # 0.9+ style: Messages are in $defs and referenced via oneOf. + if "oneOf" in s2c_schema_copy and isinstance(s2c_schema_copy["oneOf"], list): + s2c_schema_copy["oneOf"] = [ + item + for item in s2c_schema_copy["oneOf"] + if "$ref" in item + and item["$ref"].startswith("#/$defs/") + and item["$ref"].split("/")[-1] in allowed_messages + ] + + if "$defs" in s2c_schema_copy and isinstance(s2c_schema_copy["$defs"], dict): + s2c_schema_copy["$defs"] = _prune_defs_by_reachability( + defs=s2c_schema_copy["$defs"], + root_def_names=allowed_messages, + internal_ref_prefix="#/$defs/", + ) + + return replace(self, s2c_schema=s2c_schema_copy) + + def with_pruning( + self, + allowed_components: Optional[List[str]] = None, + allowed_messages: Optional[List[str]] = None, + ) -> "A2uiCatalog": + """Returns a new catalog with pruned components and messages. + + Args: + allowed_components: List of component names to include. + allowed_messages: List of message names to include in s2c_schema. + + Returns: + A copy of the catalog with pruned components and messages. + """ + catalog = self + if allowed_components: + catalog = catalog._with_pruned_components(allowed_components) + + if allowed_messages: + catalog = catalog._with_pruned_messages(allowed_messages) + + return catalog._with_pruned_common_types() def _with_pruned_common_types(self) -> "A2uiCatalog": """Returns a new catalog with unused common types pruned from the schema.""" if not self.common_types_schema or "$defs" not in self.common_types_schema: return self - def _collect_refs(obj: Any) -> set[str]: - refs = set() - if isinstance(obj, dict): - for k, v in obj.items(): - if k == "$ref" and isinstance(v, str): - refs.add(v) - else: - refs.update(_collect_refs(v)) - elif isinstance(obj, list): - for item in obj: - refs.update(_collect_refs(item)) - return refs - - visited_defs = set() - internal_refs_queue = collections.deque() - - # Initialize queue with ONLY refs targeting common_types.json from external schemas + # Initialize roots with ONLY refs targeting common_types.json from external schemas external_refs = _collect_refs(self.catalog_schema) external_refs.update(_collect_refs(self.s2c_schema)) + root_common_types = [] for ref in external_refs: if ref.startswith("common_types.json#/$defs/"): - internal_refs_queue.append(ref.split("#/$defs/")[-1]) - - while internal_refs_queue: - def_name = internal_refs_queue.popleft() - if def_name in self.common_types_schema["$defs"] and def_name not in visited_defs: - visited_defs.add(def_name) - - # Collect internal references (which just use #/$defs/) - internal_refs = _collect_refs(self.common_types_schema["$defs"][def_name]) - for ref in internal_refs: - if ref.startswith("#/$defs/"): - # Note: This assumes a flat `$defs` namespace and no escaped - # slashes (~1) or tildes (~0) in the definition names as per RFC 6901. - internal_refs_queue.append(ref.split("#/$defs/")[-1]) + root_common_types.append(ref.split("#/$defs/")[-1]) new_common_types_schema = copy.deepcopy(self.common_types_schema) - all_defs = new_common_types_schema["$defs"] - new_common_types_schema["$defs"] = { - k: v for k, v in all_defs.items() if k in visited_defs - } + new_common_types_schema["$defs"] = _prune_defs_by_reachability( + defs=new_common_types_schema["$defs"], + root_def_names=root_common_types, + ) return replace(self, common_types_schema=new_common_types_schema) diff --git a/agent_sdks/python/src/a2ui/core/schema/manager.py b/agent_sdks/python/src/a2ui/core/schema/manager.py index 185df6a2e..8c0cbf9c5 100644 --- a/agent_sdks/python/src/a2ui/core/schema/manager.py +++ b/agent_sdks/python/src/a2ui/core/schema/manager.py @@ -64,9 +64,10 @@ def _apply_modifiers(self, schema: Dict[str, Any]) -> Dict[str, Any]: def _load_schemas( self, version: str, - catalogs: List[CatalogConfig] = [], + catalogs: Optional[List[CatalogConfig]] = None, ): """Loads separate schema components and processes catalogs.""" + catalogs = catalogs or [] if version not in SPEC_VERSION_MAP: raise ValueError( f"Unknown A2UI specification version: {version}. Supported:" @@ -181,11 +182,12 @@ def _select_catalog( def get_selected_catalog( self, client_ui_capabilities: Optional[dict[str, Any]] = None, - allowed_components: List[str] = [], + allowed_components: Optional[List[str]] = None, + allowed_messages: Optional[List[str]] = None, ) -> A2uiCatalog: """Gets the selected catalog after selection and component pruning.""" catalog = self._select_catalog(client_ui_capabilities) - pruned_catalog = catalog.with_pruned_components(allowed_components) + pruned_catalog = catalog.with_pruning(allowed_components, allowed_messages) return pruned_catalog def load_examples(self, catalog: A2uiCatalog, validate: bool = False) -> str: @@ -202,7 +204,8 @@ def generate_system_prompt( workflow_description: str = "", ui_description: str = "", client_ui_capabilities: Optional[dict[str, Any]] = None, - allowed_components: List[str] = [], + allowed_components: Optional[List[str]] = None, + allowed_messages: Optional[List[str]] = None, include_schema: bool = False, include_examples: bool = False, validate_examples: bool = False, @@ -219,7 +222,7 @@ def generate_system_prompt( parts.append(f"## UI Description:\n{ui_description}") selected_catalog = self.get_selected_catalog( - client_ui_capabilities, allowed_components + client_ui_capabilities, allowed_components, allowed_messages ) if include_schema: diff --git a/agent_sdks/python/tests/core/schema/test_catalog.py b/agent_sdks/python/tests/core/schema/test_catalog.py index 8c773ac34..eff39eb0a 100644 --- a/agent_sdks/python/tests/core/schema/test_catalog.py +++ b/agent_sdks/python/tests/core/schema/test_catalog.py @@ -133,7 +133,7 @@ def test_load_examples_none_or_invalid_path(): assert catalog.load_examples("/non/existent/path") == "" -def test_with_pruned_components(): +def test_with_pruning_components(): catalog_schema = { "catalogId": "basic", "components": { @@ -151,7 +151,7 @@ def test_with_pruned_components(): ) # Test basic pruning - pruned_catalog = catalog.with_pruned_components(["Text", "Button"]) + pruned_catalog = catalog.with_pruning(allowed_components=["Text", "Button"]) pruned = pruned_catalog.catalog_schema assert "Text" in pruned["components"] assert "Button" in pruned["components"] @@ -179,13 +179,81 @@ def test_with_pruned_components(): common_types_schema={}, catalog_schema=catalog_schema_with_defs, ) - pruned_catalog_defs = catalog_with_defs.with_pruned_components(["Text"]) + pruned_catalog_defs = catalog_with_defs.with_pruning(allowed_components=["Text"]) any_comp = pruned_catalog_defs.catalog_schema["$defs"]["anyComponent"] assert len(any_comp["oneOf"]) == 1 assert any_comp["oneOf"][0]["$ref"] == "#/components/Text" # Test empty allowed components (should return original self) - assert catalog.with_pruned_components([]) is catalog + assert catalog.with_pruning(allowed_components=[]) is catalog + + +def test_with_pruning_messages(): + s2c_schema = { + "oneOf": [ + {"$ref": "#/$defs/MessageA"}, + {"$ref": "#/$defs/MessageB"}, + {"$ref": "#/$defs/MessageC"}, + ], + "$defs": { + "MessageA": {"type": "object", "properties": {"a": {"type": "string"}}}, + "MessageB": {"type": "object", "properties": {"b": {"type": "string"}}}, + "MessageC": {"type": "object", "properties": {"c": {"type": "string"}}}, + }, + } + catalog_schema = {"catalogId": "basic"} + catalog = A2uiCatalog( + version=VERSION_0_9, + name=BASIC_CATALOG_NAME, + s2c_schema=s2c_schema, + common_types_schema={}, + catalog_schema=catalog_schema, + ) + + # Prune to only MessageA and MessageC + pruned_catalog = catalog.with_pruning([], allowed_messages=["MessageA", "MessageC"]) + pruned_s2c = pruned_catalog.s2c_schema + + assert len(pruned_s2c["oneOf"]) == 2 + assert {"$ref": "#/$defs/MessageA"} in pruned_s2c["oneOf"] + assert {"$ref": "#/$defs/MessageC"} in pruned_s2c["oneOf"] + assert {"$ref": "#/$defs/MessageB"} not in pruned_s2c["oneOf"] + + assert "MessageA" in pruned_s2c["$defs"] + assert "MessageC" in pruned_s2c["$defs"] + assert "MessageB" not in pruned_s2c["$defs"] + + +def test_with_pruning_messages_internal_reachability(): + s2c_schema = { + "oneOf": [ + {"$ref": "#/$defs/MessageA"}, + ], + "$defs": { + "MessageA": { + "type": "object", + "properties": {"shared": {"$ref": "#/$defs/SharedType"}}, + }, + "SharedType": {"type": "string"}, + "UnusedType": {"type": "number"}, + }, + } + catalog_schema = {"catalogId": "basic"} + catalog = A2uiCatalog( + version=VERSION_0_9, + name=BASIC_CATALOG_NAME, + s2c_schema=s2c_schema, + common_types_schema={}, + catalog_schema=catalog_schema, + ) + + # Prune to MessageA. SharedType should be kept, UnusedType should be removed. + pruned_catalog = catalog.with_pruning([], allowed_messages=["MessageA"]) + pruned_defs = pruned_catalog.s2c_schema["$defs"] + + assert "MessageA" in pruned_defs + assert "SharedType" in pruned_defs + assert "UnusedType" not in pruned_defs def test_render_as_llm_instructions(): @@ -261,7 +329,7 @@ def test_render_as_llm_instructions_drops_empty_common_types(): assert "### Common Types Schema:" not in schema_str_empty_defs -def test_with_pruned_components_prunes_common_types(): +def test_with_pruning_common_types(): common_types = { "$defs": { "TypeForCompA": {"type": "string"}, @@ -283,8 +351,74 @@ def test_with_pruned_components_prunes_common_types(): catalog_schema=catalog_schema, ) - pruned_catalog = catalog.with_pruned_components(["CompA"]) + pruned_catalog = catalog.with_pruning(allowed_components=["CompA"]) pruned_defs = pruned_catalog.common_types_schema["$defs"] assert "TypeForCompA" in pruned_defs assert "TypeForCompB" not in pruned_defs + + +def test_with_pruning_s2c_also_prunes_common_types(): + common_types = { + "$defs": { + "TypeForA": {"type": "string"}, + "TypeForB": {"type": "number"}, + } + } + s2c_schema = { + "oneOf": [ + {"$ref": "#/$defs/MessageA"}, + {"$ref": "#/$defs/MessageB"}, + ], + "$defs": { + "MessageA": {"$ref": "common_types.json#/$defs/TypeForA"}, + "MessageB": {"$ref": "common_types.json#/$defs/TypeForB"}, + }, + } + catalog_schema = {"catalogId": "basic"} + catalog = A2uiCatalog( + version=VERSION_0_9, + name=BASIC_CATALOG_NAME, + s2c_schema=s2c_schema, + common_types_schema=common_types, + catalog_schema=catalog_schema, + ) + + # Prune to only MessageA + pruned_catalog = catalog.with_pruning([], allowed_messages=["MessageA"]) + + assert "MessageA" in pruned_catalog.s2c_schema["$defs"] + assert "MessageB" not in pruned_catalog.s2c_schema["$defs"] + + assert "TypeForA" in pruned_catalog.common_types_schema["$defs"] + assert "TypeForB" not in pruned_catalog.common_types_schema["$defs"] + + +def test_with_pruning_messages_v08(): + s2c_schema = { + "properties": { + "beginRendering": {"type": "object"}, + "surfaceUpdate": {"type": "object"}, + "deleteSurface": {"type": "object"}, + }, + "required": ["surfaceId"], + } + catalog_schema = {"catalogId": "basic"} + catalog = A2uiCatalog( + version=VERSION_0_8, + name=BASIC_CATALOG_NAME, + s2c_schema=s2c_schema, + common_types_schema={}, + catalog_schema=catalog_schema, + ) + + # Prune to only beginRendering and deleteSurface + pruned_catalog = catalog.with_pruning( + [], allowed_messages=["beginRendering", "deleteSurface"] + ) + pruned_s2c = pruned_catalog.s2c_schema + + assert "beginRendering" in pruned_s2c["properties"] + assert "deleteSurface" in pruned_s2c["properties"] + assert "surfaceUpdate" not in pruned_s2c["properties"] + assert pruned_s2c["required"] == ["surfaceId"]