Skip to content

Commit

Permalink
[GraphBolt] Allow using multiple processes for GraphBolt partition co…
Browse files Browse the repository at this point in the history
…nversion
  • Loading branch information
thvasilo committed Jul 1, 2024
1 parent cbad2f0 commit d38b519
Showing 1 changed file with 17 additions and 3 deletions.
20 changes: 17 additions & 3 deletions python/dgl/distributed/partition.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import time

import numpy as np

Check warning on line 9 in python/dgl/distributed/partition.py

View workflow job for this annotation

GitHub Actions / lintrunner

UFMT format

Run `lintrunner -a` to apply this patch.
from joblib import Parallel, delayed

import torch

Expand Down Expand Up @@ -1384,6 +1385,7 @@ def dgl_partition_to_graphbolt(
store_inner_node=False,
store_inner_edge=False,
graph_formats=None,
n_jobs=1,
):
"""Convert partitions of dgl to FusedCSCSamplingGraph of GraphBolt.
Expand Down Expand Up @@ -1411,6 +1413,9 @@ def dgl_partition_to_graphbolt(
specifying `coo` format to save edge ID mapping and destination node
IDs. If not specified, whether to save `coo` format is determined by
the availability of the format in DGL partitions. Default: None.
n_jobs: int
Number of parallel jobs to run during partition conversion. Max parallelism
is determined by the partition count.
"""
debug_mode = "DGL_DIST_DEBUG" in os.environ
if debug_mode:
Expand Down Expand Up @@ -1439,8 +1444,7 @@ def init_type_per_edge(graph, gpb):
# But this is not a problem since such information is not used in sampling.
# We can simply pass None to it.

# Iterate over partitions.
for part_id in range(num_parts):
def convert_partition(part_id, graph_formats):
graph, _, _, gpb, _, _, _ = load_partition(
part_config, part_id, load_feats=False
)
Expand Down Expand Up @@ -1564,10 +1568,20 @@ def init_type_per_edge(graph, gpb):
)
torch.save(csc_graph, csc_graph_path)


return os.path.relpath(csc_graph_path, os.path.dirname(part_config))
# Update graph path.

# Iterate over partitions.
partition_paths = Parallel(n_jobs=min(num_parts, n_jobs))(
delayed(convert_partition)(part_id, graph_formats) for part_id in range(num_parts)
)

for part_id, part_path in enumerate(partition_paths):
new_part_meta[f"part-{part_id}"][
"part_graph_graphbolt"
] = os.path.relpath(csc_graph_path, os.path.dirname(part_config))
] = part_path


# Save dtype info into partition config.
# [TODO][Rui] Always use int64_t for node/edge IDs in GraphBolt. See more
Expand Down

0 comments on commit d38b519

Please sign in to comment.