Skip to content

Commit

Permalink
feat: add outputNames hint to all Rels except joins (#1150)
Browse files Browse the repository at this point in the history
This is first PR to solve #1145 and adds outputNames hint to all Rels
except for JoinRel and CrossRel. Join relations need more work to handle
name collisions as ibis doesn't do it for every join step as part of
JoinChain.
  • Loading branch information
tokoko authored Sep 24, 2024
1 parent 5f668dc commit 6d473ad
Show file tree
Hide file tree
Showing 24 changed files with 3,264 additions and 94 deletions.
36 changes: 34 additions & 2 deletions ibis_substrait/compiler/translate.py
Original file line number Diff line number Diff line change
Expand Up @@ -793,7 +793,10 @@ def unbound_table(
read=stalg.ReadRel(
# TODO: filter,
# TODO: projection,
common=stalg.RelCommon(direct=stalg.RelCommon.Direct()),
common=stalg.RelCommon(
direct=stalg.RelCommon.Direct(),
hint=stalg.RelCommon.Hint(output_names=op.schema.fields),
),
base_schema=translate(op.schema),
named_table=stalg.ReadRel.NamedTable(names=[op.name]),
)
Expand All @@ -813,9 +816,14 @@ def filter(
compiler=compiler,
child_rel_field_offsets=child_rel_field_offsets,
)

predicates = [pred.to_expr() for pred in filter.predicates] # type: ignore
return stalg.Rel(
filter=stalg.FilterRel(
common=stalg.RelCommon(
direct=stalg.RelCommon.Direct(),
hint=stalg.RelCommon.Hint(output_names=filter.schema.fields),
),
input=relation,
condition=translate(
functools.reduce(operator.and_, predicates),
Expand All @@ -831,6 +839,7 @@ def apply_projection(
schema_len: int,
relation: stalg.Rel,
values: Mapping[str, ops.Value],
output_names: list[str],
compiler: SubstraitCompiler,
child_rel_field_offsets: Mapping[ops.TableNode, int] | None,
kwargs: Mapping,
Expand All @@ -843,7 +852,8 @@ def apply_projection(
common=stalg.RelCommon(
emit=stalg.RelCommon.Emit(
output_mapping=[next(mapping_counter) for _ in values]
)
),
hint=stalg.RelCommon.Hint(output_names=output_names),
),
expressions=[
translate(
Expand Down Expand Up @@ -874,6 +884,7 @@ def project(
schema_len=len(op.parent.schema),
relation=relation,
values=op.values,
output_names=op.schema.fields,
compiler=compiler,
child_rel_field_offsets=child_rel_field_offsets,
kwargs=kwargs,
Expand All @@ -894,6 +905,10 @@ def sort(

return stalg.Rel(
sort=stalg.SortRel(
common=stalg.RelCommon(
direct=stalg.RelCommon.Direct(),
hint=stalg.RelCommon.Hint(output_names=op.schema.fields),
),
input=relation,
sorts=[
translate(
Expand Down Expand Up @@ -979,6 +994,7 @@ def join(
schema_len=offset,
relation=relation,
values=op.values,
output_names=op.schema.fields,
compiler=compiler,
child_rel_field_offsets=child_rel_field_offsets,
kwargs=kwargs,
Expand All @@ -994,6 +1010,10 @@ def limit(
) -> stalg.Rel:
return stalg.Rel(
fetch=stalg.FetchRel(
common=stalg.RelCommon(
direct=stalg.RelCommon.Direct(),
hint=stalg.RelCommon.Hint(output_names=op.schema.fields),
),
input=translate(op.parent, compiler=compiler, **kwargs),
offset=op.offset,
count=op.n,
Expand Down Expand Up @@ -1034,6 +1054,10 @@ def set_op(
) -> stalg.Rel:
return stalg.Rel(
set=stalg.SetRel(
common=stalg.RelCommon(
direct=stalg.RelCommon.Direct(),
hint=stalg.RelCommon.Hint(output_names=op.schema.fields),
),
inputs=[
translate(op.left, compiler=compiler, **kwargs),
translate(op.right, compiler=compiler, **kwargs),
Expand All @@ -1051,6 +1075,10 @@ def aggregate(
**kwargs: Any,
) -> stalg.Rel:
aggregate = stalg.AggregateRel(
common=stalg.RelCommon(
direct=stalg.RelCommon.Direct(),
hint=stalg.RelCommon.Hint(output_names=op.schema.fields),
),
input=translate(op.parent, compiler=compiler, **kwargs),
groupings=[
stalg.AggregateRel.Grouping(
Expand Down Expand Up @@ -1310,6 +1338,10 @@ def _not_exists_subquery(
assert compiler is not None
tuples = stalg.Rel(
filter=stalg.FilterRel(
common=stalg.RelCommon(
direct=stalg.RelCommon.Direct(),
hint=stalg.RelCommon.Hint(output_names=op.schema.fields),
),
input=translate(op.foreign_table, compiler=compiler),
condition=translate(
functools.reduce(ops.And, op.predicates), # type: ignore
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,14 +69,91 @@
"root": {
"input": {
"sort": {
"common": {
"direct": {},
"hint": {
"outputNames": [
"l_returnflag",
"l_linestatus",
"sum_qty",
"sum_base_price",
"sum_disc_price",
"sum_charge",
"avg_qty",
"avg_price",
"avg_disc",
"count_order"
]
}
},
"input": {
"aggregate": {
"common": {
"direct": {},
"hint": {
"outputNames": [
"l_returnflag",
"l_linestatus",
"sum_qty",
"sum_base_price",
"sum_disc_price",
"sum_charge",
"avg_qty",
"avg_price",
"avg_disc",
"count_order"
]
}
},
"input": {
"filter": {
"common": {
"direct": {},
"hint": {
"outputNames": [
"l_orderkey",
"l_partkey",
"l_suppkey",
"l_linenumber",
"l_quantity",
"l_extendedprice",
"l_discount",
"l_tax",
"l_returnflag",
"l_linestatus",
"l_shipdate",
"l_commitdate",
"l_receiptdate",
"l_shipinstruct",
"l_shipmode",
"l_comment"
]
}
},
"input": {
"read": {
"common": {
"direct": {}
"direct": {},
"hint": {
"outputNames": [
"l_orderkey",
"l_partkey",
"l_suppkey",
"l_linenumber",
"l_quantity",
"l_extendedprice",
"l_discount",
"l_tax",
"l_returnflag",
"l_linestatus",
"l_shipdate",
"l_commitdate",
"l_receiptdate",
"l_shipinstruct",
"l_shipmode",
"l_comment"
]
}
},
"baseSchema": {
"names": [
Expand Down
Loading

0 comments on commit 6d473ad

Please sign in to comment.