Add logic key for tileable graph (#2961)
zhongchun authored Apr 27, 2022
1 parent 2966c8e commit a057995
Showing 2 changed files with 109 additions and 0 deletions.
19 changes: 19 additions & 0 deletions mars/core/graph/entity.py
@@ -19,6 +19,7 @@
from ...serialization.core import buffered
from ...serialization.serializables import Serializable, DictField, ListField, BoolField
from ...serialization.serializables.core import SerializableSerializer
from ...utils import tokenize
from .core import DAG


@@ -57,6 +58,11 @@ def copy(self) -> "EntityGraph":

class TileableGraph(EntityGraph, Iterable[Tileable]):
_result_tileables: List[Tileable]
# The logic key is a unique, deterministic key for a `TileableGraph`. Across
# multiple runs it stays the same as long as the computational logic does not
# change, so it can be used for optimizations such as HBO when the same
# `execute` is run repeatedly.
_logic_key: str

def __init__(self, result_tileables: List[Tileable] = None):
super().__init__()
@@ -74,6 +80,19 @@ def results(self):
def results(self, new_results):
self._result_tileables = new_results

@property
def logic_key(self):
if not hasattr(self, "_logic_key") or self._logic_key is None:
token_keys = []
for node in self.bfs():
token_keys.append(
tokenize(node.op.get_logic_key(), **node.extra_params)
if node.extra_params
else node.op.get_logic_key()
)
self._logic_key = tokenize(*token_keys)
return self._logic_key


class ChunkGraph(EntityGraph, Iterable[Chunk]):
_result_chunks: List[Chunk]
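The comment in the diff above mentions HBO as one use of the logic key. As a minimal sketch (not part of this commit), a deterministic logic key could index a cache of previously tuned runtime parameters, so a logically identical `execute` can reuse them instead of re-tuning; `_tuning_cache` and the tuning step below are hypothetical placeholders, not Mars APIs.

from typing import Dict

# hypothetical in-memory store mapping logic key -> tuned runtime parameters
_tuning_cache: Dict[str, dict] = {}


def tuned_parameters_for(graph) -> dict:
    """Return cached tuning parameters for a logically identical graph,
    computing and caching them on the first run."""
    key = graph.logic_key  # deterministic across runs for the same logic
    if key not in _tuning_cache:
        # placeholder for an expensive tuning step (e.g. choosing chunk sizes)
        _tuning_cache[key] = {"chunk_store_limit": 128 * 1024 ** 2}
    return _tuning_cache[key]

In the tests below, graph1 and graph2 are built from logically identical expressions and share a logic key, so under such a scheme they would map to the same cache entry.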
90 changes: 90 additions & 0 deletions mars/core/graph/tests/test_graph.py
@@ -14,6 +14,9 @@

import pytest

import numpy as np

from .... import dataframe as md
from .... import tensor as mt
from ....tests.core import flaky
from ....utils import to_str
@@ -105,3 +108,90 @@ def test_to_dot():

dot = to_str(graph.to_dot(trunc_key=5))
assert all(to_str(n.key)[:5] in dot for n in graph) is True


def test_tileable_graph_logic_key():
# Tensor
t1 = mt.random.randint(10, size=(10, 8), chunk_size=4)
t2 = mt.random.randint(10, size=(10, 8), chunk_size=5)
graph1 = (t1 + t2).build_graph(tile=False)
tt1 = mt.random.randint(10, size=(10, 8), chunk_size=4)
tt2 = mt.random.randint(10, size=(10, 8), chunk_size=5)
graph2 = (tt1 + tt2).build_graph(tile=False)
assert graph1.logic_key == graph2.logic_key
t3 = mt.random.randint(10, size=(10, 8), chunk_size=6)
tt3 = mt.random.randint(10, size=(10, 8), chunk_size=6)
graph3 = (t1 + t3).build_graph(tile=False)
graph4 = (t1 + tt3).build_graph(tile=False)
assert graph1.logic_key != graph3.logic_key
assert graph3.logic_key == graph4.logic_key
t4 = mt.random.randint(10, size=(10, 8))
graph5 = (t1 + t4).build_graph(tile=False)
assert graph1.logic_key != graph5.logic_key

# Series
s1 = md.Series([1, 3, 5, mt.nan, 6, 8])
s2 = md.Series(np.random.randn(1000), chunk_size=100)
graph1 = (s1 + s2).build_graph(tile=False)
ss1 = md.Series([1, 3, 5, mt.nan, 6, 8])
ss2 = md.Series(np.random.randn(1000), chunk_size=100)
graph2 = (ss1 + ss2).build_graph(tile=False)
assert graph1.logic_key == graph2.logic_key
s3 = md.Series(np.random.randn(1000), chunk_size=200)
ss3 = md.Series(np.random.randn(1000), chunk_size=200)
graph3 = (s1 + s3).build_graph(tile=False)
graph4 = (s1 + ss3).build_graph(tile=False)
assert graph1.logic_key != graph3.logic_key
assert graph3.logic_key == graph4.logic_key
s4 = md.Series(np.random.randn(1000))
graph5 = (s1 + s4).build_graph(tile=False)
assert graph1.logic_key != graph5.logic_key

# DataFrame
df1 = md.DataFrame(
np.random.randint(0, 100, size=(100_000, 4)), columns=list("ABCD"), chunk_size=5
)
df2 = md.DataFrame(
np.random.randint(0, 100, size=(100_000, 4)), columns=list("ABCD"), chunk_size=4
)
graph1 = (df1 + df2).build_graph(tile=False)
ddf1 = md.DataFrame(
np.random.randint(0, 100, size=(100_000, 4)), columns=list("ABCD"), chunk_size=5
)
ddf2 = md.DataFrame(
np.random.randint(0, 100, size=(100_000, 4)), columns=list("ABCD"), chunk_size=4
)
graph2 = (ddf1 + ddf2).build_graph(tile=False)
assert graph1.logic_key == graph2.logic_key
df3 = md.DataFrame(
np.random.randint(0, 100, size=(100_000, 4)), columns=list("ABCD"), chunk_size=3
)
ddf3 = md.DataFrame(
np.random.randint(0, 100, size=(100_000, 4)), columns=list("ABCD"), chunk_size=3
)
graph3 = (df1 + df3).build_graph(tile=False)
graph4 = (df1 + ddf3).build_graph(tile=False)
assert graph1.logic_key != graph3.logic_key
assert graph3.logic_key == graph4.logic_key
df5 = md.DataFrame(
np.random.randint(0, 100, size=(100_000, 4)), columns=list("ABCD")
)
graph5 = (df1 + df5).build_graph(tile=False)
assert graph1.logic_key != graph5.logic_key
graph6 = df1.describe().build_graph(tile=False)
graph7 = df2.describe().build_graph(tile=False)
assert graph6.logic_key != graph7.logic_key
graph8 = df1.apply(lambda x: x.max() - x.min()).build_graph(tile=False)
graph9 = df2.apply(lambda x: x.max() - x.min()).build_graph(tile=False)
assert graph8.logic_key != graph9.logic_key
pieces1 = [df1[:3], df1[3:7], df1[7:]]
graph10 = md.concat(pieces1).build_graph(tile=False)
pieces2 = [df2[:3], df2[3:7], df2[7:]]
graph11 = md.concat(pieces2).build_graph(tile=False)
assert graph10.logic_key != graph11.logic_key
graph12 = md.merge(df1, df2, on="A", how="left").build_graph(tile=False)
graph13 = md.merge(df1, df3, on="A", how="left").build_graph(tile=False)
assert graph12.logic_key != graph13.logic_key
graph14 = df2.groupby("A").sum().build_graph(tile=False)
graph15 = df3.groupby("A").sum().build_graph(tile=False)
assert graph14.logic_key != graph15.logic_key
