Skip to content

Commit 2224a45

Browse files
committed
Treeshake common types schema
1 parent 92b4c2d commit 2224a45

File tree

2 files changed

+78
-2
lines changed

2 files changed

+78
-2
lines changed

agent_sdks/python/src/a2ui/core/schema/catalog.py

Lines changed: 49 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ def with_pruned_components(self, allowed_components: List[str]) -> "A2uiCatalog"
102102

103103
# Allow all components if no allowed components are specified
104104
if not allowed_components:
105-
return self
105+
return self._with_pruned_common_types()
106106

107107
if CATALOG_COMPONENTS_KEY in schema_copy and isinstance(
108108
schema_copy[CATALOG_COMPONENTS_KEY], dict
@@ -132,7 +132,54 @@ def with_pruned_components(self, allowed_components: List[str]) -> "A2uiCatalog"
132132

133133
any_comp["oneOf"] = filtered_one_of
134134

135-
return replace(self, catalog_schema=schema_copy)
135+
pruned_catalog = replace(self, catalog_schema=schema_copy)
136+
return pruned_catalog._with_pruned_common_types()
137+
138+
def _with_pruned_common_types(self) -> "A2uiCatalog":
139+
"""Returns a new catalog with unused common types pruned from the schema."""
140+
if not self.common_types_schema or "$defs" not in self.common_types_schema:
141+
return self
142+
143+
def _collect_refs(obj: Any) -> set[str]:
144+
refs = set()
145+
if isinstance(obj, dict):
146+
for k, v in obj.items():
147+
if k == "$ref" and isinstance(v, str):
148+
refs.add(v)
149+
else:
150+
refs.update(_collect_refs(v))
151+
elif isinstance(obj, list):
152+
for item in obj:
153+
refs.update(_collect_refs(item))
154+
return refs
155+
156+
visited_defs = set()
157+
queue = []
158+
159+
# Initialize queue with refs from catalog_schema and s2c_schema
160+
queue.extend(_collect_refs(self.catalog_schema))
161+
queue.extend(_collect_refs(self.s2c_schema))
162+
163+
while queue:
164+
ref = queue.pop(0)
165+
if ref.startswith("#/$defs/"):
166+
# Note: This assumes a flat `$defs` namespace and no escaped
167+
# slashes (~1) or tildes (~0) in the definition names as per RFC 6901.
168+
def_name = ref.split("/")[-1]
169+
if (
170+
def_name in self.common_types_schema["$defs"]
171+
and def_name not in visited_defs
172+
):
173+
visited_defs.add(def_name)
174+
queue.extend(_collect_refs(self.common_types_schema["$defs"][def_name]))
175+
176+
new_common_types_schema = copy.deepcopy(self.common_types_schema)
177+
all_defs = new_common_types_schema["$defs"]
178+
new_common_types_schema["$defs"] = {
179+
k: v for k, v in all_defs.items() if k in visited_defs
180+
}
181+
182+
return replace(self, common_types_schema=new_common_types_schema)
136183

137184
def render_as_llm_instructions(self) -> str:
138185
"""Renders the catalog and schema as LLM instructions."""

agent_sdks/python/tests/core/schema/test_catalog.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -209,3 +209,32 @@ def test_render_as_llm_instructions():
209209
assert '"catalog": "schema"' in schema_str
210210
assert '"catalogId": "id_basic"' in schema_str
211211
assert A2UI_SCHEMA_BLOCK_END in schema_str
212+
213+
214+
def test_with_pruned_components_prunes_common_types():
215+
common_types = {
216+
"$defs": {
217+
"TypeForCompA": {"type": "string"},
218+
"TypeForCompB": {"type": "number"},
219+
}
220+
}
221+
catalog_schema = {
222+
"catalogId": "basic",
223+
"components": {
224+
"CompA": {"$ref": "#/$defs/TypeForCompA"},
225+
"CompB": {"$ref": "#/$defs/TypeForCompB"},
226+
},
227+
}
228+
catalog = A2uiCatalog(
229+
version=VERSION_0_8,
230+
name=BASIC_CATALOG_NAME,
231+
s2c_schema={},
232+
common_types_schema=common_types,
233+
catalog_schema=catalog_schema,
234+
)
235+
236+
pruned_catalog = catalog.with_pruned_components(["CompA"])
237+
pruned_defs = pruned_catalog.common_types_schema["$defs"]
238+
239+
assert "TypeForCompA" in pruned_defs
240+
assert "TypeForCompB" not in pruned_defs

0 commit comments

Comments
 (0)