diff --git a/fquery/malloy_builder.py b/fquery/malloy_builder.py
index 532df34..625823b 100644
--- a/fquery/malloy_builder.py
+++ b/fquery/malloy_builder.py
@@ -1,7 +1,3 @@
-# Copyright (c) Facebook, Inc. and its affiliates.
-#
-# This source code is licensed under the MIT license found in the
-# LICENSE file in the root directory of this source tree.
 import ast
 import operator
 
diff --git a/fquery/polars_builder.py b/fquery/polars_builder.py
new file mode 100644
index 0000000..5162bad
--- /dev/null
+++ b/fquery/polars_builder.py
@@ -0,0 +1,72 @@
+import ast
+import operator
+
+import polars as pl
+
+from .visitor import Visitor
+
+# inspired from pandas.core.computation.ops
+_cmp_ops_syms = (">", "<", ">=", "<=", "==", "!=")
+_cmp_ops_funcs = (
+    operator.gt,
+    operator.lt,
+    operator.ge,
+    operator.le,
+    operator.eq,
+    operator.ne,
+)
+_cmp_ops_dict = dict(zip(_cmp_ops_syms, _cmp_ops_funcs))
+
+
+def _field_name(ref):
+    """Return the column part of ``ref``, dropping any ``table.`` prefix."""
+    return ref.rpartition(".")[2]
+
+
+class PolarsBuilderVisitor(Visitor):
+    """Build a polars ``LazyFrame`` pipeline from a query tree.
+
+    Non-leaf nodes push ``(method_name, argument)`` pairs while the tree is
+    walked towards the leaf; the leaf loads the rows and replays the stack,
+    so the operation pushed last (nearest the leaf) is applied first.
+    """
+
+    def __init__(self, id1s):
+        # id1s is passed by callers but is not read by this visitor.
+        self.polars = None
+        self.polars_stack = []
+        self.visited = set()
+
+    async def visit_leaf(self, query):
+        """Load the leaf rows and apply every queued operation to them."""
+        # TODO: make this columnar and real lazy instead of faking laziness
+        self.polars = pl.DataFrame(await query.as_list()).lazy()
+        while self.polars_stack:
+            func, params = self.polars_stack.pop()
+            self.polars = getattr(self.polars, func)(params)
+
+    async def visit_project(self, query):
+        """Queue a ``select`` of the projected columns, then descend."""
+        self.polars_stack.append(("select", query.projector))
+        await self.visit(query.child)
+
+    async def visit_take(self, query):
+        """Queue a ``limit`` of ``query._count`` rows, then descend."""
+        self.polars_stack.append(("limit", query._count))
+        await self.visit(query.child)
+
+    async def visit_where(self, query):
+        """Queue a ``filter`` parsed from ``<field> <op> <literal>``."""
+        # maxsplit=2 keeps a literal containing spaces (e.g. 'a b') in one piece
+        left, op, right = query._expr.value.split(None, 2)
+        right = ast.literal_eval(right)
+        # only the column name is used; any "table." qualifier is dropped
+        field = _field_name(left)
+        self.polars_stack.append(("filter", _cmp_ops_dict[op](pl.col(field), right)))
+        await self.visit(query.child)
+
+    async def visit_order_by(self, query):
+        """Queue a ``sort`` on the referenced column, then descend."""
+        field = _field_name(query._expr.value)
+        self.polars_stack.append(("sort", field))
+        await self.visit(query.child)
diff --git a/fquery/query.py b/fquery/query.py
index c2c2611..7341235 100644
--- a/fquery/query.py
+++ b/fquery/query.py
@@ -12,6 +12,7 @@
 from .async_utils import wait_for
 from .execute import AbstractSyntaxTreeVisitor
 from .malloy_builder import MalloyBuilderVisitor
+from .polars_builder import PolarsBuilderVisitor
 from .sql_builder import SQLBuilderVisitor
 from .view_model import ViewModel, get_edges, get_return_type
 from .walk import (
@@ -256,6 +257,16 @@ def to_malloy(self) -> str:
         wait_for(visitor.visit(self))
         return visitor.malloy
 
+    def to_polars(self) -> Tree:
+        visitor = PolarsBuilderVisitor([])
+        wait_for(visitor.visit(self))
+        return visitor.polars.collect()
+
+    async def to_async_polars(self) -> Tree:
+        visitor = PolarsBuilderVisitor([])
+        await visitor.visit(self)
+        return await visitor.polars.collect_async()
+
     def batch_resolve_objs(self) -> List[Dict[str, List[ViewModel]]]:
         return [{str(None): [o for o in (self.resolve_obj(i) for i in self.ids) if o]}]
 
diff --git a/setup.py b/setup.py
index 74a01ff..f196325 100755
--- a/setup.py
+++ b/setup.py
@@ -34,5 +34,6 @@
             "inflection >= 0.5.1",
         ],
         "graphql": ["strawberry >= 0.37.1"],
+        "df": ["polars >= 0.12.0"],
     },
 )
diff --git a/tests/test_polars.py b/tests/test_polars.py
new file mode 100644
index 0000000..93fdf46
--- /dev/null
+++ b/tests/test_polars.py
@@ -0,0 +1,36 @@
+import ast
+import random
+import unittest
+
+import polars as pl
+from polars.testing import assert_frame_equal
+
+from .mock_user import UserQuery
+
+
+class PolarsTests(unittest.TestCase):
+    def setUp(self):
+        random.seed(100)
+        self.maxDiff = None
+
+    def test_project(self):
+        df = (
+            UserQuery(range(1, 10))
+            .project([":id", "name", "age"])
+            .where(ast.Expr("user.age >= 16"))
+            .order_by(ast.Expr("user.age"))
+            .take(3)
+            .to_polars()
+        )
+        expected = pl.DataFrame(
+            {
+                ":id": [1, 4, 2],
+                "name": ["id1", "id4", "id2"],
+                "age": [16, 16, 17],
+            }
+        )
+        assert_frame_equal(expected, df)
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/tox.ini b/tox.ini
index 7c39213..c9d3676 100644
--- a/tox.ini
+++ b/tox.ini
@@ -12,5 +12,6 @@ deps =
     sqlmodel@git+https://github.com/adsharma/sqlmodel.git@sqlmodel_rebuild
     duckdb-engine
     inflection
+    polars
 commands =
     pytest --cov --cov-config=setup.cfg -rs -v {posargs}