Skip to content

Commit

Permalink
Refine ArithmeticToEval related rules
Browse files Browse the repository at this point in the history
  • Loading branch information
ericpai committed Jul 3, 2023
1 parent dd02ba9 commit 8505806
Show file tree
Hide file tree
Showing 3 changed files with 84 additions and 78 deletions.
32 changes: 0 additions & 32 deletions mars/optimization/logical/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
# limitations under the License.
import functools
import itertools
import weakref
from abc import ABC, abstractmethod
from collections import defaultdict
from dataclasses import dataclass
Expand Down Expand Up @@ -92,8 +91,6 @@ def get_original_entity(


class OptimizationRule(ABC):
_preds_to_remove = weakref.WeakKeyDictionary()

def __init__(
self,
graph: EntityGraph,
Expand Down Expand Up @@ -217,35 +214,6 @@ def _replace_subgraph(
for result in new_results:
self._graph.results[result_indices[result.key]] = result

def _add_collapsable_predecessor(self, node: EntityType, predecessor: EntityType):
    """
    Register ``node`` as a successor that has absorbed ``predecessor``.

    The registration is stored in ``self._preds_to_remove`` under the
    *original* entity of ``predecessor`` (as resolved through
    ``self._records``), mapping it to the set of nodes registered so far.

    Parameters
    ----------
    node : EntityType
        The successor node to register.
    predecessor : EntityType
        The predecessor of ``node`` being marked as collapsable.
    """
    pred_original = self._records.get_original_entity(predecessor, predecessor)
    # The membership test and the store must use the same key. Testing
    # ``predecessor`` while storing under ``pred_original`` would reset the
    # set (dropping earlier registrations) whenever the two entities differ.
    self._preds_to_remove.setdefault(pred_original, set()).add(node)

def _remove_collapsable_predecessors(self, node: EntityType):
    """
    Drop predecessors of ``node`` whose successors have all been registered.

    A predecessor is removed from the graph, and a ``delete`` record is
    appended for its original entity, when neither its original nor its
    optimized form is a graph result and every successor it has in the graph
    appears among the (optimized) nodes registered for it in
    ``self._preds_to_remove``.

    Parameters
    ----------
    node : EntityType
        The node whose predecessors are examined; its optimization result is
        used instead when the records provide one.
    """
    records = self._records
    graph = self._graph
    node = records.get_optimization_result(node) or node

    # Collect first, mutate afterwards, so the graph is not changed while
    # its predecessors are being walked.
    doomed = []
    for pred in graph.predecessors(node):
        original = records.get_original_entity(pred, pred)
        optimized = records.get_optimization_result(pred, pred)

        # Result nodes must stay in the graph.
        if optimized in graph.results or original in graph.results:
            continue

        registered = self._preds_to_remove.get(original) or []
        registered_opt = [records.get_optimization_result(s, s) for s in registered]
        # Keep the predecessor if any successor has not been registered.
        if any(s not in registered_opt for s in graph.successors(pred)):
            continue
        doomed.append((original, optimized))

    for original, optimized in doomed:
        graph.remove_node(optimized)
        records.append_record(
            OptimizationRecord(original, None, OptimizationRecordType.delete)
        )


class OperandBasedOptimizationRule(OptimizationRule):
"""
Expand Down
5 changes: 1 addition & 4 deletions mars/optimization/logical/tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,10 +157,7 @@ def test_replace_null_subgraph():

c1.inputs.clear()
c2.inputs.clear()
r.replace_subgraph(
None,
{key_to_node[op.key] for op in [s1, s2]}
)
r.replace_subgraph(None, {key_to_node[op.key] for op in [s1, s2]})
assert g1.results == expected_results
assert set(g1) == {key_to_node[n.key] for n in {c1, v1, c2, v2, v3}}
expected_edges = {
Expand Down
125 changes: 83 additions & 42 deletions mars/optimization/logical/tileable/arithmetic_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,20 +13,27 @@
# limitations under the License.

import weakref
from typing import NamedTuple, Optional
from abc import ABC
from typing import NamedTuple, Optional, Type, Set

import numpy as np
from pandas.api.types import is_scalar

from .... import dataframe as md
from ....core import Tileable, get_output_types, ENTITY_TYPE
from ....core import Tileable, get_output_types, ENTITY_TYPE, TileableGraph
from ....core.graph import EntityGraph
from ....dataframe.arithmetic.core import DataFrameUnaryUfunc, DataFrameBinopUfunc
from ....dataframe.base.eval import DataFrameEval
from ....dataframe.indexing.getitem import DataFrameIndex
from ....dataframe.indexing.setitem import DataFrameSetitem
from ....typing import OperandType
from ....typing import OperandType, EntityType
from ....utils import implements
from ..core import OptimizationRecord, OptimizationRecordType
from ..core import (
OptimizationRecord,
OptimizationRecordType,
OptimizationRecords,
Optimizer,
)
from ..tileable.core import register_operand_based_optimization_rule
from .core import OperandBasedOptimizationRule

Expand Down Expand Up @@ -66,8 +73,70 @@ def builder(lhs: str, rhs: str):
_extract_result_cache = weakref.WeakKeyDictionary()


class _EvalRewriteOptimizationRule(OperandBasedOptimizationRule, ABC):
    """
    Base class for rules that rewrite a node into a single new node.

    Subclasses mark predecessors that a node has absorbed via
    ``_mark_predecessor`` and then call ``_replace_with_new_node`` to swap
    the node, together with every predecessor that is no longer needed, for
    the new node.
    """

    def __init__(
        self,
        graph: EntityGraph,
        records: OptimizationRecords,
        optimizer_cls: Type[Optimizer],
    ):
        super().__init__(graph, records, optimizer_cls)
        # Maps the original entity of a predecessor to the set of successor
        # nodes that have marked it.
        self._marked_predecessors = dict()

    def _mark_predecessor(self, node: EntityType, predecessor: EntityType):
        """
        Record that ``node`` has marked ``predecessor``.

        The entry is keyed by the original entity of ``predecessor`` as
        resolved through ``self._records``.
        """
        pred_original = self._records.get_original_entity(predecessor, predecessor)
        # Look up and store under the same key. Testing ``predecessor`` while
        # storing under ``pred_original`` would overwrite the existing set
        # (losing earlier marks) whenever the two entities differ.
        self._marked_predecessors.setdefault(pred_original, set()).add(node)

    def _find_nodes_to_remove(self, node: EntityType) -> Set[EntityType]:
        """
        Collect ``node`` and the predecessors that can be removed with it.

        A predecessor qualifies when neither its original nor its optimized
        form is a graph result and all of its successors in the graph are
        among the (optimized) nodes that marked it.

        Note that, despite its name, this method also appends a ``delete``
        record for every qualifying predecessor.

        Returns
        -------
        Set[EntityType]
            The (optimized) ``node`` plus the optimized form of every
            qualifying predecessor.
        """
        node = self._records.get_optimization_result(node) or node
        removed_nodes = {node}
        # Built once so the membership tests in the loop are set lookups.
        results_set = set(self._graph.results)
        removed_pairs = []
        for pred in self._graph.iter_predecessors(node):
            pred_original = self._records.get_original_entity(pred, pred)
            pred_opt = self._records.get_optimization_result(pred, pred)

            # Result nodes must stay in the graph.
            if pred_opt in results_set or pred_original in results_set:
                continue

            affect_succ = self._marked_predecessors.get(pred_original) or []
            affect_succ_opt = [
                self._records.get_optimization_result(s, s) for s in affect_succ
            ]
            if all(s in affect_succ_opt for s in self._graph.iter_successors(pred)):
                removed_pairs.append((pred_original, pred_opt))

        for pred_original, pred_opt in removed_pairs:
            removed_nodes.add(pred_opt)
            self._records.append_record(
                OptimizationRecord(pred_original, None, OptimizationRecordType.delete)
            )
        return removed_nodes

    def _replace_with_new_node(self, original_node: EntityType, new_node: EntityType):
        """
        Replace ``original_node`` and its removable predecessors by ``new_node``.

        A one-node subgraph holding ``new_node`` is substituted for the nodes
        returned by ``_find_nodes_to_remove``, and a ``replace`` record from
        the original entity of ``original_node`` to ``new_node`` is appended.
        """
        # Find all the nodes to remove
        nodes_to_remove = self._find_nodes_to_remove(original_node)

        # Build the replaced subgraph
        subgraph = TileableGraph()
        subgraph.add_node(new_node)

        # NOTE(review): ``new_node`` is tested against the current results
        # before the replacement happens; this presumably relies on the new
        # node comparing equal to the node it replaces -- confirm.
        new_results = [new_node] if new_node in self._graph.results else None
        self._replace_subgraph(subgraph, nodes_to_remove, new_results)
        self._records.append_record(
            OptimizationRecord(
                self._records.get_original_entity(original_node, original_node),
                new_node,
                OptimizationRecordType.replace,
            )
        )


@register_operand_based_optimization_rule([DataFrameUnaryUfunc, DataFrameBinopUfunc])
class SeriesArithmeticToEval(OperandBasedOptimizationRule):
class SeriesArithmeticToEval(_EvalRewriteOptimizationRule):
_var_counter = 0

@classmethod
Expand Down Expand Up @@ -151,7 +220,7 @@ def _extract_unary(self, tileable) -> EvalExtractRecord:
if in_tileable is None:
return EvalExtractRecord()

self._add_collapsable_predecessor(tileable, op.inputs[0])
self._mark_predecessor(tileable, op.inputs[0])
return EvalExtractRecord(
in_tileable, _func_name_to_builder[func_name](expr), variables
)
Expand All @@ -164,10 +233,10 @@ def _extract_binary(self, tileable) -> EvalExtractRecord:

lhs_tileable, lhs_expr, lhs_vars = self._extract_eval_expression(op.lhs)
if lhs_tileable is not None:
self._add_collapsable_predecessor(tileable, op.lhs)
self._mark_predecessor(tileable, op.lhs)
rhs_tileable, rhs_expr, rhs_vars = self._extract_eval_expression(op.rhs)
if rhs_tileable is not None:
self._add_collapsable_predecessor(tileable, op.rhs)
self._mark_predecessor(tileable, op.rhs)

if lhs_expr is None or rhs_expr is None:
return EvalExtractRecord()
Expand Down Expand Up @@ -204,24 +273,10 @@ def apply_to_operand(self, op: OperandType):
new_node = new_op.new_tileable(
[opt_in_tileable], _key=node.key, _id=node.id, **node.params
).data
self._replace_with_new_node(node, new_node)

self._remove_collapsable_predecessors(node)
self._replace_node(node, new_node)
self._graph.add_edge(opt_in_tileable, new_node)

self._records.append_record(
OptimizationRecord(node, new_node, OptimizationRecordType.replace)
)

# check node if it's in result
try:
i = self._graph.results.index(node)
self._graph.results[i] = new_node
except ValueError:
pass


class _DataFrameEvalRewriteRule(OperandBasedOptimizationRule):
class _DataFrameEvalRewriteRule(_EvalRewriteOptimizationRule):
@implements(OperandBasedOptimizationRule.match_operand)
def match_operand(self, op: OperandType) -> bool:
optimized_eval_op = self._get_optimized_eval_op(op)
Expand All @@ -245,16 +300,6 @@ def _get_optimized_eval_op(self, op: OperandType) -> OperandType:
def _get_input_columnar_node(self, op: OperandType) -> ENTITY_TYPE:
    """
    Return the input columnar node for ``op``.

    This base implementation only raises ``NotImplementedError``;
    presumably concrete rules override it -- confirm against subclasses.
    """
    raise NotImplementedError

def _update_op_node(self, old_node: ENTITY_TYPE, new_node: ENTITY_TYPE):
    """
    Swap ``old_node`` for ``new_node`` in the graph and record the change.

    After the replacement, an edge is added from every input of
    ``new_node`` to ``new_node``, and a ``replace`` record from the original
    entity of ``old_node`` to ``new_node`` is appended.
    """
    self._replace_node(old_node, new_node)

    graph = self._graph
    for source in new_node.inputs:
        graph.add_edge(source, new_node)

    # Record against the original entity, falling back to ``old_node``.
    replace_record = OptimizationRecord(
        self._records.get_original_entity(old_node, old_node),
        new_node,
        OptimizationRecordType.replace,
    )
    self._records.append_record(replace_record)

@implements(OperandBasedOptimizationRule.apply_to_operand)
def apply_to_operand(self, op: DataFrameIndex):
node = op.outputs[0]
Expand All @@ -268,10 +313,8 @@ def apply_to_operand(self, op: DataFrameIndex):
new_node = new_op.new_tileable(
[opt_in_tileable], _key=node.key, _id=node.id, **node.params
).data

self._add_collapsable_predecessor(node, in_columnar_node)
self._remove_collapsable_predecessors(node)
self._update_op_node(node, new_node)
self._mark_predecessor(node, in_columnar_node)
self._replace_with_new_node(node, new_node)


@register_operand_based_optimization_rule([DataFrameIndex])
Expand Down Expand Up @@ -360,7 +403,5 @@ def apply_to_operand(self, op: DataFrameIndex):
new_node = new_op.new_tileable(
pred_opt_node.inputs, _key=node.key, _id=node.id, **node.params
).data

self._add_collapsable_predecessor(opt_node, pred_opt_node)
self._remove_collapsable_predecessors(opt_node)
self._update_op_node(opt_node, new_node)
self._mark_predecessor(opt_node, pred_opt_node)
self._replace_with_new_node(opt_node, new_node)

0 comments on commit 8505806

Please sign in to comment.