diff --git a/python/hidet/drivers/build_graph.py b/python/hidet/drivers/build_graph.py index 3cc4b25e3..a3181554f 100644 --- a/python/hidet/drivers/build_graph.py +++ b/python/hidet/drivers/build_graph.py @@ -9,7 +9,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Set, Dict +from typing import List, Dict import os import json import shutil @@ -34,15 +34,15 @@ def get_graph_weights(graph): Get the weights of the graph. All constant tensors used by the operators in the graph, or returned directly by the graph, are considered as weights. """ - weights: Set[Tensor] = set() + weights: List[Tensor] = [] for node in graph.nodes: for x in node.inputs: if x.storage is not None: - weights.add(x) + weights.append(x) for y in graph.outputs: if y.storage is not None: - weights.add(y) - return list(weights) + weights.append(y) + return weights def get_graph_intermediates(graph): @@ -145,6 +145,7 @@ def get_graph_meta_data(graph: FlowGraph, num_kernels, space: int) -> GraphMetaD lines.append(str(node.task)) lines.append(str(graph)) lines.append(str(space)) + graph_hash = sha256('\n'.join(lines).encode('utf-8')).hexdigest()[:16] return GraphMetaData( diff --git a/python/hidet/graph/impl/graph_impl.py b/python/hidet/graph/impl/graph_impl.py index dad81cdae..8df427ef5 100644 --- a/python/hidet/graph/impl/graph_impl.py +++ b/python/hidet/graph/impl/graph_impl.py @@ -1,4 +1,4 @@ -from typing import List, Tuple, Dict, Set, Optional, Union +from typing import List, Tuple, Dict, Optional, Union from collections import defaultdict import hidet.option from hidet.graph.tensor import Tensor @@ -39,10 +39,12 @@ def graph_analyze( stop_tensors: List[Tensor] = stop_tensors or [] # find out all nodes - all_nodes: Set[Operator] = set() + # use dict for ordered set behaviour + # ordering needed for deterministic node ordering + all_nodes: Dict[Operator, bool] = {} def find_all_nodes(u: Operator): - all_nodes.add(u) + all_nodes[u] = True for x in u.inputs: if x.op is None or x in stop_tensors: continue