Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions pydough/metadata/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
"GraphMetadata",
"MaskedTableColumnMetadata",
"PropertyMetadata",
"ScalarAttributeMetadata",
"SimpleJoinMetadata",
"SimpleTableMetadata",
"SubcollectionRelationshipMetadata",
Expand All @@ -24,6 +25,7 @@
GeneralJoinMetadata,
MaskedTableColumnMetadata,
PropertyMetadata,
ScalarAttributeMetadata,
SimpleJoinMetadata,
SubcollectionRelationshipMetadata,
TableColumnMetadata,
Expand Down
10 changes: 9 additions & 1 deletion pydough/metadata/parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,13 +98,14 @@ def parse_graph_v2(graph_name: str, graph_json: dict) -> GraphMetadata:
"""
verified_analysis: list[dict] = []
additional_definitions: list[str] = []
extra_semantic_info: dict = {}
graph: GraphMetadata = GraphMetadata(
graph_name,
additional_definitions,
verified_analysis,
None,
None,
{},
extra_semantic_info,
)

# Parse and extract the metadata for all of the collections in the graph.
Expand Down Expand Up @@ -140,6 +141,7 @@ def parse_graph_v2(graph_name: str, graph_json: dict) -> GraphMetadata:
defn,
f"metadata for additional definitions inside {graph.error_name}",
)
additional_definitions.append(defn)
if "verified pydough analysis" in graph_json:
verified_analysis_json: list = extract_array(
graph_json, "verified pydough analysis", graph.error_name
Expand All @@ -156,6 +158,12 @@ def parse_graph_v2(graph_name: str, graph_json: dict) -> GraphMetadata:
HasPropertyWith("code", is_string).verify(
verified_json, "metadata for verified pydough analysis"
)
verified_analysis.append(verified_json)
if "extra semantic info" in graph_json:
extra_info_json: dict = extract_object(
graph_json, "extra semantic info", graph.error_name
)
extra_semantic_info.update(extra_info_json)

# Add all of the UDF definitions to the graph.
if "functions" in graph_json:
Expand Down
12 changes: 11 additions & 1 deletion pydough/metadata/properties/masked_table_column_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,14 +48,20 @@ def __init__(
unprotect_protocol: str,
protect_protocol: str,
server_masked: bool,
sample_values: list | None = None,
sample_values: list | None,
description: str | None,
synonyms: list[str] | None,
extra_semantic_info: dict | None,
):
super().__init__(
name,
collection,
protected_data_type,
column_name,
sample_values,
description,
synonyms,
extra_semantic_info,
)
self._unprotected_data_type: PyDoughType = data_type
self._unprotect_protocol: str = unprotect_protocol
Expand Down Expand Up @@ -172,6 +178,10 @@ def parse_from_json(
unprotect_protocol,
protect_protocol,
server_masked,
None,
None,
None,
None,
)
# Parse the optional common semantic properties like the description.
property.parse_optional_properties(property_json)
Expand Down
4 changes: 3 additions & 1 deletion pydough/metadata/properties/scalar_attribute_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,4 +73,6 @@ def parse_optional_properties(self, meta_json: dict) -> None:

# Extract the optional sample values field from the JSON object.
if "sample values" in meta_json:
extract_array(meta_json, "sample values", self.error_name)
self._sample_values = extract_array(
meta_json, "sample values", self.error_name
)
12 changes: 8 additions & 4 deletions pydough/metadata/properties/table_column_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,10 @@ def __init__(
collection: CollectionMetadata,
data_type: PyDoughType,
column_name: str,
sample_values: list | None = None,
description: str | None = None,
synonyms: list[str] | None = None,
extra_semantic_info: dict | None = None,
sample_values: list | None,
description: str | None,
synonyms: list[str] | None,
extra_semantic_info: dict | None,
):
super().__init__(
name,
Expand Down Expand Up @@ -112,6 +112,10 @@ def parse_from_json(
collection,
data_type,
column_name,
None,
None,
None,
None,
)
# Parse the optional common semantic properties like the description.
property.parse_optional_properties(property_json)
Expand Down
128 changes: 128 additions & 0 deletions tests/test_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
CollectionMetadata,
GraphMetadata,
PropertyMetadata,
ScalarAttributeMetadata,
SimpleJoinMetadata,
SimpleTableMetadata,
TableColumnMetadata,
Expand Down Expand Up @@ -296,3 +297,130 @@ def test_simple_join_info(
assert property.keys == keys, (
f"Mismatch between 'keys' of {property!r} and expected value"
)


def test_semantic_info(get_sample_graph: graph_fetcher) -> None:
    """
    Checks that the semantic metadata of the TPCH sample graph (verified
    analyses, additional definitions, descriptions, synonyms, sample values
    and extra semantic info) is populated with the expected values at the
    graph, collection and property levels.
    """
    tpch_graph: GraphMetadata = get_sample_graph("TPCH")

    # Graph level: the verified question/code pairs.
    expected_analysis = [
        {
            "question": "How many customers are in China?",
            "code": "TPCH.CALCULATE(n_chinese_customers=COUNT(customers.WHERE(nation.name == 'CHINA')))",
        },
        {
            "question": "What was the most ordered part in 1995, by quantity, by Brazilian customers?",
            "code": "parts.CALCULATE(name, quantity=SUM(lines.WHERE((YEAR(ship_date) == 1995) & (order.customer.nation.name == 'BRAZIL')).quantity)).TOP_K(1, by=quantity)",
        },
        {
            "question": "Who is the wealthiest customer in each nation in Africa?",
            "code": "nations.WHERE(region.name == 'AFRICA').CALCULATE(nation_name=name, richest_customer=customers.BEST(per='nation', by=account_balance.DESC()).name)",
        },
    ]
    assert tpch_graph.verified_pydough_analysis == expected_analysis

    # Graph level: the free-form additional definitions.
    expected_definitions = [
        "Revenue for a lineitem is the extended_price * (1 - discount) * (1 - tax) minus quantity * supply_cost from the corresponding supply record",
        "A domestic shipment is a lineitem where the customer and supplier are from the same nation",
        "Frequent buyers are customers that have placed more than 5 orders in a single year for at least two different years",
    ]
    assert tpch_graph.additional_definitions == expected_definitions

    # Graph level: the extra semantic info dictionary.
    expected_graph_info = {
        "data source": "TPC-H Benchmark Dataset",
        "data generation tool": "TPC-H dbgen tool",
        "dataset download link": "https://github.com/lovasoa/TPCH-sqlite/releases/download/v1.0/TPC-H.db",
        "schema diagram link": "https://docs.snowflake.com/en/user-guide/sample-data-tpch",
        "dataset specification link": "https://www.tpc.org/TPC_Documents_Current_Versions/pdf/TPC-H_v3.0.1.pdf",
        "data scale factor": 1,
        "intended use": "Simulating decision support systems for complex ad-hoc queries and concurrent data modifications",
        "notable characteristics": "Highly normalized schema with multiple tables and relationships, designed to represent a wholesale supplier's business environment",
        "data description": "Contains information about orders. Every order has one or more lineitems, each representing the purchase and shipment of a specific part from a specific supplier. Each order is placed by a customer, and both customers and suppliers belong to nations which in turn belong to regions. Additionally, there are supply records indicating every combination of a supplier and the parts they supply.",
    }
    assert tpch_graph.extra_semantic_info == expected_graph_info

    # Collection level: the `parts` collection.
    parts = tpch_graph.get_collection("parts")
    assert isinstance(parts, CollectionMetadata)
    expected_parts_description = "The various products supplied by various companies in shipments to different customers"
    assert parts.description == expected_parts_description
    assert parts.synonyms == ["products", "components", "items", "goods"]
    expected_parts_info = {
        "nrows": 200000,
        "distinct values": {
            "key": 200000,
            "name": 200000,
            "manufacturer": 5,
            "brand": 25,
            "part_type": 150,
            "size": 50,
            "container": 40,
            "retail_price": 20899,
            "comment": 131753,
        },
        "correlations": {
            "brand": "each brand is associated with exactly one manufacturer, and each manufacturer has exactly 5 distinct brands"
        },
    }
    assert parts.extra_semantic_info == expected_parts_info

    # Scalar property level: `parts.size`.
    size_property = parts.get_property("size")
    assert isinstance(size_property, ScalarAttributeMetadata)
    assert size_property.sample_values == [1, 10, 31, 46, 50]
    assert size_property.description == "The size of the part"
    expected_size_synonyms = [
        "dimension",
        "measurement",
        "length",
        "width",
        "height",
        "volume",
    ]
    assert size_property.synonyms == expected_size_synonyms
    expected_size_info = {
        "minimum value": 1,
        "maximum value": 50,
        "is dense": True,
        "distinct values": 50,
        "correlated fields": [],
    }
    assert size_property.extra_semantic_info == expected_size_info

    # Relationship property level: `parts.lines`.
    lines_property = parts.get_property("lines")
    assert isinstance(lines_property, SimpleJoinMetadata)
    assert lines_property.description == "The line items for shipments of the part"
    expected_lines_synonyms = [
        "shipments",
        "packages",
        "purchases",
        "deliveries",
        "sales",
    ]
    assert lines_property.synonyms == expected_lines_synonyms
    expected_lines_info = {
        "unmatched rows": 0,
        "min matches per row": 9,
        "max matches per row": 57,
        "avg matches per row": 30.01,
        "classification": "one-to-many",
    }
    assert lines_property.extra_semantic_info == expected_lines_info

    # Relationship property with only part of the semantic information
    # supplied (`parts.supply_records`): it has a description and synonyms,
    # but its extra semantic info is expected to be None.
    supply_records_property = parts.get_property("supply_records")
    assert isinstance(supply_records_property, SimpleJoinMetadata)
    expected_supply_description = "The records indicating which companies supply the part"
    assert supply_records_property.description == expected_supply_description
    expected_supply_synonyms = ["producers", "vendors", "suppliers of part"]
    assert supply_records_property.synonyms == expected_supply_synonyms
    assert supply_records_property.extra_semantic_info is None
61 changes: 55 additions & 6 deletions tests/test_metadata/sample_graphs.json
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,14 @@
"data type": "numeric",
"description": "The size of the part",
"sample values": [1, 10, 31, 46, 50],
"synonyms": ["dimension", "measurement", "length", "width", "height", "volume"]
"synonyms": ["dimension", "measurement", "length", "width", "height", "volume"],
"extra semantic info": {
"minimum value": 1,
"maximum value": 50,
"is dense": true,
"distinct values": 50,
"correlated fields": []
}
},
{
"name": "container",
Expand Down Expand Up @@ -164,7 +171,24 @@
}
],
"description": "The various products supplied by various companies in shipments to different customers",
"synonyms": ["products", "components", "items", "goods"]
"synonyms": ["products", "components", "items", "goods"],
"extra semantic info": {
"nrows": 200000,
"distinct values": {
"key": 200000,
"name": 200000,
"manufacturer": 5,
"brand": 25,
"part_type": 150,
"size": 50,
"container": 40,
"retail_price": 20899,
"comment": 131753
},
"correlations": {
"brand": "each brand is associated with exactly one manufacturer, and each manufacturer has exactly 5 distinct brands"
}
}
},
{
"name": "suppliers",
Expand Down Expand Up @@ -697,7 +721,14 @@
"always matches": false,
"keys": {"key": ["part_key"]},
"description": "The line items for shipments of the part",
"synonyms": ["shipments", "packages", "purchases", "deliveries", "sales"]
"synonyms": ["shipments", "packages", "purchases", "deliveries", "sales"],
"extra semantic info": {
"unmatched rows": 0,
"min matches per row": 9,
"max matches per row": 57,
"avg matches per row": 30.01,
"classification": "one-to-many"
}
},
{
"type": "reverse",
Expand Down Expand Up @@ -814,9 +845,27 @@
"synonyms": ["transactions", "purchases"]
}
],
"additional definitions": [],
"verified pydough analysis": [],
"extra semantic info": {}
"additional definitions": [
"Revenue for a lineitem is the extended_price * (1 - discount) * (1 - tax) minus quantity * supply_cost from the corresponding supply record",
"A domestic shipment is a lineitem where the customer and supplier are from the same nation",
"Frequent buyers are customers that have placed more than 5 orders in a single year for at least two different years"
],
"verified pydough analysis": [
{"question": "How many customers are in China?", "code": "TPCH.CALCULATE(n_chinese_customers=COUNT(customers.WHERE(nation.name == 'CHINA')))"},
{"question": "What was the most ordered part in 1995, by quantity, by Brazilian customers?", "code": "parts.CALCULATE(name, quantity=SUM(lines.WHERE((YEAR(ship_date) == 1995) & (order.customer.nation.name == 'BRAZIL')).quantity)).TOP_K(1, by=quantity)"},
{"question": "Who is the wealthiest customer in each nation in Africa?", "code": "nations.WHERE(region.name == 'AFRICA').CALCULATE(nation_name=name, richest_customer=customers.BEST(per='nation', by=account_balance.DESC()).name)"}
],
"extra semantic info": {
"data source": "TPC-H Benchmark Dataset",
"data generation tool": "TPC-H dbgen tool",
"dataset download link": "https://github.com/lovasoa/TPCH-sqlite/releases/download/v1.0/TPC-H.db",
"schema diagram link": "https://docs.snowflake.com/en/user-guide/sample-data-tpch",
"dataset specification link": "https://www.tpc.org/TPC_Documents_Current_Versions/pdf/TPC-H_v3.0.1.pdf",
"data scale factor": 1,
"intended use": "Simulating decision support systems for complex ad-hoc queries and concurrent data modifications",
"notable characteristics": "Highly normalized schema with multiple tables and relationships, designed to represent a wholesale supplier's business environment",
"data description": "Contains information about orders. Every order has one or more lineitems, each representing the purchase and shipment of a specific part from a specific supplier. Each order is placed by a customer, and both customers and suppliers belong to nations which in turn belong to regions. Additionally, there are supply records indicating every combination of a supplier and the parts they supply."
}
},
{
"name": "Empty",
Expand Down