Skip to content

Commit

Permalink
Fix error on dependent DataFrame setitems (#2701)
Browse files — browse the repository at this point in the history
  • Loading branch information
wjsi authored Feb 11, 2022
1 parent c7d2bfb commit e5c851c
Show file tree
Hide file tree
Showing 3 changed files with 116 additions and 56 deletions.
33 changes: 17 additions & 16 deletions mars/optimization/logical/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from collections import defaultdict
from dataclasses import dataclass
from enum import Enum
from typing import Dict, List, Tuple, Type
from typing import Dict, List, Optional, Tuple, Type

from ...core import OperandType, ChunkType, EntityType, enter_mode
from ...core.graph import EntityGraph
Expand Down Expand Up @@ -59,10 +59,12 @@ def append_record(self, record: OptimizationRecord):
):
self._optimized_chunk_to_records[record.new_chunk] = record

def get_optimization_result(self, original_chunk: ChunkType) -> ChunkType:
def get_optimization_result(
self, original_chunk: ChunkType, default: Optional[ChunkType] = None
) -> ChunkType:
chunk = original_chunk
if chunk not in self._original_chunk_to_records:
return
return default
while chunk in self._original_chunk_to_records:
record = self._original_chunk_to_records[chunk]
if record.record_type == OptimizationRecordType.replace:
Expand All @@ -72,10 +74,12 @@ def get_optimization_result(self, original_chunk: ChunkType) -> ChunkType:
return None
return chunk

def get_original_chunk(self, optimized_chunk: ChunkType) -> ChunkType:
def get_original_chunk(
self, optimized_chunk: ChunkType, default: Optional[ChunkType] = None
) -> ChunkType:
chunk = optimized_chunk
if chunk not in self._optimized_chunk_to_records:
return
return default
while chunk in self._optimized_chunk_to_records:
record = self._optimized_chunk_to_records[chunk]
if record.record_type == OptimizationRecordType.replace:
Expand Down Expand Up @@ -151,28 +155,25 @@ def _replace_node(self, original_node: EntityType, new_node: EntityType):
for succ in successors:
self._graph.add_edge(new_node, succ)

def _add_collapsable_predecessor(self, node: EntityType, predecessor: EntityType):
    """Mark ``predecessor`` as collapsable once ``node`` absorbs it.

    The predecessor is tracked under its *original* (pre-optimization)
    chunk so that later lookups in ``_remove_collapsable_predecessors``,
    which also key on original chunks, find the same entry.
    """
    # Map the (possibly already optimized) predecessor back to its
    # original chunk; fall back to the predecessor itself when there is
    # no optimization record for it.
    pred_original = self._records.get_original_chunk(predecessor, predecessor)
    # BUG FIX: the membership test must use the same key as the insertion
    # (pred_original, not predecessor) — otherwise a second call for the
    # same predecessor replaced the accumulated successor set instead of
    # adding to it. setdefault expresses check-then-insert atomically.
    self._preds_to_remove.setdefault(pred_original, set()).add(node)

def _remove_collapsable_predecessors(self, node: EntityType):
node = self._records.get_optimization_result(node) or node
preds_opt_to_remove = []
for pred in self._graph.predecessors(node):
pred_original = self._records.get_original_chunk(pred)
pred_original = pred_original if pred_original is not None else pred

pred_opt = self._records.get_optimization_result(pred)
pred_opt = pred_opt if pred_opt is not None else pred
pred_original = self._records.get_original_chunk(pred, pred)
pred_opt = self._records.get_optimization_result(pred, pred)

if pred_opt in self._graph.results or pred_original in self._graph.results:
continue
affect_succ = self._preds_to_remove.get(pred_original) or []
affect_succ_opt = [
self._records.get_optimization_result(s) or s for s in affect_succ
self._records.get_optimization_result(s, s) for s in affect_succ
]
if all(s in affect_succ_opt for s in self._graph.successors(pred)):
preds_opt_to_remove.append((pred_original, pred_opt))
Expand Down
126 changes: 87 additions & 39 deletions mars/optimization/logical/tileable/arithmetic_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,13 +105,12 @@ def _is_select_dataframe_column(tileable) -> bool:
and index_op.mask is None
)

@classmethod
def _extract_eval_expression(cls, tileable) -> EvalExtractRecord:
def _extract_eval_expression(self, tileable) -> EvalExtractRecord:
if is_scalar(tileable):
if isinstance(tileable, (int, bool, str, bytes, np.integer, np.bool_)):
return EvalExtractRecord(expr=repr(tileable))
else:
var_name = f"__eval_scalar_var{cls._next_var_id()}"
var_name = f"__eval_scalar_var{self._next_var_id()}"
var_dict = {var_name: tileable}
return EvalExtractRecord(expr=f"@{var_name}", variables=var_dict)

Expand All @@ -121,15 +120,15 @@ def _extract_eval_expression(cls, tileable) -> EvalExtractRecord:
if tileable in _extract_result_cache:
return _extract_result_cache[tileable]

if cls._is_select_dataframe_column(tileable):
result = cls._extract_column_select(tileable)
if self._is_select_dataframe_column(tileable):
result = self._extract_column_select(tileable)
elif isinstance(tileable.op, DataFrameUnaryUfunc):
result = cls._extract_unary(tileable)
result = self._extract_unary(tileable)
elif isinstance(tileable.op, DataFrameBinopUfunc):
if tileable.op.fill_value is not None or tileable.op.level is not None:
result = EvalExtractRecord()
else:
result = cls._extract_binary(tileable)
result = self._extract_binary(tileable)
else:
result = EvalExtractRecord()

Expand All @@ -140,35 +139,33 @@ def _extract_eval_expression(cls, tileable) -> EvalExtractRecord:
def _extract_column_select(cls, tileable) -> EvalExtractRecord:
return EvalExtractRecord(tileable.inputs[0], f"`{tileable.op.col_names}`")

def _extract_unary(self, tileable) -> EvalExtractRecord:
    """Convert a unary ufunc node into an eval-expression fragment.

    Returns an empty ``EvalExtractRecord`` when the function is not
    supported by the expression builders or when the input itself could
    not be converted.
    """
    op = tileable.op
    # Consistency fix: mirror _extract_binary and pass a None default so
    # ops exposing only the second attribute fall through to it instead
    # of raising AttributeError.
    func_name = getattr(op, "_func_name", None) or getattr(op, "_bin_func_name")
    if func_name not in _func_name_to_builder:  # pragma: no cover
        return EvalExtractRecord()

    in_tileable, expr, variables = self._extract_eval_expression(op.inputs[0])
    if in_tileable is None:
        return EvalExtractRecord()

    # The input node can be collapsed into the eval expression.
    self._add_collapsable_predecessor(tileable, op.inputs[0])
    return EvalExtractRecord(
        in_tileable, _func_name_to_builder[func_name](expr), variables
    )

@classmethod
def _extract_binary(cls, tileable) -> EvalExtractRecord:
def _extract_binary(self, tileable) -> EvalExtractRecord:
op = tileable.op
func_name = getattr(op, "_func_name", None) or getattr(op, "_bit_func_name")
if func_name not in _func_name_to_builder: # pragma: no cover
return EvalExtractRecord()

lhs_tileable, lhs_expr, lhs_vars = cls._extract_eval_expression(op.lhs)
lhs_tileable, lhs_expr, lhs_vars = self._extract_eval_expression(op.lhs)
if lhs_tileable is not None:
cls._add_collapsable_predecessor(tileable, op.lhs)
rhs_tileable, rhs_expr, rhs_vars = cls._extract_eval_expression(op.rhs)
self._add_collapsable_predecessor(tileable, op.lhs)
rhs_tileable, rhs_expr, rhs_vars = self._extract_eval_expression(op.rhs)
if rhs_tileable is not None:
cls._add_collapsable_predecessor(tileable, op.rhs)
self._add_collapsable_predecessor(tileable, op.rhs)

if lhs_expr is None or rhs_expr is None:
return EvalExtractRecord()
Expand All @@ -190,6 +187,9 @@ def _extract_binary(cls, tileable) -> EvalExtractRecord:
def apply(self, op: OperandType):
node = op.outputs[0]
in_tileable, expr, variables = self._extract_eval_expression(node)
opt_in_tileable = self._records.get_optimization_result(
in_tileable, in_tileable
)

new_op = DataFrameEval(
_key=node.op.key,
Expand All @@ -199,13 +199,13 @@ def apply(self, op: OperandType):
parser="pandas",
is_query=False,
)
new_node = new_op.new_tileable([in_tileable], **node.params).data
new_node._key = node.key
new_node._id = node.id
new_node = new_op.new_tileable(
[opt_in_tileable], _key=node.key, _id=node.id, **node.params
).data

self._remove_collapsable_predecessors(node)
self._replace_node(node, new_node)
self._graph.add_edge(in_tileable, new_node)
self._graph.add_edge(opt_in_tileable, new_node)

self._records.append_record(
OptimizationRecord(node, new_node, OptimizationRecordType.replace)
Expand Down Expand Up @@ -241,34 +241,40 @@ def _get_optimized_eval_op(self, op: OperandType) -> OperandType:
def _get_input_columnar_node(self, op: OperandType) -> ENTITY_TYPE:
    # Abstract hook: subclasses return the input node that holds the
    # columnar (eval-convertible) expression for ``op``.
    raise NotImplementedError

def _update_op_node(self, old_node: ENTITY_TYPE, new_node: ENTITY_TYPE):
    """Swap ``old_node`` for ``new_node`` in the graph, records and results.

    This span previously interleaved the superseded ``apply`` body with
    the new helper (duplicate, conflicting statements); only the clean
    helper remains here — the rewritten ``apply`` below uses it.
    """
    self._replace_node(old_node, new_node)
    for in_tileable in new_node.inputs:
        self._graph.add_edge(in_tileable, new_node)

    # Record the replacement against the *original* chunk so chained
    # optimizations keep resolving from the user-visible node.
    original_node = self._records.get_original_chunk(old_node, old_node)
    self._records.append_record(
        OptimizationRecord(original_node, new_node, OptimizationRecordType.replace)
    )

    # If the replaced node is one of the graph results, substitute it
    # there as well.
    try:
        i = self._graph.results.index(old_node)
        self._graph.results[i] = new_node
    except ValueError:
        pass
def apply(self, op: DataFrameIndex):
    """Rewrite an index/setitem node into a ``DataFrameEval`` node."""
    node = op.outputs[0]
    in_tileable = op.inputs[0]
    in_columnar_node = self._get_input_columnar_node(op)
    # Use the optimized form of the input when one exists so the new
    # eval node is wired to the current graph, not to a replaced chunk.
    opt_in_tileable = self._records.get_optimization_result(
        in_tileable, in_tileable
    )

    new_op = self._build_new_eval_op(op)
    # Preserve key/id so downstream references to the old node resolve.
    new_node = new_op.new_tileable(
        [opt_in_tileable], _key=node.key, _id=node.id, **node.params
    ).data

    # Order matters: mark and drop collapsable predecessors before the
    # node itself is swapped out of the graph.
    self._add_collapsable_predecessor(node, in_columnar_node)
    self._remove_collapsable_predecessors(node)
    self._update_op_node(node, new_node)


@register_tileable_optimization_rule([DataFrameIndex])
class DataFrameBoolEvalToQuery(_DataFrameEvalRewriteRule):
Expand All @@ -287,6 +293,7 @@ def _get_input_columnar_node(self, op: OperandType) -> ENTITY_TYPE:
def _build_new_eval_op(self, op: OperandType):
in_eval_op = self._get_optimized_eval_op(op)
return DataFrameEval(
_key=op.key,
_output_types=get_output_types(op.outputs[0]),
expr=in_eval_op.expr,
variables=in_eval_op.variables,
Expand All @@ -308,10 +315,51 @@ def _get_input_columnar_node(self, op: DataFrameSetitem) -> ENTITY_TYPE:
def _build_new_eval_op(self, op: DataFrameSetitem):
    """Build a self-targeted eval op assigning the setitem column."""
    in_eval_op = self._get_optimized_eval_op(op)
    return DataFrameEval(
        _key=op.key,
        _output_types=get_output_types(op.outputs[0]),
        # pandas eval assignment syntax: `col` = <expr>
        expr=f"`{op.indexes}` = {in_eval_op.expr}",
        variables=in_eval_op.variables,
        parser="pandas",
        is_query=False,
        # self_target: the eval assigns into its own frame rather than
        # producing a query/filter result.
        self_target=True,
    )

def apply(self, op: DataFrameIndex):
    """Rewrite a setitem to eval, then merge with a preceding eval setitem."""
    # First apply the standard setitem -> eval rewrite from the base rule.
    super().apply(op)

    node = op.outputs[0]
    opt_node = self._records.get_optimization_result(node, node)
    if not isinstance(opt_node.op, DataFrameEval):  # pragma: no cover
        return

    # when encountering consecutive SetItems, expressions can be
    # merged as a multiline expression
    pred_opt_node = opt_node.inputs[0]
    if (
        isinstance(pred_opt_node.op, DataFrameEval)
        and opt_node.op.parser == pred_opt_node.op.parser == "pandas"
        and not opt_node.op.is_query
        and not pred_opt_node.op.is_query
        and opt_node.op.self_target
        and pred_opt_node.op.self_target
    ):
        # Predecessor's assignments run first, then this node's.
        new_expr = pred_opt_node.op.expr + "\n" + opt_node.op.expr
        # On variable-name clashes the later (current) bindings win.
        new_variables = (pred_opt_node.op.variables or dict()).copy()
        new_variables.update(opt_node.op.variables or dict())

        new_op = DataFrameEval(
            _key=op.key,
            _output_types=get_output_types(op.outputs[0]),
            expr=new_expr,
            variables=new_variables,
            parser="pandas",
            is_query=False,
            self_target=True,
        )
        # Reuse the predecessor's inputs: the merged eval replaces both
        # nodes, keeping this node's key/id for downstream references.
        new_node = new_op.new_tileable(
            pred_opt_node.inputs, _key=node.key, _id=node.id, **node.params
        ).data

        self._add_collapsable_predecessor(opt_node, pred_opt_node)
        self._remove_collapsable_predecessors(opt_node)
        self._update_op_node(opt_node, new_node)
Original file line number Diff line number Diff line change
Expand Up @@ -153,14 +153,25 @@ def test_eval_setitem_to_eval(setup):
df2 = md.DataFrame(raw2, chunk_size=10)
df3 = df1.merge(df2, on="A", suffixes=("", "_"))
df3["K"] = df3["A"] * (1 - df3["B"])
df3["L"] = df3["K"] - df3["A"]
df3["M"] = df3["K"] + df3["L"]

graph = TileableGraph([df3.data])
next(TileableGraphBuilder(graph).build())
records = optimize(graph)
opt_df3 = records.get_optimization_result(df3.data)
assert opt_df3.op.expr == "`K` = (`A`) * ((1) - (`B`))"
assert opt_df3.op.expr == "\n".join(
[
"`K` = (`A`) * ((1) - (`B`))",
"`L` = (`K`) - (`A`)",
"`M` = (`K`) + (`L`)",
]
)
assert len(graph) == 4
assert len([n for n in graph if isinstance(n.op, DataFrameEval)]) == 1

r_df3 = raw.merge(raw2, on="A", suffixes=("", "_"))
r_df3["K"] = r_df3["A"] * (1 - r_df3["B"])
r_df3["L"] = r_df3["K"] - r_df3["A"]
r_df3["M"] = r_df3["K"] + r_df3["L"]
pd.testing.assert_frame_equal(df3.execute().fetch(), r_df3)

0 comments on commit e5c851c

Please sign in to comment.