diff --git a/python/dgl/distributed/partition.py b/python/dgl/distributed/partition.py
index ada2bd327227..c77f895b2557 100644
--- a/python/dgl/distributed/partition.py
+++ b/python/dgl/distributed/partition.py
@@ -7,6 +7,7 @@
 import time
 
 import numpy as np
+from joblib import Parallel, delayed
 import torch
 
 from .. import backend as F
@@ -1384,6 +1385,7 @@ def dgl_partition_to_graphbolt(
     store_inner_node=False,
     store_inner_edge=False,
     graph_formats=None,
+    n_jobs=1,
 ):
     """Convert partitions of dgl to FusedCSCSamplingGraph of GraphBolt.
 
@@ -1411,6 +1413,9 @@ def dgl_partition_to_graphbolt(
         specifying `coo` format to save edge ID mapping and destination node
         IDs. If not specified, whether to save `coo` format is determined by
         the availability of the format in DGL partitions. Default: None.
+    n_jobs: int
+        Number of parallel jobs to run during partition conversion. Maximum
+        parallelism is capped by the number of partitions. Default: 1.
     """
     debug_mode = "DGL_DIST_DEBUG" in os.environ
     if debug_mode:
@@ -1439,8 +1444,7 @@ def init_type_per_edge(graph, gpb):
     # But this is not a problem since such information is not used in sampling.
     # We can simply pass None to it.
 
-    # Iterate over partitions.
-    for part_id in range(num_parts):
+    def convert_partition(part_id, graph_formats):
         graph, _, _, gpb, _, _, _ = load_partition(
             part_config, part_id, load_feats=False
         )
@@ -1564,10 +1568,20 @@ def init_type_per_edge(graph, gpb):
         )
 
         torch.save(csc_graph, csc_graph_path)
+
+        return os.path.relpath(csc_graph_path, os.path.dirname(part_config))
 
         # Update graph path.
+
+    # Iterate over partitions.
+    partition_paths = Parallel(n_jobs=min(num_parts, n_jobs))(
+        delayed(convert_partition)(part_id, graph_formats) for part_id in range(num_parts)
+    )
+
+    for part_id, part_path in enumerate(partition_paths):
         new_part_meta[f"part-{part_id}"][
             "part_graph_graphbolt"
-        ] = os.path.relpath(csc_graph_path, os.path.dirname(part_config))
+        ] = part_path
+
     # Save dtype info into partition config.
     # [TODO][Rui] Always use int64_t for node/edge IDs in GraphBolt. See more
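
For reference, a minimal usage sketch of the new `n_jobs` knob. The partition-config path and keyword choices below are illustrative assumptions, not taken from the patch:

```python
# Illustrative only: assumes partitions were already written by
# dgl.distributed.partition_graph() and that the JSON path below exists.
from dgl.distributed import dgl_partition_to_graphbolt

dgl_partition_to_graphbolt(
    "out/ogbn-products.json",  # part_config produced by partition_graph()
    graph_formats="coo",       # hypothetical choice; see graph_formats docs
    n_jobs=4,                  # worker count; capped at min(num_parts, n_jobs)
)
```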
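A self-contained sketch of the joblib fan-out/collect pattern the patch adopts, with a toy stand-in for `convert_partition` (all names here are hypothetical):

```python
from joblib import Parallel, delayed

NUM_PARTS = 4

def fake_convert(part_id):
    # Stand-in for convert_partition: a real worker loads one DGL partition,
    # builds and saves the GraphBolt graph, and returns only the saved path.
    # With joblib's default process backend (loky), workers share no memory
    # with the parent, so they cannot mutate new_part_meta directly; returning
    # the path and updating metadata in the parent is why the patch restructures
    # the loop body into a function.
    return f"part{part_id}/graph.pt"

# Parallel returns results in submission order, so part_id stays aligned with
# its path even when workers finish out of order.
paths = Parallel(n_jobs=min(NUM_PARTS, 2))(
    delayed(fake_convert)(part_id) for part_id in range(NUM_PARTS)
)
assert paths == [f"part{i}/graph.pt" for i in range(NUM_PARTS)]
```

Note that with `n_jobs=1` joblib falls back to sequential execution in the calling process, which preserves the pre-patch behavior.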