Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SNOW-1830534: Add decoder logic for the dataframe analytics functions #2803

Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
144 changes: 139 additions & 5 deletions tests/ast/decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,16 @@
#

import logging
from typing import Any, Optional, Iterable, List, Union, Dict, Tuple
import re
from typing import Any, Optional, Iterable, List, Union, Dict, Tuple, Callable
from datetime import date, datetime, time, timedelta, timezone
from decimal import Decimal

import snowflake.snowpark._internal.proto.generated.ast_pb2 as proto

from google.protobuf.json_format import MessageToDict

from snowflake.snowpark import Session, Column
from snowflake.snowpark import Session, Column, DataFrameAnalyticsFunctions
import snowflake.snowpark.functions
from snowflake.snowpark.functions import udf, when
from snowflake.snowpark.types import (
Expand Down Expand Up @@ -73,6 +74,38 @@ def capture_local_variable_name(self, assign_expr: proto.Assign) -> str:
"""
return assign_expr.symbol.value

def get_dataframe_analytics_function_column_formatter(
self, sp_dataframe_analytics_expr: proto.Expr
) -> Callable:
"""
Create a dataframe analytics function column formatter.
This is mainly to pass the df_analytics_functions.test.

Parameters
----------
sp_dataframe_analytics_expr : proto.Expr
The dataframe analytics expression.

Returns
-------
Callable
The dataframe analytics function column formatter.
"""
if "formattedColNames" in MessageToDict(sp_dataframe_analytics_expr):
formatted_col_names = list(sp_dataframe_analytics_expr.formatted_col_names)
w_lambda_pattern = re.compile(r"^(\w+)_W_(\w+)$")
xy_lambda_pattern = re.compile(r"^(\w+)_X_(\w+)_Y_(\w+)$")
if all(re.match(xy_lambda_pattern, col) for col in formatted_col_names):
return (
lambda input, agg, window_size: f"{input}_X_{agg}_Y_{window_size}"
)
elif all(re.match(w_lambda_pattern, col) for col in formatted_col_names):
return lambda input, agg: f"{input}_W_{agg}"
else:
return lambda input_col, agg, window: f"{agg}_{input_col}_{window}"
else:
return DataFrameAnalyticsFunctions._default_col_formatter

def decode_col_exprs(self, expr: proto.Expr, is_variadic: bool) -> List[Column]:
"""
Decode a protobuf object to a list of column expressions.
Expand Down Expand Up @@ -286,7 +319,7 @@ def decode_dataframe_schema_expr(
case _:
raise ValueError(
"Unknown dataframe schema type: %s"
% df_schema_expr.WhichOneof("variant")
% df_schema_expr.WhichOneof("sealed_value")
)

def decode_data_type_expr(
Expand Down Expand Up @@ -856,8 +889,11 @@ def decode_expr(self, expr: proto.Expr) -> Any:
# DATAFRAME FUNCTIONS
case "sp_create_dataframe":
data = self.decode_dataframe_data_expr(expr.sp_create_dataframe.data)
schema = self.decode_dataframe_schema_expr(
expr.sp_create_dataframe.schema
d = MessageToDict(expr.sp_create_dataframe)
schema = (
self.decode_dataframe_schema_expr(expr.sp_create_dataframe.schema)
if "schema" in d
else None
)
df = self.session.create_dataframe(data=data, schema=schema)
if hasattr(expr, "var_id"):
Expand All @@ -882,6 +918,96 @@ def decode_expr(self, expr: proto.Expr) -> Any:
name = expr.sp_dataframe_alias.name
return df.alias(name)

case "sp_dataframe_analytics_compute_lag":
df = self.decode_expr(expr.sp_dataframe_analytics_compute_lag.df)
cols = [
self.decode_expr(col)
for col in expr.sp_dataframe_analytics_compute_lag.cols
]
group_by = list(expr.sp_dataframe_analytics_compute_lag.group_by)
lags = list(expr.sp_dataframe_analytics_compute_lag.lags)
order_by = list(expr.sp_dataframe_analytics_compute_lag.order_by)
col_formatter = self.get_dataframe_analytics_function_column_formatter(
expr.sp_dataframe_analytics_compute_lag
)
return df.analytics.compute_lag(
cols, lags, order_by, group_by, col_formatter
)

case "sp_dataframe_analytics_compute_lead":
df = self.decode_expr(expr.sp_dataframe_analytics_compute_lead.df)
cols = [
self.decode_expr(col)
for col in expr.sp_dataframe_analytics_compute_lead.cols
]
group_by = list(expr.sp_dataframe_analytics_compute_lead.group_by)
leads = list(expr.sp_dataframe_analytics_compute_lead.leads)
order_by = list(expr.sp_dataframe_analytics_compute_lead.order_by)
col_formatter = self.get_dataframe_analytics_function_column_formatter(
expr.sp_dataframe_analytics_compute_lead
)
return df.analytics.compute_lead(
cols, leads, order_by, group_by, col_formatter
)

case "sp_dataframe_analytics_cumulative_agg":
df = self.decode_expr(expr.sp_dataframe_analytics_cumulative_agg.df)
gen_aggs = self.decode_dsl_map_expr(
expr.sp_dataframe_analytics_cumulative_agg.aggs
)
# The aggs dict created has generator objects as the kv pairs. Convert them to strings/list of strings.
aggs = {str(k): list(v) for k, v in gen_aggs.items()}
group_by = list(expr.sp_dataframe_analytics_cumulative_agg.group_by)
order_by = list(expr.sp_dataframe_analytics_cumulative_agg.order_by)
is_forward = (
expr.sp_dataframe_analytics_cumulative_agg.is_forward
if hasattr(expr.sp_dataframe_analytics_cumulative_agg, "is_forward")
else False
)
col_formatter = self.get_dataframe_analytics_function_column_formatter(
expr.sp_dataframe_analytics_cumulative_agg
)
return df.analytics.cumulative_agg(
aggs, group_by, order_by, is_forward, col_formatter
)

case "sp_dataframe_analytics_moving_agg":
df = self.decode_expr(expr.sp_dataframe_analytics_moving_agg.df)
gen_aggs = self.decode_dsl_map_expr(
expr.sp_dataframe_analytics_moving_agg.aggs
)
# The aggs dict created has generator objects as the kv pairs. Convert them to strings/list of strings.
aggs = {str(k): list(v) for k, v in gen_aggs.items()}
group_by = list(expr.sp_dataframe_analytics_moving_agg.group_by)
order_by = list(expr.sp_dataframe_analytics_moving_agg.order_by)
window_sizes = list(expr.sp_dataframe_analytics_moving_agg.window_sizes)
col_formatter = self.get_dataframe_analytics_function_column_formatter(
expr.sp_dataframe_analytics_moving_agg
)
return df.analytics.moving_agg(
aggs, window_sizes, order_by, group_by, col_formatter
)

case "sp_dataframe_analytics_time_series_agg":
df = self.decode_expr(expr.sp_dataframe_analytics_time_series_agg.df)
gen_aggs = self.decode_dsl_map_expr(
expr.sp_dataframe_analytics_time_series_agg.aggs
)
# The aggs dict created has generator objects as the kv pairs. Convert them to strings/list of strings.
aggs = {str(k): list(v) for k, v in gen_aggs.items()}
group_by = list(expr.sp_dataframe_analytics_time_series_agg.group_by)
sliding_interval = (
expr.sp_dataframe_analytics_time_series_agg.sliding_interval
)
time_col = expr.sp_dataframe_analytics_time_series_agg.time_col
windows = list(expr.sp_dataframe_analytics_time_series_agg.windows)
col_formatter = self.get_dataframe_analytics_function_column_formatter(
expr.sp_dataframe_analytics_time_series_agg
)
return df.analytics.time_series_agg(
time_col, aggs, windows, group_by, sliding_interval, col_formatter
)

case "sp_dataframe_col":
col_name = expr.sp_dataframe_col.col_name
df = self.decode_expr(expr.sp_dataframe_col.df)
Expand Down Expand Up @@ -1072,6 +1198,14 @@ def decode_expr(self, expr: proto.Expr) -> Any:
else:
return df.sort(cols, ascending)

case "sp_dataframe_to_df":
df = self.decode_expr(expr.sp_dataframe_to_df.df)
col_names = list(expr.sp_dataframe_to_df.col_names)
if expr.sp_dataframe_to_df.variadic:
return df.to_df(*col_names)
else:
return df.to_df(col_names)

case "sp_dataframe_unpivot":
df = self.decode_expr(expr.sp_dataframe_unpivot.df)
column_list = [
Expand Down
Loading