Skip to content

Commit

Permalink
changes for select_expr, stat, pivot, rollup, indexers
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-vbudati committed Dec 21, 2024
1 parent ca608fb commit ade0e82
Show file tree
Hide file tree
Showing 3 changed files with 112 additions and 3 deletions.
20 changes: 18 additions & 2 deletions tests/ast/data/DataFrame.pivot.test
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,13 @@ df1 = df.pivot("mo", values=["JAN", "FEB"], default_on_null=None)

df1 = df1.sum("t")

df1 = df1.sort("k")
df1 = df1.sort("k", ascending=None)

df2 = df.pivot("mo", values=["JAN", "FEB"], default_on_null="Nothing")

df2 = df2.sum("t")

df2 = df2.sort("k")
df2 = df2.sort("k", ascending=None)

## EXPECTED ENCODED AST

Expand Down Expand Up @@ -426,6 +426,14 @@ body {
assign {
expr {
sp_dataframe_sort {
ascending {
null_val {
src {
file: "SRC_POSITION_TEST_MODE"
start_line: 38
}
}
}
cols {
string_val {
src {
Expand Down Expand Up @@ -575,6 +583,14 @@ body {
assign {
expr {
sp_dataframe_sort {
ascending {
null_val {
src {
file: "SRC_POSITION_TEST_MODE"
start_line: 40
}
}
}
cols {
string_val {
src {
Expand Down
2 changes: 1 addition & 1 deletion tests/ast/data/DataFrame.unpivot.test
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ df = df.unpivot("sales", "month", ["jan", "feb"])

df = session.create_dataframe([(1, "electronics", 100, 200), (2, "clothes", 100, 300)], schema=["empid", "dept", "jan", "feb"])

df = df.unpivot("sales", "month", ["jan", "feb"])
df = df.unpivot("sales", "month", ["jan", "feb"], False)

## EXPECTED ENCODED AST

Expand Down
93 changes: 93 additions & 0 deletions tests/ast/decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -1185,6 +1185,28 @@ def decode_expr(self, expr: proto.Expr) -> Any:
)
return df.na.replace(to_replace, value, subset)

case "sp_dataframe_pivot":
df = self.decode_expr(expr.sp_dataframe_pivot.df)
pivot_col = self.decode_expr(expr.sp_dataframe_pivot.pivot_col)
default_on_null = self.decode_expr(
expr.sp_dataframe_pivot.default_on_null
)
match expr.sp_dataframe_pivot.values.WhichOneof("sealed_value"):
case "sp_pivot_value__dataframe":
values = self.decode_expr(
expr.sp_dataframe_pivot.values.sp_pivot_value__dataframe.v
)
case "sp_pivot_value__expr":
values = self.decode_expr(
expr.sp_dataframe_pivot.values.sp_pivot_value__expr.v
)
case _:
raise ValueError(
"Unknown pivot value: %s"
% expr.sp_dataframe_pivot.values.WhichOneof("sealed_value")
)
return df.pivot(pivot_col, values, default_on_null)

case "sp_dataframe_ref":
return self.symbol_table[expr.sp_dataframe_ref.id.bitfield1][1]

Expand All @@ -1196,6 +1218,14 @@ def decode_expr(self, expr: proto.Expr) -> Any:
)
return df.rename(col_or_mapper, new_column)

case "sp_dataframe_rollup":
df = self.decode_expr(expr.sp_dataframe_rollup.df)
cols = self.decode_col_exprs(expr.sp_dataframe_rollup.cols.args)
if MessageToDict(expr.sp_dataframe_rollup.cols).get("variadic", False):
return df.rollup(*cols)
else:
return df.rollup(cols)

case "sp_dataframe_select__columns":
df = self.decode_expr(expr.sp_dataframe_select__columns.df)
# The columns can be a list of Expr or a single Expr.
Expand All @@ -1213,6 +1243,16 @@ def decode_expr(self, expr: proto.Expr) -> Any:
)
return val

case "sp_dataframe_select__exprs":
df = self.decode_expr(expr.sp_dataframe_select__exprs.df)
exprs = list(expr.sp_dataframe_select__exprs.exprs)
if MessageToDict(expr.sp_dataframe_select__exprs).get(
"variadic", False
):
return df.select_expr(*exprs)
else:
return df.select_expr(exprs)

case "sp_dataframe_show":
df = self.decode_expr(
self.symbol_table[expr.sp_dataframe_show.id.bitfield1][1]
Expand All @@ -1230,6 +1270,59 @@ def decode_expr(self, expr: proto.Expr) -> Any:
else:
return df.sort(cols, ascending=ascending)

case "sp_dataframe_stat_approx_quantile":
d = MessageToDict(expr.sp_dataframe_stat_approx_quantile)
if "df" in d:
df = self.decode_expr(expr.sp_dataframe_stat_approx_quantile.df)
else:
df = self.symbol_table[
expr.sp_dataframe_stat_approx_quantile.id.bitfield1
][1]
cols = [
self.decode_expr(col)
for col in expr.sp_dataframe_stat_approx_quantile.cols
]
percentile = list(expr.sp_dataframe_stat_approx_quantile.percentile)
statement_params = self.get_statement_params(d)
return df._stat.approx_quantile(
cols, percentile, statement_params=statement_params
)

case "sp_dataframe_stat_corr":
df = self.symbol_table[expr.sp_dataframe_stat_corr.id.bitfield1][1]
col1 = self.decode_expr(expr.sp_dataframe_stat_corr.col1)
col2 = self.decode_expr(expr.sp_dataframe_stat_corr.col2)
statement_params = self.get_statement_params(
MessageToDict(expr.sp_dataframe_stat_corr)
)
return df._stat.corr(col1, col2, statement_params=statement_params)

case "sp_dataframe_stat_cov":
df = self.symbol_table[expr.sp_dataframe_stat_cov.id.bitfield1][1]
col1 = self.decode_expr(expr.sp_dataframe_stat_cov.col1)
col2 = self.decode_expr(expr.sp_dataframe_stat_cov.col2)
statement_params = self.get_statement_params(
MessageToDict(expr.sp_dataframe_stat_cov)
)
return df._stat.cov(col1, col2, statement_params=statement_params)

case "sp_dataframe_stat_cross_tab":
df = self.symbol_table[expr.sp_dataframe_stat_cross_tab.id.bitfield1][1]
col1 = self.decode_expr(expr.sp_dataframe_stat_cross_tab.col1)
col2 = self.decode_expr(expr.sp_dataframe_stat_cross_tab.col2)
statement_params = self.get_statement_params(
MessageToDict(expr.sp_dataframe_stat_cross_tab)
)
return df._stat.crosstab(col1, col2, statement_params=statement_params)

case "sp_dataframe_stat_sample_by":
df = self.decode_expr(expr.sp_dataframe_stat_sample_by.df)
col = self.decode_expr(expr.sp_dataframe_stat_sample_by.col)
fractions = self.decode_dsl_map_expr(
expr.sp_dataframe_stat_sample_by.fractions
)
return df._stat.sample_by(col, fractions)

case "sp_dataframe_to_df":
df = self.decode_expr(expr.sp_dataframe_to_df.df)
col_names = list(expr.sp_dataframe_to_df.col_names)
Expand Down

0 comments on commit ade0e82

Please sign in to comment.