Skip to content

Commit

Permalink
feat(flink): implement support for array expansion (#8511)
Browse files — browse the repository at this point in the history
  • Loading branch information
chloeh13q authored Apr 12, 2024
1 parent 97ff704 commit a6e6564
Show file tree
Hide file tree
Showing 7 changed files with 140 additions and 14 deletions.
1 change: 0 additions & 1 deletion ibis/backends/flink/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,6 @@ class FlinkCompiler(SQLGlotCompiler):
ops.RowID,
ops.StringSplit,
ops.Translate,
ops.Unnest,
)
)

Expand Down
3 changes: 2 additions & 1 deletion ibis/backends/flink/tests/test_memtable.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,8 @@
def test_create_memtable(con, data, schema, expected):
    memtable_expr = ibis.memtable(data, schema=ibis.schema(schema))
    # Executing the memtable via `con.execute()` trips a behavioral discrepancy
    # between `TableEnvironment.execute_sql()` and `TableEnvironment.sql_query()`;
    # compiling to SQL and running it through `raw_sql()` sidesteps that, since
    # the problem only appears when the memtable is executed directly.
    compiled_sql = con.compile(memtable_expr)
    table_result = con.raw_sql(compiled_sql)
    # `raw_sql()` hands back a `TableResult`, which has no native pandas
    # conversion, so compare the collected rows directly.
    rows = list(table_result.collect())
    assert rows == expected
Expand Down
76 changes: 75 additions & 1 deletion ibis/backends/sql/dialects.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
Trino,
)
from sqlglot.dialects.dialect import rename_func
from sqlglot.helper import seq_get
from sqlglot.helper import find_new_name, seq_get

ClickHouse.Generator.TRANSFORMS |= {
sge.ArraySize: rename_func("length"),
Expand Down Expand Up @@ -111,20 +111,94 @@ def _interval_with_precision(self, e):
return f"INTERVAL {formatted_arg} {unit}"


def _explode_to_unnest():
    """Build a sqlglot transform that rewrites EXPLODE projections into CROSS JOIN UNNEST.

    NOTE: Flink doesn't support UNNEST WITH ORDINALITY or UNNEST WITH OFFSET.

    Returns a callable suitable for ``transforms.preprocess``: it takes an
    ``sge.Expression`` and returns the (possibly mutated in place) expression.
    """

    def _explode_to_unnest(expression: sge.Expression) -> sge.Expression:
        # Only SELECT statements can carry EXPLODE in their projection list;
        # every other expression kind passes through unchanged.
        if isinstance(expression, sge.Select):
            from sqlglot.optimizer.scope import Scope

            # Collect names already in use so freshly generated aliases
            # (for the UNNEST source and its output column) don't collide.
            taken_select_names = set(expression.named_selects)
            taken_source_names = {name for name, _ in Scope(expression).references}

            def new_name(names: set[str], name: str) -> str:
                # Pick an identifier not present in `names` and reserve it.
                name = find_new_name(names, name)
                names.add(name)
                return name

            # we use list here because expression.selects is mutated inside the loop
            for select in list(expression.selects):
                explode = select.find(sge.Explode)

                if explode:
                    explode_alias = ""

                    if isinstance(select, sge.Alias):
                        # EXPLODE(x) AS y: reuse the existing alias node.
                        explode_alias = select.args["alias"]
                        alias = select
                    elif isinstance(select, sge.Aliases):
                        explode_alias = select.aliases[1]
                        alias = select.replace(sge.alias_(select.this, "", copy=False))
                    else:
                        # Bare EXPLODE(x): wrap it in an (empty) alias node so
                        # there is something to attach the generated name to.
                        alias = select.replace(sge.alias_(select, ""))
                        explode = alias.find(sge.Explode)
                        assert explode

                    explode_arg = explode.this

                    # This ensures that we won't use EXPLODE's argument as a new selection
                    if isinstance(explode_arg, sge.Column):
                        taken_select_names.add(explode_arg.output_name)

                    unnest_source_alias = new_name(taken_source_names, "_u")

                    if not explode_alias:
                        explode_alias = new_name(taken_select_names, "col")

                    alias.set("alias", sge.to_identifier(explode_alias))

                    # Swap the EXPLODE(...) call for a plain column reference
                    # into the UNNEST source joined in below.
                    column = sge.column(explode_alias, table=unnest_source_alias)

                    explode.replace(column)

                    # Attach the exploded array as a CROSS JOIN UNNEST source,
                    # mutating the SELECT in place (copy=False).
                    expression.join(
                        sge.alias_(
                            sge.Unnest(
                                expressions=[explode_arg.copy()],
                            ),
                            unnest_source_alias,
                            table=[explode_alias],
                        ),
                        join_type="CROSS",
                        copy=False,
                    )

        return expression

    return _explode_to_unnest


class Flink(Hive):
class Generator(Hive.Generator):
UNNEST_WITH_ORDINALITY = False

TYPE_MAPPING = Hive.Generator.TYPE_MAPPING.copy() | {
sge.DataType.Type.TIME: "TIME",
sge.DataType.Type.STRUCT: "ROW",
}

TRANSFORMS = Hive.Generator.TRANSFORMS.copy() | {
sge.Select: transforms.preprocess([_explode_to_unnest()]),
sge.Stddev: rename_func("stddev_samp"),
sge.StddevPop: rename_func("stddev_pop"),
sge.StddevSamp: rename_func("stddev_samp"),
sge.Variance: rename_func("var_samp"),
sge.VariancePop: rename_func("var_pop"),
sge.ArrayConcat: rename_func("array_concat"),
sge.ArraySize: rename_func("cardinality"),
sge.Length: rename_func("char_length"),
sge.TryCast: lambda self,
e: f"TRY_CAST({e.this.sql(self.dialect)} AS {e.to.sql(self.dialect)})",
Expand Down
59 changes: 51 additions & 8 deletions ibis/backends/tests/test_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,7 @@ def test_array_discovery(backend):
reason="BigQuery doesn't support casting array<T> to array<U>",
raises=GoogleBadRequest,
)
@pytest.mark.notimpl(["datafusion", "flink"], raises=com.OperationNotDefinedError)
@pytest.mark.notimpl(["datafusion"], raises=com.OperationNotDefinedError)
def test_unnest_simple(backend):
array_types = backend.array_types
expected = (
Expand All @@ -261,7 +261,7 @@ def test_unnest_simple(backend):


@builtin_array
@pytest.mark.notimpl(["datafusion", "flink"], raises=com.OperationNotDefinedError)
@pytest.mark.notimpl(["datafusion"], raises=com.OperationNotDefinedError)
def test_unnest_complex(backend):
array_types = backend.array_types
df = array_types.execute()
Expand Down Expand Up @@ -356,7 +356,7 @@ def test_unnest_no_nulls(backend):
raises=ValueError,
reason="all the input arrays must have same number of dimensions",
)
@pytest.mark.notimpl(["datafusion", "flink"], raises=com.OperationNotDefinedError)
@pytest.mark.notimpl(["datafusion"], raises=com.OperationNotDefinedError)
def test_unnest_default_name(backend):
array_types = backend.array_types
df = array_types.execute()
Expand Down Expand Up @@ -583,7 +583,7 @@ def test_array_contains(backend, con, col, value):
pytest.mark.notyet(
["flink"],
raises=Py4JJavaError,
reason="SQL validation failed; Flink does not support ARRAY[]",
reason="SQL validation failed; Flink does not support ARRAY[]", # https://issues.apache.org/jira/browse/FLINK-20578
),
pytest.mark.broken(
["datafusion"],
Expand Down Expand Up @@ -621,7 +621,7 @@ def test_array_position(con, a, expected_array):
pytest.mark.notyet(
["flink"],
raises=Py4JJavaError,
reason="SQL validation failed; Flink does not support ARRAY[]",
reason="SQL validation failed; Flink does not support ARRAY[]", # https://issues.apache.org/jira/browse/FLINK-20578
)
],
),
Expand Down Expand Up @@ -803,7 +803,7 @@ def test_array_intersect(con, data):
@builtin_array
@pytest.mark.notimpl(["postgres"], raises=PsycoPg2SyntaxError)
@pytest.mark.notimpl(["risingwave"], raises=PsycoPg2InternalError)
@pytest.mark.notimpl(["datafusion", "flink"], raises=com.OperationNotDefinedError)
@pytest.mark.notimpl(["datafusion"], raises=com.OperationNotDefinedError)
@pytest.mark.broken(
["trino"], reason="inserting maps into structs doesn't work", raises=TrinoUserError
)
Expand All @@ -818,6 +818,39 @@ def test_unnest_struct(con):
tm.assert_series_equal(result, expected)


@builtin_array
@pytest.mark.notimpl(
    ["clickhouse"],
    raises=ClickHouseDatabaseError,
    reason="ClickHouse won't accept dicts for struct type values",
)
@pytest.mark.notimpl(["postgres"], raises=PsycoPg2SyntaxError)
@pytest.mark.notimpl(["risingwave"], raises=PsycoPg2InternalError)
@pytest.mark.notimpl(["datafusion"], raises=com.OperationNotDefinedError)
@pytest.mark.broken(
    ["trino"], reason="inserting maps into structs doesn't work", raises=TrinoUserError
)
@pytest.mark.broken(
    ["flink"], reason="flink unnests a and b as separate columns", raises=Py4JJavaError
)
def test_unnest_struct_with_multiple_fields(con):
    # Two rows, each an array of two-field structs; unnesting should flatten
    # the arrays while keeping each struct intact as a single value.
    rows = [
        [{"a": 1, "b": "banana"}, {"a": 2, "b": "apple"}],
        [{"a": 3, "b": "coconut"}, {"a": 4, "b": "orange"}],
    ]
    data = {"value": rows}
    t = ibis.memtable(
        data, schema=ibis.schema({"value": "!array<!struct<a: !int, b: !string>>"})
    )

    result = con.execute(t.value.unnest())

    # pandas `explode` mirrors the unnest semantics on the raw data.
    expected = pd.DataFrame(data).explode("value").iloc[:, 0].reset_index(drop=True)
    tm.assert_series_equal(result, expected)


array_zip_notimpl = pytest.mark.notimpl(
[
"dask",
Expand Down Expand Up @@ -889,7 +922,7 @@ def test_zip_null(con, fn):
)
@pytest.mark.notimpl(["postgres"], raises=PsycoPg2SyntaxError)
@pytest.mark.notimpl(["risingwave"], raises=PsycoPg2ProgrammingError)
@pytest.mark.notimpl(["datafusion", "flink"], raises=com.OperationNotDefinedError)
@pytest.mark.notimpl(["datafusion"], raises=com.OperationNotDefinedError)
@pytest.mark.notimpl(
["polars"],
raises=com.OperationNotDefinedError,
Expand All @@ -903,6 +936,11 @@ def test_zip_null(con, fn):
@pytest.mark.broken(
["trino"], reason="inserting maps into structs doesn't work", raises=TrinoUserError
)
@pytest.mark.notyet(
["flink"],
raises=Py4JJavaError,
reason="does not seem to support field selection on unnest",
)
def test_array_of_struct_unnest(con):
jobs = ibis.memtable(
{
Expand Down Expand Up @@ -1079,10 +1117,15 @@ def test_range_start_stop_step_zero(con, start, stop):
reason="ibis hasn't implemented this behavior yet",
)
@pytest.mark.notyet(
["datafusion", "flink"],
["datafusion"],
raises=com.OperationNotDefinedError,
reason="backend doesn't support unnest",
)
@pytest.mark.notyet(
["flink"],
raises=Py4JJavaError,
reason="SQL validation failed; Flink does not support ARRAY[]", # https://issues.apache.org/jira/browse/FLINK-20578
)
def test_unnest_empty_array(con):
t = ibis.memtable({"arr": [[], ["a"], ["a", "b"]]})
expr = t.arr.unnest()
Expand Down
7 changes: 6 additions & 1 deletion ibis/backends/tests/test_generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1033,7 +1033,7 @@ def query(t, group_cols):


@pytest.mark.notimpl(
["dask", "pandas", "oracle", "flink", "exasol"], raises=com.OperationNotDefinedError
["dask", "pandas", "oracle", "exasol"], raises=com.OperationNotDefinedError
)
@pytest.mark.notimpl(["druid"], raises=AssertionError)
@pytest.mark.notyet(
Expand All @@ -1046,6 +1046,11 @@ def query(t, group_cols):
reason="invalid code generated for unnesting a struct",
raises=TrinoUserError,
)
@pytest.mark.broken(
["flink"],
reason="invalid code generated for unnesting a struct",
raises=Py4JJavaError,
)
def test_pivot_longer(backend):
diamonds = backend.diamonds
df = diamonds.execute()
Expand Down
1 change: 0 additions & 1 deletion ibis/backends/tests/test_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -468,7 +468,6 @@ def test_map_length(con):
assert con.execute(expr) == 2


@pytest.mark.notimpl(["flink"], raises=exc.OperationNotDefinedError)
def test_map_keys_unnest(backend):
expr = backend.map.kv.keys().unnest()
result = expr.to_pandas()
Expand Down
7 changes: 6 additions & 1 deletion ibis/formats/pandas.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import ibis.expr.datatypes as dt
import ibis.expr.schema as sch
from ibis import util
from ibis.common.numeric import normalize_decimal
from ibis.common.temporal import normalize_timezone
from ibis.formats import DataMapper, SchemaMapper, TableProxy
Expand Down Expand Up @@ -300,7 +301,11 @@ def convert(values, names=dtype.names, converters=converters):
if values is None:
return values

items = values.items() if isinstance(values, dict) else zip(names, values)
items = (
values.items()
if isinstance(values, dict)
else zip(names, util.promote_list(values))
)
return {
k: converter(v) if v is not None else v
for converter, (k, v) in zip(converters, items)
Expand Down

0 comments on commit a6e6564

Please sign in to comment.