From d333b332484d606d667cffbe6edb515ac309f3c7 Mon Sep 17 00:00:00 2001 From: knassre-bodo Date: Fri, 19 Dec 2025 05:59:31 -0800 Subject: [PATCH 01/12] Adding max_rows feature and tests [RUN CI] --- documentation/usage.md | 2 + pydough/evaluation/evaluate_unqualified.py | 6 +- pydough/sqlglot/execute_relational.py | 27 +++- tests/test_max_rows.py | 129 ++++++++++++++++++ .../nations_top3_max2_sqlite.sql | 7 + .../nations_top3_max6_sqlite.sql | 7 + .../regions_max100_sqlite.sql | 6 + .../test_sql_refsols/regions_max1_sqlite.sql | 6 + tests/testing_utilities.py | 11 +- 9 files changed, 196 insertions(+), 5 deletions(-) create mode 100644 tests/test_max_rows.py create mode 100644 tests/test_sql_refsols/nations_top3_max2_sqlite.sql create mode 100644 tests/test_sql_refsols/nations_top3_max6_sqlite.sql create mode 100644 tests/test_sql_refsols/regions_max100_sqlite.sql create mode 100644 tests/test_sql_refsols/regions_max1_sqlite.sql diff --git a/documentation/usage.md b/documentation/usage.md index 57bc2da95..5f5894687 100644 --- a/documentation/usage.md +++ b/documentation/usage.md @@ -454,6 +454,7 @@ The `to_sql` API takes in PyDough code and transforms it into SQL query text wit - `config`: the PyDough configuration settings to use for the conversion (if omitted, `pydough.active_session.config` is used instead). - `database`: the database context to use for the conversion (if omitted, `pydough.active_session.database` is used instead). The database context matters because it controls which SQL dialect is used for the translation. - `session`: a PyDough session object which, if provided, is used instead of `pydough.active_session` or the `metadata` / `config` / `database` arguments. Note: this argument cannot be used alongside those arguments. +- `max_rows`: a positive integer which, if provided, indicates that the SQL query should produce at most that many rows. E.g. 
`max_rows=10` will ensure the SQL query ends in `LIMIT 10` (unless the query already ends in a smaller limit). Below is an example of using `pydough.to_sql` and the output (the SQL output may be outdated if PyDough's SQL conversion process has been updated): @@ -497,6 +498,7 @@ The `to_df` API does all the same steps as the [`to_sql` API](#pydoughto_sql), b - `database`: the database context to use for the conversion (if omitted, `pydough.active_session.database` is used instead). The database context matters because it controls which SQL dialect is used for the translation. - `session`: a PyDough session object which, if provided, is used instead of `pydough.active_session` or the `metadata` / `config` / `database` arguments. Note: this argument cannot be used alongside those arguments. - `display_sql`: displays the sql before executing in a logger. +- `max_rows`: a positive integer which, if provided, indicates that the output should produce at most that many rows. E.g. `max_rows=10` will ensure the result returns at most 10 rows, as if the SQL query ended with `LIMIT 10`. Below is an example of using `pydough.to_df` and the output, attached to a sqlite database containing data for the TPC-H schema: diff --git a/pydough/evaluation/evaluate_unqualified.py b/pydough/evaluation/evaluate_unqualified.py index 2ab99bffb..7cd2093ef 100644 --- a/pydough/evaluation/evaluate_unqualified.py +++ b/pydough/evaluation/evaluate_unqualified.py @@ -148,6 +148,7 @@ def to_sql(node: UnqualifiedNode, **kwargs) -> str: The SQL string corresponding to the unqualified query. 
""" column_selection: list[tuple[str, str]] | None = _load_column_selection(kwargs) + max_rows: int | None = kwargs.pop("max_rows", None) session: PyDoughSession = _load_session_info(**kwargs) qualified: PyDoughQDAG = qualify_node(node, session) if not isinstance(qualified, PyDoughCollectionQDAG): @@ -155,7 +156,7 @@ def to_sql(node: UnqualifiedNode, **kwargs) -> str: relational: RelationalRoot = convert_ast_to_relational( qualified, column_selection, session ) - return convert_relation_to_sql(relational, session) + return convert_relation_to_sql(relational, session, max_rows) def to_df(node: UnqualifiedNode, **kwargs) -> pd.DataFrame: @@ -175,6 +176,7 @@ def to_df(node: UnqualifiedNode, **kwargs) -> pd.DataFrame: The DataFrame corresponding to the unqualified query. """ column_selection: list[tuple[str, str]] | None = _load_column_selection(kwargs) + max_rows: int | None = kwargs.pop("max_rows", None) display_sql: bool = bool(kwargs.pop("display_sql", False)) session: PyDoughSession = _load_session_info(**kwargs) qualified: PyDoughQDAG = qualify_node(node, session) @@ -183,4 +185,4 @@ def to_df(node: UnqualifiedNode, **kwargs) -> pd.DataFrame: relational: RelationalRoot = convert_ast_to_relational( qualified, column_selection, session ) - return execute_df(relational, session, display_sql) + return execute_df(relational, session, display_sql, max_rows) diff --git a/pydough/sqlglot/execute_relational.py b/pydough/sqlglot/execute_relational.py index b55209707..75f241807 100644 --- a/pydough/sqlglot/execute_relational.py +++ b/pydough/sqlglot/execute_relational.py @@ -50,7 +50,9 @@ __all__ = ["convert_relation_to_sql", "execute_df"] -def convert_relation_to_sql(relational: RelationalRoot, session: PyDoughSession) -> str: +def convert_relation_to_sql( + relational: RelationalRoot, session: PyDoughSession, max_rows: int | None +) -> str: """ Convert the given relational tree to a SQL string using the given dialect. 
@@ -58,6 +60,7 @@ def convert_relation_to_sql(relational: RelationalRoot, session: PyDoughSession) `relational`: The relational tree to convert. `session`: The PyDough session encapsulating the logic used to execute the logic, including the PyDough configs and the database context. + `max_rows`: An optional limit on the number of rows to return. Returns: The SQL string representing the relational tree. @@ -65,6 +68,24 @@ def convert_relation_to_sql(relational: RelationalRoot, session: PyDoughSession) glot_expr: SQLGlotExpression = SQLGlotRelationalVisitor( session ).relational_to_sqlglot(relational) + + # If `max_rows` is specified, add a LIMIT clause to the SQLGlot expression. + if max_rows is not None: + assert isinstance(glot_expr, Select) + # If a limit does not already exist, add one. + if glot_expr.args.get("limit") is None: + glot_expr = glot_expr.limit(sqlglot_expressions.Literal.number(max_rows)) + # If one does exist, update its value to be the minimum of the + # existing limit and `max_rows`. + else: + existing_limit_expr = glot_expr.args.get("limit").expression + assert isinstance(existing_limit_expr, sqlglot_expressions.Literal) + glot_expr = glot_expr.limit( + sqlglot_expressions.Literal.number( + min(int(existing_limit_expr.this), max_rows) + ) + ) + sqlglot_dialect: SQLGlotDialect = convert_dialect_to_sqlglot( session.database.dialect ) @@ -417,6 +438,7 @@ def execute_df( relational: RelationalRoot, session: PyDoughSession, display_sql: bool = False, + max_rows: int | None = None, ) -> pd.DataFrame: """ Execute the given relational tree on the given database access @@ -428,11 +450,12 @@ def execute_df( the logic, including the database context. `display_sql`: if True, prints out the SQL that will be run before it is executed. + `max_rows`: An optional limit on the number of rows to return. 
Returns: The result of the query as a Pandas DataFrame """ - sql: str = convert_relation_to_sql(relational, session) + sql: str = convert_relation_to_sql(relational, session, max_rows) if display_sql: pyd_logger = get_logger(__name__) pyd_logger.info(f"SQL query:\n {sql}") diff --git a/tests/test_max_rows.py b/tests/test_max_rows.py new file mode 100644 index 000000000..cfbc5c9c3 --- /dev/null +++ b/tests/test_max_rows.py @@ -0,0 +1,129 @@ +""" +Integration tests for the PyDough workflow on the TPC-H queries. +""" + +from collections.abc import Callable + +import pandas as pd +import pytest + +from pydough.database_connectors import DatabaseContext, DatabaseDialect +from tests.testing_utilities import ( + graph_fetcher, +) + +from .testing_utilities import PyDoughPandasTest + + +@pytest.mark.parametrize( + "execute", + [ + pytest.param(False, id="sql"), + pytest.param(True, id="e2e", marks=pytest.mark.execute), + ], +) +@pytest.mark.parametrize( + "test_data, max_rows", + [ + pytest.param( + PyDoughPandasTest( + "result = regions", + "TPCH", + lambda: pd.DataFrame( + { + "key": [0], + "name": ["AFRICA"], + "comment": [ + "lar deposits. blithely final packages cajole. regular waters are final requests. regular accounts are according to " + ], + } + ), + "regions_max1", + ), + 1, + id="regions_max1", + ), + pytest.param( + PyDoughPandasTest( + "result = regions", + "TPCH", + lambda: pd.DataFrame( + { + "key": [0, 1, 2, 3, 4], + "name": ["AFRICA", "AMERICA", "ASIA", "EUROPE", "MIDDLE EAST"], + "comment": [ + "lar deposits. blithely final packages cajole. regular waters are final requests. regular accounts are according to ", + "hs use ironic, even requests. s", + "ges. thinly even pinto beans ca", + "ly final courts cajole furiously final excuse", + "uickly special accounts cajole carefully blithely close requests. 
carefully final asymptotes haggle furiousl", + ], + } + ), + "regions_max100", + ), + 100, + id="regions_max100", + ), + pytest.param( + PyDoughPandasTest( + "result = nations.CALCULATE(key, name).TOP_K(3, by=name.ASC())", + "TPCH", + lambda: pd.DataFrame( + { + "key": [0, 1], + "name": ["ALGERIA", "ARGENTINA"], + } + ), + "nations_top3_max2", + ), + 2, + id="nations_top3_max2", + ), + pytest.param( + PyDoughPandasTest( + "result = nations.CALCULATE(key, name).TOP_K(3, by=name.ASC())", + "TPCH", + lambda: pd.DataFrame( + { + "key": [0, 1, 2], + "name": ["ALGERIA", "ARGENTINA", "BRAZIL"], + } + ), + "nations_top3_max6", + ), + 6, + id="nations_top3_max6", + ), + ], +) +def test_max_rows( + test_data: PyDoughPandasTest, + max_rows: int, + get_sample_graph: graph_fetcher, + sqlite_tpch_db_context: DatabaseContext, + execute: bool, + get_sql_test_filename: Callable[[str, DatabaseDialect], str], + update_tests: bool, +): + """ + Test either SQL or runtime execution of custom queries on the TPC-H dataset + with various max_rows settings. 
+ """ + if execute: + test_data.run_e2e_test( + get_sample_graph, + sqlite_tpch_db_context, + max_rows=max_rows, + ) + else: + file_path: str = get_sql_test_filename( + test_data.test_name, sqlite_tpch_db_context.dialect + ) + test_data.run_sql_test( + get_sample_graph, + file_path, + update_tests, + sqlite_tpch_db_context, + max_rows=max_rows, + ) diff --git a/tests/test_sql_refsols/nations_top3_max2_sqlite.sql b/tests/test_sql_refsols/nations_top3_max2_sqlite.sql new file mode 100644 index 000000000..13108744d --- /dev/null +++ b/tests/test_sql_refsols/nations_top3_max2_sqlite.sql @@ -0,0 +1,7 @@ +SELECT + n_nationkey AS key, + n_name AS name +FROM tpch.nation +ORDER BY + 2 +LIMIT 2 diff --git a/tests/test_sql_refsols/nations_top3_max6_sqlite.sql b/tests/test_sql_refsols/nations_top3_max6_sqlite.sql new file mode 100644 index 000000000..90eaf5a19 --- /dev/null +++ b/tests/test_sql_refsols/nations_top3_max6_sqlite.sql @@ -0,0 +1,7 @@ +SELECT + n_nationkey AS key, + n_name AS name +FROM tpch.nation +ORDER BY + 2 +LIMIT 3 diff --git a/tests/test_sql_refsols/regions_max100_sqlite.sql b/tests/test_sql_refsols/regions_max100_sqlite.sql new file mode 100644 index 000000000..cdb00872d --- /dev/null +++ b/tests/test_sql_refsols/regions_max100_sqlite.sql @@ -0,0 +1,6 @@ +SELECT + r_regionkey AS key, + r_name AS name, + r_comment AS comment +FROM tpch.region +LIMIT 100 diff --git a/tests/test_sql_refsols/regions_max1_sqlite.sql b/tests/test_sql_refsols/regions_max1_sqlite.sql new file mode 100644 index 000000000..fa3ec2fe4 --- /dev/null +++ b/tests/test_sql_refsols/regions_max1_sqlite.sql @@ -0,0 +1,6 @@ +SELECT + r_regionkey AS key, + r_name AS name, + r_comment AS comment +FROM tpch.region +LIMIT 1 diff --git a/tests/testing_utilities.py b/tests/testing_utilities.py index f3b734241..229c70918 100644 --- a/tests/testing_utilities.py +++ b/tests/testing_utilities.py @@ -1205,6 +1205,7 @@ def run_sql_test( update: bool, database: DatabaseContext, config: PyDoughConfigs | 
None = None, + max_rows: int | None = None, ) -> None: """ Runs a test on the SQL code generated by the PyDough code, @@ -1221,6 +1222,7 @@ def run_sql_test( `database`: The database context to determine what dialect of SQL to use when generating the SQL test. `config`: The PyDough configuration to use for the test, if any. + `max_rows`: The maximum number of rows to return from the query. """ # Skip if indicated. if self.skip_sql: @@ -1233,7 +1235,11 @@ def run_sql_test( ) # Convert the PyDough code to SQL text - call_kwargs: dict = {"metadata": graph, "database": database} + call_kwargs: dict = { + "metadata": graph, + "database": database, + "max_rows": max_rows, + } if config is not None: call_kwargs["config"] = config if self.columns is not None: @@ -1259,6 +1265,7 @@ def run_e2e_test( config: PyDoughConfigs | None = None, display_sql: bool = False, coerce_types: bool = False, + max_rows: int | None = None, ): """ Runs an end-to-end test using the data in the SQL comparison test, @@ -1273,6 +1280,7 @@ def run_e2e_test( `display_sql`: If True, displays the SQL generated by PyDough. `coerce_types`: If True, coerces the types of the result and reference solution DataFrames to ensure compatibility. + `max_rows`: The maximum number of rows to return from the query. 
""" # Obtain the graph and the unqualified node graph: GraphMetadata = fetcher(self.graph_name) @@ -1284,6 +1292,7 @@ def run_e2e_test( "metadata": graph, "database": database, "display_sql": display_sql, + "max_rows": max_rows, } if config is not None: call_kwargs["config"] = config From 30015ebc8ee8e42aaf6be7ddd11439b73b9d679e Mon Sep 17 00:00:00 2001 From: knassre-bodo Date: Fri, 19 Dec 2025 06:07:00 -0800 Subject: [PATCH 02/12] Added more complicated test [RUN CI] --- tests/test_max_rows.py | 17 +++++++++++++++++ .../richest_customers_orders_max5_sqlite.sql | 18 ++++++++++++++++++ 2 files changed, 35 insertions(+) create mode 100644 tests/test_sql_refsols/richest_customers_orders_max5_sqlite.sql diff --git a/tests/test_max_rows.py b/tests/test_max_rows.py index cfbc5c9c3..9c9944b3f 100644 --- a/tests/test_max_rows.py +++ b/tests/test_max_rows.py @@ -95,6 +95,23 @@ 6, id="nations_top3_max6", ), + pytest.param( + PyDoughPandasTest( + "result = customers.CALCULATE(ck=key).TOP_K(10, by=account_balance.DESC()).orders.CALCULATE(ck, ok=key, tp=total_price).ORDER_BY(tp.DESC())", + "TPCH", + lambda: pd.DataFrame( + { + "ck": [61453, 23828, 61453, 144232, 129934], + "ok": [4056323, 5349568, 5503299, 5343141, 1808867], + "tp": [424918.3, 322586.39, 308661.75, 307284.63, 306708.79], + } + ), + "richest_customers_orders_max5", + order_sensitive=True, + ), + 5, + id="richest_customers_orders_max5", + ), ], ) def test_max_rows( diff --git a/tests/test_sql_refsols/richest_customers_orders_max5_sqlite.sql b/tests/test_sql_refsols/richest_customers_orders_max5_sqlite.sql new file mode 100644 index 000000000..f4cdfed6b --- /dev/null +++ b/tests/test_sql_refsols/richest_customers_orders_max5_sqlite.sql @@ -0,0 +1,18 @@ +WITH _s0 AS ( + SELECT + c_custkey + FROM tpch.customer + ORDER BY + c_acctbal DESC + LIMIT 10 +) +SELECT + _s0.c_custkey AS ck, + orders.o_orderkey AS ok, + orders.o_totalprice AS tp +FROM _s0 AS _s0 +JOIN tpch.orders AS orders + ON _s0.c_custkey = 
orders.o_custkey +ORDER BY + 3 DESC +LIMIT 5 From 50a5c5bc1d29ce3bab2dc543be4b171e1a97b521 Mon Sep 17 00:00:00 2001 From: knassre-bodo Date: Fri, 19 Dec 2025 06:07:43 -0800 Subject: [PATCH 03/12] Added more complicated test [RUN CI] --- pydough/sqlglot/execute_relational.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pydough/sqlglot/execute_relational.py b/pydough/sqlglot/execute_relational.py index 75f241807..026995f13 100644 --- a/pydough/sqlglot/execute_relational.py +++ b/pydough/sqlglot/execute_relational.py @@ -51,7 +51,7 @@ def convert_relation_to_sql( - relational: RelationalRoot, session: PyDoughSession, max_rows: int | None + relational: RelationalRoot, session: PyDoughSession, max_rows: int | None = None ) -> str: """ Convert the given relational tree to a SQL string using the given dialect. From 3e6ccdf7589bd5e99f25498b79b403108a9d4e67 Mon Sep 17 00:00:00 2001 From: knassre-bodo <105652923+knassre-bodo@users.noreply.github.com> Date: Fri, 19 Dec 2025 14:34:46 -0500 Subject: [PATCH 04/12] Apply suggestions from code review Co-authored-by: john-sanchez31 --- tests/test_max_rows.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_max_rows.py b/tests/test_max_rows.py index 9c9944b3f..573681f6a 100644 --- a/tests/test_max_rows.py +++ b/tests/test_max_rows.py @@ -1,5 +1,5 @@ """ -Integration tests for the PyDough workflow on the TPC-H queries. 
+Integration tests for the PyDough workflow with `max_rows` using TPC-H """ from collections.abc import Callable From 37457d9e6de1521bca4ccc18b183b815de645d19 Mon Sep 17 00:00:00 2001 From: knassre-bodo Date: Fri, 19 Dec 2025 11:35:13 -0800 Subject: [PATCH 05/12] [RUN CI] From 6c7f10feae152809392f000305a56a24d2f88b16 Mon Sep 17 00:00:00 2001 From: knassre-bodo Date: Fri, 19 Dec 2025 16:17:35 -0800 Subject: [PATCH 06/12] Adding semantic info handling in metadata with testing [RUN CI] --- pydough/metadata/__init__.py | 2 + pydough/metadata/parse.py | 10 +- .../masked_table_column_metadata.py | 12 ++- .../properties/scalar_attribute_metadata.py | 4 +- .../properties/table_column_metadata.py | 12 ++- tests/test_metadata.py | 94 +++++++++++++++++++ tests/test_metadata/sample_graphs.json | 52 +++++++++- 7 files changed, 174 insertions(+), 12 deletions(-) diff --git a/pydough/metadata/__init__.py b/pydough/metadata/__init__.py index 175624669..ca66b0332 100644 --- a/pydough/metadata/__init__.py +++ b/pydough/metadata/__init__.py @@ -9,6 +9,7 @@ "GraphMetadata", "MaskedTableColumnMetadata", "PropertyMetadata", + "ScalarAttributeMetadata", "SimpleJoinMetadata", "SimpleTableMetadata", "SubcollectionRelationshipMetadata", @@ -24,6 +25,7 @@ GeneralJoinMetadata, MaskedTableColumnMetadata, PropertyMetadata, + ScalarAttributeMetadata, SimpleJoinMetadata, SubcollectionRelationshipMetadata, TableColumnMetadata, diff --git a/pydough/metadata/parse.py b/pydough/metadata/parse.py index f866e77e2..bb1c67b1c 100644 --- a/pydough/metadata/parse.py +++ b/pydough/metadata/parse.py @@ -98,13 +98,14 @@ def parse_graph_v2(graph_name: str, graph_json: dict) -> GraphMetadata: """ verified_analysis: list[dict] = [] additional_definitions: list[str] = [] + extra_info: dict = {} graph: GraphMetadata = GraphMetadata( graph_name, additional_definitions, verified_analysis, None, None, - {}, + extra_info, ) # Parse and extract the metadata for all of the collections in the graph. 
@@ -140,6 +141,7 @@ def parse_graph_v2(graph_name: str, graph_json: dict) -> GraphMetadata: defn, f"metadata for additional definitions inside {graph.error_name}", ) + additional_definitions.append(defn) if "verified pydough analysis" in graph_json: verified_analysis_json: list = extract_array( graph_json, "verified pydough analysis", graph.error_name @@ -156,6 +158,12 @@ def parse_graph_v2(graph_name: str, graph_json: dict) -> GraphMetadata: HasPropertyWith("code", is_string).verify( verified_json, "metadata for verified pydough analysis" ) + verified_analysis.append(verified_json) + if "extra semantic info" in graph_json: + extra_info_json: dict = extract_object( + graph_json, "extra semantic info", graph.error_name + ) + extra_info.update(extra_info_json) # Add all of the UDF definitions to the graph. if "functions" in graph_json: diff --git a/pydough/metadata/properties/masked_table_column_metadata.py b/pydough/metadata/properties/masked_table_column_metadata.py index a30c4aaa7..500d9857c 100644 --- a/pydough/metadata/properties/masked_table_column_metadata.py +++ b/pydough/metadata/properties/masked_table_column_metadata.py @@ -48,7 +48,10 @@ def __init__( unprotect_protocol: str, protect_protocol: str, server_masked: bool, - sample_values: list | None = None, + sample_values: list | None, + description: str | None, + synonyms: list[str] | None, + extra_semantic_info: dict | None, ): super().__init__( name, @@ -56,6 +59,9 @@ def __init__( protected_data_type, column_name, sample_values, + description, + synonyms, + extra_semantic_info, ) self._unprotected_data_type: PyDoughType = data_type self._unprotect_protocol: str = unprotect_protocol @@ -172,6 +178,10 @@ def parse_from_json( unprotect_protocol, protect_protocol, server_masked, + None, + None, + None, + None, ) # Parse the optional common semantic properties like the description. 
property.parse_optional_properties(property_json) diff --git a/pydough/metadata/properties/scalar_attribute_metadata.py b/pydough/metadata/properties/scalar_attribute_metadata.py index 38c934df0..5a7922b06 100644 --- a/pydough/metadata/properties/scalar_attribute_metadata.py +++ b/pydough/metadata/properties/scalar_attribute_metadata.py @@ -73,4 +73,6 @@ def parse_optional_properties(self, meta_json: dict) -> None: # Extract the optional sample values field from the JSON object. if "sample values" in meta_json: - extract_array(meta_json, "sample values", self.error_name) + self._sample_values = extract_array( + meta_json, "sample values", self.error_name + ) diff --git a/pydough/metadata/properties/table_column_metadata.py b/pydough/metadata/properties/table_column_metadata.py index da6c49098..befb95388 100644 --- a/pydough/metadata/properties/table_column_metadata.py +++ b/pydough/metadata/properties/table_column_metadata.py @@ -40,10 +40,10 @@ def __init__( collection: CollectionMetadata, data_type: PyDoughType, column_name: str, - sample_values: list | None = None, - description: str | None = None, - synonyms: list[str] | None = None, - extra_semantic_info: dict | None = None, + sample_values: list | None, + description: str | None, + synonyms: list[str] | None, + extra_semantic_info: dict | None, ): super().__init__( name, @@ -112,6 +112,10 @@ def parse_from_json( collection, data_type, column_name, + None, + None, + None, + None, ) # Parse the optional common semantic properties like the description. 
property.parse_optional_properties(property_json) diff --git a/tests/test_metadata.py b/tests/test_metadata.py index d36aa4edf..68d3d99ca 100644 --- a/tests/test_metadata.py +++ b/tests/test_metadata.py @@ -8,6 +8,7 @@ CollectionMetadata, GraphMetadata, PropertyMetadata, + ScalarAttributeMetadata, SimpleJoinMetadata, SimpleTableMetadata, TableColumnMetadata, @@ -296,3 +297,96 @@ def test_simple_join_info( assert property.keys == keys, ( f"Mismatch between 'keys' of {property!r} and expected value" ) + + +def test_semantic_info(get_sample_graph: graph_fetcher) -> None: + """ + Testing that the semantic fields of the metadata are set correctly. + """ + graph: GraphMetadata = get_sample_graph("TPCH") + + # Verify the semantic info fields of the grpah + assert graph.verified_pydough_analysis == [ + { + "question": "How many customers are in China?", + "code": "TPCH.CALCULATE(n_chinese_customers=COUNT(customers.WHERE(nation.name == 'CHINA')))", + }, + { + "question": "What was the most ordered part in 1995, by quantity, by Brazilian customers?", + "code": "parts.CALCULATE(name, quantity=SUM(lines.WHERE((YEAR(ship_date) == 1995) & (order.customer.nation.name == 'BRAZIL')).quantity)).TOP_K(1, by=quantity)", + }, + { + "question": "Who is the wealthiest customer in each nation in Africa?", + "code": "nations.WHERE(region.name == 'AFRICA').CALCULATE(nation_name=name, richest_customer=customers.BEST(per='nation', by=account_balance.DESC()).name)", + }, + ] + + assert graph.additional_definitions == [ + "Revenue for a lineitem is the extended_price * (1 - discount) * (1 - tax) minus quantity * supply_cost from the corresponding supply record", + "A domestic shipment is a lineitem where the customer and supplier are from the same nation", + "Frequent buyers are customers that have placed more than 5 orders in a single year for at least two different years", + ] + + assert graph.extra_semantic_info == { + "data source": "TPC-H Benchmark Dataset", + "data generation tool": 
"TPC-H dbgen tool", + "dataset download link": "https://github.com/lovasoa/TPCH-sqlite/releases/download/v1.0/TPC-H.db", + "schema diagram link": "https://docs.snowflake.com/en/user-guide/sample-data-tpch", + "dataset specification link": "https://www.tpc.org/TPC_Documents_Current_Versions/pdf/TPC-H_v3.0.1.pdf", + "data scale factor": 1, + "intended use": "Simulating decision support systems for complex ad-hoc queries and concurrent data modifications", + "notable characteristics": "Highly normalized schema with multiple tables and relationships, designed to represent a wholesale supplier's business environment", + "data description": "Contains information about orders. Every order has one or more lineitems, each representing the purchase and shipment of a specific part from a specific supplier. Each order is placed by a customer, and both customers and suppliers belong to nations which in turn belong to regions. Additionally, there are supply records indicating every combination of a supplier and the parts they supply.", + } + + # Verify the semantic info fields of the parts collection + collection = graph.get_collection("parts") + assert isinstance(collection, CollectionMetadata) + assert ( + collection.description + == "The various products supplied by various companies in shipments to different customers" + ) + assert collection.synonyms == [ + "products", + "components", + "items", + "goods", + ] + assert collection.extra_semantic_info == { + "nrows": 200000, + "distinct values": { + "key": 200000, + "name": 200000, + "manufacturer": 5, + "brand": 25, + "part_type": 150, + "size": 50, + "container": 40, + "retail_price": 20899, + "comment": 131753, + }, + "correlations": { + "brand": "each brand is associated with exactly one manufacturer, and each manufacturer has exactly 5 distinct brands" + }, + } + + # Verify the semantic info fields of the size property + property = collection.get_property("size") + assert isinstance(property, ScalarAttributeMetadata) + 
assert property.sample_values == [1, 10, 31, 46, 50] + assert property.description == "The size of the part" + assert property.synonyms == [ + "dimension", + "measurement", + "length", + "width", + "height", + "volume", + ] + assert property.extra_semantic_info == { + "minimum value": 1, + "maximum value": 50, + "is dense": True, + "distinct values": 50, + "correlated fields": [], + } diff --git a/tests/test_metadata/sample_graphs.json b/tests/test_metadata/sample_graphs.json index 3b0757973..460c10a80 100644 --- a/tests/test_metadata/sample_graphs.json +++ b/tests/test_metadata/sample_graphs.json @@ -135,7 +135,14 @@ "data type": "numeric", "description": "The size of the part", "sample values": [1, 10, 31, 46, 50], - "synonyms": ["dimension", "measurement", "length", "width", "height", "volume"] + "synonyms": ["dimension", "measurement", "length", "width", "height", "volume"], + "extra semantic info": { + "minimum value": 1, + "maximum value": 50, + "is dense": true, + "distinct values": 50, + "correlated fields": [] + } }, { "name": "container", @@ -164,7 +171,24 @@ } ], "description": "The various products supplied by various companies in shipments to different customers", - "synonyms": ["products", "components", "items", "goods"] + "synonyms": ["products", "components", "items", "goods"], + "extra semantic info": { + "nrows": 200000, + "distinct values": { + "key": 200000, + "name": 200000, + "manufacturer": 5, + "brand": 25, + "part_type": 150, + "size": 50, + "container": 40, + "retail_price": 20899, + "comment": 131753 + }, + "correlations": { + "brand": "each brand is associated with exactly one manufacturer, and each manufacturer has exactly 5 distinct brands" + } + } }, { "name": "suppliers", @@ -814,9 +838,27 @@ "synonyms": ["transactions", "purchases"] } ], - "additional definitions": [], - "verified pydough analysis": [], - "extra semantic info": {} + "additional definitions": [ + "Revenue for a lineitem is the extended_price * (1 - discount) * (1 - 
tax) minus quantity * supply_cost from the corresponding supply record", + "A domestic shipment is a lineitem where the customer and supplier are from the same nation", + "Frequent buyers are customers that have placed more than 5 orders in a single year for at least two different years" + ], + "verified pydough analysis": [ + {"question": "How many customers are in China?", "code": "TPCH.CALCULATE(n_chinese_customers=COUNT(customers.WHERE(nation.name == 'CHINA')))"}, + {"question": "What was the most ordered part in 1995, by quantity, by Brazilian customers?", "code": "parts.CALCULATE(name, quantity=SUM(lines.WHERE((YEAR(ship_date) == 1995) & (order.customer.nation.name == 'BRAZIL')).quantity)).TOP_K(1, by=quantity)"}, + {"question": "Who is the wealthiest customer in each nation in Africa?", "code": "nations.WHERE(region.name == 'AFRICA').CALCULATE(nation_name=name, richest_customer=customers.BEST(per='nation', by=account_balance.DESC()).name)"} + ], + "extra semantic info": { + "data source": "TPC-H Benchmark Dataset", + "data generation tool": "TPC-H dbgen tool", + "dataset download link": "https://github.com/lovasoa/TPCH-sqlite/releases/download/v1.0/TPC-H.db", + "schema diagram link": "https://docs.snowflake.com/en/user-guide/sample-data-tpch", + "dataset specification link": "https://www.tpc.org/TPC_Documents_Current_Versions/pdf/TPC-H_v3.0.1.pdf", + "data scale factor": 1, + "intended use": "Simulating decision support systems for complex ad-hoc queries and concurrent data modifications", + "notable characteristics": "Highly normalized schema with multiple tables and relationships, designed to represent a wholesale supplier's business environment", + "data description": "Contains information about orders. Every order has one or more lineitems, each representing the purchase and shipment of a specific part from a specific supplier. Each order is placed by a customer, and both customers and suppliers belong to nations which in turn belong to regions. 
Additionally, there are supply records indicating every combination of a supplier and the parts they supply." } }, { "name": "Empty", From a1bfec53386ace8dce645c54c8de8f1546a4eb9d Mon Sep 17 00:00:00 2001 From: knassre-bodo Date: Fri, 19 Dec 2025 16:30:11 -0800 Subject: [PATCH 07/12] Added more tests [RUN CI] --- tests/test_metadata.py | 52 +++++++++++++++++++++----- tests/test_metadata/sample_graphs.json | 9 ++++- 2 files changed, 51 insertions(+), 10 deletions(-) diff --git a/tests/test_metadata.py b/tests/test_metadata.py index 68d3d99ca..ab684f9c6 100644 --- a/tests/test_metadata.py +++ b/tests/test_metadata.py @@ -305,7 +305,7 @@ def test_semantic_info(get_sample_graph: graph_fetcher) -> None: """ graph: GraphMetadata = get_sample_graph("TPCH") - # Verify the semantic info fields of the grpah + # Verify the semantic info fields of the overall graph assert graph.verified_pydough_analysis == [ { "question": "How many customers are in China?", @@ -339,7 +339,7 @@ def test_semantic_info(get_sample_graph: graph_fetcher) -> None: "data description": "Contains information about orders. Every order has one or more lineitems, each representing the purchase and shipment of a specific part from a specific supplier. Each order is placed by a customer, and both customers and suppliers belong to nations which in turn belong to regions. 
Additionally, there are supply records indicating every combination of a supplier and the parts they supply.", } - # Verify the semantic info fields of the parts collection + # Verify the semantic info fields for a collection (parts) collection = graph.get_collection("parts") assert isinstance(collection, CollectionMetadata) assert ( @@ -370,12 +370,12 @@ def test_semantic_info(get_sample_graph: graph_fetcher) -> None: }, } - # Verify the semantic info fields of the size property - property = collection.get_property("size") - assert isinstance(property, ScalarAttributeMetadata) - assert property.sample_values == [1, 10, 31, 46, 50] - assert property.description == "The size of the part" - assert property.synonyms == [ + # Verify the semantic info fields for a scalar property (part.size) + scalar_property = collection.get_property("size") + assert isinstance(scalar_property, ScalarAttributeMetadata) + assert scalar_property.sample_values == [1, 10, 31, 46, 50] + assert scalar_property.description == "The size of the part" + assert scalar_property.synonyms == [ "dimension", "measurement", "length", @@ -383,10 +383,44 @@ def test_semantic_info(get_sample_graph: graph_fetcher) -> None: "height", "volume", ] - assert property.extra_semantic_info == { + assert scalar_property.extra_semantic_info == { "minimum value": 1, "maximum value": 50, "is dense": True, "distinct values": 50, "correlated fields": [], } + + # Test the semantic information for a relationship property (part.lines) + join_property = collection.get_property("lines") + assert isinstance(join_property, SimpleJoinMetadata) + assert join_property.description == ("The line items for shipments of the part") + assert join_property.synonyms == [ + "shipments", + "packages", + "purchases", + "deliveries", + "sales", + ] + assert join_property.extra_semantic_info == { + "unmatched rows": 0, + "min matches per row": 9, + "max matches per row": 57, + "avg matches per row": 30.01, + "classification": "one-to-many", + 
} + + # Test the semantic information for a property that does not have some + # semantic information defined for it (part.supply_records) + empty_semantic_property = collection.get_property("supply_records") + assert isinstance(empty_semantic_property, SimpleJoinMetadata) + assert ( + empty_semantic_property.description + == "The records indicating which companies supply the part" + ) + assert empty_semantic_property.synonyms == [ + "producers", + "vendors", + "suppliers of part", + ] + assert empty_semantic_property.extra_semantic_info is None diff --git a/tests/test_metadata/sample_graphs.json b/tests/test_metadata/sample_graphs.json index 460c10a80..9f99dc3d6 100644 --- a/tests/test_metadata/sample_graphs.json +++ b/tests/test_metadata/sample_graphs.json @@ -721,7 +721,14 @@ "always matches": false, "keys": {"key": ["part_key"]}, "description": "The line items for shipments of the part", - "synonyms": ["shipments", "packages", "purchases", "deliveries", "sales"] + "synonyms": ["shipments", "packages", "purchases", "deliveries", "sales"], + "extra semantic info": { + "unmatched rows": 0, + "min matches per row": 9, + "max matches per row": 57, + "avg matches per row": 30.01, + "classification": "one-to-many" + } }, { "type": "reverse", From 7c963adf3397f9a30200152c1c16da8c7cc34879 Mon Sep 17 00:00:00 2001 From: knassre-bodo Date: Fri, 19 Dec 2025 16:30:57 -0800 Subject: [PATCH 08/12] Added more tests [RUN CI] --- tests/test_relational_to_sql.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_relational_to_sql.py b/tests/test_relational_to_sql.py index 801e9601f..0dcc0798d 100644 --- a/tests/test_relational_to_sql.py +++ b/tests/test_relational_to_sql.py @@ -584,7 +584,7 @@ def test_convert_relation_to_sqlite_sql( Test converting a relational tree to SQL text in the SQLite dialect. 
""" file_path: str = get_sql_test_filename(test_name, DatabaseDialect.SQLITE) - created_sql: str = convert_relation_to_sql(root, empty_sqlite_tpch_session) + created_sql: str = convert_relation_to_sql(root, empty_sqlite_tpch_session, None) if update_tests: with open(file_path, "w") as f: f.write(created_sql + "\n") @@ -964,7 +964,7 @@ def test_function_to_sql( to SQL. """ file_path: str = get_sql_test_filename(f"func_{test_name}", DatabaseDialect.ANSI) - created_sql: str = convert_relation_to_sql(root, empty_sqlite_tpch_session) + created_sql: str = convert_relation_to_sql(root, empty_sqlite_tpch_session, None) if update_tests: with open(file_path, "w") as f: f.write(created_sql + "\n") From 1ce78025ebd6147b14d5c207f07a0cc1ed11bb89 Mon Sep 17 00:00:00 2001 From: knassre-bodo Date: Fri, 19 Dec 2025 16:31:39 -0800 Subject: [PATCH 09/12] [RUN CI] --- tests/test_relational_to_sql.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_relational_to_sql.py b/tests/test_relational_to_sql.py index 801e9601f..0dcc0798d 100644 --- a/tests/test_relational_to_sql.py +++ b/tests/test_relational_to_sql.py @@ -584,7 +584,7 @@ def test_convert_relation_to_sqlite_sql( Test converting a relational tree to SQL text in the SQLite dialect. """ file_path: str = get_sql_test_filename(test_name, DatabaseDialect.SQLITE) - created_sql: str = convert_relation_to_sql(root, empty_sqlite_tpch_session) + created_sql: str = convert_relation_to_sql(root, empty_sqlite_tpch_session, None) if update_tests: with open(file_path, "w") as f: f.write(created_sql + "\n") @@ -964,7 +964,7 @@ def test_function_to_sql( to SQL. 
""" file_path: str = get_sql_test_filename(f"func_{test_name}", DatabaseDialect.ANSI) - created_sql: str = convert_relation_to_sql(root, empty_sqlite_tpch_session) + created_sql: str = convert_relation_to_sql(root, empty_sqlite_tpch_session, None) if update_tests: with open(file_path, "w") as f: f.write(created_sql + "\n") From 8525e71309199d8d8a1947c0afaac1de1e07f61d Mon Sep 17 00:00:00 2001 From: knassre-bodo Date: Sun, 21 Dec 2025 09:47:59 -0800 Subject: [PATCH 10/12] Updates [RUN CI] --- pydough/evaluation/evaluate_unqualified.py | 6 ++++++ tests/test_max_rows.py | 18 ++++++++++++++++++ 2 files changed, 24 insertions(+) diff --git a/pydough/evaluation/evaluate_unqualified.py b/pydough/evaluation/evaluate_unqualified.py index 7cd2093ef..79408bb8d 100644 --- a/pydough/evaluation/evaluate_unqualified.py +++ b/pydough/evaluation/evaluate_unqualified.py @@ -149,6 +149,9 @@ def to_sql(node: UnqualifiedNode, **kwargs) -> str: """ column_selection: list[tuple[str, str]] | None = _load_column_selection(kwargs) max_rows: int | None = kwargs.pop("max_rows", None) + assert (isinstance(max_rows, int) and max_rows > 0) or max_rows is None, ( + "`max_rows` must be a positive integer or None." + ) session: PyDoughSession = _load_session_info(**kwargs) qualified: PyDoughQDAG = qualify_node(node, session) if not isinstance(qualified, PyDoughCollectionQDAG): @@ -177,6 +180,9 @@ def to_df(node: UnqualifiedNode, **kwargs) -> pd.DataFrame: """ column_selection: list[tuple[str, str]] | None = _load_column_selection(kwargs) max_rows: int | None = kwargs.pop("max_rows", None) + assert (isinstance(max_rows, int) and max_rows > 0) or max_rows is None, ( + "`max_rows` must be a positive integer or None." 
+ ) display_sql: bool = bool(kwargs.pop("display_sql", False)) session: PyDoughSession = _load_session_info(**kwargs) qualified: PyDoughQDAG = qualify_node(node, session) diff --git a/tests/test_max_rows.py b/tests/test_max_rows.py index 573681f6a..7453127f8 100644 --- a/tests/test_max_rows.py +++ b/tests/test_max_rows.py @@ -76,6 +76,7 @@ } ), "nations_top3_max2", + order_sensitive=True, ), 2, id="nations_top3_max2", @@ -91,6 +92,7 @@ } ), "nations_top3_max6", + order_sensitive=True, ), 6, id="nations_top3_max6", @@ -112,6 +114,22 @@ 5, id="richest_customers_orders_max5", ), + pytest.param( + PyDoughPandasTest( + "result = nations.TOP_K(8, by=name.ASC()).WHERE(region.name != 'AMERICA').CALCULATE(key, name)", + "TPCH", + lambda: pd.DataFrame( + { + "key": [0, 18, 4], + "name": ["ALGERIA", "CHINA", "EGYPT"], + } + ), + "nested_topk_nations_max_3", + order_sensitive=True, + ), + 3, + id="nested_topk_nations_max_3", + ), ], ) def test_max_rows( From b9abf1d24c760fd09f3b02f5735fbd238192012b Mon Sep 17 00:00:00 2001 From: knassre-bodo Date: Sun, 21 Dec 2025 10:00:15 -0800 Subject: [PATCH 11/12] [RUN CI] From 62dd3955e2db1d7d183cb72eb7663b3819e5e589 Mon Sep 17 00:00:00 2001 From: knassre-bodo Date: Tue, 23 Dec 2025 09:52:59 -0800 Subject: [PATCH 12/12] Revisions [RUN CI] --- pydough/metadata/parse.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pydough/metadata/parse.py b/pydough/metadata/parse.py index bb1c67b1c..1260905e1 100644 --- a/pydough/metadata/parse.py +++ b/pydough/metadata/parse.py @@ -98,14 +98,14 @@ def parse_graph_v2(graph_name: str, graph_json: dict) -> GraphMetadata: """ verified_analysis: list[dict] = [] additional_definitions: list[str] = [] - extra_info: dict = {} + extra_semantic_info: dict = {} graph: GraphMetadata = GraphMetadata( graph_name, additional_definitions, verified_analysis, None, None, - extra_info, + extra_semantic_info, ) # Parse and extract the metadata for all of the collections in the graph. 
@@ -163,7 +163,7 @@ def parse_graph_v2(graph_name: str, graph_json: dict) -> GraphMetadata: extra_info_json: dict = extract_object( graph_json, "extra semantic info", graph.error_name ) - extra_info.update(extra_info_json) + extra_semantic_info.update(extra_info_json) # Add all of the UDF definitions to the graph. if "functions" in graph_json: