diff --git a/documentation/dsl.md b/documentation/dsl.md index bf847bd18..efa0cda15 100644 --- a/documentation/dsl.md +++ b/documentation/dsl.md @@ -19,10 +19,8 @@ This page describes the specification of the PyDough DSL. The specification incl * [SINGULAR](#singular) * [BEST](#best) * [CROSS](#cross) -- [Induced Properties](#induced-properties) - * [Induced Scalar Properties](#induced-scalar-properties) - * [Induced Subcollection Properties](#induced-subcollection-properties) - * [Induced Arbitrary Joins](#induced-arbitrary-joins) +- [User Generated Collections](#user-generated-collections) + * [range_collection](#range_collection) - [Larger Examples](#larger-examples) * [Example 1: Highest Residency Density States](#example-1-highest-residency-density-states) * [Example 2: Yearly Trans-Coastal Shipments](#example-2-yearly-trans-coastal-shipments) @@ -1536,25 +1534,49 @@ People.CALCULATE(Packages=COUNT(People.packages)).CROSS(Packages) People.CROSS(Addresses).current_address ``` - -## Induced Properties + +## User Generated Collections -This section of the PyDough specification has not yet been defined. +> [!WARNING] +> NOTE: User collections are currently supported **only in the Snowflake context**. - -### Induced Scalar Properties +This section describes APIs for dynamically creating PyDough collections and using them alongside other data sources. -This section of the PyDough specification has not yet been defined. + +### `pydough.range_collection` - -### Induced Subcollection Properties +The `range_collection` creates a collection that generates a sequence of integers within a specified range. This is useful for building numeric datasets dynamically. +It takes in the following arguments: -This section of the PyDough specification has not yet been defined. +- `name`: The name of the range collection. +- `column_name`: The name of the column in the collection. +- `start`: The starting value of the range (inclusive). 
+- `end`: The ending value of the range (exclusive). +- `step`: The increment between consecutive values (default: 1). - -### Induced Arbitrary Joins +Supported Signatures: +- `range_collection(name, column_name, end)`: generates integers from 0 to `end-1` with a step of 1. +- `range_collection(name, column_name, start, end)`: generates integers from `start` to `end-1` with a step of 1. +- `range_collection(name, column_name, start, end, step)`: generates integers from `start` to `end-1` with the specified `step`. -This section of the PyDough specification has not yet been defined. +#### Example + +```python +import pydough + +my_range = pydough.range_collection("simple_range", "col1", 1, 10, 2) +df = pydough.to_df(my_range) +print(df) +``` +Output: +``` + col1 +0 1 +1 3 +2 5 +3 7 +4 9 +``` ## Larger Examples diff --git a/documentation/usage.md b/documentation/usage.md index a649f51eb..57bc2da95 100644 --- a/documentation/usage.md +++ b/documentation/usage.md @@ -442,7 +442,7 @@ You can find a full example of using Postgres database with PyDough in [this usa ## Evaluation APIs -This sections describes various APIs you can use to execute PyDough code. +This section describes various APIs you can use to execute PyDough code. ### `pydough.to_sql` @@ -548,10 +548,10 @@ pydough.to_df(result, columns={"name": "name", "n_custs": "n"}) See the [demo notebooks](../demos/notebooks/1_introduction.ipynb) for more instances of how to use the `to_df` API. - + ## Transformation APIs -This sections describes various APIs you can use to transform PyDough source code into a result that can be used as input for other evaluation or exploration APIs. +This section describes various APIs you can use to transform PyDough source code into a result that can be used as input for other evaluation or exploration APIs. 
### `pydough.from_string` @@ -726,7 +726,7 @@ ORDER BY ## Exploration APIs -This sections describes various APIs you can use to explore PyDough code and figure out what each component is doing without having PyDough fully evaluate it. The following APIs take an optional `config` argument which can be used to specify the PyDough configuration settings to use for the exploration. +This section describes various APIs you can use to explore PyDough code and figure out what each component is doing without having PyDough fully evaluate it. The following APIs take an optional `config` argument which can be used to specify the PyDough configuration settings to use for the exploration. See the [demo notebooks](../demos/notebooks/2_exploration.ipynb) for more instances of how to use the exploration APIs. diff --git a/pydough/__init__.py b/pydough/__init__.py index 68d9f7618..d38291255 100644 --- a/pydough/__init__.py +++ b/pydough/__init__.py @@ -12,6 +12,7 @@ "get_logger", "init_pydough_context", "parse_json_metadata_from_file", + "range_collection", "to_df", "to_sql", ] @@ -22,6 +23,7 @@ from .logger import get_logger from .metadata import parse_json_metadata_from_file from .unqualified import display_raw, from_string, init_pydough_context +from .user_collections.user_collection_apis import range_collection # Create a default session for the user to interact with. # In most situations users will just use this session and diff --git a/pydough/conversion/agg_removal.py b/pydough/conversion/agg_removal.py index 49c0ee14e..864be2ee0 100644 --- a/pydough/conversion/agg_removal.py +++ b/pydough/conversion/agg_removal.py @@ -12,6 +12,7 @@ CallExpression, EmptySingleton, Filter, + GeneratedTable, Join, JoinType, Limit, @@ -276,7 +277,7 @@ def aggregation_uniqueness_helper( ) return node, final_uniqueness # Empty singletons don't have uniqueness information. 
- case EmptySingleton(): + case EmptySingleton() | GeneratedTable(): return node, set() case _: raise NotImplementedError( diff --git a/pydough/conversion/filter_pushdown.py b/pydough/conversion/filter_pushdown.py index feb94d1b3..e25c367c4 100644 --- a/pydough/conversion/filter_pushdown.py +++ b/pydough/conversion/filter_pushdown.py @@ -13,6 +13,7 @@ ColumnReference, EmptySingleton, Filter, + GeneratedTable, Join, JoinCardinality, JoinType, @@ -306,6 +307,11 @@ def visit_empty_singleton(self, empty_singleton: EmptySingleton) -> RelationalNo # cannot be pushed down any further. return self.flush_remaining_filters(empty_singleton, self.filters, set()) + def visit_generated_table(self, generated_table: GeneratedTable) -> RelationalNode: + # Materialize all filters before the user generated table, since they + # cannot be pushed down any further. + return self.flush_remaining_filters(generated_table, self.filters, set()) + def push_filters(node: RelationalNode, session: PyDoughSession) -> RelationalNode: """ diff --git a/pydough/conversion/hybrid_operations.py b/pydough/conversion/hybrid_operations.py index 620c7f9bb..452147758 100644 --- a/pydough/conversion/hybrid_operations.py +++ b/pydough/conversion/hybrid_operations.py @@ -16,6 +16,7 @@ "HybridPartition", "HybridPartitionChild", "HybridRoot", + "HybridUserGeneratedCollection", ] @@ -27,6 +28,9 @@ ColumnProperty, PyDoughExpressionQDAG, ) +from pydough.qdag.collections.user_collection_qdag import ( + PyDoughUserGeneratedCollectionQDag, +) from .hybrid_connection import HybridConnection from .hybrid_expressions import ( @@ -483,3 +487,34 @@ def __repr__(self): def search_term_definition(self, name: str) -> HybridExpr | None: return self.predecessor.search_term_definition(name) + + +class HybridUserGeneratedCollection(HybridOperation): + """ + Class for HybridOperation corresponding to a user-generated collection. 
+ """ + + def __init__(self, user_collection: PyDoughUserGeneratedCollectionQDag): + """ + Args: + `collection`: the QDAG node for the user-generated collection. + """ + self._user_collection: PyDoughUserGeneratedCollectionQDag = user_collection + terms: dict[str, HybridExpr] = {} + for name, typ in user_collection.collection.column_names_and_types: + terms[name] = HybridRefExpr(name, typ) + unique_exprs: list[HybridExpr] = [] + for name in sorted(self.user_collection.unique_terms, key=str): + expr: PyDoughExpressionQDAG = self.user_collection.get_expr(name) + unique_exprs.append(HybridRefExpr(name, expr.pydough_type)) + super().__init__(terms, {}, [], unique_exprs) + + @property + def user_collection(self) -> PyDoughUserGeneratedCollectionQDag: + """ + The user-generated collection that this hybrid operation represents. + """ + return self._user_collection + + def __repr__(self): + return self.user_collection.to_string() diff --git a/pydough/conversion/hybrid_translator.py b/pydough/conversion/hybrid_translator.py index a6834f71e..c8e1e1617 100644 --- a/pydough/conversion/hybrid_translator.py +++ b/pydough/conversion/hybrid_translator.py @@ -43,6 +43,9 @@ Where, WindowCall, ) +from pydough.qdag.collections.user_collection_qdag import ( + PyDoughUserGeneratedCollectionQDag, +) from pydough.types import BooleanType, NumericType from .hybrid_connection import ConnectionType, HybridConnection @@ -71,6 +74,7 @@ HybridPartition, HybridPartitionChild, HybridRoot, + HybridUserGeneratedCollection, ) from .hybrid_syncretizer import HybridSyncretizer from .hybrid_tree import HybridTree @@ -1308,6 +1312,9 @@ def define_root_link( case HybridRoot(): # A root does not need to be joined to its parent join_keys = [] + case HybridUserGeneratedCollection(): + # A user-generated collection does not need to be joined to its parent + join_keys = [] case _: raise NotImplementedError(f"{operation.__class__.__name__}") if join_keys is not None: @@ -1533,6 +1540,18 @@ def 
make_hybrid_tree( HybridLimit(hybrid.pipeline[-1], node.records_to_keep) ) return hybrid + case PyDoughUserGeneratedCollectionQDag(): + # A user-generated collection is a special case of a collection + # access that is not a sub-collection, but rather a user-defined + # collection that is defined in the PyDough user collections. + hybrid_collection = HybridUserGeneratedCollection(node) + # Create a new hybrid tree for the user-generated collection. + successor_hybrid = HybridTree(hybrid_collection, node.ancestral_mapping) + hybrid = self.make_hybrid_tree( + node.ancestor_context, parent, is_aggregate + ) + hybrid.add_successor(successor_hybrid) + return successor_hybrid case ChildOperatorChildAccess(): assert parent is not None match node.child_access: @@ -1605,6 +1624,17 @@ def make_hybrid_tree( successor_hybrid = HybridTree( HybridRoot(), node.ancestral_mapping ) + case PyDoughUserGeneratedCollectionQDag(): + # A user-generated collection is a special case of a collection + # access that is not a sub-collection, but rather a user-defined + # collection that is defined in the PyDough user collections. + hybrid_collection = HybridUserGeneratedCollection( + node.child_access + ) + # Create a new hybrid tree for the user-generated collection. + successor_hybrid = HybridTree( + hybrid_collection, node.ancestral_mapping + ) case _: raise NotImplementedError( f"{node.__class__.__name__} (child is {node.child_access.__class__.__name__})" diff --git a/pydough/conversion/hybrid_tree.py b/pydough/conversion/hybrid_tree.py index 7c50c5e6b..93201e0ac 100644 --- a/pydough/conversion/hybrid_tree.py +++ b/pydough/conversion/hybrid_tree.py @@ -48,6 +48,7 @@ HybridPartition, HybridPartitionChild, HybridRoot, + HybridUserGeneratedCollection, ) @@ -792,6 +793,8 @@ def always_exists(self) -> bool: # Stepping into a partition child always has a matching data # record for each parent, by definition. 
pass + case HybridUserGeneratedCollection(): + return start_operation.user_collection.collection.always_exists() case _: raise NotImplementedError( f"Invalid start of pipeline: {start_operation.__class__.__name__}" @@ -842,6 +845,8 @@ def is_singular(self) -> bool: case HybridChildPullUp(): if not self.children[self.pipeline[0].child_idx].subtree.is_singular(): return False + case HybridUserGeneratedCollection(): + return self.pipeline[0].user_collection.collection.is_singular() case HybridRoot(): pass case _: diff --git a/pydough/conversion/relational_converter.py b/pydough/conversion/relational_converter.py index a10e608ac..61fb70501 100644 --- a/pydough/conversion/relational_converter.py +++ b/pydough/conversion/relational_converter.py @@ -37,6 +37,7 @@ EmptySingleton, ExpressionSortInfo, Filter, + GeneratedTable, Join, JoinCardinality, JoinType, @@ -82,6 +83,7 @@ HybridPartition, HybridPartitionChild, HybridRoot, + HybridUserGeneratedCollection, ) from .hybrid_translator import HybridTranslator from .hybrid_tree import HybridTree @@ -1267,6 +1269,29 @@ def translate_hybridroot(self, context: TranslationOutput) -> TranslationOutput: new_expressions[shifted_expr] = column_ref return TranslationOutput(context.relational_node, new_expressions) + def build_user_generated_table( + self, node: HybridUserGeneratedCollection + ) -> TranslationOutput: + """Builds a user-generated table from the given hybrid user-generated collection. + + Args: + `node`: The user-generated collection node to translate. + + Returns: + The translated output payload. 
+ """ + collection = node._user_collection.collection + out_columns: dict[HybridExpr, ColumnReference] = {} + gen_columns: dict[str, RelationalExpression] = {} + for column_name, column_type in collection.column_names_and_types: + hybrid_ref = HybridRefExpr(column_name, column_type) + col_ref = ColumnReference(column_name, column_type) + out_columns[hybrid_ref] = col_ref + gen_columns[column_name] = col_ref + + answer = GeneratedTable(collection) + return TranslationOutput(answer, out_columns) + def rel_translation( self, hybrid: HybridTree, @@ -1395,6 +1420,19 @@ def rel_translation( case HybridRoot(): assert context is not None, "Malformed HybridTree pattern." result = self.translate_hybridroot(context) + case HybridUserGeneratedCollection(): + assert context is not None, "Malformed HybridTree pattern." + result = self.build_user_generated_table(operation) + result = self.join_outputs( + context, + result, + JoinType.INNER, + JoinCardinality.PLURAL_ACCESS, + JoinCardinality.SINGULAR_ACCESS, + [], + None, + None, + ) case _: raise NotImplementedError( f"TODO: support relational conversion on {operation.__class__.__name__}" diff --git a/pydough/conversion/relational_simplification.py b/pydough/conversion/relational_simplification.py index 4eaf33e7a..44e7c1a22 100644 --- a/pydough/conversion/relational_simplification.py +++ b/pydough/conversion/relational_simplification.py @@ -24,6 +24,7 @@ CorrelatedReference, EmptySingleton, Filter, + GeneratedTable, Join, JoinType, Limit, @@ -1552,6 +1553,12 @@ def visit_empty_singleton(self, node: EmptySingleton) -> None: ) self.stack.append(output_predicates) + def visit_generated_table(self, node: GeneratedTable) -> None: + output_predicates: dict[RelationalExpression, PredicateSet] = ( + self.generic_visit(node) + ) + self.stack.append(output_predicates) + def visit_project(self, node: Project) -> None: output_predicates: dict[RelationalExpression, PredicateSet] = ( self.generic_visit(node) diff --git 
a/pydough/qdag/collections/README.md b/pydough/qdag/collections/README.md index f312bef8e..cb00a1f67 100644 --- a/pydough/qdag/collections/README.md +++ b/pydough/qdag/collections/README.md @@ -16,6 +16,7 @@ The QDAG collections module contains the following hierarchy of collection classes - [`TableCollection`](table_collection.py) (concrete): Accessing a table collection directly. - [`SubCollection`](sub_collection.py) (concrete): Accessing a subcolleciton of another collection. - [`CompoundSubCollection`](sub_collection.py) (concrete): Accessing a subcollection of another collection where the subcollection property is a compound relationship. + - [`PyDoughUserGeneratedCollectionQDag`](user_collection_qdag.py) (concrete): Accessing a user-generated collection. - [`ChildOperator`](child_operator.py) (abstract): Base class for collection QDAG nodes that need to access child contexts in order to make a child reference. - [`Calculate`](calculate.py) (concrete): Operation that defines new singular expression terms in the current context and names them. - [`Where`](where.py) (concrete): Operation that filters the current context based on a predicate that is a singular expression. 
diff --git a/pydough/qdag/collections/__init__.py b/pydough/qdag/collections/__init__.py index 6eaab3bd8..e9c28db18 100644 --- a/pydough/qdag/collections/__init__.py +++ b/pydough/qdag/collections/__init__.py @@ -21,6 +21,7 @@ "TableCollection", "TopK", "Where", + "range_collection", ] from .augmenting_child_operator import AugmentingChildOperator diff --git a/pydough/qdag/collections/user_collection_qdag.py b/pydough/qdag/collections/user_collection_qdag.py new file mode 100644 index 000000000..560a1097d --- /dev/null +++ b/pydough/qdag/collections/user_collection_qdag.py @@ -0,0 +1,135 @@ +from functools import cache + +from pydough.errors import PyDoughQDAGException +from pydough.qdag import PyDoughCollectionQDAG +from pydough.qdag.abstract_pydough_qdag import PyDoughQDAG +from pydough.qdag.expressions.back_reference_expression import BackReferenceExpression +from pydough.qdag.expressions.reference import Reference +from pydough.types import NumericType +from pydough.user_collections.user_collections import PyDoughUserGeneratedCollection + +from .child_access import ChildAccess + + +class PyDoughUserGeneratedCollectionQDag(ChildAccess): + def __init__( + self, + ancestor: PyDoughCollectionQDAG, + collection: PyDoughUserGeneratedCollection, + ): + assert ancestor is not None + super().__init__(ancestor) + self._collection: PyDoughUserGeneratedCollection = collection + self._all_property_names: set[str] = set() + self._ancestral_mapping: dict[str, int] = { + name: level + 1 for name, level in ancestor.ancestral_mapping.items() + } + self._all_property_names.update(self._ancestral_mapping) + self._all_property_names.update(self.calc_terms) + + def clone_with_parent( + self, new_ancestor: PyDoughCollectionQDAG + ) -> "PyDoughUserGeneratedCollectionQDag": + """ + Copies `self` but with a new ancestor node that presumably has the + original ancestor in its predecessor chain. + + Args: + `new_ancestor`: the node to use as the new parent of the clone. 
+ + Returns: + The cloned version of `self`. + """ + return PyDoughUserGeneratedCollectionQDag(new_ancestor, self._collection) + + @property + def collection(self) -> PyDoughUserGeneratedCollection: + """ + The metadata for the table that is being referenced by the collection + node. + """ + return self._collection + + @property + def name(self) -> str: + return self.collection.name + + @property + def calc_terms(self) -> set[str]: + return set(self.collection.columns) + + @property + def ancestral_mapping(self) -> dict[str, int]: + return self._ancestral_mapping + + @property + def inherited_downstreamed_terms(self) -> set[str]: + return self.ancestor_context.inherited_downstreamed_terms + + @cache + def get_term(self, term_name: str) -> PyDoughQDAG: + # Special handling of terms down-streamed + if term_name in self.ancestral_mapping: + # Verify that the ancestor name is not also a name in the current + # context. + if term_name in self.calc_terms: + raise PyDoughQDAGException( + f"Cannot have term name {term_name!r} used in an ancestor of collection {self!r}" + ) + # Create a back-reference to the ancestor term. + return BackReferenceExpression( + self, term_name, self.ancestral_mapping[term_name] + ) + + if term_name in self.inherited_downstreamed_terms: + context: PyDoughCollectionQDAG = self + while term_name not in context.all_terms: + if context is self: + context = self.ancestor_context + else: + assert context.ancestor_context is not None + context = context.ancestor_context + return Reference( + context, term_name, context.get_expr(term_name).pydough_type + ) + + if term_name not in self.all_terms: + raise PyDoughQDAGException(self.name_mismatch_error(term_name)) + + return Reference(self, term_name, NumericType()) + + @property + def all_terms(self) -> set[str]: + """ + The set of expression/subcollection names accessible by the context. 
+ """ + return self._all_property_names + + def is_singular(self, context: "PyDoughCollectionQDAG") -> bool: + return False + + def get_expression_position(self, expr_name: str) -> int: + if expr_name not in self.calc_terms: + raise PyDoughQDAGException( + f"Unrecognized User Collection term: {expr_name!r}" + ) + return self.collection.get_expression_position(expr_name) + + @property + def unique_terms(self) -> list[str]: + return self.collection.unique_column_names + + @property + def standalone_string(self) -> str: + return self.to_string() + + @property + def key(self) -> str: + return f"USER_GENERATED_COLLECTION-{self.name}" + + def to_string(self) -> str: + return f"UserCollection[{self.collection.to_string()}]" + + @property + def tree_item_string(self) -> str: + return self.collection.to_string() diff --git a/pydough/qdag/node_builder.py b/pydough/qdag/node_builder.py index d48cb0300..06b7fcb30 100644 --- a/pydough/qdag/node_builder.py +++ b/pydough/qdag/node_builder.py @@ -17,7 +17,11 @@ PyDoughOperator, builtin_registered_operators, ) +from pydough.qdag.collections.user_collection_qdag import ( + PyDoughUserGeneratedCollectionQDag, +) from pydough.types import PyDoughType +from pydough.user_collections.user_collections import PyDoughUserGeneratedCollection from .abstract_pydough_qdag import PyDoughQDAG from .collections import ( @@ -396,3 +400,26 @@ def build_singular( The newly created PyDough SINGULAR instance. """ return Singular(preceding_context) + + def build_generated_collection( + self, + preceding_context: PyDoughCollectionQDAG, + user_collection: PyDoughUserGeneratedCollection, + ) -> PyDoughUserGeneratedCollectionQDag: + """ + Creates a new user-defined collection. + + Args: + `preceding_context`: the preceding collection that the + user-defined collection is based on. + `user_collection`: the user-defined collection to be created. + + Returns: + The newly created user-defined collection. 
+ """ + collection_qdag: PyDoughUserGeneratedCollectionQDag = ( + PyDoughUserGeneratedCollectionQDag( + ancestor=preceding_context, collection=user_collection + ) + ) + return collection_qdag diff --git a/pydough/relational/__init__.py b/pydough/relational/__init__.py index bbe146dbe..98772e97a 100644 --- a/pydough/relational/__init__.py +++ b/pydough/relational/__init__.py @@ -9,6 +9,7 @@ "EmptySingleton", "ExpressionSortInfo", "Filter", + "GeneratedTable", "Join", "JoinCardinality", "JoinType", @@ -46,6 +47,7 @@ ColumnPruner, EmptySingleton, Filter, + GeneratedTable, Join, JoinCardinality, JoinType, diff --git a/pydough/relational/relational_nodes/README.md b/pydough/relational/relational_nodes/README.md index cbc42dee9..dfddb65cc 100644 --- a/pydough/relational/relational_nodes/README.md +++ b/pydough/relational/relational_nodes/README.md @@ -32,6 +32,10 @@ The relational_nodes module provides functionality to define and manage various - `Filter`: The relational node representing a filter operation in the relational tree. +### [generated_table.py](generated_table.py) + +- `GeneratedTable`: The relational node representing a generated table collection in the relational tree. + ### [project.py](project.py) - `Project`: The relational node representing a project operation in the relational tree. 
diff --git a/pydough/relational/relational_nodes/__init__.py b/pydough/relational/relational_nodes/__init__.py index 736656cf7..b16c6f9f1 100644 --- a/pydough/relational/relational_nodes/__init__.py +++ b/pydough/relational/relational_nodes/__init__.py @@ -8,6 +8,7 @@ "ColumnPruner", "EmptySingleton", "Filter", + "GeneratedTable", "Join", "JoinCardinality", "JoinType", @@ -26,6 +27,7 @@ from .column_pruner import ColumnPruner from .empty_singleton import EmptySingleton from .filter import Filter +from .generated_table import GeneratedTable from .join import Join, JoinCardinality, JoinType from .join_type_relational_visitor import JoinTypeRelationalVisitor from .limit import Limit diff --git a/pydough/relational/relational_nodes/generated_table.py b/pydough/relational/relational_nodes/generated_table.py new file mode 100644 index 000000000..6774fffea --- /dev/null +++ b/pydough/relational/relational_nodes/generated_table.py @@ -0,0 +1,71 @@ +""" +This file contains the relational implementation for a "generatedtable" node, +which generally represents user generated table. +""" + +from typing import TYPE_CHECKING + +from pydough.relational.relational_expressions import ( + RelationalExpression, +) +from pydough.relational.relational_expressions.column_reference import ColumnReference +from pydough.user_collections.user_collections import PyDoughUserGeneratedCollection + +from .abstract_node import RelationalNode + +if TYPE_CHECKING: + from .relational_shuttle import RelationalShuttle + + +class GeneratedTable(RelationalNode): + """ + The GeneratedTable node in the relational tree. Represents + a user-generated table stored locally which is assumed to be singular + and always available. 
+ """ + + def __init__( + self, + user_collection: PyDoughUserGeneratedCollection, + ) -> None: + columns: dict[str, RelationalExpression] = { + col_name: ColumnReference(col_name, col_type) + for col_name, col_type in user_collection.column_names_and_types + } + super().__init__(columns) + self._collection = user_collection + + @property + def inputs(self) -> list[RelationalNode]: + return [] + + @property + def name(self) -> str: + """Returns the name of the generated table.""" + return self.collection.name + + @property + def collection(self) -> PyDoughUserGeneratedCollection: + """ + The user-generated collection that this generated table represents. + """ + return self._collection + + def node_equals(self, other: RelationalNode) -> bool: + return isinstance(other, GeneratedTable) and self.collection == other.collection + + def accept(self, visitor: "RelationalVisitor") -> None: # type: ignore # noqa + visitor.visit_generated_table(self) + + def accept_shuttle(self, shuttle: "RelationalShuttle") -> RelationalNode: + return shuttle.visit_generated_table(self) + + def to_string(self, compact=False) -> str: + return f"GENERATED_TABLE({self.collection})" + + def node_copy( + self, + columns: dict[str, RelationalExpression], + inputs: list[RelationalNode], + ) -> RelationalNode: + return GeneratedTable(self.collection) diff --git a/pydough/relational/relational_nodes/join_type_relational_visitor.py b/pydough/relational/relational_nodes/join_type_relational_visitor.py index 2f402337c..9ae1ee926 100644 --- a/pydough/relational/relational_nodes/join_type_relational_visitor.py +++ b/pydough/relational/relational_nodes/join_type_relational_visitor.py @@ -35,6 +35,9 @@ def visit_inputs(self, node) -> None: def visit_scan(self, scan: Scan) -> None: pass + def visit_generated_table(self, generated_table) -> None: + pass + def visit_join(self, join: Join) -> None: """ Visit a Join node, collecting join types. 
diff --git a/pydough/relational/relational_nodes/relational_expression_dispatcher.py b/pydough/relational/relational_nodes/relational_expression_dispatcher.py index b296b9869..e5ad70ba8 100644 --- a/pydough/relational/relational_nodes/relational_expression_dispatcher.py +++ b/pydough/relational/relational_nodes/relational_expression_dispatcher.py @@ -77,3 +77,6 @@ def visit_root(self, root: RelationalRoot) -> None: self.visit_common(root) for order in root.orderings: order.expr.accept(self._expr_visitor) + + def visit_generated_table(self, generated_table) -> None: + self.visit_common(generated_table) diff --git a/pydough/relational/relational_nodes/relational_shuttle.py b/pydough/relational/relational_nodes/relational_shuttle.py index 3e1b4482a..2dadcca92 100644 --- a/pydough/relational/relational_nodes/relational_shuttle.py +++ b/pydough/relational/relational_nodes/relational_shuttle.py @@ -10,6 +10,7 @@ from .aggregate import Aggregate from .empty_singleton import EmptySingleton from .filter import Filter +from .generated_table import GeneratedTable from .join import Join from .limit import Limit from .project import Project @@ -110,6 +111,15 @@ def visit_empty_singleton(self, singleton: EmptySingleton) -> RelationalNode: """ return singleton + def visit_generated_table(self, generated_table: GeneratedTable) -> RelationalNode: + """ + Visit a user GeneratedTable node. + + Args: + `generated_table`: The generated table node to visit. + """ + return self.generic_visit_inputs(generated_table) + def visit_root(self, root: RelationalRoot) -> RelationalNode: """ Visit a root node. 
diff --git a/pydough/relational/relational_nodes/relational_visitor.py b/pydough/relational/relational_nodes/relational_visitor.py index 2a138b719..7d5530383 100644 --- a/pydough/relational/relational_nodes/relational_visitor.py +++ b/pydough/relational/relational_nodes/relational_visitor.py @@ -119,3 +119,12 @@ def visit_root(self, root: RelationalRoot) -> None: Args: `root`: The root node to visit. """ + + @abstractmethod + def visit_generated_table(self, generated_table) -> None: + """ + Visit a GeneratedTable node. + + Args: + `generated_table`: The generated table node to visit. + """ diff --git a/pydough/relational/relational_nodes/tree_string_visitor.py b/pydough/relational/relational_nodes/tree_string_visitor.py index d47723c80..3dc56eda3 100644 --- a/pydough/relational/relational_nodes/tree_string_visitor.py +++ b/pydough/relational/relational_nodes/tree_string_visitor.py @@ -62,3 +62,6 @@ def visit_empty_singleton(self, empty_singleton) -> None: def visit_root(self, root) -> None: self.visit_node(root) + + def visit_generated_table(self, root) -> None: + self.visit_node(root) diff --git a/pydough/sqlglot/override_merge_subqueries.py b/pydough/sqlglot/override_merge_subqueries.py index 3cfb4d85e..22d6fea4b 100644 --- a/pydough/sqlglot/override_merge_subqueries.py +++ b/pydough/sqlglot/override_merge_subqueries.py @@ -188,6 +188,21 @@ def invalid_aggregate_convolution(inner_scope: Scope, outer_scope: Scope) -> boo return result +def has_seq4_or_table(expr: Scope) -> bool: + """Check if the expression contains SEQ4() or TABLE(). + + Args: + `expr` (Scope): The SQLGlot expression walk and check. + + Returns: + True if SEQ4() or TABLE() is found, False otherwise. 
+ """ + for e in expr.walk(): + if isinstance(e, exp.Anonymous) and e.this.upper() in {"SEQ4", "TABLE"}: + return True + return False + + def _mergeable( outer_scope: Scope, inner_scope: Scope, @@ -309,4 +324,25 @@ def _is_recursive(): and not _is_a_window_expression_in_unmergable_operation() and not _is_recursive() and not (inner_select.args.get("order") and outer_scope.is_union) + # PYDOUGH CHANGE: avoid merging CTEs when the inner scope uses + # SEQ4()/TABLE() and if any of these exist in the outer query: + # - joins + # - window functions + # - aggregations + # - limit/offset + # - where/having/qualify clauses + # - group by + and not ( + has_seq4_or_table(inner_scope.expression) + and ( + outer_scope.expression.args.get("joins") is not None + or outer_scope.expression.find(exp.Window) + or outer_scope.expression.find(exp.Limit) + or outer_scope.expression.find(exp.AggFunc) + or outer_scope.expression.find(exp.Where) + or outer_scope.expression.find(exp.Having) + or outer_scope.expression.find(exp.Qualify) + or outer_scope.expression.find(exp.Group) + ) + ) ) diff --git a/pydough/sqlglot/override_qualify.py b/pydough/sqlglot/override_qualify.py index a0aa861b2..ddd8b1fff 100644 --- a/pydough/sqlglot/override_qualify.py +++ b/pydough/sqlglot/override_qualify.py @@ -211,7 +211,13 @@ def _qualify(table: exp.Table) -> None: # PYDOUGH CHANGE: preserve quoting from the original table name # Example: keywords."CAST" should become keywords."CAST" AS "CAST" - quoted = source.this.quoted + # Only do this if the source is not an Anonymous expression + # e.g. 
TABLE(GENERATOR(...)) is not a named table + quoted = ( + source.this.quoted + if not isinstance(source.this, exp.Anonymous) + else quoted + ) # Mutates the source by attaching an alias to it # PYDOUGH CHANGE: pass along quoting information diff --git a/pydough/sqlglot/sqlglot_helpers.py b/pydough/sqlglot/sqlglot_helpers.py index ea1e76cad..0c4111423 100644 --- a/pydough/sqlglot/sqlglot_helpers.py +++ b/pydough/sqlglot/sqlglot_helpers.py @@ -3,9 +3,8 @@ that can act as wrappers around the internal implementation of SQLGlot. """ -from sqlglot.expressions import ( - Alias as SQLGlotAlias, -) +from sqlglot.expressions import Alias as SQLGlotAlias +from sqlglot.expressions import Column as SQLGlotColumn from sqlglot.expressions import Expression as SQLGlotExpression from sqlglot.expressions import ( Identifier, @@ -33,6 +32,10 @@ def get_glot_name(expr: SQLGlotExpression) -> str | None: return expr.alias elif isinstance(expr, Identifier): return expr.this + if isinstance(expr, SQLGlotColumn): + if isinstance(expr.this, Identifier): + return expr.this.this + return expr.this else: return None diff --git a/pydough/sqlglot/sqlglot_relational_visitor.py b/pydough/sqlglot/sqlglot_relational_visitor.py index 7555021d8..ef5a9eb91 100644 --- a/pydough/sqlglot/sqlglot_relational_visitor.py +++ b/pydough/sqlglot/sqlglot_relational_visitor.py @@ -9,7 +9,13 @@ from sqlglot.expressions import Alias as SQLGlotAlias from sqlglot.expressions import Column as SQLGlotColumn from sqlglot.expressions import Expression as SQLGlotExpression -from sqlglot.expressions import Identifier, Select, Subquery, TableAlias, values +from sqlglot.expressions import ( + Identifier, + Select, + Subquery, + TableAlias, + values, +) from sqlglot.expressions import Literal as SQLGlotLiteral from sqlglot.expressions import Null as SQLGlotNull from sqlglot.expressions import Star as SQLGlotStar @@ -26,6 +32,7 @@ EmptySingleton, ExpressionSortInfo, Filter, + GeneratedTable, Join, Limit, LiteralExpression, @@ 
-568,6 +575,14 @@ def visit_root(self, root: RelationalRoot) -> None: query = query.limit(limit_expr) self._stack.append(query) + def visit_generated_table(self, generated_table: "GeneratedTable") -> None: + query: SQLGlotExpression = ( + self._expr_visitor._bindings.convert_user_generated_collection( + generated_table.collection + ) + ) + self._stack.append(query) + def relational_to_sqlglot(self, root: RelationalRoot) -> SQLGlotExpression: """ Interface to convert an entire relational tree to a SQLGlot expression. diff --git a/pydough/sqlglot/transform_bindings/base_transform_bindings.py b/pydough/sqlglot/transform_bindings/base_transform_bindings.py index dd4f1e6a5..5792abd1e 100644 --- a/pydough/sqlglot/transform_bindings/base_transform_bindings.py +++ b/pydough/sqlglot/transform_bindings/base_transform_bindings.py @@ -17,6 +17,8 @@ from pydough.configs import DayOfWeek, PyDoughConfigs from pydough.errors import PyDoughSQLException from pydough.types import BooleanType, NumericType, PyDoughType, StringType +from pydough.user_collections.range_collection import RangeGeneratedCollection +from pydough.user_collections.user_collections import PyDoughUserGeneratedCollection from .sqlglot_transform_utils import ( DateTimeUnit, @@ -2153,3 +2155,40 @@ def convert_ordering( A SQLGlotExpression representing the order key transformed in any necessary way. """ return arg + + def convert_user_generated_collection( + self, + collection: PyDoughUserGeneratedCollection, + ) -> SQLGlotExpression: + """ + Converts a user-generated collection (e.g., range or dataframe) into a SQLGlot expression. + + Args: + `collection`: The user-generated collection to convert. + + Returns: + A SQLGlotExpression representing the user-generated collection. 
+ """ + + match collection: + case RangeGeneratedCollection(): + return self.convert_user_generated_range(collection) + case _: + raise PyDoughSQLException( + f"Unsupported user-generated collection type: {type(collection)}" + ) + + def convert_user_generated_range( + self, + collection: RangeGeneratedCollection, + ) -> SQLGlotExpression: + """ + Converts a user-generated range into a SQLGlot expression. + Args: + `collection`: The user-generated range to convert. + Returns: + A SQLGlotExpression representing the user-generated range as table. + """ + raise NotImplementedError( + "range_collections are not supported for this dialect" + ) diff --git a/pydough/sqlglot/transform_bindings/sf_transform_bindings.py b/pydough/sqlglot/transform_bindings/sf_transform_bindings.py index 1796a329b..e2e4c72eb 100644 --- a/pydough/sqlglot/transform_bindings/sf_transform_bindings.py +++ b/pydough/sqlglot/transform_bindings/sf_transform_bindings.py @@ -5,12 +5,15 @@ __all__ = ["SnowflakeTransformBindings"] +import math + import sqlglot.expressions as sqlglot_expressions from sqlglot.expressions import Expression as SQLGlotExpression import pydough.pydough_operators as pydop from pydough.types import PyDoughType from pydough.types.boolean_type import BooleanType +from pydough.user_collections.range_collection import RangeGeneratedCollection from .base_transform_bindings import BaseTransformBindings from .sqlglot_transform_utils import DateTimeUnit @@ -162,3 +165,118 @@ def convert_datediff( else: # For other units, use base implementation return super().convert_datediff(args, types) + + def convert_user_generated_range( + self, collection: RangeGeneratedCollection + ) -> SQLGlotExpression: + """ + Converts a user-generated range collection to its Snowflake SQLGlot + representation. + Arguments: + `collection` : The user-generated range collection to convert. + Returns: + A SQLGlotExpression representing the user-generated range as table. 
+ """ + + # Calculate the number of rows needed for the range (end-start)/step + row_count: int = math.ceil( + (collection.end - collection.start) / collection.step + ) + + # Handle empty range by injecting a single NULL row + # SELECT CAST(NULL AS INT) AS x WHERE FALSE + if row_count <= 0: + query: SQLGlotExpression = sqlglot_expressions.Select( + expressions=[ + sqlglot_expressions.Alias( + this=sqlglot_expressions.Cast( + this=sqlglot_expressions.Null(), + to=sqlglot_expressions.DataType.build("INTEGER"), + ), + alias=sqlglot_expressions.Identifier( + this=collection.column_name + ), + ) + ], + ).where(sqlglot_expressions.false()) + + else: + # Build the SQLGlot query using Snowflake's GENERATOR function + # WITH table_name AS ( + # SELECT + # start + SEQ4() * step AS column_name + # FROM TABLE(GENERATOR(ROWCOUNT => row_count)) + # ) + # SELECT column_name FROM table_name + + # Step 1. Build the base expression: SEQ4() * step + # (or just SEQ4() if step == 1) + if collection.step == 1: + seq4_expr = sqlglot_expressions.Anonymous(this="SEQ4") + else: + seq4_expr = sqlglot_expressions.Mul( + this=sqlglot_expressions.Anonymous(this="SEQ4"), + expression=sqlglot_expressions.Literal.number(collection.step), + ) + + # Step 2. Add start if start != 0 + # Final expression: start + SEQ4() * step + if collection.start != 0: + final_expr = sqlglot_expressions.Add( + this=sqlglot_expressions.Literal.number(collection.start), + expression=seq4_expr, + ) + else: + final_expr = seq4_expr + + # 3. 
Build the inner SELECT + # SELECT start + SEQ4() * step AS column_name + # FROM TABLE(GENERATOR(ROWCOUNT => row_count)) + inner_select: SQLGlotExpression = sqlglot_expressions.Select( + expressions=[ + sqlglot_expressions.Alias( + this=final_expr, + alias=sqlglot_expressions.Identifier( + this=collection.column_name + ), + ) + ] + ).from_( + sqlglot_expressions.Table( + this=sqlglot_expressions.Anonymous( + this="TABLE", + expressions=[ + sqlglot_expressions.Anonymous( + this="GENERATOR", + expressions=[ + sqlglot_expressions.Kwarg( + this=sqlglot_expressions.Var(this="ROWCOUNT"), + expression=sqlglot_expressions.Literal.number( + row_count + ), + ) + ], + ) + ], + ) + ) + ) + + # 4. Wrap it as a subquery with alias + # WITH table_name AS ( ...inner_select... ) + subquery: SQLGlotExpression = sqlglot_expressions.Subquery( + this=inner_select, + alias=sqlglot_expressions.Identifier(this=collection.name), + ) + + # 5. Outer SELECT that references the subquery + # SELECT column_name FROM table_name + query = sqlglot_expressions.Select( + expressions=[ + sqlglot_expressions.Column( + this=collection.column_name, table=collection.name + ) + ] + ).from_(subquery) + + return query diff --git a/pydough/unqualified/qualification.py b/pydough/unqualified/qualification.py index e8e8b2f45..c1566b8a6 100644 --- a/pydough/unqualified/qualification.py +++ b/pydough/unqualified/qualification.py @@ -43,6 +43,7 @@ UnqualifiedCalculate, UnqualifiedCollation, UnqualifiedCross, + UnqualifiedGeneratedCollection, UnqualifiedLiteral, UnqualifiedNode, UnqualifiedOperation, @@ -1260,6 +1261,39 @@ def qualify_cross( return qualified_child + def qualify_generated_collection( + self, + unqualified: UnqualifiedGeneratedCollection, + context: PyDoughCollectionQDAG, + is_child: bool, + ) -> PyDoughCollectionQDAG: + """ + Transforms an `UnqualifiedGeneratedCollection` into a PyDoughCollectionQDAG node. + + Args: + `unqualified`: the UnqualifiedGeneratedCollection instance to be transformed. 
+ `context`: the collection QDAG whose context the collection is being + evaluated within. + `is_child`: whether the collection is being qualified as a child + of a child operator context, such as CALCULATE or PARTITION. + + Returns: + The PyDough QDAG object for the qualified collection node. + + """ + + generated_collection_qdag: PyDoughCollectionQDAG = ( + self.builder.build_generated_collection( + context, + unqualified._parcel[0], + ) + ) + if is_child: + generated_collection_qdag = ChildOperatorChildAccess( + generated_collection_qdag + ) + return generated_collection_qdag + def qualify_node( self, unqualified: UnqualifiedNode, @@ -1324,6 +1358,10 @@ def qualify_node( answer = self.qualify_best(unqualified, context, is_child) case UnqualifiedCross(): answer = self.qualify_cross(unqualified, context, is_child) + case UnqualifiedGeneratedCollection(): + answer = self.qualify_generated_collection( + unqualified, context, is_child + ) case _: raise PyDoughUnqualifiedException( f"Cannot qualify {unqualified.__class__.__name__}: {unqualified!r}" diff --git a/pydough/unqualified/unqualified_node.py b/pydough/unqualified/unqualified_node.py index 4be3173f4..04b992e40 100644 --- a/pydough/unqualified/unqualified_node.py +++ b/pydough/unqualified/unqualified_node.py @@ -8,6 +8,7 @@ "UnqualifiedBinaryOperation", "UnqualifiedCalculate", "UnqualifiedCross", + "UnqualifiedGeneratedCollection", "UnqualifiedLiteral", "UnqualifiedNode", "UnqualifiedOperation", @@ -40,6 +41,7 @@ StringType, UnknownType, ) +from pydough.user_collections.user_collections import PyDoughUserGeneratedCollection class UnqualifiedNode(ABC): @@ -784,6 +786,13 @@ def __init__( ] = (data, by, per, allow_ties, n_best) +class UnqualifiedGeneratedCollection(UnqualifiedNode): + """Represents a user-generated collection of values.""" + + def __init__(self, user_collection: PyDoughUserGeneratedCollection): + self._parcel: tuple[PyDoughUserGeneratedCollection] = (user_collection,) + + def 
display_raw(unqualified: UnqualifiedNode) -> str: """ Prints an unqualified node in a human-readable manner that shows its @@ -881,6 +890,12 @@ def display_raw(unqualified: UnqualifiedNode) -> str: if unqualified._parcel[4] > 1: result += f", n_best={unqualified._parcel[4]}" return result + ")" + case UnqualifiedGeneratedCollection(): + result = "generated_collection(" + result += f"name={unqualified._parcel[0].name!r}, " + result += f"columns=[{', '.join(unqualified._parcel[0].columns)}]," + result += f"data={unqualified._parcel[0].to_string()}" + return result + ")" case _: raise PyDoughUnqualifiedException( f"Unsupported unqualified node: {unqualified.__class__.__name__}" diff --git a/pydough/unqualified/unqualified_transform.py b/pydough/unqualified/unqualified_transform.py index e663e8974..56feee650 100644 --- a/pydough/unqualified/unqualified_transform.py +++ b/pydough/unqualified/unqualified_transform.py @@ -1,6 +1,6 @@ """ Logic for transforming raw Python code into PyDough code by replacing undefined -variables with unqualified nodes by prepending with with `_ROOT.`. +variables with unqualified nodes by prepending it with `_ROOT.`. 
""" __all__ = ["from_string", "init_pydough_context", "transform_cell", "transform_code"] @@ -365,8 +365,8 @@ def transform_code( source: str, graph_dict: dict[str, GraphMetadata], known_names: set[str] ) -> ast.AST: """ - Transforms the source code into a new Python QDAG that has had the PyDough - decorator removed, had the definition of `_ROOT` injected at the top of the + Transforms the source code into a new Python QDAG that has the PyDough + decorator removed, has the definition of `_ROOT` injected at the top of the function body, and prepend unknown variables with `_ROOT.` Args: diff --git a/pydough/user_collections/README.md b/pydough/user_collections/README.md new file mode 100644 index 000000000..b44a758e8 --- /dev/null +++ b/pydough/user_collections/README.md @@ -0,0 +1,54 @@ +# PyDough User Collections + +This module defines the user collections that can be created on the fly and used in PyDough with other collections, for example: range collections, Pandas DataFrame collections. The user collections are registered and made available for use in PyDough code. + +## Available APIs + +### [range_collection.py](range_collection.py) + + - `RangeGeneratedCollection`: Class used to create a range collection that generates a sequence of numbers based on the specified start, end, and step values. + - `name`: The name of the range collection. + - `column_name`: The name of the column in the range collection. + - `start`: The starting value of the range (inclusive). + - `end`: The ending value of the range (exclusive). + - `step`: The step value for incrementing the range. Default is 1. + +### [user_collection_apis.py](user_collection_apis.py) + - `range_collection`: Function to create a range collection with the specified parameters. + - `name`: The name of the range collection. + - `column_name`: The name of the column in the range collection. + - `start`: The starting value of the range (inclusive). + - `end`: The ending value of the range (exclusive). 
+ - `step`: The step value for incrementing the range. Default is 1. + - Returns: An instance of `RangeGeneratedCollection`. + +### [user_collections.py](user_collections.py) + - `PyDoughUserGeneratedCollection`: Base class for all user-generated collections in PyDough. + +## Usage + +You can access user collections through `pydough` and call them with the required arguments. For example: + +```python +import pydough + +my_range = pydough.range_collection( + "simple_range", + "col1", + 1, 10, 2 + ) +``` +Output: +``` + col1 +0 1 +1 3 +2 5 +3 7 +4 9 +``` + +## Detailed Explanation + +The user collections module provides a way to create collections that are not part of the static metadata graph but can be generated dynamically based on user input or code. The most common user collection are integer range collections and Pandas DataFrame collections. +The range collection, generates a sequence of numbers. The `RangeGeneratedCollection` class allows users to define a range collection by specifying the start, end, and step values. The `range_collection` function is a convenient API to create instances of `RangeGeneratedCollection`. \ No newline at end of file diff --git a/pydough/user_collections/__init__.py b/pydough/user_collections/__init__.py new file mode 100644 index 000000000..fa78c0640 --- /dev/null +++ b/pydough/user_collections/__init__.py @@ -0,0 +1,5 @@ +""" +Module of PyDough dealing with APIs used for user generated collections. +""" + +__all__ = ["range_collection"] diff --git a/pydough/user_collections/range_collection.py b/pydough/user_collections/range_collection.py new file mode 100644 index 000000000..fa747e606 --- /dev/null +++ b/pydough/user_collections/range_collection.py @@ -0,0 +1,94 @@ +"""A user-defined collection of integers in a specified range. +Usage: +`pydough.range_collection(name, column, *args)` + args: start, end, step + +This module defines a collection that generates integers from `start` to `end` +with a specified `step`. 
The user must specify the name of the collection and the +name of the column that will hold the integer values. +""" + +from pydough.types import NumericType +from pydough.types.pydough_type import PyDoughType +from pydough.user_collections.user_collections import PyDoughUserGeneratedCollection + +all = ["RangeGeneratedCollection"] + + +class RangeGeneratedCollection(PyDoughUserGeneratedCollection): + """Integer range-based collection.""" + + def __init__( + self, + name: str, + column_name: str, + range: range, + ) -> None: + super().__init__( + name=name, + columns=[ + column_name, + ], + types=[NumericType()], + ) + self._range = range + self._start = self._range.start + self._end = self._range.stop + self._step = self._range.step + + @property + def start(self) -> int: + """Return the start of the range.""" + return self._start + + @property + def end(self) -> int: + """Return the end of the range.""" + return self._end + + @property + def step(self) -> int: + """Return the step of the range.""" + return self._step + + @property + def range(self) -> range: + """Return the range object representing the collection.""" + return self._range + + @property + def column_names_and_types(self) -> list[tuple[str, PyDoughType]]: + return [(self.columns[0], NumericType())] + + @property + def column_name(self) -> str: + return self.columns[0] + + @property + def unique_column_names(self) -> list[str]: + return [self.columns[0]] + + def __len__(self) -> int: + return len(self._range) + + def is_singular(self) -> bool: + """Returns True if the collection is guaranteed to contain at most one row.""" + return len(self) <= 1 + + def always_exists(self) -> bool: + """Check if the range collection is always non-empty.""" + return len(self) > 0 + + def to_string(self) -> str: + """Return a string representation of the range collection.""" + return f"RangeCollection({self.name!r}, {self.columns[0]}={self.range})" + + def equals(self, other) -> bool: + return ( + isinstance(other, 
RangeGeneratedCollection) + and self.name == other.name + and self.columns == other.columns + and self.start == other.start + and self.end == other.end + and self.step == other.step + ) diff --git a/pydough/user_collections/user_collection_apis.py b/pydough/user_collections/user_collection_apis.py new file mode 100644 index 000000000..ea008691e --- /dev/null +++ b/pydough/user_collections/user_collection_apis.py @@ -0,0 +1,45 @@ +""" +Implementation of User Collection APIs in PyDough. +""" + +__all__ = ["range_collection"] + +from pydough.unqualified.unqualified_node import UnqualifiedGeneratedCollection +from pydough.user_collections.range_collection import RangeGeneratedCollection + + +def range_collection( + name: str, column: str, *args: int +) -> UnqualifiedGeneratedCollection: + """ + Implementation of the `pydough.range_collection` function, which provides + a way to create a collection of integer ranges over a specified column in PyDough. + + Args: + `name` : The name of the collection. + `column` : The column to create ranges for. + `*args` : Variable length arguments that specify the range parameters. + Supported formats: + - `range_collection(end)`: generates a range from 0 to `end-1` + with a step of 1. + - `range_collection(start, end)`: generates a range from `start` + to `end-1` with a step of 1. + - `range_collection(start, end, step)`: generates a range from + `start` to `end-1` with the specified step. + Returns: + A collection of integer ranges. 
+ """ + if not isinstance(name, str): + raise TypeError(f"Expected 'name' to be a string, got {type(name).__name__}") + if not isinstance(column, str): + raise TypeError( + f"Expected 'column' to be a string, got {type(column).__name__}" + ) + r = range(*args) + range_collection = RangeGeneratedCollection( + name=name, + column_name=column, + range=r, + ) + + return UnqualifiedGeneratedCollection(range_collection) diff --git a/pydough/user_collections/user_collections.py b/pydough/user_collections/user_collections.py new file mode 100644 index 000000000..1ec3049ce --- /dev/null +++ b/pydough/user_collections/user_collections.py @@ -0,0 +1,87 @@ +""" +Base definition of PyDough QDAG collection type for accesses to a user defined +collection of the current context. +""" + +from abc import ABC, abstractmethod + +from pydough.types.pydough_type import PyDoughType + +__all__ = ["PyDoughUserGeneratedCollection"] + + +class PyDoughUserGeneratedCollection(ABC): + """ + Abstract base class for a user defined table collection. + This class defines the interface for accessing a user defined table collection + directly, without any specific implementation details. + It is intended to be subclassed by specific implementations that provide + the actual behavior and properties of the collection. 
+ """ + + def __init__(self, name: str, columns: list[str], types: list[PyDoughType]) -> None: + self._name = name + self._columns = columns + self._types = types + + def __eq__(self, other) -> bool: + return self.equals(other) + + def __repr__(self) -> str: + return self.to_string() + + def __hash__(self) -> int: + return hash(repr(self)) + + def __str__(self) -> str: + return self.to_string() + + @property + def name(self) -> str: + """Return the name used for the collection.""" + return self._name + + @property + def columns(self) -> list[str]: + """Return column names.""" + return self._columns + + @property + @abstractmethod + def column_names_and_types(self) -> list[tuple[str, PyDoughType]]: + """Return column names and their types.""" + + @property + @abstractmethod + def unique_column_names(self) -> list[str]: + """Return the set of unique column names in the collection.""" + + @abstractmethod + def always_exists(self) -> bool: + """Check if the collection is always non-empty.""" + + @abstractmethod + def is_singular(self) -> bool: + """Returns True if the collection is guaranteed to contain at most one row.""" + + @abstractmethod + def to_string(self) -> str: + """Return a string representation of the collection.""" + + @abstractmethod + def equals(self, other) -> bool: + """ + Check if this collection is equal to another collection. + Two collections are considered equal if they have the same name and columns. + """ + + def get_expression_position(self, expr_name: str) -> int: + """ + Get the position of an expression in the collection. + This is used to determine the order of expressions in the collection. 
+ """ + if expr_name not in self.columns: + raise ValueError( + f"Expression {expr_name!r} not found in collection {self.name!r}" + ) + return self.columns.index(expr_name) diff --git a/tests/test_pipeline_sf.py b/tests/test_pipeline_sf.py index 0b887bc97..f7954fb46 100644 --- a/tests/test_pipeline_sf.py +++ b/tests/test_pipeline_sf.py @@ -6,6 +6,7 @@ # mypy: ignore-errors # ruff & mypy should not try to typecheck or verify any of this +from collections.abc import Callable import pandas as pd import pytest import datetime @@ -35,6 +36,22 @@ from .testing_utilities import PyDoughPandasTest from pydough import init_pydough_context, to_df, to_sql +# NOTE: this should move to test_pipeline_tpch_custom.py once the +# other dialects are supported +from tests.test_pydough_functions.user_collections import ( + simple_range_1, + simple_range_2, + simple_range_3, + simple_range_4, + simple_range_5, + user_range_collection_1, + user_range_collection_2, + user_range_collection_3, + user_range_collection_4, + user_range_collection_5, + user_range_collection_6, +) + @pytest.fixture( params=[ @@ -573,3 +590,273 @@ def test_pipeline_e2e_snowflake_custom_datasets( ) else: pytest.skip("Skipping non-keywords custom dataset tests for Snowflake.") + + +# NOTE: this should move and be part of tpch_custom_pipeline_test_data once the +# other dialects are supported +@pytest.fixture( + params=[ + pytest.param( + PyDoughPandasTest( + simple_range_1, + "TPCH", + lambda: pd.DataFrame({"value": range(10)}), + "simple_range_1", + ), + id="simple_range_1", + ), + pytest.param( + PyDoughPandasTest( + simple_range_2, + "TPCH", + lambda: pd.DataFrame({"value": range(9, -1, -1)}), + "simple_range_2", + ), + id="simple_range_2", + ), + pytest.param( + PyDoughPandasTest( + simple_range_3, + "TPCH", + lambda: pd.DataFrame({"foo": range(15, 20)}), + "simple_range_3", + ), + id="simple_range_3", + ), + pytest.param( + PyDoughPandasTest( + simple_range_4, + "TPCH", + lambda: pd.DataFrame({"foo": range(10, 
0, -1)}), + "simple_range_4", + ), + id="simple_range_4", + ), + pytest.param( + PyDoughPandasTest( + simple_range_5, + "TPCH", + # TODO: even though generated SQL has CAST(NULL AS INT) AS x + # it returns x as object datatype. + # using `x: range(-1)` returns int64 so temp. using dtype=object + lambda: pd.DataFrame({"x": pd.Series(range(-1), dtype="object")}), + "simple_range_5", + ), + id="simple_range_5", + ), + pytest.param( + PyDoughPandasTest( + user_range_collection_1, + "TPCH", + lambda: pd.DataFrame( + { + "part_size": [ + 1, + 6, + 11, + 16, + 21, + 26, + 31, + 36, + 41, + 46, + 51, + 56, + 61, + 66, + 71, + 76, + 81, + 86, + 91, + 96, + ], + "n_parts": [ + 228, + 225, + 206, + 234, + 228, + 221, + 231, + 208, + 245, + 226, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + ], + } + ), + "user_range_collection_1", + ), + id="user_range_collection_1", + ), + pytest.param( + PyDoughPandasTest( + user_range_collection_2, + "TPCH", + lambda: pd.DataFrame( + { + "x": [0, 2, 4, 6, 8], + "n_prefix": [1, 56, 56, 56, 56], + "n_suffix": [101, 100, 100, 100, 100], + } + ), + "user_range_collection_2", + ), + id="user_range_collection_2", + ), + pytest.param( + PyDoughPandasTest( + user_range_collection_3, + "TPCH", + lambda: pd.DataFrame( + { + "x": [0, 2, 4, 6, 8], + "n_prefix": [1, 56, 56, 56, 56], + "n_suffix": [101, 100, 100, 100, 100], + } + ), + "user_range_collection_3", + ), + id="user_range_collection_3", + ), + pytest.param( + PyDoughPandasTest( + user_range_collection_4, + "TPCH", + lambda: pd.DataFrame( + { + "part_size": [1, 2, 4, 5, 6, 10], + "name": [ + "azure lime burnished blush salmon", + "spring green chocolate azure navajo", + "cornflower bisque thistle floral azure", + "azure aquamarine tomato lace peru", + "antique cyan tomato azure dim", + "red cream rosy hot azure", + ], + "retail_price": [ + 1217.13, + 1666.60, + 1863.87, + 1114.16, + 1716.72, + 1746.81, + ], + } + ), + "user_range_collection_4", + ), + id="user_range_collection_4", + ), + 
pytest.param( + PyDoughPandasTest( + user_range_collection_5, + "TPCH", + lambda: pd.DataFrame( + { + "part_size": [1, 11, 21, 31, 41, 51, 6, 16, 26, 36, 46, 56], + "n_parts": [ + 1135, + 1067, + 1128, + 1109, + 1038, + 0, + 1092, + 1154, + 1065, + 1094, + 1088, + 0, + ], + } + ), + "user_range_collection_5", + ), + id="user_range_collection_5", + ), + pytest.param( + PyDoughPandasTest( + user_range_collection_6, + "TPCH", + lambda: pd.DataFrame( + { + "year": [ + 1990, + 1991, + 1992, + 1993, + 1994, + 1995, + 1996, + 1997, + 1998, + 1999, + 2000, + ], + "n_orders": [0, 0, 1, 2, 0, 0, 1, 1, 2, 0, 0], + } + ), + "user_range_collection_6", + ), + id="user_range_collection_6", + ), + ], +) +def sf_user_generated_data(request) -> PyDoughPandasTest: + """ + Test data for e2e tests for user generated collections on Snowflake. Returns an instance of + PyDoughPandasTest containing information about the test. + """ + return request.param + + +@pytest.mark.snowflake +@pytest.mark.execute +def test_e2e_sf_user_generated_data( + sf_user_generated_data: PyDoughPandasTest, + get_sf_sample_graph: graph_fetcher, + sf_conn_db_context: DatabaseContext, +): + """ + Test executing the TPC-H queries from the original code generation, + with Snowflake as the executing database. + Using the `connection` as keyword argument to the DatabaseContext. + """ + sf_user_generated_data.run_e2e_test( + get_sf_sample_graph, + sf_conn_db_context("SNOWFLAKE_SAMPLE_DATA", "TPCH_SF1"), + coerce_types=True, + ) + + +# TODO: delete this test once the other dialects are supported +# and moved to tpch_custom_pipeline_test_data +# It's needed here to access sf_user_generated_data fixture +# that has user-generated test cases. 
+def test_pipeline_until_relational_tpch_custom_sf( + sf_user_generated_data: PyDoughPandasTest, + get_sample_graph: graph_fetcher, + get_plan_test_filename: Callable[[str], str], + update_tests: bool, +) -> None: + """ + Tests that a PyDough unqualified node can be correctly translated to its + qualified DAG version, with the correct string representation. Run on + custom queries with the TPC-H graph. + """ + file_path: str = get_plan_test_filename(sf_user_generated_data.test_name) + sf_user_generated_data.run_relational_test( + get_sample_graph, file_path, update_tests + ) diff --git a/tests/test_plan_refsols/simple_range_1.txt b/tests/test_plan_refsols/simple_range_1.txt new file mode 100644 index 000000000..0305c5c26 --- /dev/null +++ b/tests/test_plan_refsols/simple_range_1.txt @@ -0,0 +1,2 @@ +ROOT(columns=[('value', value)], orderings=[]) + GENERATED_TABLE(RangeCollection('simple_range', value=range(0, 10))) diff --git a/tests/test_plan_refsols/simple_range_2.txt b/tests/test_plan_refsols/simple_range_2.txt new file mode 100644 index 000000000..dcaca909b --- /dev/null +++ b/tests/test_plan_refsols/simple_range_2.txt @@ -0,0 +1,2 @@ +ROOT(columns=[('value', value)], orderings=[(value):desc_last]) + GENERATED_TABLE(RangeCollection('simple_range', value=range(0, 10))) diff --git a/tests/test_plan_refsols/simple_range_3.txt b/tests/test_plan_refsols/simple_range_3.txt new file mode 100644 index 000000000..cdf4a5808 --- /dev/null +++ b/tests/test_plan_refsols/simple_range_3.txt @@ -0,0 +1,2 @@ +ROOT(columns=[('foo', foo)], orderings=[(foo):asc_first]) + GENERATED_TABLE(RangeCollection('T1', foo=range(15, 20))) diff --git a/tests/test_plan_refsols/simple_range_4.txt b/tests/test_plan_refsols/simple_range_4.txt new file mode 100644 index 000000000..373e87f0b --- /dev/null +++ b/tests/test_plan_refsols/simple_range_4.txt @@ -0,0 +1,2 @@ +ROOT(columns=[('N', N)], orderings=[(N):asc_first]) + GENERATED_TABLE(RangeCollection('T2', N=range(10, 0, -1))) diff --git 
a/tests/test_plan_refsols/simple_range_5.txt b/tests/test_plan_refsols/simple_range_5.txt new file mode 100644 index 000000000..7ca2f3a49 --- /dev/null +++ b/tests/test_plan_refsols/simple_range_5.txt @@ -0,0 +1,2 @@ +ROOT(columns=[('x', x)], orderings=[]) + GENERATED_TABLE(RangeCollection('T3', x=range(0, -1))) diff --git a/tests/test_plan_refsols/user_range_collection_1.txt b/tests/test_plan_refsols/user_range_collection_1.txt new file mode 100644 index 000000000..2355c0122 --- /dev/null +++ b/tests/test_plan_refsols/user_range_collection_1.txt @@ -0,0 +1,6 @@ +ROOT(columns=[('part_size', part_size), ('n_parts', DEFAULT_TO(n_rows, 0:numeric))], orderings=[]) + JOIN(condition=t0.part_size == t1.p_size, type=LEFT, cardinality=SINGULAR_FILTER, reverse_cardinality=PLURAL_FILTER, columns={'n_rows': t1.n_rows, 'part_size': t0.part_size}) + GENERATED_TABLE(RangeCollection('sizes', part_size=range(1, 100, 5))) + AGGREGATE(keys={'p_size': p_size}, aggregations={'n_rows': COUNT()}) + FILTER(condition=CONTAINS(p_name, 'turquoise':string), columns={'p_size': p_size}) + SCAN(table=tpch.PART, columns={'p_name': p_name, 'p_size': p_size}) diff --git a/tests/test_plan_refsols/user_range_collection_2.txt b/tests/test_plan_refsols/user_range_collection_2.txt new file mode 100644 index 000000000..99dd2df9f --- /dev/null +++ b/tests/test_plan_refsols/user_range_collection_2.txt @@ -0,0 +1,10 @@ +ROOT(columns=[('x', x), ('n_prefix', n_rows), ('n_suffix', agg_1_1)], orderings=[(x):asc_first]) + JOIN(condition=t0.x == t1.x, type=INNER, cardinality=SINGULAR_ACCESS, reverse_cardinality=SINGULAR_ACCESS, columns={'agg_1_1': t1.n_rows, 'n_rows': t0.n_rows, 'x': t0.x}) + AGGREGATE(keys={'x': x}, aggregations={'n_rows': COUNT()}) + JOIN(condition=STARTSWITH(STRING(t1.y), STRING(t0.x)), type=INNER, cardinality=PLURAL_FILTER, reverse_cardinality=SINGULAR_FILTER, columns={'x': t0.x}) + GENERATED_TABLE(RangeCollection('a', x=range(0, 10))) + GENERATED_TABLE(RangeCollection('b', y=range(0, 1001, 
2))) + AGGREGATE(keys={'x': x}, aggregations={'n_rows': COUNT()}) + JOIN(condition=ENDSWITH(STRING(t1.y), STRING(t0.x)), type=INNER, cardinality=PLURAL_FILTER, reverse_cardinality=SINGULAR_FILTER, columns={'x': t0.x}) + GENERATED_TABLE(RangeCollection('a', x=range(0, 10))) + GENERATED_TABLE(RangeCollection('b', y=range(0, 1001, 2))) diff --git a/tests/test_plan_refsols/user_range_collection_3.txt b/tests/test_plan_refsols/user_range_collection_3.txt new file mode 100644 index 000000000..0f1aec332 --- /dev/null +++ b/tests/test_plan_refsols/user_range_collection_3.txt @@ -0,0 +1,10 @@ +ROOT(columns=[('x', x), ('n_prefix', n_rows), ('n_suffix', n_suffix)], orderings=[(x):asc_first]) + JOIN(condition=t0.x == t1.x, type=INNER, cardinality=SINGULAR_ACCESS, reverse_cardinality=SINGULAR_ACCESS, columns={'n_rows': t0.n_rows, 'n_suffix': t1.n_rows, 'x': t0.x}) + AGGREGATE(keys={'x': x}, aggregations={'n_rows': COUNT()}) + JOIN(condition=STARTSWITH(STRING(t1.y), STRING(t0.x)), type=INNER, cardinality=PLURAL_FILTER, reverse_cardinality=SINGULAR_FILTER, columns={'x': t0.x}) + GENERATED_TABLE(RangeCollection('a', x=range(0, 10))) + GENERATED_TABLE(RangeCollection('b', y=range(0, 1001, 2))) + AGGREGATE(keys={'x': x}, aggregations={'n_rows': COUNT()}) + JOIN(condition=ENDSWITH(STRING(t1.y), STRING(t0.x)), type=INNER, cardinality=PLURAL_FILTER, reverse_cardinality=SINGULAR_FILTER, columns={'x': t0.x}) + GENERATED_TABLE(RangeCollection('a', x=range(0, 10))) + GENERATED_TABLE(RangeCollection('b', y=range(0, 1001, 2))) diff --git a/tests/test_plan_refsols/user_range_collection_4.txt b/tests/test_plan_refsols/user_range_collection_4.txt new file mode 100644 index 000000000..46b108cb7 --- /dev/null +++ b/tests/test_plan_refsols/user_range_collection_4.txt @@ -0,0 +1,6 @@ +ROOT(columns=[('part_size', part_size), ('name', p_name), ('retail_price', p_retailprice)], orderings=[(part_size):asc_first]) + FILTER(condition=RANKING(args=[], partition=[part_size], 
order=[(p_retailprice):asc_last], allow_ties=False) == 1:numeric, columns={'p_name': p_name, 'p_retailprice': p_retailprice, 'part_size': part_size}) + JOIN(condition=t1.p_size == t0.part_size, type=INNER, cardinality=PLURAL_FILTER, reverse_cardinality=SINGULAR_FILTER, columns={'p_name': t1.p_name, 'p_retailprice': t1.p_retailprice, 'part_size': t0.part_size}) + GENERATED_TABLE(RangeCollection('sizes', part_size=range(1, 11))) + FILTER(condition=CONTAINS(p_container, 'SM DRUM':string) & CONTAINS(p_name, 'azure':string) & CONTAINS(p_type, 'PLATED':string), columns={'p_name': p_name, 'p_retailprice': p_retailprice, 'p_size': p_size}) + SCAN(table=tpch.PART, columns={'p_container': p_container, 'p_name': p_name, 'p_retailprice': p_retailprice, 'p_size': p_size, 'p_type': p_type}) diff --git a/tests/test_plan_refsols/user_range_collection_5.txt b/tests/test_plan_refsols/user_range_collection_5.txt new file mode 100644 index 000000000..dc9a61778 --- /dev/null +++ b/tests/test_plan_refsols/user_range_collection_5.txt @@ -0,0 +1,8 @@ +ROOT(columns=[('part_size', part_size), ('n_parts', DEFAULT_TO(n_rows, 0:numeric))], orderings=[]) + JOIN(condition=t0.part_size == t1.part_size, type=LEFT, cardinality=SINGULAR_FILTER, reverse_cardinality=SINGULAR_ACCESS, columns={'n_rows': t1.n_rows, 'part_size': t0.part_size}) + GENERATED_TABLE(RangeCollection('sizes', part_size=range(1, 60, 5))) + AGGREGATE(keys={'part_size': part_size}, aggregations={'n_rows': COUNT()}) + JOIN(condition=MONOTONIC(t0.part_size, t1.p_size, t0.part_size + 4:numeric), type=INNER, cardinality=PLURAL_FILTER, reverse_cardinality=SINGULAR_FILTER, columns={'part_size': t0.part_size}) + GENERATED_TABLE(RangeCollection('sizes', part_size=range(1, 60, 5))) + FILTER(condition=CONTAINS(p_name, 'almond':string), columns={'p_size': p_size}) + SCAN(table=tpch.PART, columns={'p_name': p_name, 'p_size': p_size}) diff --git a/tests/test_plan_refsols/user_range_collection_6.txt 
b/tests/test_plan_refsols/user_range_collection_6.txt new file mode 100644 index 000000000..d3151cb5d --- /dev/null +++ b/tests/test_plan_refsols/user_range_collection_6.txt @@ -0,0 +1,12 @@ +ROOT(columns=[('year', year), ('n_orders', DEFAULT_TO(ndistinct_o_custkey, 0:numeric))], orderings=[(year):asc_first]) + JOIN(condition=t0.year == t1.year_o_orderdate, type=LEFT, cardinality=SINGULAR_FILTER, reverse_cardinality=PLURAL_FILTER, columns={'ndistinct_o_custkey': t1.ndistinct_o_custkey, 'year': t0.year}) + GENERATED_TABLE(RangeCollection('years', year=range(1990, 2001))) + AGGREGATE(keys={'year_o_orderdate': YEAR(o_orderdate)}, aggregations={'ndistinct_o_custkey': NDISTINCT(o_custkey)}) + JOIN(condition=t0.o_custkey == t1.c_custkey, type=INNER, cardinality=SINGULAR_FILTER, reverse_cardinality=PLURAL_FILTER, columns={'o_custkey': t0.o_custkey, 'o_orderdate': t0.o_orderdate}) + FILTER(condition=o_clerk == 'Clerk#000000925':string, columns={'o_custkey': o_custkey, 'o_orderdate': o_orderdate}) + SCAN(table=tpch.ORDERS, columns={'o_clerk': o_clerk, 'o_custkey': o_custkey, 'o_orderdate': o_orderdate}) + JOIN(condition=t0.c_nationkey == t1.n_nationkey, type=INNER, cardinality=SINGULAR_FILTER, reverse_cardinality=PLURAL_FILTER, columns={'c_custkey': t0.c_custkey}) + FILTER(condition=c_mktsegment == 'AUTOMOBILE':string, columns={'c_custkey': c_custkey, 'c_nationkey': c_nationkey}) + SCAN(table=tpch.CUSTOMER, columns={'c_custkey': c_custkey, 'c_mktsegment': c_mktsegment, 'c_nationkey': c_nationkey}) + FILTER(condition=n_name == 'JAPAN':string, columns={'n_nationkey': n_nationkey}) + SCAN(table=tpch.NATION, columns={'n_name': n_name, 'n_nationkey': n_nationkey}) diff --git a/tests/test_pydough_functions/user_collections.py b/tests/test_pydough_functions/user_collections.py new file mode 100644 index 000000000..3325929ca --- /dev/null +++ b/tests/test_pydough_functions/user_collections.py @@ -0,0 +1,151 @@ +""" +Various functions containing user generated collections as 
+PyDough code snippets for testing purposes. +""" +# ruff: noqa +# mypy: ignore-errors +# ruff & mypy should not try to typecheck or verify any of this + +import pydough + +import pytest + +# Snowflake only. +# Other dialects does not support range collections yet. +pytestmark = pytest.mark.snowflake + + +def simple_range_1(): + # Generates a table with column named `value` containing integers from 0 to 9. + return pydough.range_collection( + "simple_range", + "value", + 10, # end value + ) + + +def simple_range_2(): + # Generates a table with column named `value` containing integers from 0 to 9, + # ordered in descending order. + return pydough.range_collection( + "simple_range", + "value", + 10, # end value + ).ORDER_BY(value.DESC()) + + +def simple_range_3(): + # Generates a table with column named `foo` containing integers from 15 to + # 20 exclusive, ordered in ascending order. + return pydough.range_collection("T1", "foo", 15, 20).ORDER_BY(foo.ASC()) + + +def simple_range_4(): + # Generate a table with 1 column named `N` counting backwards + # from 10 to 1 (inclusive) + return pydough.range_collection("T2", "N", 10, 0, -1).ORDER_BY(N.ASC()) + + +def simple_range_5(): + # Generate a table with 1 column named `x` which is an empty range + return pydough.range_collection("T3", "x", -1) + + +def user_range_collection_1(): + # Creates a collection `sizes` with a single property `part_size` whose values are the + # integers from 1 (inclusive) to 100 (exclusive), skipping by 5s, then for each size value, + # counts how many turquoise parts have that size. 
+ sizes = pydough.range_collection("sizes", "part_size", 1, 100, 5) + turquoise_parts = parts.WHERE(CONTAINS(name, "turquoise")) + return sizes.CALCULATE(part_size).CALCULATE( + part_size, n_parts=COUNT(CROSS(turquoise_parts).WHERE(size == part_size)) + ) + + +def user_range_collection_2(): + # Generate two tables with one column: `a` has a column `x` of digits 0-9, + # `b` has a column `y` of every even number from 0 to 1000 (inclusive), and for + # every row of `a` count how many rows of `b` have `x` has a prefix of `y`, and + # how many have `x` as a suffix of `y` + table_a = pydough.range_collection("a", "x", 10) + table_b = pydough.range_collection("b", "y", 0, 1001, 2) + result = ( + table_a.CALCULATE(x) + .CALCULATE( + x, + n_prefix=COUNT(CROSS(table_b).WHERE(STARTSWITH(STRING(y), STRING(x)))), + n_suffix=COUNT(CROSS(table_b).WHERE(ENDSWITH(STRING(y), STRING(x)))), + ) + .ORDER_BY(x.ASC()) + ) + return result + + +def user_range_collection_3(): + # Same as user_range_collection_2 but only includes rows of x that + # have at least one prefix/suffix max + table_a = pydough.range_collection("a", "x", 10) + table_b = pydough.range_collection("b", "y", 0, 1001, 2) + prefix_b = CROSS(table_b).WHERE(STARTSWITH(STRING(y), STRING(x))) + suffix_b = CROSS(table_b).WHERE(ENDSWITH(STRING(y), STRING(x))) + return ( + table_a.CALCULATE(x) + .CALCULATE( + x, + n_prefix=COUNT(prefix_b), + n_suffix=COUNT(suffix_b), + ) + .WHERE(HAS(prefix_b) & HAS(suffix_b)) + .ORDER_BY(x.ASC()) + ) + + +def user_range_collection_4(): + # For every part size 1-10, find the name & + # retail price of the cheapest part of that size that + # is azure, plated, and has a small drum container + sizes = pydough.range_collection("sizes", "part_size", 1, 11) + azure_parts = parts.WHERE( + CONTAINS(name, "azure") + & CONTAINS(part_type, "PLATED") + & CONTAINS(container, "SM DRUM") + ) + return ( + sizes.CALCULATE(part_size) + .CROSS(azure_parts) + .WHERE(size == part_size) + .BEST(per="sizes", 
by=retail_price.ASC()) + .CALCULATE(part_size, name, retail_price) + .ORDER_BY(part_size.ASC()) + ) + + +def user_range_collection_5(): + # Creates a collection `sizes` with a single property `part_size` whose values are the + # integers from 1 (inclusive) to 60 (exclusive), skipping by 5s, then for each size value, + # counts how many almond parts are in the interval of 5 sizes starting with that size + sizes = pydough.range_collection("sizes", "part_size", 1, 60, 5) + almond_parts = parts.WHERE(CONTAINS(name, "almond")) + return sizes.CALCULATE(part_size).CALCULATE( + part_size, + n_parts=COUNT( + CROSS(almond_parts).WHERE(MONOTONIC(part_size, size, part_size + 4)) + ), + ) + + +def user_range_collection_6(): + # For every year from 1990 to 2000, how many orders were made in that year + # by a Japanese customer in the automobile market segment, processed by clerk 925 + years = pydough.range_collection("years", "year", 1990, 2001) + selected_orders = orders.WHERE( + (clerk == "Clerk#000000925") + & (customer.market_segment == "AUTOMOBILE") + & (customer.nation.name == "JAPAN") + ).CALCULATE(order_year=YEAR(order_date)) + order_years = selected_orders.PARTITION(name="yrs", by=(order_year, customer_key)) + return ( + years.CALCULATE(year) + .CALCULATE(year, n_orders=COUNT(CROSS(order_years).WHERE(order_year == year))) + .ORDER_BY(year.ASC()) + ) diff --git a/tests/test_pydough_to_sql.py b/tests/test_pydough_to_sql.py index 493087db9..b1141c0c3 100644 --- a/tests/test_pydough_to_sql.py +++ b/tests/test_pydough_to_sql.py @@ -52,6 +52,19 @@ window_sliding_frame_relsize, window_sliding_frame_relsum, ) +from tests.test_pydough_functions.user_collections import ( + simple_range_1, + simple_range_2, + simple_range_3, + simple_range_4, + simple_range_5, + user_range_collection_1, + user_range_collection_2, + user_range_collection_3, + user_range_collection_4, + user_range_collection_5, + user_range_collection_6, +) from tests.testing_utilities import ( graph_fetcher, ) @@ 
-191,6 +204,83 @@ pytest.param( casting_functions, None, "casting_functions", id="casting_functions" ), + pytest.param( + simple_range_1, + None, + "simple_range_1", + id="simple_range_1", + marks=pytest.mark.snowflake, + ), + pytest.param( + simple_range_2, + None, + "simple_range_2", + id="simple_range_2", + marks=pytest.mark.snowflake, + ), + pytest.param( + simple_range_3, + None, + "simple_range_3", + id="simple_range_3", + marks=pytest.mark.snowflake, + ), + pytest.param( + simple_range_4, + None, + "simple_range_4", + id="simple_range_4", + marks=pytest.mark.snowflake, + ), + pytest.param( + simple_range_5, + None, + "simple_range_5", + id="simple_range_5", + marks=pytest.mark.snowflake, + ), + pytest.param( + user_range_collection_1, + None, + "user_range_collection_1", + id="user_range_collection_1", + marks=pytest.mark.snowflake, + ), + pytest.param( + user_range_collection_2, + None, + "user_range_collection_2", + id="user_range_collection_2", + marks=pytest.mark.snowflake, + ), + pytest.param( + user_range_collection_3, + None, + "user_range_collection_3", + id="user_range_collection_3", + marks=pytest.mark.snowflake, + ), + pytest.param( + user_range_collection_4, + None, + "user_range_collection_4", + id="user_range_collection_4", + marks=pytest.mark.snowflake, + ), + pytest.param( + user_range_collection_5, + None, + "user_range_collection_5", + id="user_range_collection_5", + marks=pytest.mark.snowflake, + ), + pytest.param( + user_range_collection_6, + None, + "user_range_collection_6", + id="user_range_collection_6", + marks=pytest.mark.snowflake, + ), ], ) def test_pydough_to_sql_tpch( @@ -206,6 +296,13 @@ def test_pydough_to_sql_tpch( Tests that a PyDough unqualified node can be correctly translated to its qualified DAG version, with the correct string representation. 
""" + if (empty_context_database.dialect != DatabaseDialect.SNOWFLAKE) and ( + ("simple_range_" in test_name) + or ("user_range_collection_" in pydough_code.__name__) + ): + pytest.skip( + f"Skipping test {empty_context_database.dialect}-{test_name} since it is only supported on Snowflake" + ) graph: GraphMetadata = get_sample_graph("TPCH") root: UnqualifiedNode = init_pydough_context(graph)(pydough_code)() actual_sql: str = to_sql( diff --git a/tests/test_qualification.py b/tests/test_qualification.py index 4b2f1d790..2f46c9865 100644 --- a/tests/test_qualification.py +++ b/tests/test_qualification.py @@ -62,6 +62,10 @@ impl_tpch_q21, impl_tpch_q22, ) +from tests.test_pydough_functions.user_collections import ( + simple_range_1, + simple_range_2, +) @pytest.mark.parametrize( @@ -938,6 +942,23 @@ """, id="simple_cross_6", ), + pytest.param( + simple_range_1, + """ +──┬─ TPCH + └─── RangeCollection('simple_range', value=range(0, 10)) + """, + id="simple_range_1", + ), + pytest.param( + simple_range_2, + """ +──┬─ TPCH + ├─── RangeCollection('simple_range', value=range(0, 10)) + └─── OrderBy[value.DESC(na_pos='last')] + """, + id="simple_range_2", + ), ], ) def test_qualify_node_to_ast_string( diff --git a/tests/test_sql_refsols/simple_range_1_snowflake.sql b/tests/test_sql_refsols/simple_range_1_snowflake.sql new file mode 100644 index 000000000..4b3d04b0c --- /dev/null +++ b/tests/test_sql_refsols/simple_range_1_snowflake.sql @@ -0,0 +1,8 @@ +WITH simple_range AS ( + SELECT + SEQ4() AS value + FROM TABLE(GENERATOR(ROWCOUNT => 10)) +) +SELECT + value +FROM simple_range diff --git a/tests/test_sql_refsols/simple_range_2_snowflake.sql b/tests/test_sql_refsols/simple_range_2_snowflake.sql new file mode 100644 index 000000000..02bcba6a3 --- /dev/null +++ b/tests/test_sql_refsols/simple_range_2_snowflake.sql @@ -0,0 +1,10 @@ +WITH simple_range AS ( + SELECT + SEQ4() AS value + FROM TABLE(GENERATOR(ROWCOUNT => 10)) +) +SELECT + value +FROM simple_range +ORDER BY + 1 DESC 
NULLS LAST diff --git a/tests/test_sql_refsols/simple_range_3_snowflake.sql b/tests/test_sql_refsols/simple_range_3_snowflake.sql new file mode 100644 index 000000000..2d57558b1 --- /dev/null +++ b/tests/test_sql_refsols/simple_range_3_snowflake.sql @@ -0,0 +1,10 @@ +WITH t1 AS ( + SELECT + 15 + SEQ4() AS foo + FROM TABLE(GENERATOR(ROWCOUNT => 5)) +) +SELECT + foo +FROM t1 +ORDER BY + 1 NULLS FIRST diff --git a/tests/test_sql_refsols/simple_range_4_snowflake.sql b/tests/test_sql_refsols/simple_range_4_snowflake.sql new file mode 100644 index 000000000..8f6f8658d --- /dev/null +++ b/tests/test_sql_refsols/simple_range_4_snowflake.sql @@ -0,0 +1,10 @@ +WITH t2 AS ( + SELECT + 10 + SEQ4() * -1 AS n + FROM TABLE(GENERATOR(ROWCOUNT => 10)) +) +SELECT + n AS N +FROM t2 +ORDER BY + 1 NULLS FIRST diff --git a/tests/test_sql_refsols/simple_range_5_snowflake.sql b/tests/test_sql_refsols/simple_range_5_snowflake.sql new file mode 100644 index 000000000..e4f7b84dd --- /dev/null +++ b/tests/test_sql_refsols/simple_range_5_snowflake.sql @@ -0,0 +1,4 @@ +SELECT + CAST(NULL AS INT) AS x +WHERE + FALSE diff --git a/tests/test_sql_refsols/user_range_collection_1_snowflake.sql b/tests/test_sql_refsols/user_range_collection_1_snowflake.sql new file mode 100644 index 000000000..fff55ca17 --- /dev/null +++ b/tests/test_sql_refsols/user_range_collection_1_snowflake.sql @@ -0,0 +1,20 @@ +WITH sizes AS ( + SELECT + 1 + SEQ4() * 5 AS part_size + FROM TABLE(GENERATOR(ROWCOUNT => 20)) +), _s1 AS ( + SELECT + p_size, + COUNT(*) AS n_rows + FROM tpch.part + WHERE + CONTAINS(p_name, 'turquoise') + GROUP BY + 1 +) +SELECT + sizes.part_size, + COALESCE(_s1.n_rows, 0) AS n_parts +FROM sizes AS sizes +LEFT JOIN _s1 AS _s1 + ON _s1.p_size = sizes.part_size diff --git a/tests/test_sql_refsols/user_range_collection_2_snowflake.sql b/tests/test_sql_refsols/user_range_collection_2_snowflake.sql new file mode 100644 index 000000000..1ef3f8592 --- /dev/null +++ 
b/tests/test_sql_refsols/user_range_collection_2_snowflake.sql @@ -0,0 +1,44 @@ +WITH a AS ( + SELECT + SEQ4() AS x + FROM TABLE(GENERATOR(ROWCOUNT => 10)) +), b AS ( + SELECT + SEQ4() * 2 AS y + FROM TABLE(GENERATOR(ROWCOUNT => 501)) +), _s4 AS ( + SELECT + a.x, + COUNT(*) AS n_rows + FROM a AS a + JOIN b AS b + ON STARTSWITH(CAST(b.y AS TEXT), CAST(a.x AS TEXT)) + GROUP BY + 1 +), a_2 AS ( + SELECT + SEQ4() AS x + FROM TABLE(GENERATOR(ROWCOUNT => 10)) +), b_2 AS ( + SELECT + SEQ4() * 2 AS y + FROM TABLE(GENERATOR(ROWCOUNT => 501)) +), _s5 AS ( + SELECT + a.x, + COUNT(*) AS n_rows + FROM a_2 AS a + JOIN b_2 AS b + ON ENDSWITH(CAST(b.y AS TEXT), CAST(a.x AS TEXT)) + GROUP BY + 1 +) +SELECT + _s4.x, + _s4.n_rows AS n_prefix, + _s5.n_rows AS n_suffix +FROM _s4 AS _s4 +JOIN _s5 AS _s5 + ON _s4.x = _s5.x +ORDER BY + 1 NULLS FIRST diff --git a/tests/test_sql_refsols/user_range_collection_3_snowflake.sql b/tests/test_sql_refsols/user_range_collection_3_snowflake.sql new file mode 100644 index 000000000..1ef3f8592 --- /dev/null +++ b/tests/test_sql_refsols/user_range_collection_3_snowflake.sql @@ -0,0 +1,44 @@ +WITH a AS ( + SELECT + SEQ4() AS x + FROM TABLE(GENERATOR(ROWCOUNT => 10)) +), b AS ( + SELECT + SEQ4() * 2 AS y + FROM TABLE(GENERATOR(ROWCOUNT => 501)) +), _s4 AS ( + SELECT + a.x, + COUNT(*) AS n_rows + FROM a AS a + JOIN b AS b + ON STARTSWITH(CAST(b.y AS TEXT), CAST(a.x AS TEXT)) + GROUP BY + 1 +), a_2 AS ( + SELECT + SEQ4() AS x + FROM TABLE(GENERATOR(ROWCOUNT => 10)) +), b_2 AS ( + SELECT + SEQ4() * 2 AS y + FROM TABLE(GENERATOR(ROWCOUNT => 501)) +), _s5 AS ( + SELECT + a.x, + COUNT(*) AS n_rows + FROM a_2 AS a + JOIN b_2 AS b + ON ENDSWITH(CAST(b.y AS TEXT), CAST(a.x AS TEXT)) + GROUP BY + 1 +) +SELECT + _s4.x, + _s4.n_rows AS n_prefix, + _s5.n_rows AS n_suffix +FROM _s4 AS _s4 +JOIN _s5 AS _s5 + ON _s4.x = _s5.x +ORDER BY + 1 NULLS FIRST diff --git a/tests/test_sql_refsols/user_range_collection_4_snowflake.sql 
b/tests/test_sql_refsols/user_range_collection_4_snowflake.sql new file mode 100644 index 000000000..8ded54daa --- /dev/null +++ b/tests/test_sql_refsols/user_range_collection_4_snowflake.sql @@ -0,0 +1,25 @@ +WITH sizes AS ( + SELECT + 1 + SEQ4() AS part_size + FROM TABLE(GENERATOR(ROWCOUNT => 10)) +), _t0 AS ( + SELECT + part.p_name, + part.p_retailprice, + sizes.part_size + FROM sizes AS sizes + JOIN tpch.part AS part + ON CONTAINS(part.p_container, 'SM DRUM') + AND CONTAINS(part.p_name, 'azure') + AND CONTAINS(part.p_type, 'PLATED') + AND part.p_size = sizes.part_size + QUALIFY + ROW_NUMBER() OVER (PARTITION BY part_size ORDER BY part.p_retailprice) = 1 +) +SELECT + part_size, + p_name AS name, + p_retailprice AS retail_price +FROM _t0 +ORDER BY + 1 NULLS FIRST diff --git a/tests/test_sql_refsols/user_range_collection_5_snowflake.sql b/tests/test_sql_refsols/user_range_collection_5_snowflake.sql new file mode 100644 index 000000000..c28aefc89 --- /dev/null +++ b/tests/test_sql_refsols/user_range_collection_5_snowflake.sql @@ -0,0 +1,28 @@ +WITH sizes AS ( + SELECT + 1 + SEQ4() * 5 AS part_size + FROM TABLE(GENERATOR(ROWCOUNT => 12)) +), sizes_2 AS ( + SELECT + 1 + SEQ4() * 5 AS part_size + FROM TABLE(GENERATOR(ROWCOUNT => 12)) +), _s3 AS ( + SELECT + sizes.part_size, + COUNT(*) AS n_rows + FROM sizes_2 AS sizes + JOIN tpch.part AS part + ON CONTAINS(part.p_name, 'almond') + AND part.p_size <= ( + sizes.part_size + 4 + ) + AND part.p_size >= sizes.part_size + GROUP BY + 1 +) +SELECT + sizes.part_size, + COALESCE(_s3.n_rows, 0) AS n_parts +FROM sizes AS sizes +LEFT JOIN _s3 AS _s3 + ON _s3.part_size = sizes.part_size diff --git a/tests/test_sql_refsols/user_range_collection_6_snowflake.sql b/tests/test_sql_refsols/user_range_collection_6_snowflake.sql new file mode 100644 index 000000000..db1899850 --- /dev/null +++ b/tests/test_sql_refsols/user_range_collection_6_snowflake.sql @@ -0,0 +1,26 @@ +WITH years AS ( + SELECT + 1990 + SEQ4() AS year + FROM 
TABLE(GENERATOR(ROWCOUNT => 11)) +), _s5 AS ( + SELECT + YEAR(CAST(orders.o_orderdate AS TIMESTAMP)) AS year_o_orderdate, + COUNT(DISTINCT orders.o_custkey) AS ndistinct_o_custkey + FROM tpch.orders AS orders + JOIN tpch.customer AS customer + ON customer.c_custkey = orders.o_custkey AND customer.c_mktsegment = 'AUTOMOBILE' + JOIN tpch.nation AS nation + ON customer.c_nationkey = nation.n_nationkey AND nation.n_name = 'JAPAN' + WHERE + orders.o_clerk = 'Clerk#000000925' + GROUP BY + 1 +) +SELECT + years.year, + COALESCE(_s5.ndistinct_o_custkey, 0) AS n_orders +FROM years AS years +LEFT JOIN _s5 AS _s5 + ON _s5.year_o_orderdate = years.year +ORDER BY + 1 NULLS FIRST