From b9e6430f32185e44639fe27519b9f2c954994e3b Mon Sep 17 00:00:00 2001 From: EricPai Date: Tue, 27 Jun 2023 17:12:48 +0800 Subject: [PATCH] Add replace_subgraph with tests --- mars/core/entity/core.py | 4 + mars/optimization/logical/core.py | 78 ++++++- mars/optimization/logical/tests/__init__.py | 13 ++ mars/optimization/logical/tests/test_core.py | 221 +++++++++++++++++++ 4 files changed, 315 insertions(+), 1 deletion(-) create mode 100644 mars/optimization/logical/tests/__init__.py create mode 100644 mars/optimization/logical/tests/test_core.py diff --git a/mars/core/entity/core.py b/mars/core/entity/core.py index 6a27ac65d2..d3234a59d5 100644 --- a/mars/core/entity/core.py +++ b/mars/core/entity/core.py @@ -42,6 +42,10 @@ def __init__(self, *args, **kwargs): def op(self): return self._op + @property + def outputs(self): + return self._op.outputs + @property def inputs(self): return self.op.inputs diff --git a/mars/optimization/logical/core.py b/mars/optimization/logical/core.py index ba49f825f0..61bfcd2c00 100644 --- a/mars/optimization/logical/core.py +++ b/mars/optimization/logical/core.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import functools +import itertools import weakref from abc import ABC, abstractmethod from collections import defaultdict @@ -19,7 +20,7 @@ from enum import Enum from typing import Dict, List, Optional, Type, Set -from ...core import OperandType, EntityType, enter_mode +from ...core import OperandType, EntityType, enter_mode, Entity from ...core.graph import EntityGraph from ...utils import implements @@ -130,6 +131,77 @@ def _replace_node(self, original_node: EntityType, new_node: EntityType): for succ in successors: self._graph.add_edge(new_node, succ) + def _replace_subgraph( + self, + graph: Optional[EntityGraph], + removed_nodes: Optional[Set[EntityType]], + new_results: Optional[List[Entity]] = None, + ): + """ + Replace the subgraph from the self._graph represented by a list of nodes with input graph. + It will delete the nodes in removed_nodes with all linked edges first, and then add (or update if it's still + existed in self._graph) the nodes and edges of the input graph. + + Parameters + ---------- + graph: EntityGraph, optional. + The input graph. If it's none, no new node and edge will be added. + removed_nodes: Set[EntityType], optional. + The nodes to be removed. All the edges connected with them are removed as well. + new_results: List[EntityType], optional, default is None + The updated results of the graph. If it's None, then the results will not be updated. + + Raises + ------ + ReplaceSubgraphError: + If the input key of the removed node's successor can't be found in the subgraph. + Or some of the nodes of the subgraph are in removed ones. + """ + infected_successors = set() + + output_to_node = dict() + removed_nodes = removed_nodes or set() + if graph is not None: + # Add the output key -> node of the subgraph + for node in graph.iter_nodes(): + if node in removed_nodes: + raise ReplaceSubgraphError(f"The node {node} is in the removed set") + for output in node.outputs: + output_to_node[output.key] = node + + for node in removed_nodes: + for infected_successor in self._graph.iter_successors(node): + if infected_successor not in removed_nodes: + infected_successors.add(infected_successor) + # Check whether infected successors' inputs are in subgraph + for infected_successor in infected_successors: + for inp in infected_successor.inputs: + if inp.key not in output_to_node: + raise ReplaceSubgraphError( + f"The output {inp} of node {infected_successor} is missing in the subgraph" + ) + for node in removed_nodes: + self._graph.remove_node(node) + + if graph is None: + return + + # Add the output key -> node of the original graph + for node in self._graph.iter_nodes(): + for output in node.outputs: + output_to_node[output.key] = node + + for node in graph.iter_nodes(): + self._graph.add_node(node) + + for node in itertools.chain(graph.iter_nodes(), infected_successors): + for inp in node.inputs: + pred_node = output_to_node[inp.key] + self._graph.add_edge(pred_node, node) + + if new_results is not None: + self._graph.results = new_results.copy() + def _add_collapsable_predecessor(self, node: EntityType, predecessor: EntityType): pred_original = self._records.get_original_entity(predecessor, predecessor) if predecessor not in self._preds_to_remove: @@ -283,3 +355,7 @@ def optimize(cls, graph: EntityGraph) -> OptimizationRecords: graph.results = new_results return records + + +class ReplaceSubgraphError(Exception): + pass diff --git a/mars/optimization/logical/tests/__init__.py b/mars/optimization/logical/tests/__init__.py new file mode 100644 index 0000000000..c71e83c08e --- /dev/null +++ b/mars/optimization/logical/tests/__init__.py @@ -0,0 +1,13 @@ +# Copyright 1999-2021 Alibaba Group Holding Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/mars/optimization/logical/tests/test_core.py b/mars/optimization/logical/tests/test_core.py new file mode 100644 index 0000000000..0900bce3d4 --- /dev/null +++ b/mars/optimization/logical/tests/test_core.py @@ -0,0 +1,221 @@ +# Copyright 1999-2021 Alibaba Group Holding Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import itertools +import pytest + +import mars.tensor as mt +from ..core import OptimizationRule, ReplaceSubgraphError +from .... import dataframe as md + + +class _MockRule(OptimizationRule): + def apply(self) -> bool: + pass + + def replace_subgraph(self, graph, removed_nodes, new_results=None): + self._replace_subgraph(graph, removed_nodes, new_results) + + +def test_replace_tileable_subgraph(): + """ + Original Graph: + s1 ---> c1 ---> v1 ---> v4 ----> v6(output) <--- v5 <--- c5 <--- s5 + | ^ + | | + V | + v3 ------| + ^ + | + s2 ---> c2 ---> v2 + + Target Graph: + s1 ---> c1 ---> v1 ---> v7 ----> v8(output) <--- v5 <--- c5 <--- s5 + ^ + | + s2 ---> c2 ---> v2 + + The nodes [v3, v4, v6] will be removed. + Subgraph only contains [v7, v8] + """ + s1 = mt.random.randint(0, 100, size=(5, 4)) + v1 = md.DataFrame(s1, columns=list("ABCD"), chunk_size=5) + s2 = mt.random.randint(0, 100, size=(5, 4)) + v2 = md.DataFrame(s2, columns=list("ABCD"), chunk_size=5) + v3 = v1.add(v2) + v4 = v3.add(v1) + s5 = mt.random.randint(0, 100, size=(5, 4)) + v5 = md.DataFrame(s5, columns=list("ABCD"), chunk_size=4) + v6 = v5.sub(v4) + g1 = v6.build_graph() + v7 = v1.sub(v2) + v8 = v7.add(v5) + g2 = v8.build_graph() + + # Here we use a trick way to construct the subgraph for test only + key_to_node = dict() + for node in g2.iter_nodes(): + key_to_node[node.key] = node + for key, node in key_to_node.items(): + if key != v7.key and key != v8.key: + g2.remove_node(node) + r = _MockRule(g1, None, None) + for node in g1.iter_nodes(): + key_to_node[node.key] = node + + c1 = g1.successors(key_to_node[s1.key])[0] + c2 = g1.successors(key_to_node[s2.key])[0] + c5 = g1.successors(key_to_node[s5.key])[0] + + expected_results = [v8.outputs[0]] + r.replace_subgraph( + g2, {key_to_node[op.key] for op in [v3, v4, v6]}, expected_results + ) + assert g1.results == expected_results + + expected_nodes = {s1, c1, v1, s2, c2, v2, s5, c5, v5, v7, v8} + assert set(g1) == {key_to_node[n.key] for n in expected_nodes} + + expected_edges = { + s1: [c1], + c1: [v1], + v1: [v7], + s2: [c2], + c2: [v2], + v2: [v7], + s5: [c5], + c5: [v5], + v5: [v8], + v7: [v8], + v8: [], + } + for pred, successors in expected_edges.items(): + pred_node = key_to_node[pred.key] + assert g1.count_successors(pred_node) == len(successors) + for successor in successors: + assert g1.has_successor(pred_node, key_to_node[successor.key]) + + +def test_replace_null_subgraph(): + """ + Original Graph: + s1 ---> c1 ---> v1 ---> v3 <--- v2 <--- c2 <--- s2 + + Target Graph: + c1 ---> v1 ---> v3 <--- v2 <--- c2 + + The nodes [s1, s2] will be removed. + Subgraph is None + """ + s1 = mt.random.randint(0, 100, size=(10, 4)) + v1 = md.DataFrame(s1, columns=list("ABCD"), chunk_size=5) + s2 = mt.random.randint(0, 100, size=(10, 4)) + v2 = md.DataFrame(s2, columns=list("ABCD"), chunk_size=5) + v3 = v1.add(v2) + g1 = v3.build_graph() + key_to_node = {node.key: node for node in g1.iter_nodes()} + c1 = g1.successors(key_to_node[s1.key])[0] + c2 = g1.successors(key_to_node[s2.key])[0] + r = _MockRule(g1, None, None) + expected_results = [v3.outputs[0]] + # delete c5 s5 will fail + with pytest.raises(ReplaceSubgraphError) as e: + r.replace_subgraph(None, {key_to_node[op.key] for op in [s1, s2]}) + assert g1.results == expected_results + assert set(g1) == {key_to_node[n.key] for n in {s1, c1, v1, s2, c2, v2, v3}} + expected_edges = { + s1: [c1], + c1: [v1], + v1: [v3], + s2: [c2], + c2: [v2], + v2: [v3], + v3: [], + } + for pred, successors in expected_edges.items(): + pred_node = key_to_node[pred.key] + assert g1.count_successors(pred_node) == len(successors) + for successor in successors: + assert g1.has_successor(pred_node, key_to_node[successor.key]) + + c1.inputs.clear() + c2.inputs.clear() + r.replace_subgraph(None, {key_to_node[op.key] for op in [s1, s2]}) + assert g1.results == expected_results + assert set(g1) == {key_to_node[n.key] for n in {c1, v1, c2, v2, v3}} + expected_edges = { + c1: [v1], + v1: [v3], + c2: [v2], + v2: [v3], + v3: [], + } + for pred, successors in expected_edges.items(): + pred_node = key_to_node[pred.key] + assert g1.count_successors(pred_node) == len(successors) + for successor in successors: + assert g1.has_successor(pred_node, key_to_node[successor.key]) + + +def test_replace_subgraph_without_removing_nodes(): + """ + Original Graph: + s1 ---> c1 ---> v1 ---> v4 <--- v2 <--- c2 <--- s2 + + Target Graph: + s1 ---> c1 ---> v1 ---> v4 <--- v2 <--- c2 <--- s2 + s3 ---> c3 ---> v3 + + Nothing will be removed. + Subgraph only contains [s3, c3, v3] + """ + s1 = mt.random.randint(0, 100, size=(10, 4)) + v1 = md.DataFrame(s1, columns=list("ABCD"), chunk_size=5) + s2 = mt.random.randint(0, 100, size=(10, 4)) + v2 = md.DataFrame(s2, columns=list("ABCD"), chunk_size=5) + v4 = v1.add(v2) + g1 = v4.build_graph() + + s3 = mt.random.randint(0, 100, size=(10, 4)) + v3 = md.DataFrame(s3, columns=list("ABCD"), chunk_size=5) + g2 = v3.build_graph() + key_to_node = { + node.key: node for node in itertools.chain(g1.iter_nodes(), g2.iter_nodes()) + } + expected_results = [v3.outputs[0], v4.outputs[0]] + c1 = g1.successors(key_to_node[s1.key])[0] + c2 = g1.successors(key_to_node[s2.key])[0] + c3 = g2.successors(key_to_node[s3.key])[0] + r = _MockRule(g1, None, None) + r.replace_subgraph(g2, None, expected_results) + assert g1.results == expected_results + assert set(g1) == { + key_to_node[n.key] for n in {s1, c1, v1, s2, c2, v2, s3, c3, v3, v4} + } + expected_edges = { + s1: [c1], + c1: [v1], + v1: [v4], + s2: [c2], + c2: [v2], + v2: [v4], + s3: [c3], + c3: [v3], + v3: [], + v4: [], + } + for pred, successors in expected_edges.items(): + pred_node = key_to_node[pred.key] + assert g1.count_successors(pred_node) == len(successors) + for successor in successors: + assert g1.has_successor(pred_node, key_to_node[successor.key])