From 6268e28d465b2b9f0f88605cf552d13492f7d7c0 Mon Sep 17 00:00:00 2001 From: Kevin Tong Date: Tue, 20 Feb 2024 14:45:26 -0500 Subject: [PATCH 1/3] fix graph metadata hash --- python/hidet/drivers/build_graph.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/python/hidet/drivers/build_graph.py b/python/hidet/drivers/build_graph.py index 3cc4b25e3..7fb8bc75b 100644 --- a/python/hidet/drivers/build_graph.py +++ b/python/hidet/drivers/build_graph.py @@ -145,6 +145,11 @@ def get_graph_meta_data(graph: FlowGraph, num_kernels, space: int) -> GraphMetaD lines.append(str(node.task)) lines.append(str(graph)) lines.append(str(space)) + + # graph nodes are not traversed in deterministic order + # sort to ensure same graph --> same hash + lines.sort() + graph_hash = sha256('\n'.join(lines).encode('utf-8')).hexdigest()[:16] return GraphMetaData( From 12d404cafc49e7eec7a9800629a361ac49ecdb13 Mon Sep 17 00:00:00 2001 From: Kevin Tong Date: Sat, 24 Feb 2024 03:10:48 -0500 Subject: [PATCH 2/3] determinism --- python/hidet/drivers/build_graph.py | 10 +++++----- python/hidet/graph/impl/graph_impl.py | 8 ++++++-- 2 files changed, 11 insertions(+), 7 deletions(-) diff --git a/python/hidet/drivers/build_graph.py b/python/hidet/drivers/build_graph.py index 7fb8bc75b..c22722e79 100644 --- a/python/hidet/drivers/build_graph.py +++ b/python/hidet/drivers/build_graph.py @@ -34,15 +34,15 @@ def get_graph_weights(graph): Get the weights of the graph. All constant tensors used by the operators in the graph, or returned directly by the graph, are considered as weights. """ - weights: Set[Tensor] = set() + weights: List[Tensor] = list() for node in graph.nodes: for x in node.inputs: if x.storage is not None: - weights.add(x) + weights.append(x) for y in graph.outputs: if y.storage is not None: - weights.add(y) - return list(weights) + weights.append(y) + return weights def get_graph_intermediates(graph): @@ -151,7 +151,7 @@ def get_graph_meta_data(graph: FlowGraph, num_kernels, space: int) -> GraphMetaD lines.sort() graph_hash = sha256('\n'.join(lines).encode('utf-8')).hexdigest()[:16] - + return GraphMetaData( inputs=inputs, outputs=outputs, diff --git a/python/hidet/graph/impl/graph_impl.py b/python/hidet/graph/impl/graph_impl.py index dad81cdae..01065d21a 100644 --- a/python/hidet/graph/impl/graph_impl.py +++ b/python/hidet/graph/impl/graph_impl.py @@ -39,10 +39,12 @@ def graph_analyze( stop_tensors: List[Tensor] = stop_tensors or [] # find out all nodes - all_nodes: Set[Operator] = set() + # use dict for ordered set behaviour + # ordering needed for deterministic node ordering + all_nodes: Dict[Operator, bool] = {} def find_all_nodes(u: Operator): - all_nodes.add(u) + all_nodes[u] = True for x in u.inputs: if x.op is None or x in stop_tensors: continue @@ -56,6 +58,8 @@ def valid(t: Tensor) -> bool: for ot in outputs: if ot.trace and ot not in stop_tensors: find_all_nodes(ot.op) + print("all_nodes") + print(all_nodes) # topological sort out_degree: Dict[Operator, int] = {u: 0 for u in all_nodes} From 4cfe9813ff3ae22c68ecd24ec17a1f8f4064ddac Mon Sep 17 00:00:00 2001 From: Kevin Tong Date: Sat, 24 Feb 2024 03:18:08 -0500 Subject: [PATCH 3/3] lint --- python/hidet/drivers/build_graph.py | 10 +++------- python/hidet/graph/impl/graph_impl.py | 4 +--- 2 files changed, 4 insertions(+), 10 deletions(-) diff --git a/python/hidet/drivers/build_graph.py b/python/hidet/drivers/build_graph.py index c22722e79..a3181554f 100644 --- a/python/hidet/drivers/build_graph.py +++ b/python/hidet/drivers/build_graph.py @@ -9,7 +9,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Set, Dict +from typing import List, Dict import os import json import shutil @@ -34,7 +34,7 @@ def get_graph_weights(graph): Get the weights of the graph. All constant tensors used by the operators in the graph, or returned directly by the graph, are considered as weights. """ - weights: List[Tensor] = list() + weights: List[Tensor] = [] for node in graph.nodes: for x in node.inputs: if x.storage is not None: @@ -146,12 +146,8 @@ def get_graph_meta_data(graph: FlowGraph, num_kernels, space: int) -> GraphMetaD lines.append(str(graph)) lines.append(str(space)) - # graph nodes are not traversed in deterministic order - # sort to ensure same graph --> same hash - lines.sort() - graph_hash = sha256('\n'.join(lines).encode('utf-8')).hexdigest()[:16] - + return GraphMetaData( inputs=inputs, outputs=outputs, diff --git a/python/hidet/graph/impl/graph_impl.py b/python/hidet/graph/impl/graph_impl.py index 01065d21a..8df427ef5 100644 --- a/python/hidet/graph/impl/graph_impl.py +++ b/python/hidet/graph/impl/graph_impl.py @@ -1,4 +1,4 @@ -from typing import List, Tuple, Dict, Set, Optional, Union +from typing import List, Tuple, Dict, Optional, Union from collections import defaultdict import hidet.option from hidet.graph.tensor import Tensor @@ -58,8 +58,6 @@ def valid(t: Tensor) -> bool: for ot in outputs: if ot.trace and ot not in stop_tensors: find_all_nodes(ot.op) - print("all_nodes") - print(all_nodes) # topological sort out_degree: Dict[Operator, int] = {u: 0 for u in all_nodes}