Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SNOW-1830524 Add decoder logic for Dataframe.join #2802

Open
wants to merge 4 commits into
base: vbudati/SNOW-1794510-merge-decoder
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion tests/ast/data/Dataframe.join.asof.test
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ df1 = session.create_dataframe([["A", 1, 15, 3.21], ["A", 2, 16, 3.22], ["B", 1,

df2 = session.create_dataframe([["A", 1, 14, 3.19], ["B", 2, 16, 3.04]], schema=["c1", "c2", "c3", "c4"])

df1.join(df2, on=(df1["c1"] == df2["c1"]) & (df1["c2"] == df2["c2"]), how="asof", lsuffix="_L", rsuffix="_R", match_condition=df1["c3"] >= df2["c3"]).sort("C1_L", "C2_L").collect()
df1.join(df2, on=(df1["c1"] == df2["c1"]) & (df1["c2"] == df2["c2"]), how="asof", lsuffix="_L", rsuffix="_R", match_condition=df1["c3"] >= df2["c3"]).sort("C1_L", "C2_L", ascending=None).collect()

## EXPECTED ENCODED AST

Expand Down Expand Up @@ -522,6 +522,14 @@ body {
assign {
expr {
sp_dataframe_sort {
ascending {
null_val {
src {
file: "SRC_POSITION_TEST_MODE"
start_line: 36
}
}
}
cols {
string_val {
src {
Expand Down
146 changes: 10 additions & 136 deletions tests/ast/data/Dataframe.join.prefix.test
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,10 @@ df2 = session.create_dataframe([[1, 2, 3, 4, 5]], schema=['\"A\"','\"B\"','\"C\"

df3 = df1.filter(col("\"A\"") == 1).join(df2.select((col("\"A\"") + 1).as_("\"A\""), col("\"B\""), col("\"C\""), col("\"l_0001_C\""), col("\"l_0003_B\"")))

df4 = df3.sort(df3.columns)
# Commented out since df3.columns produces different results in the first encoding and in the encode-decode-encode result.
Copy link
Collaborator

@sfc-gh-evandenberg sfc-gh-evandenberg Dec 20, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you find out why this is happening? We shouldn't be removing valid test cases here. I don't see a reason why encode-decode-encode needs to be value equivalent — this is a good example. As long as the two are semantically equivalent (the generated unique columns correspond to each other correctly), that is all that is required.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The unique columns correspond to each other in both cases but the values are different. I'm not sure what the best way around this is since hardcoding the decoder seems like a bad idea. I can add the test back in.

I'm not familiar with how the column names are generated but can try to figure that out.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fixed!

# df4 = df3.sort(df3.columns)

df4.collect()
df3.collect()

## EXPECTED UNPARSER OUTPUT

Expand All @@ -22,9 +23,7 @@ df3 = df2.select((col("\"A\"") + 1).as_("\"A\""), col("\"B\""), col("\"C\""), co

df3 = df3_res1.join(df3, how="inner")

df4 = df3.sort(["\"l_0004_A\"", "\"l_0004_B\"", "\"l_0004_C\"", "\"r_0000_A\"", "\"l_0000_A\"", "\"l_0002_A\"", "\"r_0006_A\"", "\"r_0006_B\"", "\"r_0006_C\"", "\"l_0001_C\"", "\"l_0003_B\""])

df4.collect()
df3.collect()

## EXPECTED ENCODED AST

Expand Down Expand Up @@ -499,159 +498,34 @@ body {
}
}
}
body {
assign {
expr {
sp_dataframe_sort {
cols {
string_val {
src {
file: "SRC_POSITION_TEST_MODE"
start_line: 31
}
v: "\"l_0004_A\""
}
}
cols {
string_val {
src {
file: "SRC_POSITION_TEST_MODE"
start_line: 31
}
v: "\"l_0004_B\""
}
}
cols {
string_val {
src {
file: "SRC_POSITION_TEST_MODE"
start_line: 31
}
v: "\"l_0004_C\""
}
}
cols {
string_val {
src {
file: "SRC_POSITION_TEST_MODE"
start_line: 31
}
v: "\"r_0000_A\""
}
}
cols {
string_val {
src {
file: "SRC_POSITION_TEST_MODE"
start_line: 31
}
v: "\"l_0000_A\""
}
}
cols {
string_val {
src {
file: "SRC_POSITION_TEST_MODE"
start_line: 31
}
v: "\"l_0002_A\""
}
}
cols {
string_val {
src {
file: "SRC_POSITION_TEST_MODE"
start_line: 31
}
v: "\"r_0006_A\""
}
}
cols {
string_val {
src {
file: "SRC_POSITION_TEST_MODE"
start_line: 31
}
v: "\"r_0006_B\""
}
}
cols {
string_val {
src {
file: "SRC_POSITION_TEST_MODE"
start_line: 31
}
v: "\"r_0006_C\""
}
}
cols {
string_val {
src {
file: "SRC_POSITION_TEST_MODE"
start_line: 31
}
v: "\"l_0001_C\""
}
}
cols {
string_val {
src {
file: "SRC_POSITION_TEST_MODE"
start_line: 31
}
v: "\"l_0003_B\""
}
}
df {
sp_dataframe_ref {
id {
bitfield1: 5
}
}
}
src {
file: "SRC_POSITION_TEST_MODE"
start_line: 31
}
}
}
symbol {
value: "df4"
}
uid: 6
var_id {
bitfield1: 6
}
}
}
body {
assign {
expr {
sp_dataframe_collect {
block: true
case_sensitive: true
id {
bitfield1: 6
bitfield1: 5
}
src {
file: "SRC_POSITION_TEST_MODE"
start_line: 33
start_line: 34
}
}
}
symbol {
}
uid: 7
uid: 6
var_id {
bitfield1: 7
bitfield1: 6
}
}
}
body {
eval {
uid: 8
uid: 7
var_id {
bitfield1: 7
bitfield1: 6
}
}
}
Expand Down
95 changes: 90 additions & 5 deletions tests/ast/decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -458,6 +458,42 @@ def decode_data_type_expr(
"Unknown data type: %s" % data_type_expr.WhichOneof("variant")
)

def decode_join_type(self, join_type: proto.SpJoinType) -> str:
    """
    Decode a join type expression to get the join type.

    Parameters
    ----------
    join_type : proto.SpJoinType
        The expression to decode.

    Returns
    -------
    str
        The decoded join type.
    """
    # Map each protobuf variant name onto the string accepted by
    # DataFrame.join(how=...).
    variant_to_how = {
        "sp_join_type__asof": "asof",
        "sp_join_type__cross": "cross",
        "sp_join_type__full_outer": "full",
        "sp_join_type__inner": "inner",
        "sp_join_type__left_anti": "anti",
        "sp_join_type__left_outer": "left",
        "sp_join_type__left_semi": "semi",
        "sp_join_type__right_outer": "right",
    }
    variant = join_type.WhichOneof("variant")
    how = variant_to_how.get(variant)
    if how is None:
        raise ValueError("Unknown join type: %s" % variant)
    return how

def decode_timezone_expr(self, tz_expr: proto.PythonTimeZone) -> Any:
"""
Decode a Python timezone expression to get the timezone.
Expand Down Expand Up @@ -935,6 +971,11 @@ def decode_expr(self, expr: proto.Expr) -> Any:
other = self.decode_expr(expr.sp_dataframe_except.other)
return df.except_(other)

case "sp_dataframe_filter":
df = self.decode_expr(expr.sp_dataframe_filter.df)
condition = self.decode_expr(expr.sp_dataframe_filter.condition)
return df.filter(condition)

case "sp_dataframe_first":
df = self.decode_expr(expr.sp_dataframe_first.df)
block = expr.sp_dataframe_first.block
Expand All @@ -960,6 +1001,47 @@ def decode_expr(self, expr: proto.Expr) -> Any:
other = self.decode_expr(expr.sp_dataframe_intersect.other)
return df.intersect(other)

case "sp_dataframe_join":
d = MessageToDict(expr.sp_dataframe_join)
join_expr = d.get("joinExpr", None)
join_expr = (
self.decode_expr(expr.sp_dataframe_join.join_expr)
if join_expr
else None
)
join_type = d.get("joinType", None)
join_type = (
self.decode_join_type(expr.sp_dataframe_join.join_type)
if join_type
else None
)
lhs = self.decode_expr(expr.sp_dataframe_join.lhs)
rhs = self.decode_expr(expr.sp_dataframe_join.rhs)
lsuffix = d.get("lsuffix", "")
rsuffix = d.get("rsuffix", "")
match_condition = d.get("matchCondition", None)
match_condition = (
self.decode_expr(expr.sp_dataframe_join.match_condition)
if match_condition
else None
)
return lhs.join(
right=rhs,
on=join_expr,
how=join_type,
lsuffix=lsuffix,
rsuffix=rsuffix,
match_condition=match_condition,
)

case "sp_dataframe_natural_join":
lhs = self.decode_expr(expr.sp_dataframe_natural_join.lhs)
rhs = self.decode_expr(expr.sp_dataframe_natural_join.rhs)
join_type = self.decode_join_type(
expr.sp_dataframe_natural_join.join_type
)
return lhs.natural_join(right=rhs, how=join_type)

case "sp_dataframe_na_drop__python":
df = self.decode_expr(expr.sp_dataframe_na_drop__python.df)
how = expr.sp_dataframe_na_drop__python.how
Expand Down Expand Up @@ -1063,14 +1145,17 @@ def decode_expr(self, expr: proto.Expr) -> Any:

case "sp_dataframe_sort":
df = self.decode_expr(expr.sp_dataframe_sort.df)
cols = self.decode_col_exprs(
expr.sp_dataframe_sort.cols, expr.sp_dataframe_sort.cols.variadic
is_variadic = (
expr.sp_dataframe_sort.cols_variadic
if hasattr(expr.sp_dataframe_sort, "cols_variadic")
else False
)
cols = self.decode_col_exprs(expr.sp_dataframe_sort.cols, is_variadic)
ascending = self.decode_expr(expr.sp_dataframe_sort.ascending)
if expr.sp_dataframe_sort.cols_variadic:
return df.sort(*cols, ascending)
if is_variadic:
return df.sort(*cols, ascending=ascending)
else:
return df.sort(cols, ascending)
return df.sort(cols, ascending=ascending)

case "sp_dataframe_unpivot":
df = self.decode_expr(expr.sp_dataframe_unpivot.df)
Expand Down
Loading