diff --git a/CHANGELOG.md b/CHANGELOG.md
index 7c6dcf03..9859774e 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -11,10 +11,14 @@ This project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.htm
 - `pw.io.python.write` accepting `ConnectorObserver` as an alternative to `pw.io.subscribe`.
 - `pw.io.iceberg.read` and `pw.io.iceberg.write` now support S3 as data backend and AWS Glue catalog implementations.
 - All output connectors now support the `sort_by` field for ordering output within a single minibatch.
+- A new UDF executor `pw.udfs.fully_async_executor`. It allows for the creation of non-blocking asynchronous UDFs whose results can be returned at a later processing time.
+- A Future data type to represent results of fully asynchronous UDFs.
+- `pw.Table.await_futures` method to wait for results of fully asynchronous UDFs.

 ### Changed
 - **BREAKING**: Changed the interface of `LLMReranker`, the `use_logit_bias`, `cache_strategy`, `retry_strategy` and `kwargs` arguments are no longer supported.
 - **BREAKING**: LLMReranker no longer inherits from pw.UDF
+- **BREAKING**: `pw.stdlib.utils.AsyncTransformer.output_table` now returns a table with columns of the Future data type.

 ## [0.18.0] - 2025-02-07
diff --git a/python/pathway/engine.pyi b/python/pathway/engine.pyi
index 3f4d3ba1..2a284063 100644
--- a/python/pathway/engine.pyi
+++ b/python/pathway/engine.pyi
@@ -51,6 +51,8 @@ class PathwayType:
     PY_OBJECT_WRAPPER: PathwayType
     @staticmethod
     def optional(arg: PathwayType) -> PathwayType: ...
+    @staticmethod
+    def future(arg: PathwayType) -> PathwayType: ...

 class ConnectorMode(Enum):
     STATIC: ConnectorMode
@@ -693,10 +695,11 @@ class Scope:
     def error_log(self, properties: ConnectorProperties) -> tuple[Table, ErrorLog]: ...
     def set_error_log(self, error_log: ErrorLog | None) -> None: ...
     def set_operator_properties(self, id: int, depends_on_error_log: bool) -> None: ...
-    def remove_errors_from_table(
+    def remove_value_from_table(
         self,
         table: Table,
         column_paths: Iterable[ColumnPath],
+        value: Value,
         table_properties: TableProperties,
     ) -> Table: ...
     def remove_retractions_from_table(
@@ -704,11 +707,27 @@ class Scope:
         table: Table,
         table_properties: TableProperties,
     ) -> Table: ...
+    def async_transformer(
+        self,
+        table: Table,
+        column_paths: Iterable[ColumnPath],
+        on_change: Callable,
+        on_time_end: Callable,
+        on_end: Callable,
+        data_source: DataStorage,
+        data_format: DataFormat,
+        table_properties: ConnectorProperties,
+        skip_errors: bool,
+    ) -> Table: ...

 class Error: ...

 ERROR: Error

+class Pending: ...
+
+PENDING: Pending
+
 class Done:
     def __lt__(self, other: Frontier) -> bool: ...
     def __le__(self, other: Frontier) -> bool: ...
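A minimal usage sketch of the API introduced by the changelog entries above (fully asynchronous UDFs, the Future data type, and `pw.Table.await_futures`), adapted from the doctests added elsewhere in this diff; the input table and sleep durations are illustrative only:

```python
import asyncio

import pathway as pw

t = pw.debug.table_from_markdown(
    """
    a | b
    1 | 2
    3 | 4
    5 | 6
    """
)

# A fully asynchronous UDF: processing time can advance before the result is ready,
# so the output column first carries Pending values and has type Future(INT).
@pw.udf(executor=pw.udfs.fully_async_executor())
async def long_running_async_function(a: int, b: int) -> int:
    c = a * b
    await asyncio.sleep(0.1 * c)
    return c

result = t.with_columns(res=long_running_async_function(pw.this.a, pw.this.b))
# `res` has dtype Future(INT); such columns can be propagated further but cannot
# be used in most expressions until the futures are awaited.

# await_futures() filters out Pending values and strips the Future wrapper,
# leaving a plain INT column usable in ordinary expressions.
awaited_result = result.await_futures()
pw.debug.compute_and_print(awaited_result, include_id=False)
```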
diff --git a/python/pathway/internals/api.py b/python/pathway/internals/api.py index 068b5692..89e3c36b 100644 --- a/python/pathway/internals/api.py +++ b/python/pathway/internals/api.py @@ -34,6 +34,7 @@ dict[str, _Value], tuple[_Value, ...], Error, + Pending, ] CapturedTable = dict[Pointer, tuple[Value, ...]] CapturedStream = list[DataRow] diff --git a/python/pathway/internals/column.py b/python/pathway/internals/column.py index 06b7f80e..ca1ae7c2 100644 --- a/python/pathway/internals/column.py +++ b/python/pathway/internals/column.py @@ -13,12 +13,15 @@ import pathway.internals as pw from pathway.engine import ExternalIndexFactory from pathway.internals import column_properties as cp, dtype as dt, trace +from pathway.internals.datasource import GenericDataSource from pathway.internals.expression import ColumnExpression, ColumnReference from pathway.internals.helpers import SetOnceProperty, StableSet from pathway.internals.parse_graph import G +from pathway.internals.schema import Schema from pathway.internals.universe import Universe if TYPE_CHECKING: + from pathway.internals import api from pathway.internals.expression import InternalColRef from pathway.internals.operator import OutputHandle from pathway.internals.table import Table @@ -161,6 +164,23 @@ def properties(self) -> cp.ColumnProperties: return self._properties +class ColumnWithoutExpression(ColumnWithContext): + _dtype: dt.DType + + def __init__( + self, + context: Context, + universe: Universe, + dtype: dt.DType, + ) -> None: + super().__init__(context, universe) + self._dtype = dtype + + @cached_property + def context_dtype(self) -> dt.DType: + return self._dtype + + class ColumnWithExpression(ColumnWithContext): """Column holding expression and context.""" @@ -1103,12 +1123,16 @@ def id_column_type(self) -> dt.DType: @dataclass(eq=False, frozen=True) -class RemoveErrorsContext( +class FilterOutValueContext( Context, column_properties_evaluator=cp.PreserveDependenciesPropsEvaluator ): - """Context of `table.remove_errors() operation.""" + """Context of operations that filter all columns of the table. 
+ + Used in `table.remove_errors()` and ``table.await_futures()` + """ orig_id_column: IdColumn + value_to_filter_out: api.Value def column_dependencies_external(self) -> Iterable[Column]: return [self.orig_id_column] @@ -1144,3 +1168,31 @@ def id_column_type(self) -> dt.DType: @cached_property def universe(self) -> Universe: return self.id_column_to_filter.universe.superset() + + +@dataclass(eq=False, frozen=True) +class AsyncTransformerContext( + Context, column_properties_evaluator=cp.PreserveDependenciesPropsEvaluator +): + """Context of `AsyncTransformer` operation.""" + + input_id_column: IdColumn + input_columns: list[Column] + schema: type[Schema] + on_change: Callable + on_time_end: Callable + on_end: Callable + datasource: GenericDataSource + + def column_dependencies_external(self) -> Iterable[Column]: + return [self.input_id_column] + self.input_columns + + def input_universe(self) -> Universe: + return self.input_id_column.universe + + def id_column_type(self) -> dt.DType: + return self.input_id_column.dtype + + @cached_property + def universe(self) -> Universe: + return self.input_id_column.universe.subset() diff --git a/python/pathway/internals/column_properties.py b/python/pathway/internals/column_properties.py index 0c721881..7ff63158 100644 --- a/python/pathway/internals/column_properties.py +++ b/python/pathway/internals/column_properties.py @@ -1,4 +1,4 @@ -# Copyright © 2024 Pathway +# Copyright © 2025 Pathway from __future__ import annotations @@ -7,10 +7,8 @@ from dataclasses import dataclass from typing import TYPE_CHECKING, Any -from pathway.internals import dtype as dt - if TYPE_CHECKING: - import pathway.internals.column as clmn + from pathway.internals import column as clmn, dtype as dt @dataclass(frozen=True) @@ -39,7 +37,8 @@ def _append_only(self, column: clmn.ColumnWithContext) -> bool: class PreserveDependenciesPropsEvaluator(ColumnPropertiesEvaluator): def _append_only(self, column: clmn.ColumnWithContext): - return self._has_property(column, "append_only", True) + maybe_append_only = self._check_expression(column) + return maybe_append_only and self._has_property(column, "append_only", True) def _has_property(self, column: clmn.ColumnWithContext, name: str, value: Any): return all( @@ -47,6 +46,21 @@ def _has_property(self, column: clmn.ColumnWithContext, name: str, value: Any): for col in column.column_dependencies() ) + def _check_expression(self, column: clmn.ColumnWithContext) -> bool: + from pathway.internals.column import ColumnWithExpression + from pathway.internals.expression_props_evaluator import ( + ExpressionPropsEvaluator, + PropsEvaluatorState, + ) + + if isinstance(column, ColumnWithExpression): + evaluator = ExpressionPropsEvaluator() + props = PropsEvaluatorState(True) + evaluator.eval_expression(column.expression, props=props) + return props.append_only + else: + return True + class UpdateRowsPropsEvaluator(ColumnPropertiesEvaluator): context: clmn.UpdateRowsContext diff --git a/python/pathway/internals/dtype.py b/python/pathway/internals/dtype.py index dfbbcd29..8a2f9224 100644 --- a/python/pathway/internals/dtype.py +++ b/python/pathway/internals/dtype.py @@ -2,6 +2,7 @@ from __future__ import annotations +import asyncio import collections import datetime import functools @@ -16,8 +17,8 @@ import numpy.typing as npt import pandas as pd -from pathway.engine import PathwayType from pathway.internals import api, datetime_types, json as js +from pathway.internals.api import PathwayType if typing.TYPE_CHECKING: from pathway.internals.schema 
import SchemaMetaclass @@ -457,6 +458,35 @@ def max_size(self) -> float: return math.inf +class Future(DType): + wrapped: DType + + def __repr__(self): + return f"Future({self.wrapped})" + + def __new__(cls, arg: DType) -> Future: + arg = wrap(arg) + if isinstance(arg, Future): + return arg + return super().__new__(cls, arg) + + def _set_args(self, wrapped): + self.wrapped = wrapped + + def to_engine(self) -> PathwayType: + return api.PathwayType.future(self.wrapped.to_engine()) + + def is_value_compatible(self, arg): + return arg is api.PENDING or self.wrapped.is_value_compatible(arg) + + @cached_property + def typehint(self) -> type[asyncio.Future]: + return asyncio.Future[self.wrapped.typehint] # type: ignore[name-defined] + + def max_size(self) -> float: + return self.wrapped.max_size() + + class _DateTimeNaive(DType): def __repr__(self): return "DATE_TIME_NAIVE" @@ -658,6 +688,10 @@ def wrap(input_type) -> DType: ) elif input_type == datetime.timedelta: raise TypeError(f"Unsupported type {input_type}, use pw.DURATION") + elif typing.get_origin(input_type) == asyncio.Future: + args = get_args(input_type) + (arg,) = args + return Future(wrap(arg)) else: dtype = { int: INT, diff --git a/python/pathway/internals/expression.py b/python/pathway/internals/expression.py index 35ae9d8b..12eb24eb 100644 --- a/python/pathway/internals/expression.py +++ b/python/pathway/internals/expression.py @@ -9,12 +9,12 @@ from typing import TYPE_CHECKING, Any, Mapping, cast from pathway.internals import api, dtype as dt, helpers -from pathway.internals.api import Value from pathway.internals.operator_input import OperatorInput from pathway.internals.shadows import operator from pathway.internals.trace import Trace if TYPE_CHECKING: + from pathway.internals.api import Value from pathway.internals.column import Column, ColumnWithExpression from pathway.internals.expressions import ( DateTimeNamespace, @@ -745,6 +745,7 @@ class ApplyExpression(ColumnExpression): _return_type: dt.DType _propagate_none: bool _deterministic: bool + _check_for_disallowed_types: bool _args: tuple[ColumnExpression, ...] 
_kwargs: dict[str, ColumnExpression] _fun: Callable @@ -757,15 +758,14 @@ def __init__( deterministic: bool, args: tuple[ColumnExpression | Value, ...], kwargs: Mapping[str, ColumnExpression | Value], + _check_for_disallowed_types: bool = True, ): super().__init__() self._fun = fun - return_type = dt.wrap(return_type) - if propagate_none: - return_type = dt.Optional(return_type) - self._return_type = return_type + self._return_type = dt.wrap(return_type) self._propagate_none = propagate_none self._deterministic = deterministic + self._check_for_disallowed_types = _check_for_disallowed_types self._args = tuple(ColumnExpression._wrap(arg) for arg in args) @@ -783,15 +783,40 @@ def _to_internal(self) -> InternalColExpr: self._return_type, self._propagate_none, self._deterministic, + self._check_for_disallowed_types, *self._args, **self._kwargs, ) + @property + def _maybe_optional_return_type(self) -> dt.DType: + if self._propagate_none: + return dt.Optional(self._return_type) + else: + return self._return_type + class AsyncApplyExpression(ApplyExpression): pass +class FullyAsyncApplyExpression(ApplyExpression): + autocommit_duration_ms: int | None + + def __init__( + self, + fun: Callable, + return_type: Any, + propagate_none: bool, + deterministic: bool, + autocommit_duration_ms: int | None, + args: tuple[ColumnExpression | Value, ...], + kwargs: Mapping[str, ColumnExpression | Value], + ): + super().__init__(fun, return_type, propagate_none, deterministic, args, kwargs) + self.autocommit_duration_ms = autocommit_duration_ms + + class CastExpression(ColumnExpression): _return_type: dt.DType _expr: ColumnExpression diff --git a/python/pathway/internals/expression_printer.py b/python/pathway/internals/expression_printer.py index 2e544977..1571a55d 100644 --- a/python/pathway/internals/expression_printer.py +++ b/python/pathway/internals/expression_printer.py @@ -163,6 +163,10 @@ def eval_fill_error(self, expression: expr.FillErrorExpression): args = self._eval_args_kwargs((expression._expr, expression._replacement)) return f"pathway.fill_error({args})" + def eval_fully_async_apply(self, expression: expr.FullyAsyncApplyExpression): + args = self._eval_args_kwargs(expression._args, expression._kwargs) + return f"pathway.apply_fully_async({expression._fun.__name__}, {args})" + def get_expression_info(expression: expr.ColumnExpression) -> str: printer = ExpressionFormatter() diff --git a/python/pathway/internals/expression_props_evaluator.py b/python/pathway/internals/expression_props_evaluator.py new file mode 100644 index 00000000..5da08483 --- /dev/null +++ b/python/pathway/internals/expression_props_evaluator.py @@ -0,0 +1,25 @@ +# Copyright © 2025 Pathway + +from __future__ import annotations + +from dataclasses import dataclass + +from pathway.internals import expression as expr +from pathway.internals.expression_visitor import IdentityTransform + + +@dataclass +class PropsEvaluatorState: + append_only: bool + + +class ExpressionPropsEvaluator(IdentityTransform): + def eval_fully_async_apply( + self, + expression: expr.FullyAsyncApplyExpression, + props: PropsEvaluatorState | None = None, + **kwargs, + ) -> expr.FullyAsyncApplyExpression: + assert props is not None + props.append_only = False + return super().eval_fully_async_apply(expression, props=props, **kwargs) diff --git a/python/pathway/internals/expression_visitor.py b/python/pathway/internals/expression_visitor.py index 0ad4ff7c..5860152b 100644 --- a/python/pathway/internals/expression_visitor.py +++ 
b/python/pathway/internals/expression_visitor.py @@ -37,6 +37,7 @@ def eval_expression(self, expression, **kwargs): expr.IsNoneExpression: self.eval_none, expr.UnwrapExpression: self.eval_unwrap, expr.FillErrorExpression: self.eval_fill_error, + expr.FullyAsyncApplyExpression: self.eval_fully_async_apply, } if not isinstance(expression, expr.ColumnExpression): return self.eval_any(expression, **kwargs) @@ -108,6 +109,9 @@ def eval_unwrap(self, expression: expr.UnwrapExpression): ... @abstractmethod def eval_fill_error(self, expression: expr.FillErrorExpression): ... + @abstractmethod + def eval_fully_async_apply(self, expression: expr.FullyAsyncApplyExpression): ... + def eval_any(self, expression, **kwargs): expression = expr.ColumnConstExpression(expression) return self.eval_const(expression, **kwargs) @@ -173,6 +177,7 @@ def eval_apply( deterministic=expression._deterministic, args=expr_args, kwargs=expr_kwargs, + _check_for_disallowed_types=expression._check_for_disallowed_types, ) def eval_async_apply( @@ -192,6 +197,24 @@ def eval_async_apply( kwargs=expr_kwargs, ) + def eval_fully_async_apply( + self, expression: expr.FullyAsyncApplyExpression, **kwargs + ) -> expr.FullyAsyncApplyExpression: + expr_args = [self.eval_expression(arg, **kwargs) for arg in expression._args] + expr_kwargs = { + name: self.eval_expression(arg, **kwargs) + for name, arg in expression._kwargs.items() + } + return expr.FullyAsyncApplyExpression( + expression._fun, + expression._return_type, + propagate_none=expression._propagate_none, + deterministic=expression._deterministic, + autocommit_duration_ms=expression.autocommit_duration_ms, + args=tuple(expr_args), + kwargs=expr_kwargs, + ) + def eval_pointer( self, expression: expr.PointerExpression, **kwargs ) -> expr.PointerExpression: diff --git a/python/pathway/internals/graph_runner/expression_evaluator.py b/python/pathway/internals/graph_runner/expression_evaluator.py index f558a79a..6c5418fe 100644 --- a/python/pathway/internals/graph_runner/expression_evaluator.py +++ b/python/pathway/internals/graph_runner/expression_evaluator.py @@ -29,6 +29,7 @@ get_convert_operators_mapping, get_unary_expression, ) +from pathway.internals.schema import schema_from_types from pathway.internals.udfs import udf if TYPE_CHECKING: @@ -203,6 +204,8 @@ def test_type(val): return val ret = test_type(expression) + assert isinstance(ret, expr.ApplyExpression) + ret._check_for_disallowed_types = False ret._dtype = dtype return ret @@ -270,15 +273,16 @@ def run( # START temporary solution for eval_async_apply for intermediate_storage in eval_state.storages: - [column] = intermediate_storage.get_columns() properties = self._table_properties(intermediate_storage) - engine_input_table = self.scope.override_table_universe( - eval_state.get_temporary_table(intermediate_storage), + # restrict instead of override because of edge case in fully async UDF + # with missing rows. 
+ engine_input_table = self.scope.restrict_table( engine_input_table, + eval_state.get_temporary_table(intermediate_storage), properties, ) input_storage = Storage.merge_storages( - self.context.universe, input_storage, intermediate_storage + self.context.universe, intermediate_storage, input_storage ) # END temporary solution for eval_async_apply @@ -307,7 +311,9 @@ def run_subexpressions( output_columns.append(output_column) output_storage = Storage.flat(self.context.universe, output_columns) - engine_output_table = self.run(output_storage, old_path=None) + engine_output_table = self.run( + output_storage, old_path=None, disable_runtime_typechecking=True + ) # checks already added in the main call to run return (output_columns, output_storage, engine_output_table) def eval_expression( @@ -473,6 +479,80 @@ def eval_async_apply( eval_state.set_temporary_table(output_storage, engine_table) return self.eval_dependency(tmp_column, eval_state=eval_state) + def eval_fully_async_apply( + self, + expression: expr.FullyAsyncApplyExpression, + eval_state: RowwiseEvalState | None = None, + ): + fun, args = self._prepare_positional_apply( + fun=expression._fun, + args=expression._args, + kwargs=expression._kwargs, + ) + + columns, input_storage, engine_input_table = self.run_subexpressions(args) + tmp_column = clmn.MaterializedColumn( + self.context.universe, ColumnProperties(dtype=expression._dtype) + ) + status_column = clmn.MaterializedColumn( + self.context.universe, ColumnProperties(dt.STR) + ) + from pathway.stdlib.utils.async_transformer import ( + _ASYNC_STATUS_COLUMN, + _BaseAsyncTransformer, + ) + + output_columns = { + "result": tmp_column, + _ASYNC_STATUS_COLUMN: status_column, + } + schema = schema_from_types(result=expression._dtype) + + class Transformer(_BaseAsyncTransformer, output_schema=schema): + async def invoke(self, **kwargs) -> dict: + args = [] + for i, (name, arg) in enumerate(kwargs.items()): + if arg is api.ERROR: + return dict(result=api.ERROR) + if arg is None and expression._propagate_none: + return dict(result=None) + assert f"{i}" == name + args.append(arg) + try: + return dict(result=await fun(*args)) + except ( + Exception + ): # FIXME: remove when AsyncTransformer returns `api.ERROR` for failure + self._connector._logger.error( + "Exception in fully_async_udf:", exc_info=True + ) + return dict(result=api.ERROR) + + transformer = Transformer(expression.autocommit_duration_ms) + + ordered_output_columns = [ + output_columns[name] + for name in transformer.wrapped_output_schema._dtypes().keys() + ] + output_storage = Storage.flat(self.context.universe, ordered_output_columns) + paths = [input_storage.get_path(column) for column in columns] + datasource = transformer._get_datasource() + engine_table = self.scope.async_transformer( + engine_input_table, + paths, + transformer._connector.on_subscribe_change, + transformer._connector.on_subscribe_time_end, + transformer._connector.on_subscribe_end, + datasource.datastorage, + datasource.dataformat, + datasource.connector_properties, + skip_errors=False, + ) + + assert eval_state is not None + eval_state.set_temporary_table(output_storage, engine_table) + return self.eval_dependency(tmp_column, eval_state=eval_state) + def eval_cast( self, expression: expr.CastExpression, @@ -1345,8 +1425,10 @@ def run(self, output_storage: Storage) -> api.Table: return self.state.get_table(self.context.universe) -class RemoveErrorsEvaluator(ExpressionEvaluator, context_type=clmn.RemoveErrorsContext): - context: clmn.RemoveErrorsContext 
+class FilterOutValueEvaluator( + ExpressionEvaluator, context_type=clmn.FilterOutValueContext +): + context: clmn.FilterOutValueContext def run(self, output_storage: Storage) -> api.Table: input_storage = self.state.get_storage(self.context.input_universe()) @@ -1356,9 +1438,10 @@ def run(self, output_storage: Storage) -> api.Table: path = input_storage.get_path(column.expression._column) column_paths.append(path) properties = self._table_properties(output_storage) - return self.scope.remove_errors_from_table( + return self.scope.remove_value_from_table( self.state.get_table(input_storage._universe), column_paths, + self.context.value_to_filter_out, properties, ) @@ -1375,3 +1458,26 @@ def run(self, output_storage: Storage) -> api.Table: self.state.get_table(input_storage._universe), properties, ) + + +class AsyncTransformerEvaluator( + ExpressionEvaluator, context_type=clmn.AsyncTransformerContext +): + context: clmn.AsyncTransformerContext + + def run(self, output_storage: Storage) -> api.Table: + input_storage = self.state.get_storage(self.context.input_universe()) + column_paths = [ + input_storage.get_path(column) for column in self.context.input_columns + ] + return self.scope.async_transformer( + self.state.get_table(input_storage._universe), + column_paths, + self.context.on_change, + self.context.on_time_end, + self.context.on_end, + self.context.datasource.datastorage, + self.context.datasource.dataformat, + self.context.datasource.connector_properties, + skip_errors=True, + ) diff --git a/python/pathway/internals/graph_runner/path_evaluator.py b/python/pathway/internals/graph_runner/path_evaluator.py index c8f88e85..6e943fd1 100644 --- a/python/pathway/internals/graph_runner/path_evaluator.py +++ b/python/pathway/internals/graph_runner/path_evaluator.py @@ -22,11 +22,12 @@ def compute_paths( input_storages: dict[Universe, Storage], operator: op.Operator, context: clmn.Context, + table_columns: Iterable[clmn.Column], ): evaluator: PathEvaluator match operator: case op.InputOperator(): - evaluator = FlatStoragePathEvaluator(context) + evaluator = FlatOrderedStoragePathEvaluator(context) case op.RowTransformerOperator(): evaluator = FlatStoragePathEvaluator(context) case op.ContextualizedIntermediateOperator(): @@ -36,7 +37,7 @@ def compute_paths( f"Operator {operator} in update_storage() but it shouldn't produce tables." ) output_columns = list(output_columns) - return evaluator.compute(output_columns, input_storages).restrict_to( + return evaluator.compute(output_columns, input_storages, table_columns).restrict_to( output_columns, require_all=True ) @@ -77,6 +78,7 @@ def compute( self, output_columns: Iterable[clmn.Column], input_storages: dict[Universe, Storage], + table_columns: Iterable[clmn.Column], ) -> Storage: ... 
_context_mapping: ClassVar[dict[type[clmn.Context], type[PathEvaluator]]] = {} @@ -93,12 +95,16 @@ def for_context(cls, context: clmn.Context) -> type[PathEvaluator]: class FlatStoragePathEvaluator( PathEvaluator, - context_types=[clmn.GroupedContext, clmn.RemoveErrorsContext], + context_types=[ + clmn.GroupedContext, + clmn.FilterOutValueContext, + ], ): def compute( self, output_columns: Iterable[clmn.Column], input_storages: dict[Universe, Storage], + table_columns: Iterable[clmn.Column], ) -> Storage: return Storage.flat(self.context.universe, output_columns) @@ -111,6 +117,7 @@ def compute( self, output_columns: Iterable[clmn.Column], input_storages: dict[Universe, Storage], + table_columns: Iterable[clmn.Column], ) -> Storage: return Storage.flat(self.context.universe, output_columns, shift=1) @@ -162,6 +169,7 @@ def compute( self, output_columns: Iterable[clmn.Column], input_storages: dict[Universe, Storage], + table_columns: Iterable[clmn.Column], ) -> Storage: input_storage = input_storages[self.context.universe] output_columns = list(output_columns) @@ -192,6 +200,7 @@ def compute( self, output_columns: Iterable[clmn.Column], input_storages: dict[Universe, Storage], + table_columns: Iterable[clmn.Column], ) -> Storage: input_storage = input_storages[self.context.universe] return Storage.merge_storages( @@ -212,6 +221,7 @@ def compute( self, output_columns: Iterable[clmn.Column], input_storages: dict[Universe, Storage], + table_columns: Iterable[clmn.Column], ) -> Storage: context = self.context output_columns_list = list(output_columns) @@ -275,6 +285,7 @@ def compute( storage = evaluator.compute( output_columns_list, {source_universe: input_storages[source_universe]}, + table_columns, ) return storage.with_maybe_flattened_inputs(flattened_inputs) @@ -315,6 +326,7 @@ def compute( self, output_columns: Iterable[clmn.Column], input_storages: dict[Universe, Storage], + table_columns: Iterable[clmn.Column], ) -> Storage: input_storage = input_storages[self.context.input_universe()] paths: dict[clmn.Column, ColumnPath] = {} @@ -342,6 +354,7 @@ def compute( self, output_columns: Iterable[clmn.Column], input_storages: dict[Universe, Storage], + table_columns: Iterable[clmn.Column], ) -> Storage: input_storage = input_storages[self.context.input_universe()] required_columns: StableSet[clmn.Column] = StableSet() @@ -397,6 +410,7 @@ def compute( self, output_columns: Iterable[clmn.Column], input_storages: dict[Universe, Storage], + table_columns: Iterable[clmn.Column], ) -> Storage: left_storage, right_storage = self.maybe_flatten_input_storages( output_columns, input_storages @@ -488,6 +502,7 @@ def compute( self, output_columns: Iterable[clmn.Column], input_storages: dict[Universe, Storage], + table_columns: Iterable[clmn.Column], ) -> Storage: output_columns = list(output_columns) left_input_storage, right_input_storage = self.maybe_flatten_input_storages( @@ -496,11 +511,15 @@ def compute( join_storage = self.merge_storages(left_input_storage, right_input_storage) if self.context.assign_id: output_storage = AddNewColumnsPathEvaluator(self.context).compute( - output_columns, {self.context.universe: left_input_storage} + output_columns, + {self.context.universe: left_input_storage}, + table_columns, ) else: output_storage = FlatStoragePathEvaluator(self.context).compute( - output_columns, {} + output_columns, + {}, + table_columns, ) return output_storage.with_maybe_flattened_inputs( { @@ -518,6 +537,7 @@ def compute( self, output_columns: Iterable[clmn.Column], input_storages: 
dict[Universe, Storage], + table_columns: Iterable[clmn.Column], ) -> Storage: prefixed_input_storage = input_storages[self.context.orig_universe].with_prefix( (0,) @@ -555,6 +575,7 @@ def compute( self, output_columns: Iterable[clmn.Column], input_storages: dict[Universe, Storage], + table_columns: Iterable[clmn.Column], ) -> Storage: orig_storage_columns: StableSet[clmn.Column] = StableSet() newly_created_columns: StableSet[clmn.ColumnWithReference] = StableSet() @@ -589,3 +610,18 @@ def compute( {"orig_storage": orig_storage, "new_storage": new_storage} ) ) + + +class FlatOrderedStoragePathEvaluator( + PathEvaluator, + context_types=[ + clmn.AsyncTransformerContext, + ], +): + def compute( + self, + output_columns: Iterable[clmn.Column], + input_storages: dict[Universe, Storage], + table_columns: Iterable[clmn.Column], + ) -> Storage: + return Storage.flat(self.context.universe, table_columns) diff --git a/python/pathway/internals/graph_runner/storage_graph.py b/python/pathway/internals/graph_runner/storage_graph.py index b3a45d50..caed5b1a 100644 --- a/python/pathway/internals/graph_runner/storage_graph.py +++ b/python/pathway/internals/graph_runner/storage_graph.py @@ -237,6 +237,7 @@ def _compute_storage_paths_ordinary( storages, operator, table._id_column.context, + table._columns.values(), ) if path_storage.max_depth > 3: # TODO: 3 is arbitrarily specified number. Check what's best. @@ -253,11 +254,13 @@ def _compute_storage_paths_input( self, operator: Operator, storages: dict[Universe, Storage] ) -> None: for table in operator.output_tables: + output_columns = self.column_deps_at_output[operator][table] path_storage = path_evaluator.compute_paths( - table._columns.values(), + output_columns, {}, operator, table._id_column.context, + table._columns.values(), ) self.output_storages[operator][table] = path_storage assert table._universe not in storages diff --git a/python/pathway/internals/table.py b/python/pathway/internals/table.py index 2e168dc1..979c1b11 100644 --- a/python/pathway/internals/table.py +++ b/python/pathway/internals/table.py @@ -1008,6 +1008,7 @@ def groupby( "All Table.groupby() arguments have to be a ColumnReference." ) + self._check_for_disallowed_types(*args) return groupbys.GroupedTable.create( table=self, grouping_columns=args, @@ -1147,6 +1148,7 @@ def deduplicate( _value = value self._validate_expression(_value) self._validate_expression(instance) + self._check_for_disallowed_types(_value, instance) value_col = self._eval(_value) instance_col = self._eval(instance) @@ -2210,6 +2212,7 @@ def sort( ^T0B95XH... | Eve | 15 | 80 | | ^GBSDEEW... """ instance = clmn.ColumnExpression._wrap(instance) + self._check_for_disallowed_types(key, instance) context = clmn.SortingContext( self._eval(key), self._eval(instance), @@ -2261,6 +2264,15 @@ def _validate_expression(self, expression: expr.ColumnExpression): + " sets of keys are equal." ) + def _check_for_disallowed_types(self, *expressions: expr.ColumnExpression): + for expression in expressions: + dtype = self.eval_type(expression) + if isinstance(dtype, dt.Future): + raise TypeError( + f"Using column of type {dtype.typehint} is not allowed here." + + " Consider applying `await_futures()` to the table first." 
+ ) + def _wrap_column_in_context( self, context: clmn.Context, @@ -2511,7 +2523,69 @@ def remove_errors(self) -> Table[TSchema]: 5 | 5 | 1 6 | 2 | 3 """ - context = clmn.RemoveErrorsContext(self._id_column) + context = clmn.FilterOutValueContext(self._id_column, api.ERROR) + return self._table_with_context(context) + + def await_futures(self) -> Table[TSchema]: + """Waits for the results of asynchronous computation. + + It strips the ``Future`` wrapper from table columns where applicable. In practice, + it filters out the ``Pending`` values and produces a column with a data type that + was the argument of `Future`. + + Columns of `Future` data type are produced by fully asynchronous UDFs. Columns of + this type can be propagated further, but can't be used in most expressions + (e.g. arithmetic operations). You can wait for their results using this method + and later use the results in expressions you want. + + Example: + + >>> import pathway as pw + >>> import asyncio + >>> + >>> t = pw.debug.table_from_markdown( + ... ''' + ... a | b + ... 1 | 2 + ... 3 | 4 + ... 5 | 6 + ... ''' + ... ) + >>> + >>> @pw.udf(executor=pw.udfs.fully_async_executor()) + ... async def long_running_async_function(a: int, b: int) -> int: + ... c = a * b + ... await asyncio.sleep(0.1 * c) + ... return c + ... + >>> + >>> result = t.with_columns(res=long_running_async_function(pw.this.a, pw.this.b)) + >>> print(result.schema) + id | a | b | res + ANY_POINTER | INT | INT | Future(INT) + >>> + >>> awaited_result = result.await_futures() + >>> print(awaited_result.schema) + id | a | b | res + ANY_POINTER | INT | INT | INT + >>> pw.debug.compute_and_print(awaited_result, include_id=False) + a | b | res + 1 | 2 | 2 + 3 | 4 | 12 + 5 | 6 | 30 + """ + result = self._await_futures() + new_dtypes = {} + for name, column in result._columns.items(): + if isinstance(column.dtype, dt.Future): + new_dtypes[name] = column.dtype.wrapped + new_schema = self.schema.with_types(**new_dtypes) + return result._with_schema(new_schema) + + @trace_user_frame + @contextualized_operator + def _await_futures(self) -> Table[TSchema]: + context = clmn.FilterOutValueContext(self._id_column, api.PENDING) return self._table_with_context(context) @contextualized_operator @@ -2519,6 +2593,14 @@ def _remove_retractions(self) -> Table[TSchema]: context = clmn.RemoveRetractionsContext(self._id_column) return self._table_with_context(context) + @contextualized_operator + def _async_transformer(self, context: clmn.AsyncTransformerContext) -> Table: + columns = { + name: clmn.ColumnWithoutExpression(context, context.universe, dtype) + for name, dtype in context.schema._dtypes().items() + } + return Table(_columns=columns, _context=context) + def _subtables(self) -> StableSet[Table]: return StableSet([self]) diff --git a/python/pathway/internals/type_interpreter.py b/python/pathway/internals/type_interpreter.py index afd37932..2f2db462 100644 --- a/python/pathway/internals/type_interpreter.py +++ b/python/pathway/internals/type_interpreter.py @@ -221,6 +221,11 @@ def eval_reducer( **kwargs, ) -> expr.ReducerExpression: expression = super().eval_reducer(expression, state=state, **kwargs) + args_dtypes = [e._dtype for e in expression._args] + kwargs_dtypes = [e._dtype for e in expression._kwargs.values()] + self._check_for_disallowed_types( + f"pathway.reducers.{expression._reducer.name}", *args_dtypes, *kwargs_dtypes + ) return _wrap( expression, expression._reducer.return_type( @@ -238,7 +243,13 @@ def eval_apply( **kwargs, ) -> expr.ApplyExpression: 
expression = super().eval_apply(expression, state=state, **kwargs) - return _wrap(expression, expression._return_type) + args_dtypes = [e._dtype for e in expression._args] + kwargs_dtypes = [e._dtype for e in expression._kwargs.values()] + if expression._check_for_disallowed_types: + self._check_for_disallowed_types( + "pathway.apply", *args_dtypes, *kwargs_dtypes + ) + return _wrap(expression, expression._maybe_optional_return_type) def eval_async_apply( self, @@ -247,7 +258,21 @@ def eval_async_apply( **kwargs, ) -> expr.AsyncApplyExpression: expression = super().eval_async_apply(expression, state=state, **kwargs) - return _wrap(expression, expression._return_type) + args_dtypes = [e._dtype for e in expression._args] + kwargs_dtypes = [e._dtype for e in expression._kwargs.values()] + self._check_for_disallowed_types( + "pathway.apply_async", *args_dtypes, *kwargs_dtypes + ) + return _wrap(expression, expression._maybe_optional_return_type) + + def eval_fully_async_apply( + self, + expression: expr.FullyAsyncApplyExpression, + state: TypeInterpreterState | None = None, + **kwargs, + ) -> expr.FullyAsyncApplyExpression: + expression = super().eval_fully_async_apply(expression, state=state, **kwargs) + return _wrap(expression, dt.Future(expression._maybe_optional_return_type)) def eval_call( self, @@ -275,6 +300,9 @@ def eval_pointer( ) -> expr.PointerExpression: expression = super().eval_pointer(expression, state=state, **kwargs) arg_types = [arg._dtype for arg in expression._args] + if expression._instance is not None: + arg_types.append(expression._instance._dtype) + self._check_for_disallowed_types("pathway.pointer_from", *arg_types) if expression._optional and any( isinstance(arg, dt.Optional) or arg == dt.ANY for arg in arg_types ): @@ -327,6 +355,7 @@ def eval_coalesce( ) -> expr.CoalesceExpression: expression = super().eval_coalesce(expression, state=state, **kwargs) dtypes = [arg._dtype for arg in expression._args] + self._check_for_disallowed_types("pathway.coalesce", *dtypes) ret_type = dtypes[0] non_optional_arg = False for dtype in dtypes: @@ -360,10 +389,12 @@ def eval_require( args = [ self.eval_expression(arg, state=state, **kwargs) for arg in expression._args ] + arg_dtypes = [arg._dtype for arg in args] new_state = state.with_new_col( [arg for arg in expression._args if isinstance(arg, expr.ColumnReference)] ) val = self.eval_expression(expression._val, state=new_state, **kwargs) + self._check_for_disallowed_types("pathway.require", val._dtype, *arg_dtypes) expression = expr.RequireExpression(val, *args) ret_type = dt.Optional(val._dtype) return _wrap(expression, ret_type) @@ -375,6 +406,7 @@ def eval_not_none( **kwargs, ) -> expr.IsNotNoneExpression: ret = super().eval_not_none(expression, state=state, **kwargs) + self._check_for_disallowed_types("pathway.is_not_none", ret._expr._dtype) return _wrap(ret, dt.BOOL) def eval_none( @@ -384,6 +416,7 @@ def eval_none( **kwargs, ) -> expr.IsNoneExpression: ret = super().eval_none(expression, state=state, **kwargs) + self._check_for_disallowed_types("pathway.is_none", ret._expr._dtype) return _wrap(ret, dt.BOOL) def eval_ifelse( @@ -393,7 +426,7 @@ def eval_ifelse( **kwargs, ) -> expr.IfElseExpression: assert state is not None - if_ = self.eval_expression(expression._if, state=state) + if_ = self.eval_expression(expression._if, state=state, **kwargs) if_dtype = if_._dtype if if_dtype != dt.BOOL: raise TypeError( @@ -404,19 +437,19 @@ def eval_ifelse( if_._expr, expr.ColumnReference ): then_ = self.eval_expression( - 
expression._then, state=state.with_new_col([if_._expr]) + expression._then, state=state.with_new_col([if_._expr]), **kwargs ) else: - then_ = self.eval_expression(expression._then, state=state) + then_ = self.eval_expression(expression._then, state=state, **kwargs) if isinstance(if_, expr.IsNoneExpression) and isinstance( if_._expr, expr.ColumnReference ): else_ = self.eval_expression( - expression._else, state=state.with_new_col([if_._expr]) + expression._else, state=state.with_new_col([if_._expr], **kwargs) ) else: - else_ = self.eval_expression(expression._else, state=state) + else_ = self.eval_expression(expression._else, state=state, **kwargs) then_dtype = then_._dtype else_dtype = else_._dtype @@ -441,16 +474,9 @@ def eval_make_tuple( ) -> expr.MakeTupleExpression: expression = super().eval_make_tuple(expression, state=state, **kwargs) dtypes = tuple(arg._dtype for arg in expression._args) + self._check_for_disallowed_types("pathway.make_tuple", *dtypes) return _wrap(expression, dt.Tuple(*dtypes)) - def _eval_json_get( - self, - expression: expr.GetExpression, - state: TypeInterpreterState | None = None, - **kwargs, - ) -> expr.GetExpression: - return _wrap(expression, dt.JSON) - def eval_get( self, expression: expr.GetExpression, @@ -570,6 +596,7 @@ def eval_unwrap( ) -> expr.UnwrapExpression: expression = super().eval_unwrap(expression, state=state, **kwargs) dtype = expression._expr._dtype + self._check_for_disallowed_types("pathway.unwrap", dtype) return _wrap(expression, dt.unoptionalize(dtype)) def eval_fill_error( @@ -589,6 +616,19 @@ def eval_fill_error( ) return _wrap(expression, lca) + def _check_for_disallowed_types(self, name: str, *dtypes: dt.DType) -> None: + disallowed_dtypes: list[dt.DType] = [] + for dtype in dtypes: + if isinstance(dtype, dt.Future): + disallowed_dtypes.append(dtype) + if disallowed_dtypes: + dtypes_repr = ", ".join(f"{dtype.typehint}" for dtype in disallowed_dtypes) + # adjust message if more than dt.Future is involved + raise TypeError( + f"Cannot perform {name} when column of type {dtypes_repr} is involved." + + " Consider applying `await_futures()` to the table used here." 
+ ) + class ReducerInterprerer(TypeInterpreter): id_column_type: dt.DType diff --git a/python/pathway/internals/udfs/__init__.py b/python/pathway/internals/udfs/__init__.py index 4662d922..6f01d704 100644 --- a/python/pathway/internals/udfs/__init__.py +++ b/python/pathway/internals/udfs/__init__.py @@ -26,6 +26,7 @@ async_executor, async_options, auto_executor, + fully_async_executor, sync_executor, with_capacity, with_timeout, @@ -45,6 +46,7 @@ "auto_executor", "async_executor", "sync_executor", + "fully_async_executor", "CacheStrategy", "DefaultCache", "DiskCache", @@ -203,6 +205,7 @@ def __call__(self, *args, **kwargs) -> expr.ColumnExpression: return_type=self._get_return_type(), propagate_none=self.propagate_none, deterministic=self.deterministic, + **self.executor.additional_expression_args(), args=args, kwargs=kwargs, ) diff --git a/python/pathway/internals/udfs/executors.py b/python/pathway/internals/udfs/executors.py index 4a8149aa..1716893e 100644 --- a/python/pathway/internals/udfs/executors.py +++ b/python/pathway/internals/udfs/executors.py @@ -8,7 +8,7 @@ import sys from collections.abc import Awaitable, Callable from dataclasses import dataclass -from typing import ParamSpec, TypeVar +from typing import Any, ParamSpec, TypeVar import pathway.internals.expression as expr from pathway.internals.runtime_type_check import check_arg_types @@ -31,6 +31,9 @@ def _wrap(self, fun: Callable) -> Callable: ... @abc.abstractmethod def _apply_expression_type(self) -> type[expr.ApplyExpression]: ... + def additional_expression_args(self) -> dict[str, Any]: + return {} + @dataclass class AutoExecutor(Executor): @@ -219,6 +222,104 @@ def async_executor( ) +@dataclass(frozen=True, kw_only=True) +class FullyAsyncExecutor(AsyncExecutor): + autocommit_duration_ms: int | None + + @property + def _apply_expression_type(self) -> type[expr.ApplyExpression]: + return expr.FullyAsyncApplyExpression + + def additional_expression_args(self) -> dict[str, Any]: + return dict(autocommit_duration_ms=self.autocommit_duration_ms) + + +def fully_async_executor( + *, + capacity: int | None = None, + timeout: float | None = None, + retry_strategy: AsyncRetryStrategy | None = None, + autocommit_duration_ms: int | None = 1500, +) -> Executor: + """ + Returns the fully asynchronous executor for Pathway UDFs. + + Can be applied to a regular or an asynchronous function. If applied to a regular + function, it is executed in ``asyncio`` loop's ``run_in_executor``. + + In contrast to regular asynchronous UDFs, these UDFs are fully asynchronous. + It means that computations from the next batch can start even if the previous batch hasn't + finished yet. When a UDF is started, instead of a result, a special ``Pending`` value + is emitted. When the function finishes, an update with the true return value is produced. + + Using fully asynchronous UDFs allows processing time to advance even if the function + doesn't return. As a result downstream computations are not blocked. + + The data type of column returned from the fully async UDF is ``Future[return_type]`` to + allow for ``Pending`` values. Columns of this type can be propagated further, but can't + be used in most expressions (e.g. arithmetic operations). They can be passed to the next + fully async UDF though. To strip the ``Future`` wrapper and wait for the result, you can + use :py:meth:`pathway.Table.await_futures` method on :py:class:`pathway.Table`. 
In practice, + it filters out the ``Pending`` values and produces a column with the data type as returned + by the fully async UDF. + + Args: + capacity: Maximum number of concurrent operations allowed. + Defaults to None, indicating no specific limit. + timeout: Maximum time (in seconds) to wait for the function result. When both + ``timeout`` and ``retry_strategy`` are used, timeout applies to a single retry. + Defaults to None, indicating no time limit. + retry_strategy: Strategy for handling retries in case of failures. + Defaults to None, meaning no retries. + + Example: + + >>> import pathway as pw + >>> import asyncio + >>> + >>> t = pw.debug.table_from_markdown( + ... ''' + ... a | b | __time__ + ... 1 | 2 | 2 + ... 3 | 4 | 4 + ... 5 | 6 | 4 + ... ''' + ... ) + >>> + >>> @pw.udf(executor=pw.udfs.fully_async_executor()) + ... async def long_running_async_function(a: int, b: int) -> int: + ... c = a * b + ... await asyncio.sleep(0.1 * c) + ... return c + ... + >>> + >>> result = t.with_columns(res=long_running_async_function(pw.this.a, pw.this.b)) + >>> pw.debug.compute_and_print(result, include_id=False) + a | b | res + 1 | 2 | 2 + 3 | 4 | 12 + 5 | 6 | 30 + >>> + >>> pw.debug.compute_and_print_update_stream(result, include_id=False) # doctest: +SKIP + a | b | res | __time__ | __diff__ + 1 | 2 | Pending | 2 | 1 + 3 | 4 | Pending | 4 | 1 + 5 | 6 | Pending | 4 | 1 + 1 | 2 | Pending | 1739290145300 | -1 + 1 | 2 | 2 | 1739290145300 | 1 + 3 | 4 | Pending | 1739290146300 | -1 + 3 | 4 | 12 | 1739290146300 | 1 + 5 | 6 | Pending | 1739290148100 | -1 + 5 | 6 | 30 | 1739290148100 | 1 + """ + return FullyAsyncExecutor( + capacity=capacity, + timeout=timeout, + retry_strategy=retry_strategy, + autocommit_duration_ms=autocommit_duration_ms, + ) + + T = TypeVar("T") P = ParamSpec("P") diff --git a/python/pathway/io/python/__init__.py b/python/pathway/io/python/__init__.py index 9702b809..47ac5766 100644 --- a/python/pathway/io/python/__init__.py +++ b/python/pathway/io/python/__init__.py @@ -27,7 +27,6 @@ _get_unique_name, assert_schema_not_none, get_data_format_type, - internal_read_method, read_schema, ) @@ -314,6 +313,47 @@ def _deletions_enabled(self) -> bool: return True +def _create_python_datasource( + subject: ConnectorSubject, + *, + schema: type[Schema], + autocommit_duration_ms: int | None = 1500, + name: str | None = None, + _stacklevel: int = 1, + **kwargs, +) -> datasource.GenericDataSource: + schema, api_schema = read_schema(schema) + data_format = api.DataFormat( + **api_schema, + format_type="transparent", + session_type=subject._session_type, + ) + data_storage = api.DataStorage( + storage_type="python", + python_subject=api.PythonSubject( + start=subject.start, + seek=subject.seek, + on_persisted_run=subject.on_persisted_run, + read=subject._read, + end=subject.end, + is_internal=subject._is_internal(), + deletions_enabled=subject._deletions_enabled, + ), + ) + data_source_options = datasource.DataSourceOptions( + commit_duration_ms=autocommit_duration_ms, + unique_name=_get_unique_name(name, kwargs, stacklevel=_stacklevel + 1), + ) + return datasource.GenericDataSource( + datastorage=data_storage, + dataformat=data_format, + data_source_options=data_source_options, + schema=schema, + datasource_name=subject._datasource_name, + append_only=not subject._deletions_enabled, + ) + + @check_arg_types @trace_user_frame def read( @@ -404,37 +444,14 @@ def read( schema |= MetadataSchema schema = assert_schema_not_none(schema, data_format_type) - schema, api_schema = read_schema(schema) - 
data_format = api.DataFormat( - **api_schema, - format_type="transparent", - session_type=subject._session_type, - ) - data_storage = api.DataStorage( - storage_type="python", - python_subject=api.PythonSubject( - start=subject.start, - seek=subject.seek, - on_persisted_run=subject.on_persisted_run, - read=subject._read, - end=subject.end, - is_internal=subject._is_internal(), - deletions_enabled=subject._deletions_enabled, - ), - read_method=internal_read_method(format), - ) - data_source_options = datasource.DataSourceOptions( - commit_duration_ms=autocommit_duration_ms, - unique_name=_get_unique_name(name, kwargs, stacklevel=_stacklevel + 5), - ) return table_from_datasource( - datasource.GenericDataSource( - datastorage=data_storage, - dataformat=data_format, - data_source_options=data_source_options, + _create_python_datasource( + subject, schema=schema, - datasource_name=subject._datasource_name, - append_only=not subject._deletions_enabled, + autocommit_duration_ms=autocommit_duration_ms, + name=name, + _stacklevel=_stacklevel + 5, + **kwargs, ), debug_datasource=datasource.debug_datasource(debug_data), ) diff --git a/python/pathway/stdlib/utils/async_transformer.py b/python/pathway/stdlib/utils/async_transformer.py index 11c0d030..83a58b9d 100644 --- a/python/pathway/stdlib/utils/async_transformer.py +++ b/python/pathway/stdlib/utils/async_transformer.py @@ -6,24 +6,21 @@ import collections import functools import inspect -import json import logging import re -from abc import ABC, abstractmethod +from abc import ABCMeta, abstractmethod from collections.abc import Awaitable, Callable from dataclasses import dataclass, field from enum import Enum from typing import Any, ClassVar import pathway.internals as pw +import pathway.internals.column as clmn import pathway.internals.dtype as dt import pathway.io as io -from pathway.internals import api, operator, parse_graph, udfs -from pathway.internals.api import Pointer -from pathway.internals.helpers import StableSet -from pathway.internals.operator import Operator +from pathway.internals import api, datasource, udfs +from pathway.internals.api import Pointer, SessionType from pathway.internals.schema import Schema, schema_from_types -from pathway.internals.table_subscription import subscribe from pathway.internals.type_interpreter import eval_type @@ -34,7 +31,7 @@ class _AsyncStatus(Enum): _ASYNC_STATUS_COLUMN = "_async_status" -_AsyncStatusSchema = schema_from_types(**{_ASYNC_STATUS_COLUMN: str}) +_AsyncStatusSchema = schema_from_types(**{_ASYNC_STATUS_COLUMN: dt.Future(dt.STR)}) _INSTANCE_COLUMN = "_pw_instance" @@ -43,6 +40,7 @@ class _Entry: key: Pointer time: int is_addition: bool + task_id: Pointer ResultType = dict[str, api.Value] | _AsyncStatus | None @@ -52,7 +50,9 @@ class _Entry: class _Instance: pending: collections.deque[_Entry] = field(default_factory=collections.deque) finished: dict[_Entry, ResultType] = field(default_factory=dict) - buffer: list[tuple[Pointer, bool, ResultType]] = field(default_factory=list) + buffer: list[tuple[Pointer, bool, Pointer, ResultType]] = field( + default_factory=list + ) buffer_time: int | None = None correct: bool = True @@ -61,15 +61,14 @@ class _AsyncConnector(io.python.ConnectorSubject): _requests: asyncio.Queue _apply: Callable _loop: asyncio.AbstractEventLoop - _transformer: AsyncTransformer - _state: dict[Pointer, Any] + _transformer: _BaseAsyncTransformer _tasks: dict[Pointer, asyncio.Task] _invoke: Callable[..., Awaitable[dict[str, Any]]] _instances: dict[api.Value, _Instance] 
_time_finished: int | None _logger: logging.Logger - def __init__(self, transformer: AsyncTransformer) -> None: + def __init__(self, transformer: _BaseAsyncTransformer) -> None: super().__init__(datasource_name="async-transformer") self._transformer = transformer self._event_loop = asyncio.new_event_loop() @@ -92,7 +91,6 @@ def set_options( def run(self) -> None: self._tasks = {} - self._state = {} self._transformer.open() self._instances = collections.defaultdict(_Instance) self._time_finished = None @@ -108,28 +106,49 @@ async def loop_forever(event_loop: asyncio.AbstractEventLoop): self._on_time_end(request) continue - (key, values, time, addition) = request - instance = values[_INSTANCE_COLUMN] - entry = _Entry(key=key, time=time, is_addition=addition) + (key, values, time, diff) = request + row = {} + input_table = self._transformer._input_table + task_id = values[-1] + values = values[:-1] + if input_table is not None: + for field_name, field_value in zip( + input_table._columns.keys(), values, strict=True + ): + row[field_name] = field_value + else: + for i, field_value in enumerate(values): + row[f"{i}"] = field_value + + assert diff in [-1, 1], "diff should be 1 or -1" + addition = diff == 1 + instance = row.get(_INSTANCE_COLUMN, key) + entry = _Entry( + key=key, time=time, is_addition=addition, task_id=task_id + ) self._instances[instance].pending.append(entry) previous_task = self._tasks.get(key, None) - if previous_task is None: - self._set_status(key, _AsyncStatus.PENDING) async def task( key: Pointer, values: dict[str, Any], time: int, addition: bool, + task_id: Pointer, previous_task: asyncio.Task | None, ): - instance = values.pop(_INSTANCE_COLUMN) + instance = values.pop(_INSTANCE_COLUMN, key) if not addition: if previous_task is not None: await previous_task self._on_task_finished( - key, instance, time, is_addition=False, result=None + key, + instance, + time, + is_addition=False, + result=None, + task_id=task_id, ) else: result: dict[str, Any] | _AsyncStatus @@ -147,11 +166,16 @@ async def task( if previous_task is not None: await previous_task self._on_task_finished( - key, instance, time, is_addition=True, result=result + key, + instance, + time, + is_addition=True, + result=result, + task_id=task_id, ) current_task = event_loop.create_task( - task(key, values, time, addition, previous_task) + task(key, row, time, addition, task_id, previous_task) ) self._tasks[key] = current_task @@ -176,9 +200,10 @@ def _on_task_finished( *, is_addition: bool, result: dict[str, Any] | _AsyncStatus | None, + task_id: Pointer, ) -> None: instance_data = self._instances[instance] - entry = _Entry(key=key, time=time, is_addition=is_addition) + entry = _Entry(key=key, time=time, is_addition=is_addition, task_id=task_id) instance_data.finished[entry] = result self._maybe_produce_instance(instance) @@ -202,7 +227,9 @@ def _maybe_produce_instance(self, instance: api.Value) -> None: result = instance_data.finished.pop(entry) if result == _AsyncStatus.FAILURE: instance_data.correct = False - instance_data.buffer.append((entry.key, entry.is_addition, result)) + instance_data.buffer.append( + (entry.key, entry.is_addition, entry.task_id, result) + ) instance_data.pending.popleft() if ( @@ -216,34 +243,32 @@ def _maybe_produce_instance(self, instance: api.Value) -> None: def _flush_buffer(self, instance_data: _Instance) -> None: if not instance_data.buffer: return + self.commit() self._disable_commits() - for key, is_addition, result in instance_data.buffer: + for key, is_addition, task_id, 
result in instance_data.buffer: if is_addition and instance_data.correct: assert isinstance(result, dict) - self._upsert(key, result) + self._upsert(key, result, task_id) elif is_addition: - self._set_status(key, _AsyncStatus.FAILURE) + self._set_failure(key, task_id) else: - self._remove_by_key(key) + self._remove_by_key(key, task_id) self._enable_commits() # does a commit as well instance_data.buffer.clear() - def _set_status(self, key: Pointer, status: _AsyncStatus) -> None: + def _set_failure(self, key: Pointer, task_id: Pointer) -> None: + # TODO: replace None with api.ERROR data = {col: None for col in self._transformer.output_schema.column_names()} - self._upsert(key, data, status) + self._upsert(key, data, task_id, _AsyncStatus.FAILURE) - def _upsert(self, key: Pointer, data: dict, status=_AsyncStatus.SUCCESS) -> None: - data = {**data, _ASYNC_STATUS_COLUMN: status.value} - payload = json.dumps(data).encode() - self._remove_by_key(key) - self._add(key, payload) - self._state[key] = data + def _upsert( + self, key: Pointer, data: dict, task_id: Pointer, status=_AsyncStatus.SUCCESS + ) -> None: + data[_ASYNC_STATUS_COLUMN] = status.value + self._add_inner(task_id, data) - def _remove_by_key(self, key) -> None: - if key in self._state: - payload = json.dumps(self._state[key]).encode() - self._remove(key, payload) - del self._state[key] + def _remove_by_key(self, key: Pointer, task_id: Pointer) -> None: + self._remove_inner(task_id, {}) def _check_result_against_schema(self, result: dict) -> None: if result.keys() != self._transformer.output_schema.keys(): @@ -253,7 +278,7 @@ def on_stop(self) -> None: self._transformer.close() def on_subscribe_change( - self, key: Pointer, row: dict[str, Any], time: int, is_addition: bool + self, key: Pointer, row: list[Any], time: int, is_addition: bool ) -> None: self._put_request((key, row, time, is_addition)) @@ -277,8 +302,72 @@ def _maybe_create_queue(self) -> None: def _is_internal(self) -> bool: return True + @property + def _session_type(self) -> SessionType: + return SessionType.UPSERT + + +class _BaseAsyncTransformer(metaclass=ABCMeta): + output_schema: ClassVar[type[pw.Schema]] + wrapped_output_schema: ClassVar[type[pw.Schema]] + _connector: _AsyncConnector + _autocommit_duration_ms: int | None + _input_table: pw.Table | None + + def __init__(self, autocommit_duration_ms: int | None = 1500) -> None: + assert self.output_schema is not None + self._connector = _AsyncConnector(self) + self._autocommit_duration_ms = autocommit_duration_ms + self._input_table = None + + def __init_subclass__( + cls, /, output_schema: type[pw.Schema] | None = None, **kwargs + ): + super().__init_subclass__(**kwargs) + if output_schema is None: + return + cls.output_schema = output_schema + cls.wrapped_output_schema = ( + output_schema.with_types( + **{ + key: dt.Future(dt.Optional(orig_dtype)) + for key, orig_dtype in output_schema._dtypes().items() + } + ) + | _AsyncStatusSchema + ) + + def _get_datasource(self) -> datasource.GenericDataSource: + return io.python._create_python_datasource( + self._connector, + schema=self.wrapped_output_schema, + autocommit_duration_ms=self._autocommit_duration_ms, + ) + + def open(self) -> None: + """ + Called before actual work. Suitable for one time setup. + """ + pass + + def close(self) -> None: + """ + Called once at the end. Proper place for cleanup. + """ + pass + + @abstractmethod + async def invoke(self, *args, **kwargs) -> dict[str, Any]: + """ + Called for every row of input_table. 
The arguments will correspond to the + columns in the input table. + + Should return dict of values matching :py:attr:`output_schema`. + """ + ... -class AsyncTransformer(ABC): + +class AsyncTransformer(_BaseAsyncTransformer, metaclass=ABCMeta): """ Allows to perform async transformations on a table. @@ -294,7 +383,7 @@ class AsyncTransformer(ABC): ... ret: int ... >>> class AsyncIncrementTransformer(pw.AsyncTransformer, output_schema=OutputSchema): - ... async def invoke(self, value) -> Dict[str, Any]: + ... async def invoke(self, value) -> dict[str, Any]: ... await asyncio.sleep(0.1) ... return {"ret": value + 1 } ... @@ -310,8 +399,6 @@ class AsyncTransformer(ABC): 45 """ - output_schema: ClassVar[type[pw.Schema]] - _connector: _AsyncConnector _input_table: pw.Table _instance_expression: pw.ColumnExpression | api.Value @@ -322,8 +409,7 @@ def __init__( instance: pw.ColumnExpression | api.Value = pw.this.id, autocommit_duration_ms: int | None = 1500, ) -> None: - assert self.output_schema is not None - self._connector = _AsyncConnector(self) + super().__init__(autocommit_duration_ms=autocommit_duration_ms) # TODO: when AsyncTransformer uses persistence backend for cache # just take the settings for persistence config @@ -342,7 +428,6 @@ def __init__( ) self._input_table = input_table - self._autocommit_duration_ms = autocommit_duration_ms def _check_signature_matches_schema( self, sig: inspect.Signature, schema: type[Schema] @@ -366,30 +451,7 @@ def _check_signature_matches_schema( raise e def __init_subclass__(cls, /, output_schema: type[pw.Schema], **kwargs): - super().__init_subclass__(**kwargs) - cls.output_schema = output_schema - - def open(self) -> None: - """ - Called before actual work. Suitable for one time setup. - """ - pass - - def close(self) -> None: - """ - Called once at the end. Proper place for cleanup. - """ - pass - - @abstractmethod - async def invoke(self, *args, **kwargs) -> dict[str, Any]: - """ - Called for every row of input_table. The arguments will correspond to the - columns in the input table. - - Should return dict of values matching :py:attr:`output_schema`. - """ - ... + super().__init_subclass__(output_schema, **kwargs) def with_options( self, @@ -423,7 +485,7 @@ def successful(self) -> pw.Table: The resulting table containing only rows that were executed successfully. """ return ( - self.output_table.filter( + self.finished.filter( pw.this[_ASYNC_STATUS_COLUMN] == _AsyncStatus.SUCCESS.value ) .without(pw.this[_ASYNC_STATUS_COLUMN]) @@ -437,7 +499,7 @@ def failed(self) -> pw.Table: If the ``instance`` argument is specified, it also contains rows that were executed successfully but at least one element from their instance with less or equal time failed. """ - return self.output_table.filter( + return self.finished.filter( pw.this[_ASYNC_STATUS_COLUMN] == _AsyncStatus.FAILURE.value ).without(pw.this[_ASYNC_STATUS_COLUMN]) @@ -454,9 +516,7 @@ def finished(self) -> pw.Table: If you want to get only rows that executed successfully, use ``successful`` property instead. """ - return self.output_table.filter( - pw.this[_ASYNC_STATUS_COLUMN] != _AsyncStatus.PENDING.value - ) + return self.output_table.await_futures() @functools.cached_property def output_table(self) -> pw.Table: @@ -472,40 +532,16 @@ def output_table(self) -> pw.Table: a Table containing only rows that were executed successfully. 
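To make the relationship between these result tables concrete, here is a hedged usage sketch built from the docstring example above; the attribute semantics follow the property definitions shown in this hunk:

```python
import asyncio
from typing import Any

import pathway as pw

class OutputSchema(pw.Schema):
    ret: int

class AsyncIncrementTransformer(pw.AsyncTransformer, output_schema=OutputSchema):
    async def invoke(self, value) -> dict[str, Any]:
        await asyncio.sleep(0.1)
        return {"ret": value + 1}

input_table = pw.debug.table_from_markdown(
    """
    value
    1
    44
    """
)
transformer = AsyncIncrementTransformer(input_table=input_table)

ok = transformer.successful     # rows whose invoke() succeeded, status column dropped
bad = transformer.failed        # rows whose invoke() raised (or whose instance failed)
done = transformer.finished     # both kinds, with the futures already awaited
raw = transformer.output_table  # Future-typed columns; await_futures() before using them
```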
""" - subscribe( - self._input_table, - skip_persisted_batch=False, - on_change=self._connector.on_subscribe_change, - on_time_end=self._connector.on_subscribe_time_end, - on_end=self._connector.on_subscribe_end, - ) - output_node = list(parse_graph.G.global_scope.nodes)[-1] - - schema = self.output_schema.with_types( - **{ - key: dt.Optional(orig_dtype) - for key, orig_dtype in self.output_schema._dtypes().items() - } - ) - - table: pw.Table = io.python.read( - self._connector, - schema=schema | _AsyncStatusSchema, - autocommit_duration_ms=self._autocommit_duration_ms, + input_id_column = self._input_table._id_column + input_columns = list(self._input_table._columns.values()) + + context = clmn.AsyncTransformerContext( + input_id_column, + input_columns, + self.wrapped_output_schema, + self._connector.on_subscribe_change, + self._connector.on_subscribe_time_end, + self._connector.on_subscribe_end, + self._get_datasource(), ) - input_node = table._source.operator - - class AsyncInputHandle(operator.InputHandle): - @property - def dependencies(self) -> StableSet[Operator]: - return StableSet([output_node]) - - input_node._inputs = { - "async_input": AsyncInputHandle( - operator=input_node, - name="async_input", - value=self._input_table, - ) - } - - return table.promise_universe_is_subset_of(self._input_table) + return self._input_table._async_transformer(context) diff --git a/python/pathway/tests/test_async_transformer.py b/python/pathway/tests/test_async_transformer.py index b8d86741..be82c892 100644 --- a/python/pathway/tests/test_async_transformer.py +++ b/python/pathway/tests/test_async_transformer.py @@ -26,7 +26,6 @@ run, wait_result_with_checker, write_csv, - xfail_on_multiple_threads, ) @@ -106,8 +105,6 @@ async def invoke(self, value: int) -> dict[str, Any]: ) -@pytest.mark.flaky(reruns=2) -@xfail_on_multiple_threads @needs_multiprocessing_fork def test_idempotency(monkeypatch): monkeypatch.delenv("PATHWAY_PERSISTENT_STORAGE", raising=False) @@ -250,6 +247,7 @@ async def invoke(self, value: int) -> dict[str, Any]: expected, persistence_config=pw.persistence.Config( pw.persistence.Backend.filesystem(cache_dir), + persistence_mode=pw.PersistenceMode.SELECTIVE_PERSISTING, ), ) diff --git a/python/pathway/tests/test_column_properties.py b/python/pathway/tests/test_column_properties.py index 0d77b68e..78e068d8 100644 --- a/python/pathway/tests/test_column_properties.py +++ b/python/pathway/tests/test_column_properties.py @@ -786,3 +786,23 @@ class Schema(pw.Schema): assert result._id_column.properties.append_only assert result.a._column.properties.append_only assert result.b._column.properties.append_only + + +@pytest.mark.parametrize("append_only_1", [True, False]) +@pytest.mark.parametrize("append_only_2", [True, False]) +def test_fully_async_udf(append_only_1, append_only_2): + class Schema(pw.Schema): + a: int = pw.column_definition(append_only=append_only_1) + b: int = pw.column_definition(append_only=append_only_2) + + @pw.udf(executor=pw.udfs.fully_async_executor()) + def foo(a: int, b: int) -> int: + return a + b + + table = table_from_datasource(TestDataSource(schema=Schema)) + result = table.with_columns(c=foo(pw.this.a, pw.this.b)) + + assert result._id_column.properties.append_only == (append_only_1 or append_only_2) + assert result.a._column.properties.append_only == append_only_1 + assert result.b._column.properties.append_only == append_only_2 + assert result.c._column.properties.append_only is False diff --git a/python/pathway/tests/test_expression_repr.py 
b/python/pathway/tests/test_expression_repr.py index e74b8bff..e2eaad84 100644 --- a/python/pathway/tests/test_expression_repr.py +++ b/python/pathway/tests/test_expression_repr.py @@ -111,6 +111,48 @@ def test_apply(): ) +def test_apply_udf(): + @pw.udf + def foo(a: int) -> int: + return a + + t = T( + """ + pet | owner | age + 1 | Alice | 10 + """ + ) + assert repr(foo(t.age)) == "pathway.apply(foo, .age)" + + +def test_async_apply_udf(): + @pw.udf + async def foo(a: int) -> int: + return a + + t = T( + """ + pet | owner | age + 1 | Alice | 10 + """ + ) + assert repr(foo(t.age)) == "pathway.apply_async(foo, .age)" + + +def test_fully_async_apply_udf(): + @pw.udf(executor=pw.udfs.fully_async_executor()) + async def foo(a: int) -> int: + return a + + t = T( + """ + pet | owner | age + 1 | Alice | 10 + """ + ) + assert repr(foo(t.age)) == "pathway.apply_fully_async(foo, .age)" + + def test_cast(): t = T( """ diff --git a/python/pathway/tests/test_persistence.py b/python/pathway/tests/test_persistence.py index 9f189669..f0a738e9 100644 --- a/python/pathway/tests/test_persistence.py +++ b/python/pathway/tests/test_persistence.py @@ -1,5 +1,6 @@ # Copyright © 2024 Pathway +import asyncio import json import multiprocessing import os @@ -10,8 +11,8 @@ import pytest import pathway as pw -from pathway.engine import SessionType from pathway.internals import api +from pathway.internals.api import SessionType from pathway.internals.parse_graph import G from pathway.tests.utils import ( CsvPathwayChecker, @@ -708,6 +709,13 @@ def logic(t: pw.Table) -> pw.Table: run(["a,b"], {"2,5,-1"}) +def get_checker(output_path: pathlib.Path, expected: set[str]) -> Callable: + def check() -> None: + assert_sets_equality_from_path(output_path, expected) + + return LogicChecker(check) + + @pytest.mark.parametrize( "mode", [api.PersistenceMode.PERSISTING, api.PersistenceMode.OPERATOR_PERSISTING], @@ -728,7 +736,7 @@ class InputSchema(pw.Schema): persistence_mode=mode, ) - def setup(inputs: list[str]) -> None: + def wait_result(inputs: list[str], expected: set[str]) -> None: nonlocal count count += 1 G.clear() @@ -737,48 +745,18 @@ def setup(inputs: list[str]) -> None: t_1 = pw.io.csv.read(input_path, schema=InputSchema, mode="streaming") res = t_1._buffer(pw.this.t + 10, pw.this.t) pw.io.csv.write(res, output_path) + wait_result_with_checker( + get_checker(output_path, expected), + timeout_sec=10, + target=run, + kwargs={"persistence_config": persistence_config}, + ) - def get_checker(expected: set[str]) -> Callable: - def check() -> None: - assert_sets_equality_from_path(output_path, expected) - - return LogicChecker(check) - - setup(["t", "1", "3", "11"]) - wait_result_with_checker( - get_checker({"1,1"}), - timeout_sec=10, - target=run, - kwargs={"persistence_config": persistence_config}, - ) - setup(["t", "15", "16"]) - wait_result_with_checker( - get_checker({"3,1"}), - timeout_sec=10, - target=run, - kwargs={"persistence_config": persistence_config}, - ) - setup(["t", "6", "21"]) - wait_result_with_checker( - get_checker({"6,1", "11,1"}), - timeout_sec=10, - target=run, - kwargs={"persistence_config": persistence_config}, - ) - setup(["t", "9", "10"]) - wait_result_with_checker( - get_checker({"9,1", "10,1"}), - timeout_sec=10, - target=run, - kwargs={"persistence_config": persistence_config}, - ) - setup(["t", "26"]) - wait_result_with_checker( - get_checker({"15,1", "16,1"}), - timeout_sec=10, - target=run, - kwargs={"persistence_config": persistence_config}, - ) + wait_result(["t", "1", "3", "11"], {"1,1"}) + 
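Because the refactored tests above drive everything through a shared `wait_result` helper, the user-facing persistence wiring they rely on is easy to miss; a compact sketch follows, with a placeholder cache path and under the assumption that `pw.run` accepts the config the same way the test harness passes it:

```python
import pathway as pw

persistence_config = pw.persistence.Config(
    pw.persistence.Backend.filesystem("./pathway-cache"),  # placeholder location
    persistence_mode=pw.PersistenceMode.SELECTIVE_PERSISTING,
)

# ... define inputs, an AsyncTransformer or fully async UDFs, and outputs here ...

pw.run(persistence_config=persistence_config)
```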
wait_result(["t", "15", "16"], {"3,1"}) + wait_result(["t", "6", "21"], {"6,1", "11,1"}) + wait_result(["t", "9", "10"], {"9,1", "10,1"}) + wait_result(["t", "26"], {"15,1", "16,1"}) @pytest.mark.parametrize("mode", [api.PersistenceMode.OPERATOR_PERSISTING]) @@ -817,7 +795,7 @@ class InputSchema(pw.Schema): persistence_mode=mode, ) - def setup(inputs: list[str]) -> None: + def wait_result(inputs: list[str], expected: set[str]) -> None: nonlocal count count += 1 G.clear() @@ -826,55 +804,19 @@ def setup(inputs: list[str]) -> None: t_1 = pw.io.csv.read(input_path, schema=InputSchema, mode="streaming") res = t_1._forget(pw.this.t + 10, pw.this.t, mark_forgetting_records=False) pw.io.csv.write(res, output_path) + wait_result_with_checker( + get_checker(output_path, expected), + timeout_sec=10, + target=run, + kwargs={"persistence_config": persistence_config}, + ) - def get_checker(expected: set[str]) -> Callable: - def check() -> None: - assert_sets_equality_from_path(output_path, expected) - - return LogicChecker(check) - - setup(["t", "1", "3", "11"]) - wait_result_with_checker( - get_checker({"1,1", "3,1", "11,1"}), - timeout_sec=10, - target=run, - kwargs={"persistence_config": persistence_config}, - ) - setup(["t", "15", "16"]) - wait_result_with_checker( - get_checker({"1,-1", "15,1", "16,1"}), - timeout_sec=10, - target=run, - kwargs={"persistence_config": persistence_config}, - ) - setup(["t", "6", "21"]) - wait_result_with_checker( - get_checker({"3,-1", "21,1"}), - timeout_sec=10, - target=run, - kwargs={"persistence_config": persistence_config}, - ) - setup(["t", "9", "10"]) - wait_result_with_checker( - get_checker({"11,-1"}), - timeout_sec=10, - target=run, - kwargs={"persistence_config": persistence_config}, - ) - setup(["t", "26"]) - wait_result_with_checker( - get_checker({"26,1"}), - timeout_sec=10, - target=run, - kwargs={"persistence_config": persistence_config}, - ) - setup(["t", "22"]) - wait_result_with_checker( - get_checker({"15,-1", "16,-1", "22,1"}), - timeout_sec=10, - target=run, - kwargs={"persistence_config": persistence_config}, - ) + wait_result(["t", "1", "3", "11"], {"1,1", "3,1", "11,1"}) + wait_result(["t", "15", "16"], {"1,-1", "15,1", "16,1"}) + wait_result(["t", "6", "21"], {"3,-1", "21,1"}) + wait_result(["t", "9", "10"], {"11,-1"}) + wait_result(["t", "26"], {"26,1"}) + wait_result(["t", "22"], {"15,-1", "16,-1", "22,1"}) @pytest.mark.parametrize( @@ -923,3 +865,175 @@ def run_computation(inputs: list[dict[str, int]], expected: set[str]): run_computation([{"a": 3, "b": 9}], {"3,10,-1", "3,9,1"}) run_computation([{"a": 4, "b": 6}], {"4,6,1"}) run_computation([{"a": 1, "b": 0}], {"1,4,-1", "1,0,1"}) + + +@pytest.mark.parametrize("mode", [api.PersistenceMode.OPERATOR_PERSISTING]) +@pytest.mark.parametrize("as_udf", [True, False]) +@needs_multiprocessing_fork +def test_async_transformer_append_only(tmp_path, mode, as_udf): + class InputSchema(pw.Schema): + a: int + + input_path = tmp_path / "1" + os.makedirs(input_path) + output_path = tmp_path / "out.csv" + persistent_storage_path = tmp_path / "p" + count = 0 + persistence_config = pw.persistence.Config( + pw.persistence.Backend.filesystem(persistent_storage_path), + persistence_mode=mode, + ) + + can_pass = [True, True, False, False, True, True, True, False, True] + + class OutputSchema(pw.Schema): + ret: int + + if as_udf: + + @pw.udf(executor=pw.udfs.fully_async_executor()) + async def decrement(a: int) -> int: + while not can_pass[a]: + await asyncio.sleep(1) + return a - 1 + + else: + + class 
Decrementer(pw.AsyncTransformer, output_schema=OutputSchema): + async def invoke(self, a) -> dict[str, int]: + while not can_pass[a]: + await asyncio.sleep(1) + return {"ret": a - 1} + + def wait_result(inputs: list[str], expected: set[str]) -> None: + nonlocal count + count += 1 + G.clear() + path = input_path / str(count) + write_lines(path, inputs) + t_1 = pw.io.csv.read(input_path, schema=InputSchema, mode="streaming") + if as_udf: + res = t_1.select(ret=decrement(pw.this.a)).await_futures() + else: + res = Decrementer(t_1).successful + pw.io.csv.write(res, output_path) + wait_result_with_checker( + get_checker(output_path, expected), + timeout_sec=30, + target=run, + kwargs={"persistence_config": persistence_config}, + ) + + wait_result(["a", "0", "1", "2"], {"-1,1", "0,1"}) + can_pass[2] = True + can_pass[3] = True + wait_result(["a", "3"], {"1,1", "2,1"}) + wait_result(["a", "4", "5"], {"3,1", "4,1"}) + wait_result(["a", "6", "7"], {"5,1"}) + can_pass[7] = True + wait_result(["a", "8"], {"6,1", "7,1"}) + + +def get_async_transformer_tester( + tmp_path: pathlib.Path, + input_path: pathlib.Path, + mode: api.PersistenceMode, + as_udf: bool, + can_pass: list[bool], +) -> Callable[[list[str], set[str]], None]: + class InputSchema(pw.Schema): + a: int = pw.column_definition(primary_key=True) + b: int + + os.makedirs(input_path) + output_path = tmp_path / "out.csv" + persistent_storage_path = tmp_path / "p" + count = 0 + persistence_config = pw.persistence.Config( + pw.persistence.Backend.filesystem(persistent_storage_path), + persistence_mode=mode, + ) + + class OutputSchema(pw.Schema): + a: int + b: int + + if as_udf: + + @pw.udf(executor=pw.udfs.fully_async_executor()) + async def decrement(a: int) -> int: + while not can_pass[a]: + await asyncio.sleep(1) + return a - 1 + + else: + + class Decrementer(pw.AsyncTransformer, output_schema=OutputSchema): + async def invoke(self, a, b) -> dict[str, int]: + while not can_pass[b]: + await asyncio.sleep(1) + return {"a": a, "b": b - 1} + + def wait_result(inputs: list[str], expected: set[str]) -> None: + nonlocal count + count += 1 + G.clear() + path = input_path / str(count) + write_lines(path, inputs) + t_1 = pw.io.csv.read(input_path, schema=InputSchema, mode="streaming") + if as_udf: + res = t_1.select(a=pw.this.a, b=decrement(pw.this.b)).await_futures() + else: + res = Decrementer(t_1).successful + pw.io.csv.write(res, output_path) + wait_result_with_checker( + get_checker(output_path, expected), + timeout_sec=30, + target=run, + kwargs={"persistence_config": persistence_config}, + ) + + return wait_result + + +@pytest.mark.parametrize("mode", [api.PersistenceMode.OPERATOR_PERSISTING]) +@needs_multiprocessing_fork +def test_async_transformer(tmp_path, mode): + input_path = tmp_path / "1" + can_pass = [True, True, False, False, True, True, True, False, True] + wait_result = get_async_transformer_tester( + tmp_path, input_path, mode, False, can_pass + ) + wait_result(["a,b", "0,0", "1,1", "2,2"], {"0,-1,1", "1,0,1"}) + can_pass[2] = True + can_pass[3] = True + wait_result(["a,b", "3,3"], {"2,1,1", "3,2,1"}) + wait_result(["a,b", "0,4", "2,3"], {"0,-1,-1", "2,1,-1", "0,3,1", "2,2,1"}) + wait_result(["a,b", "6,6", "7,7"], {"6,5,1"}) + os.remove(input_path / "4") + wait_result(["a,b", "8,8"], {"6,5,-1", "8,7,1"}) + os.remove(input_path / "3") + wait_result(["a,b"], {"0,3,-1", "2,2,-1"}) + + +@pytest.mark.parametrize("mode", [api.PersistenceMode.OPERATOR_PERSISTING]) +@needs_multiprocessing_fork +def test_fully_async_udf(tmp_path, mode): + 
input_path = tmp_path / "1" + can_pass = [True, True, False, False, True, True, True, False, True] + wait_result = get_async_transformer_tester( + tmp_path, input_path, mode, True, can_pass + ) + wait_result(["a,b", "0,0", "1,1", "2,2"], {"0,-1,1", "1,0,1"}) + can_pass[2] = True + can_pass[3] = True + wait_result(["a,b", "3,3"], {"2,1,1", "3,2,1"}) + os.remove(input_path / "1") + wait_result( + ["a,b", "0,4", "2,3"], {"0,-1,-1", "2,1,-1", "0,3,1", "2,2,1", "1,0,-1"} + ) + wait_result(["a,b", "6,6", "7,7"], {"6,5,1"}) + os.remove(input_path / "4") + wait_result(["a,b", "8,8"], {"6,5,-1", "8,7,1"}) + os.remove(input_path / "3") + wait_result(["a,b"], {"0,3,-1", "2,2,-1"}) diff --git a/python/pathway/tests/test_udf.py b/python/pathway/tests/test_udf.py index 7f7ff76d..bb34fa99 100644 --- a/python/pathway/tests/test_udf.py +++ b/python/pathway/tests/test_udf.py @@ -8,6 +8,7 @@ import re import sys import threading +import time import warnings from typing import Optional from unittest import mock @@ -16,10 +17,13 @@ import pathway as pw from pathway.internals import api +from pathway.internals.udfs.executors import Executor from pathway.tests.utils import ( T, assert_stream_equality, assert_table_equality, + assert_table_equality_wo_index, + assert_table_equality_wo_types, run_all, warns_here, xfail_on_multiple_threads, @@ -89,12 +93,22 @@ def __wrapped__(self, a: int) -> int: ) -def test_udf_async_options(tmp_path: pathlib.Path): +def get_async_executor(fully_async: bool) -> Executor: + if fully_async: + return pw.udfs.fully_async_executor() + else: + return pw.udfs.async_executor() + + +@pytest.mark.parametrize("fully_async", [True, False]) +def test_udf_async_options(tmp_path: pathlib.Path, fully_async): cache_dir = tmp_path / "test_cache" counter = mock.Mock() - @pw.udf(cache_strategy=pw.udfs.DiskCache()) + @pw.udf( + executor=get_async_executor(fully_async), cache_strategy=pw.udfs.DiskCache() + ) async def inc(x: int) -> int: counter() return x + 5 @@ -108,6 +122,8 @@ async def inc(x: int) -> int: """ ) result = input.select(ret=inc(pw.this.foo)) + if fully_async: + result = result.await_futures() expected = T( """ ret @@ -136,12 +152,13 @@ async def inc(x: int) -> int: assert counter.call_count == 3 +@pytest.mark.parametrize("fully_async", [True, False]) @pytest.mark.skipif(sys.version_info < (3, 11), reason="test requires asyncio.Barrier") -def test_udf_async(): +def test_udf_async(fully_async): barrier = asyncio.Barrier(3) # type: ignore[attr-defined] # mypy complains because of versions lower than 3.11 - @pw.udf + @pw.udf(executor=get_async_executor(fully_async)) async def inc(a: int) -> int: await barrier.wait() return a + 3 @@ -157,6 +174,9 @@ async def inc(a: int) -> int: result = input.select(ret=inc(pw.this.a)) + if fully_async: + result = result.await_futures() + assert_table_equality( result, T( @@ -193,10 +213,11 @@ def inc(a: int) -> int: run_all() -def test_udf_sync_with_async_executor(): +@pytest.mark.parametrize("fully_async", [True, False]) +def test_udf_sync_with_async_executor(fully_async): barrier = threading.Barrier(3, timeout=10) - @pw.udf(executor=pw.udfs.async_executor()) + @pw.udf(executor=get_async_executor(fully_async)) def inc(a: int) -> int: barrier.wait() return a + 3 @@ -212,6 +233,9 @@ def inc(a: int) -> int: result = input.select(ret=inc(pw.this.a)) + if fully_async: + result = result.await_futures() + assert_table_equality( result, T( @@ -225,7 +249,8 @@ def inc(a: int) -> int: ) -def test_udf_async_class(): +@pytest.mark.parametrize("fully_async", [True, 
False]) +def test_udf_async_class(fully_async): class Inc(pw.UDF): def __init__(self, inc, **kwargs) -> None: super().__init__(**kwargs) @@ -244,8 +269,10 @@ async def __wrapped__(self, a: int) -> int: """ ) - inc = Inc(40) + inc = Inc(40, executor=get_async_executor(fully_async)) result = input.select(ret=inc(pw.this.a)) + if fully_async: + result = result.await_futures() assert_table_equality( result, @@ -260,10 +287,11 @@ async def __wrapped__(self, a: int) -> int: ) -def test_udf_propagate_none(): +@pytest.mark.parametrize("fully_async", [True, False]) +def test_udf_propagate_none(fully_async): internal_add = mock.Mock() - @pw.udf(propagate_none=True) + @pw.udf(executor=get_async_executor(fully_async), propagate_none=True) def add(a: int, b: int) -> int: assert a is not None assert b is not None @@ -280,6 +308,8 @@ def add(a: int, b: int) -> int: ) result = input.select(ret=add(pw.this.a, pw.this.b)) + if fully_async: + result = result.await_futures() assert_table_equality( result, @@ -531,10 +561,11 @@ async def inc(a: int) -> int: assert internal_inc.call_count == 5 -def test_async_udf_propagate_none(): +@pytest.mark.parametrize("fully_async", [True, False]) +def test_async_udf_propagate_none(fully_async): internal_add = mock.Mock() - @pw.udf(propagate_none=True) + @pw.udf(propagate_none=True, executor=get_async_executor(fully_async)) async def add(a: int, b: int) -> int: assert a is not None assert b is not None @@ -551,6 +582,8 @@ async def add(a: int, b: int) -> int: ) result = input.select(ret=add(pw.this.a, pw.this.b)) + if fully_async: + result = result.await_futures() assert_table_equality( result, @@ -566,10 +599,11 @@ async def add(a: int, b: int) -> int: internal_add.assert_called_once() -def test_async_udf_with_none(): +@pytest.mark.parametrize("fully_async", [True, False]) +def test_async_udf_with_none(fully_async): internal_add = mock.Mock() - @pw.udf() + @pw.udf(executor=get_async_executor(fully_async)) async def add(a: int, b: int) -> int: internal_add() if a is None: @@ -588,6 +622,8 @@ async def add(a: int, b: int) -> int: ) result = input.select(ret=add(pw.this.a, pw.this.b)) + if fully_async: + result = result.await_futures() assert_table_equality( result, @@ -603,8 +639,14 @@ async def add(a: int, b: int) -> int: assert internal_add.call_count == 3 -def test_udf_timeout(): - @pw.udf(executor=pw.udfs.async_executor(timeout=0.1)) +@pytest.mark.parametrize("fully_async", [True, False]) +def test_udf_timeout(fully_async): + if fully_async: + executor = pw.udfs.fully_async_executor(timeout=0.1) + else: + executor = pw.udfs.async_executor(timeout=0.1) + + @pw.udf(executor=executor) async def inc(a: int) -> int: await asyncio.sleep(2) return a + 1 @@ -618,7 +660,9 @@ async def inc(a: int) -> int: input.select(ret=inc(pw.this.a)) expected: type[Exception] - if sys.version_info < (3, 11): + if fully_async: + expected = api.EngineError + elif sys.version_info < (3, 11): expected = asyncio.exceptions.TimeoutError else: expected = TimeoutError @@ -626,8 +670,14 @@ async def inc(a: int) -> int: run_all() -def test_udf_too_fast_for_timeout(): - @pw.udf(executor=pw.udfs.async_executor(timeout=10.0)) +@pytest.mark.parametrize("fully_async", [True, False]) +def test_udf_too_fast_for_timeout(fully_async): + if fully_async: + executor = pw.udfs.fully_async_executor(timeout=10.0) + else: + executor = pw.udfs.async_executor(timeout=10.0) + + @pw.udf(executor=executor) async def inc(a: int) -> int: return a + 1 @@ -641,6 +691,8 @@ async def inc(a: int) -> int: ) result = 
input.select(ret=inc(pw.this.a)) + if fully_async: + result = result.await_futures() assert_table_equality( result, T( @@ -654,11 +706,11 @@ async def inc(a: int) -> int: ) -@pytest.mark.parametrize("sync", [True, False]) -def test_udf_in_memory_cache(sync: bool) -> None: +@pytest.mark.parametrize("sync", ["sync", "async", "fully_async"]) +def test_udf_in_memory_cache(sync: str) -> None: internal_inc = mock.Mock() - if sync: + if sync == "sync": @pw.udf(cache_strategy=pw.udfs.InMemoryCache()) def inc(a: int) -> int: @@ -667,7 +719,10 @@ def inc(a: int) -> int: else: - @pw.udf(cache_strategy=pw.udfs.InMemoryCache()) + @pw.udf( + cache_strategy=pw.udfs.InMemoryCache(), + executor=get_async_executor(sync == "fully_async"), + ) async def inc(a: int) -> int: await asyncio.sleep(a / 10) internal_inc(a) @@ -684,6 +739,8 @@ async def inc(a: int) -> int: """ ) result = input.select(ret=inc(pw.this.a)) + if sync == "fully_async": + result = result.await_futures() expected = T( """ ret @@ -704,11 +761,11 @@ async def inc(a: int) -> int: assert internal_inc.call_count == 3 # count did not change -@pytest.mark.parametrize("sync", [True, False]) -def test_udf_in_memory_cache_with_limit(sync: bool) -> None: +@pytest.mark.parametrize("sync", ["sync", "async", "fully_async"]) +def test_udf_in_memory_cache_with_limit(sync: str) -> None: internal_inc = mock.Mock() - if sync: + if sync == "sync": @pw.udf(cache_strategy=pw.udfs.InMemoryCache(max_size=0)) def inc(a: int) -> int: @@ -717,7 +774,10 @@ def inc(a: int) -> int: else: - @pw.udf(cache_strategy=pw.udfs.InMemoryCache(max_size=0)) + @pw.udf( + cache_strategy=pw.udfs.InMemoryCache(max_size=0), + executor=get_async_executor(sync == "fully_async"), + ) async def inc(a: int) -> int: await asyncio.sleep(a / 10) internal_inc(a) @@ -732,6 +792,8 @@ async def inc(a: int) -> int: """ ) result = input.select(ret=inc(pw.this.a)) + if sync == "fully_async": + result = result.await_futures() expected = T( """ ret @@ -745,11 +807,23 @@ async def inc(a: int) -> int: assert internal_inc.call_count == 3 -@pytest.mark.parametrize("sync", [True, False]) +@pytest.mark.parametrize( + "sync", + [ + "sync", + "async", + pytest.param( + "fully_async", + marks=pytest.mark.xfail( + sys.platform != "linux", reason="InMemoryCache uses incompatible loop" + ), + ), + ], +) def test_udf_in_memory_cache_multiple_places(sync: bool) -> None: internal_inc = mock.Mock() - if sync: + if sync == "sync": @pw.udf(cache_strategy=pw.udfs.InMemoryCache()) def inc(a: int) -> int: @@ -758,7 +832,10 @@ def inc(a: int) -> int: else: - @pw.udf(cache_strategy=pw.udfs.InMemoryCache()) + @pw.udf( + cache_strategy=pw.udfs.InMemoryCache(), + executor=get_async_executor(sync == "fully_async"), + ) async def inc(a: int) -> int: internal_inc(a) return a + 1 @@ -775,6 +852,8 @@ async def inc(a: int) -> int: ) result = input.with_columns(ret=inc(pw.this.a)) result = result.with_columns(ret_2=inc(pw.this.a)) + if sync == "fully_async": + result = result.await_futures() expected = T( """ a | ret | ret_2 @@ -815,10 +894,19 @@ def f(a: int) -> int: f(pw.this.a) -def test_cast_on_return() -> None: - @pw.udf() - def f(a: int) -> float: - return a +@pytest.mark.parametrize("sync", ["sync", "async", "fully_async"]) +def test_cast_on_return(sync: str) -> None: + if sync == "sync": + + @pw.udf() + def f(a: int) -> float: + return a + + else: + + @pw.udf(executor=get_async_executor(sync == "fully_async")) + async def f(a: int) -> float: + return a t = pw.debug.table_from_markdown( """ @@ -829,6 +917,8 @@ def f(a: int) -> 
float: """ ).with_columns(a=f(pw.this.a)) + if sync == "fully_async": + t = t.await_futures() res = t.select(c=pw.this.a + pw.this.b) expected = pw.debug.table_from_markdown( """ @@ -966,3 +1056,412 @@ def f(a: int) -> int: ) assert_stream_equality(result, expected) + + +def test_fully_async_udf(): + @pw.udf(executor=pw.udfs.fully_async_executor()) + async def inc(a: int) -> int: + return a + 1 + + input = pw.debug.table_from_markdown( + """ + a + 1 + 2 + 3 + """ + ) + + result = input.select(ret=inc(pw.this.a)) + + assert_table_equality_wo_types( + result, + T( + """ + ret + 2 + 3 + 4 + """, + ), + ) + + +def test_fully_async_udf_propagation_allowed(): + @pw.udf(executor=pw.udfs.fully_async_executor()) + async def inc(a: int) -> int: + return a + 1 + + input = pw.debug.table_from_markdown( + """ + a + 1 + 2 + 3 + """ + ) + + t = input.with_columns(ret=inc(pw.this.a)) + result = t.select(a=pw.this.a + 2, b=pw.this.ret) + + assert_table_equality_wo_types( + result, + T( + """ + a | b + 3 | 2 + 4 | 3 + 5 | 4 + """, + ), + ) + + +@pytest.mark.skipif( + sys.version_info < (3, 11), reason="_asyncio.Future arg not printed" +) +def test_future_dtype_disallowed_expression(): + @pw.udf(executor=pw.udfs.fully_async_executor()) + async def inc(a: int) -> int: + return a + 1 + + input = pw.debug.table_from_markdown( + """ + a + 1 + 2 + 3 + """ + ) + + msg = "Pathway does not support using binary operator add on columns of types _asyncio.Future[int], ." + with pytest.raises(TypeError, match=re.escape(msg)): + input.select(ret=inc(pw.this.a) + 1) + + +def table_with_future_ret() -> pw.Table: + @pw.udf(executor=pw.udfs.fully_async_executor()) + async def inc(a: int) -> int: + return a + 1 + + input = pw.debug.table_from_markdown( + """ + a + 1 + 2 + 3 + """ + ) + return input.with_columns(ret=inc(pw.this.a)) + + +@pytest.mark.skipif( + sys.version_info < (3, 11), reason="_asyncio.Future arg not printed" +) +def test_future_dtype_disallowed_reduce(): + t = table_with_future_ret() + msg = ( + "Cannot perform pathway.reducers.sum when column of type _asyncio.Future[int] is involved." + + " Consider applying `await_futures()` to the table used here" + ) + with pytest.raises(TypeError, match=re.escape(msg)): + t.reduce(s=pw.reducers.sum(pw.this.ret)) + + +@pytest.mark.skipif( + sys.version_info < (3, 11), reason="_asyncio.Future arg not printed" +) +def test_future_dtype_disallowed_in_groupby(): + t = table_with_future_ret() + msg = ( + "Using column of type _asyncio.Future[int] is not allowed here." + + " Consider applying `await_futures()` to the table first." + ) + with pytest.raises(TypeError, match=re.escape(msg)): + t.groupby(pw.this.ret) + + +@pytest.mark.skipif( + sys.version_info < (3, 11), reason="_asyncio.Future arg not printed" +) +def test_future_dtype_disallowed_in_sort_key(): + t = table_with_future_ret() + msg = ( + "Using column of type _asyncio.Future[int] is not allowed here." + + " Consider applying `await_futures()` to the table first." + ) + with pytest.raises(TypeError, match=re.escape(msg)): + t.sort(pw.this.ret) + + +@pytest.mark.skipif( + sys.version_info < (3, 11), reason="_asyncio.Future arg not printed" +) +def test_future_dtype_disallowed_in_sort_instance(): + t = table_with_future_ret() + msg = ( + "Using column of type _asyncio.Future[int] is not allowed here." + + " Consider applying `await_futures()` to the table first." 
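A condensed, hedged illustration of the restriction these tests pin down: a Future-typed column produced by a fully asynchronous UDF can be carried along and output, but it cannot enter ordinary expressions, reducers, or grouping keys until the table is awaited:

```python
import pathway as pw

@pw.udf(executor=pw.udfs.fully_async_executor())
async def inc(a: int) -> int:
    return a + 1

t = pw.debug.table_from_markdown(
    """
    a
    1
    2
    3
    """
).with_columns(ret=inc(pw.this.a))

# t.select(ret=pw.this.ret + 1)   # rejected: `ret` still has a Future dtype
# t.groupby(pw.this.ret)          # rejected for the same reason
ready = t.await_futures()                   # wait for the asynchronous results
result = ready.select(ret=pw.this.ret + 1)  # now allowed
```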
+ ) + with pytest.raises(TypeError, match=re.escape(msg)): + t.sort(pw.this.a, instance=pw.this.ret) + + +@pytest.mark.skipif( + sys.version_info < (3, 11), reason="_asyncio.Future arg not printed" +) +def test_future_dtype_disallowed_in_deduplicate(): + t = table_with_future_ret() + + def acceptor(new_value, old_value) -> bool: + return new_value >= old_value + 2 + + msg = ( + "Using column of type _asyncio.Future[int] is not allowed here." + + " Consider applying `await_futures()` to the table first." + ) + with pytest.raises(TypeError, match=re.escape(msg)): + t.deduplicate(value=pw.this.ret, acceptor=acceptor) + + +@pytest.mark.skipif( + sys.version_info < (3, 11), reason="_asyncio.Future arg not printed" +) +def test_future_dtype_disallowed_in_deduplicate_instance(): + t = table_with_future_ret() + + def acceptor(new_value, old_value) -> bool: + return new_value >= old_value + 2 + + msg = ( + "Using column of type _asyncio.Future[int] is not allowed here." + + " Consider applying `await_futures()` to the table first." + ) + with pytest.raises(TypeError, match=re.escape(msg)): + t.deduplicate(value=pw.this.a, instance=pw.this.ret, acceptor=acceptor) + + +@pytest.mark.skipif( + sys.version_info < (3, 11), reason="_asyncio.Future arg not printed" +) +def test_future_dtype_disallowed_in_expressions(): + t = table_with_future_ret() + msg = ( + "Cannot perform pathway.pointer_from when column of type _asyncio.Future[int] is involved." + + " Consider applying `await_futures()` to the table used here." + ) + with pytest.raises(TypeError, match=re.escape(msg)): + t.select(p=t.pointer_from(t.ret)) + + msg = "Cannot perform pathway.if_else on columns of types _asyncio.Future[int] and ." + with pytest.raises(TypeError, match=re.escape(msg)): + t.select(p=pw.if_else(t.a > 2, t.ret, 2)) + + msg = ( + "Cannot perform pathway.make_tuple when column of type _asyncio.Future[int] is involved." + + " Consider applying `await_futures()` to the table used here." + ) + with pytest.raises(TypeError, match=re.escape(msg)): + t.select(p=pw.make_tuple(t.ret, 2)) + + msg = ( + "Cannot perform pathway.is_none when column of type _asyncio.Future[int] is involved." + + " Consider applying `await_futures()` to the table used here." + ) + with pytest.raises(TypeError, match=re.escape(msg)): + t.select(p=t.ret.is_none()) + + msg = ( + "Cannot perform pathway.is_not_none when column of type _asyncio.Future[int] is involved." + + " Consider applying `await_futures()` to the table used here." + ) + with pytest.raises(TypeError, match=re.escape(msg)): + t.select(p=t.ret.is_not_none()) + + @pw.udf + def foo(a: int) -> int: + return a - 1 + + @pw.udf + async def bar(a: int) -> int: + return a - 1 + + msg = ( + "Cannot perform pathway.apply when column of type _asyncio.Future[int] is involved." + + " Consider applying `await_futures()` to the table used here." + ) + with pytest.raises(TypeError, match=re.escape(msg)): + t.select(p=foo(t.ret)) + + msg = ( + "Cannot perform pathway.apply_async when column of type _asyncio.Future[int] is involved." + + " Consider applying `await_futures()` to the table used here." 
+ ) + with pytest.raises(TypeError, match=re.escape(msg)): + t.select(p=bar(t.ret)) + + +@pytest.mark.skipif( + sys.version_info < (3, 11), reason="_asyncio.Future arg not printed" +) +def test_future_dtype_disallowed_in_expressions_2(): + @pw.udf(executor=pw.udfs.fully_async_executor()) + async def inc(a: int | None) -> int | None: + if a is None: + return None + return a + 1 + + input = pw.debug.table_from_markdown( + """ + a | b + 1 | 1 + 2 | 1 + 3 | 1 + | 1 + """ + ) + t = input.with_columns(ret=inc(pw.this.a)) + + msg = ( + "Cannot perform pathway.coalesce when column of type _asyncio.Future[int | None] is involved." + + " Consider applying `await_futures()` to the table used here." + ) + with pytest.raises(TypeError, match=re.escape(msg)): + t.select(p=pw.coalesce(t.ret, t.b)) + + msg = ( + "Cannot perform pathway.require when column of type _asyncio.Future[int | None] is involved." + + " Consider applying `await_futures()` to the table used here." + ) + with pytest.raises(TypeError, match=re.escape(msg)): + t.select(p=pw.require(t.ret, t.a)) + + +def test_fully_async_udf_expression_allowed_after_await(): + result = table_with_future_ret().await_futures().select(ret=pw.this.ret + 2) + + assert_table_equality( + result, + T( + """ + ret + 4 + 5 + 6 + """, + ), + ) + + +def test_fully_async_udf_reducer_allowed_after_await(): + result = ( + table_with_future_ret().await_futures().reduce(s=pw.reducers.sum(pw.this.ret)) + ) + + assert_table_equality_wo_index( + result, + T( + """ + s + 9 + """, + ), + ) + + +def test_fully_async_udf_chaining(): + @pw.udf(executor=pw.udfs.fully_async_executor()) + async def inc(a: int) -> int: + print(a) + return a + 1 + + input = pw.debug.table_from_markdown( + """ + a + 1 + 2 + 3 + """ + ) + + result = input.select(ret=inc(inc(pw.this.a))) + + assert_table_equality_wo_types( + result, + T( + """ + ret + 3 + 4 + 5 + """, + ), + ) + + +@pytest.mark.parametrize("fully_async", [True, False]) +def test_fully_async_udf_error_propagation(fully_async): + + @pw.udf(executor=get_async_executor(fully_async)) + async def inc(a: int) -> int: + return a + 1 + + input = pw.debug.table_from_markdown( + """ + a | b + 1 | 1 + 2 | 0 + 3 | 1 + """ + ) + + result = input.select(ret=inc(pw.this.a // pw.this.b)) + if fully_async: + result = result.await_futures() + result = result.select(ret=pw.fill_error(pw.this.ret, -1)) + + assert_table_equality( + result, + T( + """ + ret + 2 + -1 + 4 + """, + ), + terminate_on_error=False, + ) + + +def test_fully_async_udf_first_result_after_deletion_and_next_insertion(): + class InputSchema(pw.Schema): + a: int + + class InputSubject(pw.io.python.ConnectorSubject): + def run(self): + time.sleep(2) + self._add_inner(api.ref_scalar(3), dict(a=10)) + time.sleep(0.2) + self._remove_inner(api.ref_scalar(3), dict(a=10)) + time.sleep(0.2) + self._add_inner(api.ref_scalar(3), dict(a=12)) + + @pw.udf + def foo(a: int) -> int: + return a + 1 + + @pw.udf(executor=pw.udfs.fully_async_executor(autocommit_duration_ms=10)) + async def bar(a: int) -> int: + time.sleep(0.5) + return a + 2 + + t = pw.io.python.read(InputSubject(), schema=InputSchema, autocommit_duration_ms=10) + res = t.select(x=foo(pw.this.a), y=bar(pw.this.a)) + expected = pw.debug.table_from_markdown( + """ + | x | y + 3 | 13 | 14 + """ + ) + assert_table_equality_wo_types(res, expected) diff --git a/python/pathway/udfs.py b/python/pathway/udfs.py index 723606e6..e7b61cba 100644 --- a/python/pathway/udfs.py +++ b/python/pathway/udfs.py @@ -57,6 +57,7 @@ async_options, auto_executor, 
coerce_async, + fully_async_executor, sync_executor, udf, with_cache_strategy, @@ -69,6 +70,7 @@ "auto_executor", "async_executor", "sync_executor", + "fully_async_executor", "CacheStrategy", "DefaultCache", "DiskCache", diff --git a/src/connectors/adaptors.rs b/src/connectors/adaptors.rs index 928cc8fb..8e9ff377 100644 --- a/src/connectors/adaptors.rs +++ b/src/connectors/adaptors.rs @@ -1,6 +1,4 @@ // Copyright © 2024 Pathway - -use std::collections::hash_map::Entry; use std::collections::HashMap; use differential_dataflow::collection::AsCollection; @@ -56,27 +54,7 @@ impl UpsertSession fn consolidate_buffer(&mut self) { let mut keep: HashMap = HashMap::new(); for ((key, value), time, diff) in self.buffer.drain(..) { - if diff == 1 { - // If there's no entry, insert. - // If there's an entry that is an insertion, replace it with a newer insertion - // (both have the same timestamp). - // If there's an entry that is a deletion, replace it with a newer entry. - // It'll remove the old value (upsert) and insert a new one. - keep.insert(key, (value, time, diff)); - } else { - assert_eq!(diff, -1); - match keep.entry(key) { - Entry::Occupied(occupied_entry) => { - // If there's an entry, remove it. - occupied_entry.remove(); - } - Entry::Vacant(vacant_entry) => { - // If there's no entry, it means we remove entry from previous batches. - // So we have to keep a deletion. - vacant_entry.insert((value, time, diff)); - } - } - } + keep.insert(key, (value, time, diff)); } self.buffer.extend( keep.into_iter() diff --git a/src/connectors/data_format.rs b/src/connectors/data_format.rs index cc91624b..b86ca72d 100644 --- a/src/connectors/data_format.rs +++ b/src/connectors/data_format.rs @@ -423,6 +423,12 @@ pub enum FormatterError { #[error("Error value is not bson-serializable")] ErrorValueNonBsonSerializable, + #[error("Pending value is not json-serializable")] + PendingValueNonJsonSerializable, + + #[error("Pending value is not bson-serializable")] + PendingValueNonBsonSerializable, + #[error("this connector doesn't support this value type")] UnsupportedValueType, @@ -1235,6 +1241,7 @@ fn serialize_value_to_json(value: &Value) -> Result { Ok(json!(encoded)) } Value::Error => Err(FormatterError::ErrorValueNonJsonSerializable), + Value::Pending => Err(FormatterError::PendingValueNonJsonSerializable), } } @@ -2062,6 +2069,7 @@ fn serialize_value_to_bson(value: &Value) -> Result { Value::PyObjectWrapper(_) => Err(FormatterError::TypeNonBsonSerializable { type_: Type::PyObjectWrapper, }), + Value::Pending => Err(FormatterError::PendingValueNonBsonSerializable), } } diff --git a/src/connectors/data_lake/delta.rs b/src/connectors/data_lake/delta.rs index 6eb8e898..84a3d8fd 100644 --- a/src/connectors/data_lake/delta.rs +++ b/src/connectors/data_lake/delta.rs @@ -167,7 +167,7 @@ impl DeltaBatchWriter { DeltaTableKernelType::Struct(struct_descriptor.into()) } Type::Optional(wrapped) => return Self::delta_table_type(wrapped), - Type::Any => return Err(WriteError::UnsupportedType(type_.clone())), + Type::Any | Type::Future(_) => return Err(WriteError::UnsupportedType(type_.clone())), }; Ok(delta_type) } diff --git a/src/connectors/data_lake/iceberg.rs b/src/connectors/data_lake/iceberg.rs index e080076d..4011e914 100644 --- a/src/connectors/data_lake/iceberg.rs +++ b/src/connectors/data_lake/iceberg.rs @@ -196,7 +196,7 @@ impl IcebergTableParams { let array_type = IcebergListType::new(nested_type.into()); IcebergType::List(array_type) } - Type::Any | Type::Array(_, _) | Type::Tuple(_) => { + Type::Any | 
Type::Array(_, _) | Type::Tuple(_) | Type::Future(_) => { return Err(WriteError::UnsupportedType(type_.clone())) } }; diff --git a/src/connectors/data_lake/writer.rs b/src/connectors/data_lake/writer.rs index c1f2883a..a2958c87 100644 --- a/src/connectors/data_lake/writer.rs +++ b/src/connectors/data_lake/writer.rs @@ -360,7 +360,7 @@ impl LakeWriter { let struct_descriptor = ArrowFields::from(struct_fields); ArrowDataType::Struct(struct_descriptor) } - Type::Any => return Err(WriteError::UnsupportedType(type_.clone())), + Type::Any | Type::Future(_) => return Err(WriteError::UnsupportedType(type_.clone())), }) } diff --git a/src/connectors/data_storage.rs b/src/connectors/data_storage.rs index 95359c51..31706ffd 100644 --- a/src/connectors/data_storage.rs +++ b/src/connectors/data_storage.rs @@ -1185,7 +1185,7 @@ impl PsqlWriter { } return Err(WriteError::UnsupportedType(type_.clone())); } - Type::Any | Type::Array(_, _) => { + Type::Any | Type::Array(_, _) | Type::Future(_) => { return Err(WriteError::UnsupportedType(type_.clone())) } }) @@ -1299,6 +1299,7 @@ mod to_sql { try_forward!(Vec, bincode::serialize(self).map_err(|e| *e)?); "python object" } + Self::Pending => "pending", }; Err(Box::new(WrongPathwayType { pathway_type: pathway_type.to_owned(), diff --git a/src/engine/dataflow.rs b/src/engine/dataflow.rs index 37c21200..3afbb668 100644 --- a/src/engine/dataflow.rs +++ b/src/engine/dataflow.rs @@ -3,6 +3,7 @@ #![allow(clippy::module_name_repetitions)] #![allow(clippy::non_canonical_partial_ord_impl)] // False positive with Derivative +mod async_transformer; mod complex_columns; pub mod config; mod export; @@ -85,6 +86,7 @@ use timely::progress::timestamp::Refines; use timely::progress::Timestamp as TimestampTrait; use xxhash_rust::xxh3::Xxh3 as Hasher; +use self::async_transformer::async_transformer; use self::complex_columns::complex_columns; use self::export::{export_table, import_table}; use self::maybe_total::MaybeTotalScope; @@ -99,7 +101,9 @@ use self::variable::SafeVariable; use super::error::{DataError, DataResult, DynError, DynResult, Trace}; use super::expression::AnyExpression; use super::external_index_wrappers::{ExternalIndexData, ExternalIndexQuery}; -use super::graph::{DataRow, ExportedTable, OperatorProperties, SubscribeCallbacks}; +use super::graph::{ + DataRow, ExportedTable, OperatorProperties, SubscribeCallbacks, SubscribeConfig, +}; use super::http_server::maybe_run_http_server_thread; use super::license::License; use super::progress_reporter::{maybe_run_reporter, MonitoringLevel}; @@ -853,6 +857,17 @@ enum Tuple { More(Arc<[Value]>), } +impl Tuple { + fn with_appended(self, value: Value) -> Self { + match self { + Tuple::Zero => Tuple::One(value), + Tuple::One(old_value) => Tuple::Two([old_value, value]), + Tuple::Two([value_1, value_2]) => Tuple::More(Arc::new([value_1, value_2, value])), + Tuple::More(values) => Tuple::More(values.iter().cloned().chain([value]).collect()), + } + } +} + impl Deref for Tuple { type Target = [Value]; @@ -996,6 +1011,16 @@ impl FilterOutErrors for Collection { } } +trait FilterOutPending { + fn filter_out_pending(&self) -> Self; +} + +impl FilterOutPending for Collection { + fn filter_out_pending(&self) -> Self { + self.filter(move |(_key, values)| !values.as_value_slice().contains(&Value::Pending)) + } +} + #[derive(Derivative, Debug, Clone, Serialize, Deserialize)] #[derivative(PartialEq, Eq, PartialOrd, Ord, Hash)] struct KeyWith( @@ -3090,17 +3115,18 @@ impl DataflowGraphInner { Ok(()) } - fn remove_errors_from_table( + fn 
remove_value_from_table( &mut self, table_handle: TableHandle, column_paths: Vec, + value: Value, table_properties: Arc, ) -> Result { let new_values = self .extract_columns(table_handle, column_paths)? .as_collection() - .filter_out_errors(None) - .map_named("remove_errors_from_table", |(key, tuple)| { + .filter(move |(_key, values)| !values.as_value_slice().contains(&value)) + .map_named("remove_value_from_table", |(key, tuple)| { (key, Value::from(tuple.as_value_slice())) }); @@ -3616,35 +3642,37 @@ impl> DataflowGraphInner .alloc(Table::from_collection(values).with_properties(table_properties))) } - fn new_upsert_collection( + fn maybe_persisted_upsert_collection( &mut self, collection: &Collection, ) -> Result> { - collection.maybe_persist_with_logic( - self, - "upsert_collection", - |collection| { - let upsert_stream = collection.inner.map(|((key, value), time, diff)| { - // same behavior for new and persisted variants - let value = match value { - OldOrNew::Old(value) | OldOrNew::New(value) => value, - }; - let value_for_upsert = if diff == 1 { - Some(value) - } else { - assert_eq!(diff, -1); - None - }; - (key, value_for_upsert, time) - }); - arrange_from_upsert::>>>( - &upsert_stream, - "UpsertSession", - ) - .as_collection(|k, v| (*k, v.clone())) - }, - |d| d, - ) + collection + .maybe_persist_with_logic( + self, + "upsert_collection", + |collection| { + let upsert_stream = collection.inner.map(|((key, value), time, diff)| { + // same behavior for new and persisted variants + let value = match value { + OldOrNew::Old(value) | OldOrNew::New(value) => value, + }; + let value_for_upsert = if diff == 1 { + Some(value) + } else { + assert_eq!(diff, -1); + None + }; + (key, value_for_upsert, time) + }); + arrange_from_upsert::>>>( + &upsert_stream, + "UpsertSession", + ) + .as_collection(|k, v| (*k, v.clone())) + }, + |d| d, + )? + .filter_out_persisted(&mut self.persistence_wrapper) } fn new_collection( @@ -3663,7 +3691,7 @@ impl> DataflowGraphInner SessionType::Upsert => { let mut upsert_session = UpsertSession::new(); let collection = upsert_session.to_collection(&mut self.scope); - let collection = self.new_upsert_collection(&collection)?; + let collection = self.maybe_persisted_upsert_collection(&collection)?; Ok((Box::new(upsert_session), collection)) } } @@ -4096,15 +4124,19 @@ impl> DataflowGraphInner } #[allow(clippy::too_many_arguments)] + #[allow(clippy::too_many_lines)] fn subscribe_table( &mut self, table_handle: TableHandle, column_paths: Vec, callbacks: SubscribeCallbacks, - skip_persisted_batch: bool, - skip_errors: bool, + config: SubscribeConfig, unique_name: Option, sort_by_indices: Option>, + logic: impl FnOnce( + &mut DataflowGraphInner, + Collection, + ) -> Result>, ) -> Result<()> { let worker_index = self.scope.index(); @@ -4117,7 +4149,7 @@ impl> DataflowGraphInner .persistence_wrapper .get_worker_persistent_storage() .cloned(); - let skip_initial_time = skip_persisted_batch && worker_persistent_storage.is_some(); + let skip_initial_time = config.skip_persisted_batch && worker_persistent_storage.is_some(); let error_reporter = self.error_reporter.clone(); let error_reporter_2 = self.error_reporter.clone(); @@ -4128,6 +4160,7 @@ impl> DataflowGraphInner mut on_data, mut on_time_end, mut on_end, + mut on_frontier, } = callbacks; let wrapper_2 = wrapper.clone(); @@ -4138,12 +4171,17 @@ impl> DataflowGraphInner let output_columns = self .extract_columns(table_handle, column_paths)? 
.as_collection(); - let output_columns = if skip_errors { + let output_columns = if config.skip_errors { output_columns.filter_out_errors(Some(error_logger)) } else { output_columns }; - output_columns + let output_columns = if config.skip_pending { + output_columns.filter_out_pending() + } else { + output_columns + }; + logic(self, output_columns)? .consolidate_for_output(true) .inspect(move |batch| { if batch.time.is_from_persistence() && skip_initial_time { @@ -4178,10 +4216,17 @@ impl> DataflowGraphInner // the first inspect for this frontier. if let Err(frontier) = event { stats.on_time_committed(frontier.first().copied().map(|t| t.0)); - if worker_index == 0 && frontier.is_empty() { - if let Some(on_end) = on_end.as_mut() { + if worker_index == 0 { + if frontier.is_empty() { + if let Some(on_end) = on_end.as_mut() { + wrapper_2 + .run(on_end) + .unwrap_with_reporter(&error_reporter_2); + } + } else if let Some(on_frontier) = on_frontier.as_mut() { + assert_eq!(frontier.len(), 1); wrapper_2 - .run(on_end) + .run(|| on_frontier(frontier[0])) .unwrap_with_reporter(&error_reporter_2); } } @@ -4940,8 +4985,7 @@ impl Graph for InnerDataflowGraph { _table_handle: TableHandle, _column_paths: Vec, _callbacks: SubscribeCallbacks, - _skip_persisted_batch: bool, - _skip_errors: bool, + _config: SubscribeConfig, _unique_name: Option, _sort_by_indices: Option>, ) -> Result<()> { @@ -5326,15 +5370,33 @@ impl Graph for InnerDataflowGraph { Err(Error::IoNotPossible) } - fn remove_errors_from_table( + fn remove_value_from_table( &self, table_handle: TableHandle, column_paths: Vec, + value: Value, table_properties: Arc, ) -> Result { - self.0 - .borrow_mut() - .remove_errors_from_table(table_handle, column_paths, table_properties) + self.0.borrow_mut().remove_value_from_table( + table_handle, + column_paths, + value, + table_properties, + ) + } + + fn async_transformer( + &self, + _table_handle: TableHandle, + _column_paths: Vec, + _callbacks: SubscribeCallbacks, + _reader: Box, + _parser: Box, + _commit_duration: Option, + _table_properties: Arc, + _skip_errors: bool, + ) -> Result { + Err(Error::IoNotPossible) } } @@ -5534,8 +5596,7 @@ impl> Graph for OuterDataflo table_handle: TableHandle, column_paths: Vec, callbacks: SubscribeCallbacks, - skip_persisted_batch: bool, - skip_errors: bool, + config: SubscribeConfig, unique_name: Option, sort_by_indices: Option>, ) -> Result<()> { @@ -5543,10 +5604,10 @@ impl> Graph for OuterDataflo table_handle, column_paths, callbacks, - skip_persisted_batch, - skip_errors, + config, unique_name, sort_by_indices, + |_graph, collection| Ok(collection), ) } @@ -5980,15 +6041,43 @@ impl> Graph for OuterDataflo self.0.borrow_mut().import_table(table) } - fn remove_errors_from_table( + fn remove_value_from_table( &self, table_handle: TableHandle, column_paths: Vec, + value: Value, table_properties: Arc, ) -> Result { - self.0 - .borrow_mut() - .remove_errors_from_table(table_handle, column_paths, table_properties) + self.0.borrow_mut().remove_value_from_table( + table_handle, + column_paths, + value, + table_properties, + ) + } + + fn async_transformer( + &self, + table_handle: TableHandle, + column_paths: Vec, + callbacks: SubscribeCallbacks, + reader: Box, + parser: Box, + commit_duration: Option, + table_properties: Arc, + skip_errors: bool, + ) -> Result { + async_transformer( + &mut self.0.borrow_mut(), + table_handle, + column_paths, + callbacks, + reader, + parser, + commit_duration, + table_properties, + skip_errors, + ) } } diff --git 
a/src/engine/dataflow/async_transformer.rs b/src/engine/dataflow/async_transformer.rs new file mode 100644 index 00000000..626447a5 --- /dev/null +++ b/src/engine/dataflow/async_transformer.rs @@ -0,0 +1,396 @@ +// Copyright © 2025 Pathway + +use std::cell::RefCell; +use std::collections::hash_map::Entry; +use std::collections::HashMap; +use std::mem::take; +use std::rc::Rc; +use std::sync::Arc; +use std::time::Duration; + +use differential_dataflow::input::InputSession; +use differential_dataflow::{AsCollection, Collection}; +use timely::dataflow::operators::{Exchange, Inspect, Map}; +use timely::progress::Timestamp as _; + +use crate::connectors::adaptors::InputAdaptor; +use crate::connectors::data_format::Parser; +use crate::connectors::data_storage::ReaderBuilder; +use crate::connectors::{Connector, PersistenceMode, SnapshotAccess}; +use crate::engine::graph::{SubscribeCallbacks, SubscribeConfig}; +use crate::engine::{ + ColumnPath, Error, Key, OriginalOrRetraction, Result, TableHandle, TableProperties, Timestamp, + Value, +}; + +use super::maybe_total::MaybeTotalScope; +use super::operators::output::ConsolidateForOutput; +use super::operators::{MapWrapped, MaybeTotal, Reshard}; +use super::{DataflowGraphInner, MaybePersist, Table, Tuple}; + +/// `AsyncTransformer` allows for fully asynchronous computation on the python side. +/// Computation results are returned to the engine in a later (or equal) time +/// than the entry that triggered the computation. +/// +/// To achieve full asynchronicity, `AsyncTransformer` utilizes +/// python output connector (subscribe) and python input connector. +/// The computation happens in two streams: +/// +/// 1. stream to subscribe operator: +/// - output columns extraction (in subscribe) +/// - computation of final time in the input collection +/// (used to call `on_end` in subscribe when the input stream finishes) +/// - concatenation of the forgetting stream with input rows that have their computation finished +/// - persisting the stream. On restart used to produce rows for which `AsyncTransformer` didn't finish in the previous run. +/// - saving values of the input stream in `values_currently_processed` so that we can get input values when we get a result. +/// To get an appropriate input entry, `task_id` is generated. Every task has a different `task_id` (even rows with the same key). +/// In this step also forgetting entries are removed (`is_original` check). +/// - consolidation (in subscribe) +/// - sending data to python (in subscribe), +/// `on_frontier` is used to call `on_end` when the original stream finishes. +/// Without that, we would never finish as the forgetting stream would provide new time updates forever. +/// When we call `on_end`, the python connector finishes and that makes the forgetting stream finish. +/// +/// 2. stream from python input connector +/// - special input session (`AsyncTransformerSession`) used to assign `seq_id` for each entry. +/// It is used to deduplicate values with the same (key, time) pair. The deduplication cannot +/// be done in the input session itself because we need all values to produce the forgetting stream. +/// - running `PythonConnector`/`TransparentParser` pair with this input session +/// - extraction of a true row key and enriching the stream with input values from `values_currently_processed`. +/// - creation of the forgetting stream using the key and input values retrieved earlier. 
+/// The time is moved to the next retraction time so that it is easy to filter out forgetting entries later. +/// - grouping of the stream as batches (`consolidate_for_output`) +/// - keeping only the last entry (with greatest `seq_id`) for each key in each batch. +/// Having more entries for a single (key, time) pair would result in inconsistency in upserting later. +/// The final produced entry for some (key, time) pairs could not be the one with the greatest `seq_id`. +/// - upsert operator with persistence. Upsert is there to avoid recomputation for deletions. Without upserting, +/// there could be problems with inconsistency as the `AsyncTransformer` can be non-deterministic. +/// Persistence is there to be able to update values computed in previous runs. + +struct AsyncTransformerSession { + input_session: InputSession, + sequential_id: i64, +} + +impl AsyncTransformerSession { + #[allow(clippy::wrong_self_convention)] // consistent with other InputAdaptor implementors + fn to_collection>( + &mut self, + scope: &mut S, + ) -> Collection { + self.input_session.to_collection(scope) + } +} + +impl InputAdaptor for AsyncTransformerSession { + /// The implementation below mostly reuses differential dataflow's `InputSession` internals. + /// + /// It adds a sequential id for each entry so that it is possible later to + /// deduplicate entries for a single (key, time) pair leaving only the last one. + + fn new() -> Self { + AsyncTransformerSession { + input_session: InputSession::new(), + sequential_id: 1, + } + } + + fn flush(&mut self) { + self.input_session.flush(); + } + + fn advance_to(&mut self, time: Timestamp) { + if *self.time() < time { + self.sequential_id = 0; + } + self.input_session.advance_to(time); + } + + fn insert(&mut self, key: Key, value: Value) { + self.input_session.insert((key, value, self.sequential_id)); + self.sequential_id += 1; + } + + fn remove(&mut self, key: Key, value: Value) { + self.input_session.remove((key, value, self.sequential_id)); + self.sequential_id += 1; + } + + fn time(&self) -> &Timestamp { + self.input_session.time() + } +} + +struct StreamCloseData { + max_time: Timestamp, + stream_closed: bool, + on_end_called: bool, +} + +impl StreamCloseData { + fn new(max_time: Timestamp, stream_closed: bool, on_end_called: bool) -> Self { + Self { + max_time, + stream_closed, + on_end_called, + } + } +} + +fn run_input_connector( + graph: &mut DataflowGraphInner, + reader: Box, + parser: Box, + commit_duration: Option, + input_session: Box>, +) -> Result<()> +where + S: MaybeTotalScope, +{ + let connector = Connector::new( + commit_duration, + parser.column_count(), + graph.terminate_on_error, + graph.create_error_logger()?.into(), + ); + let state = connector.run( + reader, + parser, + input_session, + move |values, _offset| { + let values = values.expect("key should be present"); + if let [Value::Pointer(key)] = values[..]
{ + key + } else { + panic!("values should contain exactly one key") + } + }, + graph.output_probe.clone(), + None, + None, + true, + None, + PersistenceMode::Batch, // default value from connector_table + SnapshotAccess::Full, // default value from connector_table + graph.error_reporter.clone(), + )?; + + graph.pollers.push(state.poller); + graph.connector_threads.push(state.input_thread_handle); + + Ok(()) +} + +fn run_output_connector( + graph: &mut DataflowGraphInner, + table_handle: TableHandle, + column_paths: Vec, + python_input_values: &Collection, + mut callbacks: SubscribeCallbacks, + values_currently_processed: Rc>>, + skip_errors: bool, +) -> Result<()> +where + S: MaybeTotalScope, +{ + let forgetting_stream = python_input_values + .inner + .map(move |((key, input_values, _values, _seq_id), time, diff)| { + // move time to retraction time so that it can be filtered out easily later + ((key, input_values), time.next_retraction_time(), diff) + }) + .as_collection() + .negate(); + + // safe because it is created inside a timely worker and a single worker exectues all operators in a single thread + let max_time_in_original_stream = Rc::new(RefCell::new(StreamCloseData::new( + Timestamp::minimum(), + false, + false, + ))); + + let logic = { + let max_time_in_original_stream = max_time_in_original_stream.clone(); + move |graph: &mut DataflowGraphInner, collection: Collection| { + Ok(collection + .inner + .exchange(|_data| 0) // Move all data to worker 0 to compute max time present in data. + // Currently async transformer is run only on worker 0 so it is not a bottleneck. + // If parallelization is needed, count the number of entries in each worker + // and only run ``on_end`` when all entries are sent. + .inspect_core(move |data_or_frontier| match data_or_frontier { + // needed for a hack to run ``on_end`` at the end of the true input stream + Ok(data) => { + let mut max_time_in_original_stream = + max_time_in_original_stream.borrow_mut(); + for (_key_val, timestamp, _diff) in data.1 { + max_time_in_original_stream.max_time = + std::cmp::max(*timestamp, max_time_in_original_stream.max_time); + } + } + Err(frontier) => { + if frontier.is_empty() { + max_time_in_original_stream.borrow_mut().stream_closed = true; + } + } + }) + .as_collection() + .concat(&forgetting_stream) + .maybe_persist(graph, "async_transformer")? 
// distributes data between workers if persistence is enabled + .reshard_to_first_worker() // if persistence is disabled, we need to reshard anyway + .inner + .flat_map(move |((key, tuple), time, diff)| { + // filter out retractions + // put new values in values_currently_processed so that they can be retracted later + if time.is_original() { + let task_id = Key::random(); + values_currently_processed + .borrow_mut() + .insert(task_id, (key, tuple.clone())); + Some(( + (key, tuple.with_appended(Value::Pointer(task_id))), + time, + diff, + )) + } else { + None + } + }) + .as_collection()) + } + }; + + // hack to run ``on_end`` at the end of the true input stream (ignoring forgetting stream from input reader) + let on_end = take(&mut callbacks.on_end); + if let Some(mut on_end) = on_end { + callbacks.on_frontier = Some(Box::new(move |timestamp| { + let mut max_time_in_original_stream = max_time_in_original_stream.borrow_mut(); + if max_time_in_original_stream.stream_closed + && max_time_in_original_stream.max_time < timestamp // strict comparison because in on_frontier timestamp is still on input + && !max_time_in_original_stream.on_end_called + { + max_time_in_original_stream.on_end_called = true; + on_end()?; + } + Ok(()) + })); + } + + graph.subscribe_table( + table_handle, + column_paths, + callbacks, + SubscribeConfig { + skip_persisted_batch: false, + skip_errors, + skip_pending: true, + }, + None, + None, + logic, + ) +} + +#[allow(clippy::too_many_arguments)] +pub fn async_transformer( + graph: &mut DataflowGraphInner, + table_handle: TableHandle, + column_paths: Vec, + callbacks: SubscribeCallbacks, + reader: Box, + parser: Box, + commit_duration: Option, + table_properties: Arc, + skip_errors: bool, +) -> Result +where + S: MaybeTotalScope, +{ + let mut input_session = AsyncTransformerSession::new(); + let python_input_values = input_session.to_collection(&mut graph.scope); + + if graph.scope.index() == 0 { + run_input_connector( + graph, + reader, + parser, + commit_duration, + Box::new(input_session), + )?; + } + + // safe because it is created inside a single timely worker and a single worker exectues all operators in a single thread + let values_currently_processed: Rc>> = + Rc::new(RefCell::new(HashMap::new())); + + let values = python_input_values + .reshard_to_first_worker() + .inner + .map({ + let values_currently_processed = values_currently_processed.clone(); + move |((task_id, values, seq_id), time, diff)| { + // get input values (they have to be the same to remove values from persisted input) + let (input_key, input_values) = values_currently_processed + .borrow_mut() + .remove(&task_id) + .expect("task_id has to be present"); + + ((input_key, input_values, values, seq_id), time, diff) + } + }) + .as_collection(); + + run_output_connector( + graph, + table_handle, + column_paths, + &values, + callbacks, + values_currently_processed, + skip_errors, + )?; + + let table = graph + .tables + .get(table_handle) + .ok_or(Error::InvalidTableHandle)?; + + let pending = MaybeTotal::distinct( + &table + .values() + .clone() + .maybe_persist(graph, "AsyncTransformer: Pending")?, + ) + .filter_out_persisted(&mut graph.persistence_wrapper)? 
+ .map_named("AsyncTransformer: Pending", |(key, _values)| { + (key, Tuple::Zero, Value::Pending, 0) + }); + + let output_values = values + .concat(&pending) + .consolidate_for_output(false) + .flat_map(|batch| { + let mut keep: HashMap = HashMap::new(); + for ((key, _input_values, values, seq_id), diff) in batch.data { + match keep.entry(key) { + Entry::Occupied(mut entry) => { + let (_, _, entry_seq_id) = entry.get(); + if seq_id > *entry_seq_id { + entry.insert((values, diff, seq_id)); + } + } + Entry::Vacant(entry) => { + entry.insert((values, diff, seq_id)); + } + } + } + keep.into_iter() + .map(move |(key, (value, diff, _seq_id))| ((key, value), batch.time, diff)) + }) + .as_collection(); + + let output_values_maybe_persisted = graph.maybe_persisted_upsert_collection(&output_values)?; + Ok(graph.tables.alloc( + Table::from_collection(output_values_maybe_persisted).with_properties(table_properties), + )) +} diff --git a/src/engine/dataflow/operators.rs b/src/engine/dataflow/operators.rs index 43fede84..b88668e8 100644 --- a/src/engine/dataflow/operators.rs +++ b/src/engine/dataflow/operators.rs @@ -419,6 +419,7 @@ where R: Semigroup, { fn reshard(&self) -> Collection; + fn reshard_to_first_worker(&self) -> Collection; } impl Reshard for Collection @@ -432,4 +433,8 @@ where .exchange(|(data, _time, _diff)| data.shard()) .as_collection() } + + fn reshard_to_first_worker(&self) -> Collection { + self.inner.exchange(|_| 0).as_collection() + } } diff --git a/src/engine/dataflow/persist.rs b/src/engine/dataflow/persist.rs index fe95f5d3..95fba7c2 100644 --- a/src/engine/dataflow/persist.rs +++ b/src/engine/dataflow/persist.rs @@ -23,7 +23,7 @@ use crate::engine::dataflow::maybe_total::MaybeTotalScope; use crate::engine::dataflow::operators::stateful_reduce::StatefulReduce; use crate::engine::dataflow::operators::MapWrapped; use crate::engine::dataflow::shard::Shard; -use crate::engine::dataflow::{MaybeUpdate, Poller, SortingCell}; +use crate::engine::dataflow::{MaybeUpdate, Poller, SortingCell, Tuple}; use crate::engine::reduce::IntSumState; use crate::engine::{Key, Result, Timestamp, Value}; use crate::persistence::config::PersistenceManagerConfig; @@ -204,6 +204,7 @@ pub(super) enum PersistableCollection { KeyIsizeIsize(Collection), KeyKeyValueKeyValueIsize(Collection), KeyVecValueIsize(Collection), isize>), + KeyTupleIsize(Collection), } macro_rules! 
impl_conversion { @@ -291,6 +292,7 @@ impl_conversion!( (Key, Vec), isize ); +impl_conversion!(PersistableCollection::KeyTupleIsize, (Key, Tuple), isize); pub struct TimestampBasedPersistenceWrapper { persistence_config: PersistenceManagerConfig, @@ -423,6 +425,9 @@ impl> PersistenceWrapper PersistableCollection::KeyVecValueIsize(collection) => { self.generic_maybe_persist(&collection, name, persistent_id) } + PersistableCollection::KeyTupleIsize(collection) => { + self.generic_maybe_persist(&collection, name, persistent_id) + } } } @@ -482,6 +487,9 @@ impl> PersistenceWrapper PersistableCollection::KeyVecValueIsize(collection) => { generic_filter_out_persisted(&collection) } + PersistableCollection::KeyTupleIsize(collection) => { + generic_filter_out_persisted(&collection) + } } } diff --git a/src/engine/dataflow/shard.rs b/src/engine/dataflow/shard.rs index 789930bc..8d1fa5ef 100644 --- a/src/engine/dataflow/shard.rs +++ b/src/engine/dataflow/shard.rs @@ -31,6 +31,12 @@ impl Shard for (Key, T, U) { } } +impl Shard for (Key, T, U, V) { + fn shard(&self) -> u64 { + self.0.shard() + } +} + impl Shard for i32 { #[allow(clippy::cast_sign_loss)] fn shard(&self) -> u64 { diff --git a/src/engine/graph.rs b/src/engine/graph.rs index 4ebc541a..e9f9ac5a 100644 --- a/src/engine/graph.rs +++ b/src/engine/graph.rs @@ -128,31 +128,26 @@ pub enum ColumnPath { impl ColumnPath { pub fn extract(&self, key: &Key, value: &Value) -> Result { - match self { - Self::Key => Ok(Value::from(*key)), - Self::ValuePath(path) => { - let mut value = value; - for i in path { - if *value == Value::None || *value == Value::Error { - break; - // needed in outer joins and replacing rows with duplicated ids with error - } - value = value - .as_tuple()? - .get(*i) - .ok_or_else(|| Error::InvalidColumnPath(self.clone()))?; - } - Ok(value.clone()) - } - } + self.extract_inner(Some(key), value) } pub fn extract_from_value(&self, value: &Value) -> Result { + self.extract_inner(None, value) + } + + fn extract_inner(&self, key: Option<&Key>, value: &Value) -> Result { match self { - Self::Key => Err(Error::ExtractFromValueNotSupportedForKey), + Self::Key => match key { + Some(key) => Ok(Value::from(*key)), + None => Err(Error::ExtractFromValueNotSupportedForKey), + }, Self::ValuePath(path) => { let mut value = value; for i in path { + if *value == Value::None || *value == Value::Error || *value == Value::Pending { + break; + // needed in outer joins and replacing rows with duplicated ids with error + } value = value .as_tuple()? 
.get(*i) @@ -550,6 +545,7 @@ pub struct SubscribeCallbacks { pub on_data: Option, pub on_time_end: Option, pub on_end: Option, + pub on_frontier: Option, } pub struct SubscribeCallbacksBuilder { @@ -564,6 +560,7 @@ impl SubscribeCallbacksBuilder { on_data: None, on_time_end: None, on_end: None, + on_frontier: None, }, } } @@ -604,6 +601,13 @@ impl Default for SubscribeCallbacksBuilder { } } +#[derive(Debug, Clone, Copy)] +pub struct SubscribeConfig { + pub skip_persisted_batch: bool, + pub skip_errors: bool, + pub skip_pending: bool, +} + pub type ExportedTableCallback = Box ControlFlow<()> + Send>; pub trait ExportedTable: Send + Sync + Any { @@ -736,8 +740,7 @@ pub trait Graph { table_handle: TableHandle, column_paths: Vec, callbacks: SubscribeCallbacks, - skip_persisted_batch: bool, - skip_errors: bool, + config: SubscribeConfig, unique_name: Option, sort_by_indices: Option>, ) -> Result<()>; @@ -986,12 +989,26 @@ pub trait Graph { fn import_table(&self, table: Arc) -> Result; - fn remove_errors_from_table( + fn remove_value_from_table( &self, table_handle: TableHandle, column_paths: Vec, + value: Value, table_properties: Arc, ) -> Result; + + #[allow(clippy::too_many_arguments)] + fn async_transformer( + &self, + table_handle: TableHandle, + column_paths: Vec, + callbacks: SubscribeCallbacks, + reader: Box, + parser: Box, + commit_duration: Option, + table_properties: Arc, + skip_errors: bool, + ) -> Result; } #[allow(clippy::module_name_repetitions)] @@ -1194,8 +1211,7 @@ impl Graph for ScopedGraph { table_handle: TableHandle, column_paths: Vec, callbacks: SubscribeCallbacks, - skip_persisted_batch: bool, - skip_errors: bool, + config: SubscribeConfig, unique_name: Option, sort_by_indices: Option>, ) -> Result<()> { @@ -1204,8 +1220,7 @@ impl Graph for ScopedGraph { table_handle, column_paths, callbacks, - skip_persisted_batch, - skip_errors, + config, unique_name, sort_by_indices, ) @@ -1641,12 +1656,40 @@ impl Graph for ScopedGraph { self.try_with(|g| g.import_table(table)) } - fn remove_errors_from_table( + fn remove_value_from_table( + &self, + table_handle: TableHandle, + column_paths: Vec, + value: Value, + table_properties: Arc, + ) -> Result { + self.try_with(|g| { + g.remove_value_from_table(table_handle, column_paths, value, table_properties) + }) + } + + fn async_transformer( &self, table_handle: TableHandle, column_paths: Vec, + callbacks: SubscribeCallbacks, + reader: Box, + parser: Box, + commit_duration: Option, table_properties: Arc, + skip_errors: bool, ) -> Result { - self.try_with(|g| g.remove_errors_from_table(table_handle, column_paths, table_properties)) + self.try_with(|g| { + g.async_transformer( + table_handle, + column_paths, + callbacks, + reader, + parser, + commit_duration, + table_properties, + skip_errors, + ) + }) } } diff --git a/src/engine/timestamp.rs b/src/engine/timestamp.rs index ad34adbb..d8115bed 100644 --- a/src/engine/timestamp.rs +++ b/src/engine/timestamp.rs @@ -152,10 +152,16 @@ pub trait OriginalOrRetraction { fn is_retraction(&self) -> bool { !self.is_original() } + #[must_use] + fn next_retraction_time(&self) -> Self; } impl OriginalOrRetraction for Timestamp { fn is_original(&self) -> bool { self.0 % 2 == 0 } + + fn next_retraction_time(&self) -> Self { + Self(self.0 + 1) + } } diff --git a/src/engine/value.rs b/src/engine/value.rs index 37cef457..d9005c5b 100644 --- a/src/engine/value.rs +++ b/src/engine/value.rs @@ -225,6 +225,7 @@ pub enum Value { Json(Handle), Error, PyObjectWrapper(Handle), + Pending, } const _: () = 
assert!(align_of::() <= 16); @@ -362,6 +363,7 @@ impl Display for Value { Self::Json(json) => write!(fmt, "{json}"), Self::Error => write!(fmt, "Error"), Self::PyObjectWrapper(ob) => write!(fmt, "{ob}"), + Self::Pending => write!(fmt, "Pending"), } } } @@ -501,6 +503,7 @@ pub enum Kind { Json, Error, PyObjectWrapper, + Pending, } #[derive(Debug, Clone, PartialEq, Eq)] @@ -521,6 +524,7 @@ pub enum Type { List(Arc), PyObjectWrapper, Optional(Arc), + Future(Arc), } impl Type { @@ -565,6 +569,7 @@ impl Display for Type { Type::List(arg) => write!(f, "list[{arg}]"), Type::PyObjectWrapper => write!(f, "PyObjectWrapper"), Type::Optional(arg) => write!(f, "{arg} | None"), + Type::Future(arg) => write!(f, "Future[{arg}]"), } } } @@ -589,6 +594,7 @@ impl Value { Self::Json(_) => Kind::Json, Self::Error => Kind::Error, Self::PyObjectWrapper(_) => Kind::PyObjectWrapper, + Self::Pending => Kind::Pending, } } } @@ -733,6 +739,7 @@ impl HashInto for Value { Self::Json(json) => json.hash_into(hasher), Self::Error => panic!("trying to hash error"), // FIXME Self::PyObjectWrapper(ob) => ob.hash_into(hasher), + Self::Pending => panic!("trying to hash pending"), // FIXME } } } diff --git a/src/python_api.rs b/src/python_api.rs index 2ee91b25..cb16a29f 100644 --- a/src/python_api.rs +++ b/src/python_api.rs @@ -5,7 +5,8 @@ use crate::async_runtime::create_async_tokio_runtime; use crate::engine::graph::{ - ErrorLogHandle, ExportedTable, OperatorProperties, SubscribeCallbacksBuilder, + ErrorLogHandle, ExportedTable, OperatorProperties, SubscribeCallbacks, + SubscribeCallbacksBuilder, SubscribeConfig, }; use crate::engine::license::{Error as LicenseError, License}; use crate::engine::{ @@ -332,6 +333,9 @@ fn py_type_error(ob: &Bound, type_: &Type) -> PyErr { } pub fn extract_value(ob: &Bound, type_: &Type) -> PyResult { + if ob.is_instance_of::() { + return Ok(Value::Error); + } let extracted = match type_ { Type::Any => ob.extract().ok(), Type::Optional(arg) => { @@ -415,6 +419,13 @@ pub fn extract_value(ob: &Bound, type_: &Type) -> PyResult { }; Some(Value::from(value.into_internal())) } + Type::Future(arg) => { + if ob.is_instance_of::() { + Some(Value::Pending) + } else { + Some(extract_value(ob, arg)?) 
+ } + } }; extracted.ok_or_else(|| py_type_error(ob, type_)) } @@ -424,6 +435,10 @@ impl<'py> FromPyObject<'py> for Value { let py = ob.py(); if ob.is_none() { Ok(Value::None) + } else if ob.is_exact_instance_of::() { + Ok(Value::Error) + } else if ob.is_exact_instance_of::() { + Ok(Value::Pending) } else if let Ok(s) = ob.downcast_exact::() { Ok(Value::from(s.to_str()?)) } else if let Ok(b) = ob.downcast_exact::() { @@ -533,6 +548,7 @@ impl ToPyObject for Value { Self::Json(j) => json_to_py_object(py, j), Self::Error => ERROR.clone_ref(py).into_py(py), Self::PyObjectWrapper(op) => PyObjectWrapper::from_internal(py, op).into_py(py), + Self::Pending => PENDING.clone_ref(py).into_py(py), } } } @@ -1682,6 +1698,10 @@ impl PathwayType { pub fn optional(wrapped: Type) -> Type { Type::Optional(wrapped.into()) } + #[staticmethod] + pub fn future(wrapped: Type) -> Type { + Type::Future(wrapped.into()) + } } #[pyclass(module = "pathway.engine", frozen, name = "ReadMethod")] @@ -2042,6 +2062,30 @@ mod error { } use error::{Error, ERROR}; +mod pending { + use once_cell::sync::Lazy; + use pyo3::prelude::*; + + struct InnerPending; + #[pyclass(module = "pathway.engine", frozen)] + pub struct Pending(InnerPending); + + #[pymethods] + impl Pending { + #[allow(clippy::unused_self)] + fn __repr__(&self) -> &'static str { + "Pending" + } + } + + pub static PENDING: Lazy> = Lazy::new(|| { + Python::with_gil(|py| { + Py::new(py, Pending(InnerPending)).expect("creating PENDING should not fail") + }) + }); +} +use pending::{Pending, PENDING}; + impl<'py> FromPyObject<'py> for ColumnPath { fn extract_bound(ob: &Bound<'py, PyAny>) -> PyResult { let py = ob.py(); @@ -3165,33 +3209,16 @@ impl Scope { sort_by_indices: Option>, ) -> PyResult<()> { self_.borrow().register_unique_name(unique_name.as_ref())?; - let callbacks = SubscribeCallbacksBuilder::new() - .wrapper(BatchWrapper::WithGil) - .on_data(Box::new(move |key, values, time, diff| { - Python::with_gil(|py| { - on_change.call1(py, (key, PyTuple::new_bound(py, values), time, diff))?; - Ok(()) - }) - })) - .on_time_end(Box::new(move |new_time| { - Python::with_gil(|py| { - on_time_end.call1(py, (new_time,))?; - Ok(()) - }) - })) - .on_end(Box::new(move || { - Python::with_gil(|py| { - on_end.call0(py)?; - Ok(()) - }) - })) - .build(); + let callbacks = build_subscribe_callback(on_change, on_time_end, on_end); self_.borrow().graph.subscribe_table( table.handle, column_paths, callbacks, - skip_persisted_batch, - skip_errors, + SubscribeConfig { + skip_persisted_batch, + skip_errors, + skip_pending: true, + }, unique_name, sort_by_indices, )?; @@ -3266,19 +3293,96 @@ impl Scope { Table::new(self_, table_handle) } - pub fn remove_errors_from_table( + pub fn remove_value_from_table( self_: &Bound, table: PyRef, #[pyo3(from_py_with = "from_py_iterable")] column_paths: Vec, + value: Value, table_properties: TableProperties, ) -> PyResult> { - let new_table_handle = self_.borrow().graph.remove_errors_from_table( + let new_table_handle = self_.borrow().graph.remove_value_from_table( table.handle, column_paths, + value, table_properties.0, )?; Table::new(self_, new_table_handle) } + + #[allow(clippy::too_many_arguments)] + pub fn async_transformer( + self_: &Bound, + table: PyRef
, + #[pyo3(from_py_with = "from_py_iterable")] column_paths: Vec, + on_change: Py, + on_time_end: Py, + on_end: Py, + data_source: &Bound, + data_format: &Bound, + properties: ConnectorProperties, + skip_errors: bool, + ) -> PyResult> { + let py = self_.py(); + + let callbacks = build_subscribe_callback(on_change, on_time_end, on_end); + let connector_index = *self_.borrow().total_connectors.borrow(); + *self_.borrow().total_connectors.borrow_mut() += 1; + let (reader_impl, parallel_readers) = data_source.borrow().construct_reader( + py, + &data_format.borrow(), + connector_index, + self_.borrow().worker_index(), + self_.borrow().license.as_ref(), + false, + )?; + assert_eq!(parallel_readers, 1); // python connector that has parallel_readers == 1 has to be used + + let parser_impl = data_format.borrow().construct_parser(py)?; + let commit_duration = properties + .commit_duration_ms + .map(time::Duration::from_millis); + let column_properties = properties.column_properties(); + + let table_handle = self_.borrow().graph.async_transformer( + table.handle, + column_paths, + callbacks, + reader_impl, + parser_impl, + commit_duration, + Arc::new(EngineTableProperties::flat(column_properties)), + skip_errors, + )?; + Table::new(self_, table_handle) + } +} + +fn build_subscribe_callback( + on_change: Py, + on_time_end: Py, + on_end: Py, +) -> SubscribeCallbacks { + SubscribeCallbacksBuilder::new() + .wrapper(BatchWrapper::WithGil) + .on_data(Box::new(move |key, values, time, diff| { + Python::with_gil(|py| { + on_change.call1(py, (key, PyTuple::new_bound(py, values), time, diff))?; + Ok(()) + }) + })) + .on_time_end(Box::new(move |new_time| { + Python::with_gil(|py| { + on_time_end.call1(py, (new_time,))?; + Ok(()) + }) + })) + .on_end(Box::new(move || { + Python::with_gil(|py| { + on_end.call0(py)?; + Ok(()) + }) + })) + .build() } type CapturedTableData = Arc>>; @@ -3306,8 +3410,11 @@ fn capture_table_data( table.handle, column_paths, callbacks, - false, - false, + SubscribeConfig { + skip_persisted_batch: false, + skip_errors: false, + skip_pending: false, + }, None, None, )?; @@ -5705,6 +5812,7 @@ fn engine(_py: Python<'_>, m: &Bound) -> PyResult<()> { m.add_class::()?; m.add_class::()?; m.add_class::()?; + m.add_class::()?; m.add_class::()?; m.add_class::()?; @@ -5726,6 +5834,7 @@ fn engine(_py: Python<'_>, m: &Bound) -> PyResult<()> { m.add("DONE", &*DONE)?; m.add("ERROR", &*ERROR)?; + m.add("PENDING", &*PENDING)?; Ok(()) }
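A few standalone sketches of the mechanisms introduced in this change may help; each one uses simplified stand-in types rather than the engine's own, and any helper names are invented for illustration.

The forgetting stream relies on the timestamp parity convention extended in `timestamp.rs`: original times are even, retraction times are odd, and `next_retraction_time` shifts an original time by one so that a later `is_original()` check can drop the forgetting entries. A minimal sketch of that convention, assuming a plain `u64` stands in for the engine's `Timestamp`:

```rust
/// Simplified stand-in for the engine's `Timestamp`: a u64 where even values
/// are original times and odd values are retraction times, following the
/// `OriginalOrRetraction` convention shown in `timestamp.rs`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
struct Timestamp(u64);

impl Timestamp {
    fn is_original(self) -> bool {
        self.0 % 2 == 0
    }

    /// Shift an original time to the matching retraction time, as the
    /// forgetting stream does before being concatenated with the input.
    fn next_retraction_time(self) -> Self {
        Timestamp(self.0 + 1)
    }
}

fn main() {
    let original = Timestamp(4);
    let forgetting = original.next_retraction_time();

    // Downstream, forgetting entries are dropped with a plain parity check,
    // which is what the `time.is_original()` filter in the flat_map relies on.
    let stream = [original, forgetting];
    let kept: Vec<Timestamp> = stream.into_iter().filter(|t| t.is_original()).collect();
    assert_eq!(kept, vec![Timestamp(4)]);
}
```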
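The `values_currently_processed` map ties a result coming back from Python to the input row that triggered it: each outgoing row gets a fresh `task_id`, and the stored (key, input values) pair is removed when the result arrives, so the matching forgetting entry can be built. A rough sketch of that round trip, using a counter instead of `Key::random()` and invented names (`InFlight`, `register`, `resolve`):

```rust
use std::cell::RefCell;
use std::collections::HashMap;
use std::rc::Rc;

type TaskId = u64;

/// Stash of rows currently being processed on the Python side, keyed by a
/// per-task id, in the spirit of `values_currently_processed` (simplified).
#[derive(Default)]
struct InFlight {
    next_id: TaskId,
    rows: HashMap<TaskId, (u64 /* key */, Vec<i64> /* input values */)>,
}

impl InFlight {
    /// Called on the way out: remember the input row and tag it with a task id.
    fn register(&mut self, key: u64, input: Vec<i64>) -> TaskId {
        let task_id = self.next_id;
        self.next_id += 1;
        self.rows.insert(task_id, (key, input));
        task_id
    }

    /// Called when a result returns: recover the original key and input values,
    /// which are needed to build the forgetting (retraction) entry.
    fn resolve(&mut self, task_id: TaskId) -> (u64, Vec<i64>) {
        self.rows.remove(&task_id).expect("task_id has to be present")
    }
}

fn main() {
    // Shared single-threaded state, mirroring the Rc<RefCell<...>> in the diff.
    let in_flight = Rc::new(RefCell::new(InFlight::default()));
    let task_id = in_flight.borrow_mut().register(7, vec![1, 2, 3]);
    // ... the row travels to Python and a result comes back carrying `task_id` ...
    let (key, input) = in_flight.borrow_mut().resolve(task_id);
    assert_eq!((key, input), (7, vec![1, 2, 3]));
}
```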
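The per-batch deduplication (keep only the entry with the greatest `seq_id` for each key, so the upsert never sees two results for one (key, time) pair) can also be shown in isolation. A std-only sketch; the `dedup_last_per_key` helper and the plain tuples standing in for the engine's `Key`/`Value`/`Tuple` types are illustrative, not part of the actual operator:

```rust
use std::collections::hash_map::Entry;
use std::collections::HashMap;

/// Keep only the entry with the greatest `seq_id` per key within one batch,
/// mirroring the dedup step that runs before the persisted upsert.
fn dedup_last_per_key<K, V>(batch: Vec<(K, V, i64, isize)>) -> Vec<(K, V, isize)>
where
    K: std::hash::Hash + Eq,
{
    let mut keep: HashMap<K, (V, isize, i64)> = HashMap::new();
    for (key, value, seq_id, diff) in batch {
        match keep.entry(key) {
            Entry::Occupied(mut entry) => {
                // A later result for the same key supersedes an earlier one.
                if seq_id > entry.get().2 {
                    entry.insert((value, diff, seq_id));
                }
            }
            Entry::Vacant(entry) => {
                entry.insert((value, diff, seq_id));
            }
        }
    }
    keep.into_iter()
        .map(|(key, (value, diff, _seq_id))| (key, value, diff))
        .collect()
}

fn main() {
    // Two results arrived for key "a" in one batch; only the one with the
    // greater seq_id survives, keeping the later upsert consistent.
    let batch = vec![
        ("a", "first", 1, 1),
        ("a", "second", 2, 1),
        ("b", "only", 3, 1),
    ];
    let deduped = dedup_last_per_key(batch);
    assert_eq!(deduped.len(), 2);
    assert!(deduped.contains(&("a", "second", 1)));
}
```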
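Calling `on_end` is deferred until the original input stream (not the self-sustaining forgetting stream) is exhausted: the code tracks the largest original time and whether that stream has closed, then fires `on_end` once the frontier strictly passes that time. The decision logic, sketched with plain types under the assumption that a `u64` frontier time is enough (the `maybe_fire_on_end` helper is invented for illustration):

```rust
use std::cell::RefCell;
use std::rc::Rc;

/// Shared state in the spirit of `StreamCloseData`: the largest time seen in
/// the original stream, whether that stream closed, and whether `on_end` ran.
struct StreamCloseData {
    max_time: u64,
    stream_closed: bool,
    on_end_called: bool,
}

/// Decide whether `on_end` should fire for a given frontier timestamp:
/// fire at most once, only after the original stream closed and the frontier
/// has strictly passed every time it produced.
fn maybe_fire_on_end(
    state: &Rc<RefCell<StreamCloseData>>,
    frontier_time: u64,
    on_end: &mut dyn FnMut(),
) {
    let mut state = state.borrow_mut();
    if state.stream_closed && state.max_time < frontier_time && !state.on_end_called {
        state.on_end_called = true;
        on_end();
    }
}

fn main() {
    let state = Rc::new(RefCell::new(StreamCloseData {
        max_time: 10,
        stream_closed: true,
        on_end_called: false,
    }));
    let mut fired = 0;
    // Frontier still at the last input time: nothing happens (strict comparison).
    maybe_fire_on_end(&state, 10, &mut || fired += 1);
    // Frontier moved past it: `on_end` fires exactly once.
    maybe_fire_on_end(&state, 12, &mut || fired += 1);
    maybe_fire_on_end(&state, 14, &mut || fired += 1);
    assert_eq!(fired, 1);
}
```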
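Finally, the extended `ColumnPath` extraction stops descending into tuples as soon as it reaches `None`, `Error`, or the new `Pending` value, so a still-pending async result propagates through path lookups instead of causing a type error. A simplified stand-alone sketch with a cut-down `Value` enum (not the engine's) and an `extract` helper invented for the example:

```rust
/// Simplified stand-ins for the engine's `Value` and a tuple path; enough to
/// show why path extraction short-circuits on sentinel values.
#[derive(Clone, Debug, PartialEq)]
enum Value {
    None,
    Error,
    Pending,
    Int(i64),
    Tuple(Vec<Value>),
}

/// Walk an index path into nested tuples, but stop as soon as a sentinel
/// (`None`, `Error`, `Pending`) is reached, in the spirit of `extract_inner`.
fn extract(path: &[usize], mut value: &Value) -> Option<Value> {
    for i in path {
        if matches!(value, Value::None | Value::Error | Value::Pending) {
            break;
        }
        match value {
            Value::Tuple(fields) => value = fields.get(*i)?,
            _ => return None, // path points into a non-tuple value
        }
    }
    Some(value.clone())
}

fn main() {
    // A row whose second column is still being computed by the async UDF.
    let row = Value::Tuple(vec![Value::Int(1), Value::Pending]);
    assert_eq!(extract(&[0], &row), Some(Value::Int(1)));
    assert_eq!(extract(&[1], &row), Some(Value::Pending));
    // Sentinels short-circuit regardless of how deep the path goes.
    assert_eq!(extract(&[1, 0], &Value::Pending), Some(Value::Pending));
    assert_eq!(extract(&[5], &Value::Error), Some(Value::Error));
    assert_eq!(
        extract(&[0], &Value::Tuple(vec![Value::None])),
        Some(Value::None)
    );
}
```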