diff --git a/tests/ast/decoder.py b/tests/ast/decoder.py index 21368b17c3..38fc45d128 100644 --- a/tests/ast/decoder.py +++ b/tests/ast/decoder.py @@ -3,7 +3,8 @@ # import logging -from typing import Any, Optional, Iterable, List, Union, Dict, Tuple +import re +from typing import Any, Optional, Iterable, List, Union, Dict, Tuple, Callable from datetime import date, datetime, time, timedelta, timezone from decimal import Decimal @@ -11,7 +12,7 @@ from google.protobuf.json_format import MessageToDict -from snowflake.snowpark import Session, Column +from snowflake.snowpark import Session, Column, DataFrameAnalyticsFunctions import snowflake.snowpark.functions from snowflake.snowpark.functions import udf, when from snowflake.snowpark.types import ( @@ -73,6 +74,38 @@ def capture_local_variable_name(self, assign_expr: proto.Assign) -> str: """ return assign_expr.symbol.value + def get_dataframe_analytics_function_column_formatter( + self, sp_dataframe_analytics_expr: proto.Expr + ) -> Callable: + """ + Create a dataframe analytics function column formatter. + This is mainly used to satisfy the df_analytics_functions test. + + Parameters + ---------- + sp_dataframe_analytics_expr : proto.Expr + The dataframe analytics expression. + + Returns + ------- + Callable + The dataframe analytics function column formatter. 
+ """ + if "formattedColNames" in MessageToDict(sp_dataframe_analytics_expr): + formatted_col_names = list(sp_dataframe_analytics_expr.formatted_col_names) + w_lambda_pattern = re.compile(r"^(\w+)_W_(\w+)$") + xy_lambda_pattern = re.compile(r"^(\w+)_X_(\w+)_Y_(\w+)$") + if all(re.match(xy_lambda_pattern, col) for col in formatted_col_names): + return ( + lambda input, agg, window_size: f"{input}_X_{agg}_Y_{window_size}" + ) + elif all(re.match(w_lambda_pattern, col) for col in formatted_col_names): + return lambda input, agg: f"{input}_W_{agg}" + else: + return lambda input_col, agg, window: f"{agg}_{input_col}_{window}" + else: + return DataFrameAnalyticsFunctions._default_col_formatter + def decode_col_exprs(self, expr: proto.Expr, is_variadic: bool) -> List[Column]: """ Decode a protobuf object to a list of column expressions. @@ -286,7 +319,7 @@ def decode_dataframe_schema_expr( case _: raise ValueError( "Unknown dataframe schema type: %s" - % df_schema_expr.WhichOneof("variant") + % df_schema_expr.WhichOneof("sealed_value") ) def decode_data_type_expr( @@ -856,8 +889,11 @@ def decode_expr(self, expr: proto.Expr) -> Any: # DATAFRAME FUNCTIONS case "sp_create_dataframe": data = self.decode_dataframe_data_expr(expr.sp_create_dataframe.data) - schema = self.decode_dataframe_schema_expr( - expr.sp_create_dataframe.schema + d = MessageToDict(expr.sp_create_dataframe) + schema = ( + self.decode_dataframe_schema_expr(expr.sp_create_dataframe.schema) + if "schema" in d + else None ) df = self.session.create_dataframe(data=data, schema=schema) if hasattr(expr, "var_id"): @@ -882,6 +918,96 @@ def decode_expr(self, expr: proto.Expr) -> Any: name = expr.sp_dataframe_alias.name return df.alias(name) + case "sp_dataframe_analytics_compute_lag": + df = self.decode_expr(expr.sp_dataframe_analytics_compute_lag.df) + cols = [ + self.decode_expr(col) + for col in expr.sp_dataframe_analytics_compute_lag.cols + ] + group_by = list(expr.sp_dataframe_analytics_compute_lag.group_by) + 
lags = list(expr.sp_dataframe_analytics_compute_lag.lags) + order_by = list(expr.sp_dataframe_analytics_compute_lag.order_by) + col_formatter = self.get_dataframe_analytics_function_column_formatter( + expr.sp_dataframe_analytics_compute_lag + ) + return df.analytics.compute_lag( + cols, lags, order_by, group_by, col_formatter + ) + + case "sp_dataframe_analytics_compute_lead": + df = self.decode_expr(expr.sp_dataframe_analytics_compute_lead.df) + cols = [ + self.decode_expr(col) + for col in expr.sp_dataframe_analytics_compute_lead.cols + ] + group_by = list(expr.sp_dataframe_analytics_compute_lead.group_by) + leads = list(expr.sp_dataframe_analytics_compute_lead.leads) + order_by = list(expr.sp_dataframe_analytics_compute_lead.order_by) + col_formatter = self.get_dataframe_analytics_function_column_formatter( + expr.sp_dataframe_analytics_compute_lead + ) + return df.analytics.compute_lead( + cols, leads, order_by, group_by, col_formatter + ) + + case "sp_dataframe_analytics_cumulative_agg": + df = self.decode_expr(expr.sp_dataframe_analytics_cumulative_agg.df) + gen_aggs = self.decode_dsl_map_expr( + expr.sp_dataframe_analytics_cumulative_agg.aggs + ) + # The aggs dict created has generator objects as the kv pairs. Convert them to strings/list of strings. 
+ aggs = {str(k): list(v) for k, v in gen_aggs.items()} + group_by = list(expr.sp_dataframe_analytics_cumulative_agg.group_by) + order_by = list(expr.sp_dataframe_analytics_cumulative_agg.order_by) + is_forward = ( + expr.sp_dataframe_analytics_cumulative_agg.is_forward + if hasattr(expr.sp_dataframe_analytics_cumulative_agg, "is_forward") + else False + ) + col_formatter = self.get_dataframe_analytics_function_column_formatter( + expr.sp_dataframe_analytics_cumulative_agg + ) + return df.analytics.cumulative_agg( + aggs, group_by, order_by, is_forward, col_formatter + ) + + case "sp_dataframe_analytics_moving_agg": + df = self.decode_expr(expr.sp_dataframe_analytics_moving_agg.df) + gen_aggs = self.decode_dsl_map_expr( + expr.sp_dataframe_analytics_moving_agg.aggs + ) + # The aggs dict created has generator objects as the kv pairs. Convert them to strings/list of strings. + aggs = {str(k): list(v) for k, v in gen_aggs.items()} + group_by = list(expr.sp_dataframe_analytics_moving_agg.group_by) + order_by = list(expr.sp_dataframe_analytics_moving_agg.order_by) + window_sizes = list(expr.sp_dataframe_analytics_moving_agg.window_sizes) + col_formatter = self.get_dataframe_analytics_function_column_formatter( + expr.sp_dataframe_analytics_moving_agg + ) + return df.analytics.moving_agg( + aggs, window_sizes, order_by, group_by, col_formatter + ) + + case "sp_dataframe_analytics_time_series_agg": + df = self.decode_expr(expr.sp_dataframe_analytics_time_series_agg.df) + gen_aggs = self.decode_dsl_map_expr( + expr.sp_dataframe_analytics_time_series_agg.aggs + ) + # The aggs dict created has generator objects as the kv pairs. Convert them to strings/list of strings. 
+ aggs = {str(k): list(v) for k, v in gen_aggs.items()} + group_by = list(expr.sp_dataframe_analytics_time_series_agg.group_by) + sliding_interval = ( + expr.sp_dataframe_analytics_time_series_agg.sliding_interval + ) + time_col = expr.sp_dataframe_analytics_time_series_agg.time_col + windows = list(expr.sp_dataframe_analytics_time_series_agg.windows) + col_formatter = self.get_dataframe_analytics_function_column_formatter( + expr.sp_dataframe_analytics_time_series_agg + ) + return df.analytics.time_series_agg( + time_col, aggs, windows, group_by, sliding_interval, col_formatter + ) + case "sp_dataframe_col": col_name = expr.sp_dataframe_col.col_name df = self.decode_expr(expr.sp_dataframe_col.df) @@ -1072,6 +1198,14 @@ def decode_expr(self, expr: proto.Expr) -> Any: else: return df.sort(cols, ascending) + case "sp_dataframe_to_df": + df = self.decode_expr(expr.sp_dataframe_to_df.df) + col_names = list(expr.sp_dataframe_to_df.col_names) + if expr.sp_dataframe_to_df.variadic: + return df.to_df(*col_names) + else: + return df.to_df(col_names) + case "sp_dataframe_unpivot": df = self.decode_expr(expr.sp_dataframe_unpivot.df) column_list = [