Skip to content

Commit

Permalink
python tuples in expressions (#8246)
Browse files Browse the repository at this point in the history
Co-authored-by: Jakub Kowalski <[email protected]>
GitOrigin-RevId: 1326542e0eb5fc5ec5206a7f91bee5e2e6bff354
  • Loading branch information
2 people authored and Manul from Pathway committed Feb 21, 2025
1 parent 18dde12 commit 9900020
Show file tree
Hide file tree
Showing 6 changed files with 216 additions and 5 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ This project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.htm
### Added
- Added structure-aware chunking for `DoclingParser`
- Added `table_parsing_strategy` for `DoclingParser`
- Added support for Python tuples in expressions.

### Changed
- **BREAKING**: Changed the argument in `DoclingParser` from `parse_images` (bool) into `image_parsing_strategy` (Literal["llm"] | None)
Expand Down
7 changes: 5 additions & 2 deletions python/pathway/internals/custom_reducers.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@ def mark_stub(fun):

class ReducerProtocol(Protocol):
    """Structural type for a callable that builds a column expression.

    Each positional argument may be a column expression, a plain API value,
    or a Python tuple of column expressions. The call returns a column
    expression.
    """

    def __call__(
        self,
        *args: expr.ColumnExpression | api.Value | tuple[expr.ColumnExpression, ...],
    ) -> expr.ColumnExpression: ...


Expand Down Expand Up @@ -95,7 +96,9 @@ def stateful_many(
for group 2 (at processing times 2, 4, 6).
"""

def wrapper(
    *args: expr.ColumnExpression | api.Value | tuple[expr.ColumnExpression, ...],
) -> expr.ColumnExpression:
    """Build a reducer expression that applies ``combine_many`` to ``args``.

    Args:
        *args: Column expressions, plain values, or Python tuples of column
            expressions; they are forwarded unchanged to
            ``expr.ReducerExpression``.

    Returns:
        The ``expr.ReducerExpression`` wrapping a ``StatefulManyReducer``
        built from the enclosing function's ``combine_many``.
    """
    return expr.ReducerExpression(StatefulManyReducer(combine_many), *args)

return wrapper
Expand Down
10 changes: 10 additions & 0 deletions python/pathway/internals/desugaring.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,17 @@


class DesugaringTransform(IdentityTransform):
def eval_tuple_of_maybe_expressions(self, expression: tuple, **kwargs):
    """Convert a Python tuple into a ``MakeTupleExpression`` when needed.

    Every element is passed through ``self.eval_expression``. If at least
    one evaluated element is a ``ColumnExpression``, all evaluated elements
    are wrapped in ``expr.MakeTupleExpression``. Otherwise the original
    tuple is returned untouched.
    """
    # NOTE(review): ``kwargs`` is accepted but not forwarded to
    # ``eval_expression`` -- confirm this is intentional.
    evaluated = tuple(self.eval_expression(item) for item in expression)

    for item in evaluated:
        if isinstance(item, expr.ColumnExpression):
            return expr.MakeTupleExpression(*evaluated)

    return expression

def eval_any(self, expression, **kwargs):
    """Return ``expression`` as is unless it is a Python tuple.

    Tuples are delegated to ``eval_tuple_of_maybe_expressions`` together
    with ``kwargs``; every other object is returned unchanged.
    """
    if not isinstance(expression, tuple):
        return expression
    return self.eval_tuple_of_maybe_expressions(expression, **kwargs)


Expand Down
10 changes: 7 additions & 3 deletions python/pathway/internals/expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,11 @@ def __repr__(self):
return ExpressionFormatter().eval_expression(self)

@staticmethod
def _wrap(
    arg: ColumnExpression | Value | tuple[ColumnExpression, ...]
) -> ColumnExpression:
    """Coerce ``arg`` into a ``ColumnExpression``.

    Args:
        arg: A column expression, a plain value, or a Python tuple.

    Returns:
        ``arg`` itself when it already is a ``ColumnExpression``; a
        ``MakeTupleExpression`` over the elements of a non-empty tuple; a
        ``ColumnConstExpression`` holding ``()`` for an empty tuple; or a
        ``ColumnConstExpression`` holding ``arg`` for any other value.
    """
    if isinstance(arg, tuple):
        # An empty tuple is wrapped as a constant rather than as a tuple
        # expression with no elements.
        return MakeTupleExpression(*arg) if arg else ColumnConstExpression(())
    if not isinstance(arg, ColumnExpression):
        return ColumnConstExpression(arg)
    return arg
Expand Down Expand Up @@ -713,8 +717,8 @@ class ReducerExpression(ColumnExpression):
def __init__(
self,
reducer: Reducer,
*args: ColumnExpression | Value,
**kwargs: ColumnExpression | Value,
*args: ColumnExpression | Value | tuple[ColumnExpression, ...],
**kwargs: ColumnExpression | Value | tuple[ColumnExpression, ...],
):
super().__init__()
self._reducer = reducer
Expand Down
160 changes: 160 additions & 0 deletions python/pathway/tests/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -6529,3 +6529,163 @@ def test_doesnt_warn_if_all_operators_used():
with warnings.catch_warnings():
warnings.simplefilter("error")
warn_if_some_operators_unused()


def test_python_tuple_select():
    """Python tuples passed to ``select`` match ``pw.make_tuple`` results."""
    source = T(
        """
        a | b
        1 | 2
        3 | 4
        5 | 6
        """
    )
    # The three variants mix ``pw.this`` references with direct column
    # references taken from the table.
    candidates = [
        source.select(c=(pw.this.a, pw.this.b), d=(pw.this.a, 2)),
        source.select(c=(source.a, pw.this.b), d=(source.a, 2)),
        source.select(c=(pw.this.a, source.b), d=(pw.this.a, 2)),
    ]
    expected = source.select(
        c=pw.make_tuple(pw.this.a, pw.this.b), d=pw.make_tuple(pw.this.a, 2)
    )
    for candidate in candidates:
        assert_table_equality(candidate, expected)


def test_python_tuple_comparison():
    """Comparisons work with a Python tuple on either side of the operator."""
    table = T(
        """
        a | b
        1 | 2
        4 | 3
        5 | 5
        """
    )
    # Each operator is exercised with the Python tuple on the right-hand side
    # and on the left-hand side of a ``pw.make_tuple`` expression.
    comparisons = {
        "x": pw.make_tuple(pw.this.a, pw.this.b) < (pw.this.b, pw.this.a),
        "y": (pw.this.a, pw.this.b) < pw.make_tuple(pw.this.b, pw.this.a),
        "z": pw.make_tuple(pw.this.a, pw.this.b) > (pw.this.b, pw.this.a),
        "t": (pw.this.a, pw.this.b) > pw.make_tuple(pw.this.b, pw.this.a),
        "e": pw.make_tuple(pw.this.a, pw.this.b) == (pw.this.b, pw.this.a),
        "n": (pw.this.a, pw.this.b) != pw.make_tuple(pw.this.b, pw.this.a),
    }
    result = table.select(**comparisons)
    expected = T(
        """
        x | y | z | t | e | n
        True | True | False | False | False | True
        False | False | True | True | False | True
        False | False | False | False | True | False
        """
    )

    assert_table_equality(result, expected)


def test_python_tuple_inside_udf():
    """A Python tuple of columns can be passed as a single UDF argument."""
    table = T(
        """
        a | b | c
        1 | 2 | 3
        4 | 3 | 3
        5 | 5 | 3
        """
    )

    @pw.udf
    def tuple_sum(values: tuple) -> int:
        return sum(values)

    packed = (pw.this.a, pw.this.c, pw.this.b)
    result = table.select(s=tuple_sum(packed))

    expected = T(
        """
        s
        6
        10
        13
        """
    )
    assert_table_equality(result, expected)


def test_python_tuple_if_else():
    """Both branches of ``pw.if_else`` accept Python tuples of columns."""
    table = T(
        """
        a | b | c
        0 | 2 | 3
        1 | 3 | 0
        1 | 4 | 5
        0 | 5 | 2
        """
    )
    # Pick (b, c) when a == 1 and (c, b) otherwise, then read element 0.
    chosen = pw.if_else(
        pw.this.a == 1, (pw.this.b, pw.this.c), (pw.this.c, pw.this.b)
    )
    result = table.select(z=chosen.get(0))
    expected = T(
        """
        z
        3
        3
        4
        2
        """
    )
    assert_table_equality(result, expected)


def test_python_tuple_stateful_reducer():
    """A stateful reducer receives a Python tuple of columns as one value."""

    @pw.reducers.stateful_single  # type: ignore[arg-type]
    def sum2d(state: int | None, values: tuple[int, int]) -> int:
        # ``state`` is None on the first call, so start accumulating from 0.
        previous = 0 if state is None else state
        return previous + (values[0] + values[1])

    table = T(
        """
        a | b
        1 | 2
        3 | 4
        """
    )
    result = table.reduce(s=sum2d((pw.this.a, pw.this.b)))
    expected = T(
        """
        s
        10
        """
    )
    assert_table_equality_wo_index_types(result, expected)


def test_python_tuple_sorting():
    """``Table.sort`` accepts a Python tuple of columns as its key.

    The table is sorted by ``(b, c)``. Each row then looks up the ``a`` value
    of its predecessor in that order, and the result is compared against the
    expected predecessors.
    """
    t = T(
        """
        a | b | c
        1 | 3 | 2
        2 | 4 | 1
        3 | 3 | 6
        4 | 2 | 8
        5 | 5 | 6
        6 | 1 | 4
        7 | 2 | 2
        8 | 3 | 3
        """
    )
    # Named ``sorted_t`` rather than ``sorted`` so the builtin is not shadowed.
    sorted_t = t.sort(key=(pw.this.b, pw.this.c))
    # ``optional=True`` is passed because the first row in sort order has no
    # predecessor (see the empty ``prev_a`` cell below).
    result = t.select(pw.this.a, prev_a=t.ix(sorted_t.prev, optional=True).a)
    expected = T(
        """
        a | prev_a
        1 | 4
        2 | 3
        3 | 8
        4 | 7
        5 | 2
        6 |
        7 | 6
        8 | 1
        """
    )
    assert_table_equality(result, expected)
33 changes: 33 additions & 0 deletions python/pathway/tests/test_deduplicate.py
Original file line number Diff line number Diff line change
Expand Up @@ -540,3 +540,36 @@ def acceptor(new_value, old_value) -> bool:
assert_stream_equality_wo_index(
result, expected_2, persistence_config=persistence_config
)


def test_deduplicate_python_tuple():
    """``deduplicate`` accepts a Python tuple of columns as its ``value``."""
    events = pw.debug.table_from_markdown(
        """
        a | b | __time__
        1 | 1 | 2
        1 | 2 | 4
        3 | 1 | 6
        3 | 0 | 8
        4 | 2 | 10
        4 | 2 | 12
        4 | 1 | 14
        """
    )

    def acceptor(new_value: tuple, old_value: tuple) -> bool:
        # Accept a new (a, b) pair only when it is strictly greater than the
        # old one under tuple comparison.
        return new_value > old_value

    result = events.deduplicate(value=(pw.this.a, pw.this.b), acceptor=acceptor)
    expected = pw.debug.table_from_markdown(
        """
        a | b | __time__ | __diff__
        1 | 1 | 2 | 1
        1 | 1 | 4 | -1
        1 | 2 | 4 | 1
        1 | 2 | 6 | -1
        3 | 1 | 6 | 1
        3 | 1 | 10 | -1
        4 | 2 | 10 | 1
        """
    )
    assert_stream_equality_wo_index(result, expected)

0 comments on commit 9900020

Please sign in to comment.