diff --git a/mars/core/graph/entity.py b/mars/core/graph/entity.py
index 2301996785..e0b8b0e602 100644
--- a/mars/core/graph/entity.py
+++ b/mars/core/graph/entity.py
@@ -19,6 +19,7 @@
 from ...serialization.core import buffered
 from ...serialization.serializables import Serializable, DictField, ListField, BoolField
 from ...serialization.serializables.core import SerializableSerializer
+from ...utils import tokenize
 from .core import DAG
 
 
@@ -57,6 +58,11 @@ def copy(self) -> "EntityGraph":
 
 class TileableGraph(EntityGraph, Iterable[Tileable]):
     _result_tileables: List[Tileable]
+    # The logic key is a unique and deterministic key for a `TileableGraph`:
+    # across multiple runs it stays the same as long as the computational
+    # logic does not change, so it can be used for optimizations (e.g. HBO)
+    # when the same `execute` is run again.
+    _logic_key: str
 
     def __init__(self, result_tileables: List[Tileable] = None):
         super().__init__()
@@ -74,6 +80,19 @@ def results(self):
     def results(self, new_results):
         self._result_tileables = new_results
 
+    @property
+    def logic_key(self):
+        if not hasattr(self, "_logic_key") or self._logic_key is None:
+            token_keys = []
+            for node in self.bfs():
+                token_keys.append(
+                    tokenize(node.op.get_logic_key(), **node.extra_params)
+                    if node.extra_params
+                    else node.op.get_logic_key()
+                )
+            self._logic_key = tokenize(*token_keys)
+        return self._logic_key
+
 
 class ChunkGraph(EntityGraph, Iterable[Chunk]):
     _result_chunks: List[Chunk]
diff --git a/mars/core/graph/tests/test_graph.py b/mars/core/graph/tests/test_graph.py
index ca5c1a6501..a9362b3ab9 100644
--- a/mars/core/graph/tests/test_graph.py
+++ b/mars/core/graph/tests/test_graph.py
@@ -14,6 +14,9 @@
 
 import pytest
 
+import numpy as np
+
+from .... import dataframe as md
 from .... import tensor as mt
 from ....tests.core import flaky
 from ....utils import to_str
@@ -105,3 +108,90 @@ def test_to_dot():
 
     dot = to_str(graph.to_dot(trunc_key=5))
     assert all(to_str(n.key)[:5] in dot for n in graph) is True
+
+
+def test_tileable_graph_logic_key():
+    # Tensor
+    t1 = mt.random.randint(10, size=(10, 8), chunk_size=4)
+    t2 = mt.random.randint(10, size=(10, 8), chunk_size=5)
+    graph1 = (t1 + t2).build_graph(tile=False)
+    tt1 = mt.random.randint(10, size=(10, 8), chunk_size=4)
+    tt2 = mt.random.randint(10, size=(10, 8), chunk_size=5)
+    graph2 = (tt1 + tt2).build_graph(tile=False)
+    assert graph1.logic_key == graph2.logic_key
+    t3 = mt.random.randint(10, size=(10, 8), chunk_size=6)
+    tt3 = mt.random.randint(10, size=(10, 8), chunk_size=6)
+    graph3 = (t1 + t3).build_graph(tile=False)
+    graph4 = (t1 + tt3).build_graph(tile=False)
+    assert graph1.logic_key != graph3.logic_key
+    assert graph3.logic_key == graph4.logic_key
+    t4 = mt.random.randint(10, size=(10, 8))
+    graph5 = (t1 + t4).build_graph(tile=False)
+    assert graph1.logic_key != graph5.logic_key
+
+    # Series
+    s1 = md.Series([1, 3, 5, mt.nan, 6, 8])
+    s2 = md.Series(np.random.randn(1000), chunk_size=100)
+    graph1 = (s1 + s2).build_graph(tile=False)
+    ss1 = md.Series([1, 3, 5, mt.nan, 6, 8])
+    ss2 = md.Series(np.random.randn(1000), chunk_size=100)
+    graph2 = (ss1 + ss2).build_graph(tile=False)
+    assert graph1.logic_key == graph2.logic_key
+    s3 = md.Series(np.random.randn(1000), chunk_size=200)
+    ss3 = md.Series(np.random.randn(1000), chunk_size=200)
+    graph3 = (s1 + s3).build_graph(tile=False)
+    graph4 = (s1 + ss3).build_graph(tile=False)
+    assert graph1.logic_key != graph3.logic_key
+    assert graph3.logic_key == graph4.logic_key
+    s4 = md.Series(np.random.randn(1000))
+    graph5 = (s1 + s4).build_graph(tile=False)
+    assert graph1.logic_key != graph5.logic_key
+
+    # DataFrame
+    df1 = md.DataFrame(
+        np.random.randint(0, 100, size=(100_000, 4)), columns=list("ABCD"), chunk_size=5
+    )
+    df2 = md.DataFrame(
+        np.random.randint(0, 100, size=(100_000, 4)), columns=list("ABCD"), chunk_size=4
+    )
+    graph1 = (df1 + df2).build_graph(tile=False)
+    ddf1 = md.DataFrame(
+        np.random.randint(0, 100, size=(100_000, 4)), columns=list("ABCD"), chunk_size=5
+    )
+    ddf2 = md.DataFrame(
+        np.random.randint(0, 100, size=(100_000, 4)), columns=list("ABCD"), chunk_size=4
+    )
+    graph2 = (ddf1 + ddf2).build_graph(tile=False)
+    assert graph1.logic_key == graph2.logic_key
+    df3 = md.DataFrame(
+        np.random.randint(0, 100, size=(100_000, 4)), columns=list("ABCD"), chunk_size=3
+    )
+    ddf3 = md.DataFrame(
+        np.random.randint(0, 100, size=(100_000, 4)), columns=list("ABCD"), chunk_size=3
+    )
+    graph3 = (df1 + df3).build_graph(tile=False)
+    graph4 = (df1 + ddf3).build_graph(tile=False)
+    assert graph1.logic_key != graph3.logic_key
+    assert graph3.logic_key == graph4.logic_key
+    df5 = md.DataFrame(
+        np.random.randint(0, 100, size=(100_000, 4)), columns=list("ABCD")
+    )
+    graph5 = (df1 + df5).build_graph(tile=False)
+    assert graph1.logic_key != graph5.logic_key
+    graph6 = df1.describe().build_graph(tile=False)
+    graph7 = df2.describe().build_graph(tile=False)
+    assert graph6.logic_key != graph7.logic_key
+    graph8 = df1.apply(lambda x: x.max() - x.min()).build_graph(tile=False)
+    graph9 = df2.apply(lambda x: x.max() - x.min()).build_graph(tile=False)
+    assert graph8.logic_key != graph9.logic_key
+    pieces1 = [df1[:3], df1[3:7], df1[7:]]
+    graph10 = md.concat(pieces1).build_graph(tile=False)
+    pieces2 = [df2[:3], df2[3:7], df2[7:]]
+    graph11 = md.concat(pieces2).build_graph(tile=False)
+    assert graph10.logic_key != graph11.logic_key
+    graph12 = md.merge(df1, df2, on="A", how="left").build_graph(tile=False)
+    graph13 = md.merge(df1, df3, on="A", how="left").build_graph(tile=False)
+    assert graph12.logic_key != graph13.logic_key
+    graph14 = df2.groupby("A").sum().build_graph(tile=False)
+    graph15 = df3.groupby("A").sum().build_graph(tile=False)
+    assert graph14.logic_key != graph15.logic_key