|
10 | 10 | from . import backend as F
|
11 | 11 | from . import graph_index
|
12 | 12 | from . import heterograph_index
|
13 |
| -from . import ndarray as nd |
14 | 13 | from .heterograph import DGLHeteroGraph
|
15 | 14 | from . import utils
|
16 | 15 | from .utils import recursive_apply, context_of
|
@@ -142,12 +141,12 @@ def _process_nodes(ntype, v):
|
142 | 141 | return F.astype(F.nonzero_1d(F.copy_to(v, graph.device)), graph.idtype)
|
143 | 142 | else:
|
144 | 143 | return utils.prepare_tensor(graph, v, 'nodes["{}"]'.format(ntype))
|
| 144 | + nodes = {ntype: _process_nodes(ntype, v) for ntype, v in nodes.items()} |
| 145 | + device = context_of(nodes) |
145 | 146 |
|
146 |
| - induced_nodes = [] |
147 |
| - for ntype in graph.ntypes: |
148 |
| - nids = nodes.get(ntype, F.copy_to(F.tensor([], graph.idtype), graph.device)) |
149 |
| - induced_nodes.append(_process_nodes(ntype, nids)) |
150 |
| - device = context_of(induced_nodes) |
| 147 | + induced_nodes = [ |
| 148 | + nodes.get(ntype, F.copy_to(F.tensor([], graph.idtype), device)) |
| 149 | + for ntype in graph.ntypes] |
151 | 150 | sgi = graph._graph.node_subgraph(induced_nodes, relabel_nodes)
|
152 | 151 | induced_edges = sgi.induced_edges
|
153 | 152 | # (BarclayII) should not write induced_nodes = sgi.induced_nodes due to the same
|
@@ -301,13 +300,13 @@ def _process_edges(etype, e):
|
301 | 300 | return F.astype(F.nonzero_1d(F.copy_to(e, graph.device)), graph.idtype)
|
302 | 301 | else:
|
303 | 302 | return utils.prepare_tensor(graph, e, 'edges["{}"]'.format(etype))
|
304 |
| - |
305 | 303 | edges = {graph.to_canonical_etype(etype): e for etype, e in edges.items()}
|
306 |
| - induced_edges = [] |
307 |
| - for cetype in graph.canonical_etypes: |
308 |
| - eids = edges.get(cetype, F.copy_to(F.tensor([], graph.idtype), graph.device)) |
309 |
| - induced_edges.append(_process_edges(cetype, eids)) |
310 |
| - device = context_of(induced_edges) |
| 304 | + edges = {etype: _process_edges(etype, e) for etype, e in edges.items()} |
| 305 | + device = context_of(edges) |
| 306 | + induced_edges = [ |
| 307 | + edges.get(cetype, F.copy_to(F.tensor([], graph.idtype), device)) |
| 308 | + for cetype in graph.canonical_etypes] |
| 309 | + |
311 | 310 | sgi = graph._graph.edge_subgraph(induced_edges, not relabel_nodes)
|
312 | 311 | induced_nodes_or_device = sgi.induced_nodes if relabel_nodes else device
|
313 | 312 | subg = _create_hetero_subgraph(
|
@@ -430,12 +429,9 @@ def in_subgraph(graph, nodes, *, relabel_nodes=False, store_ids=True, output_dev
|
430 | 429 | nodes = {graph.ntypes[0] : nodes}
|
431 | 430 | nodes = utils.prepare_tensor_dict(graph, nodes, 'nodes')
|
432 | 431 | device = context_of(nodes)
|
433 |
| - nodes_all_types = [] |
434 |
| - for ntype in graph.ntypes: |
435 |
| - if ntype in nodes: |
436 |
| - nodes_all_types.append(F.to_dgl_nd(nodes[ntype])) |
437 |
| - else: |
438 |
| - nodes_all_types.append(nd.NULL[graph._idtype_str]) |
| 432 | + nodes_all_types = [ |
| 433 | + F.to_dgl_nd(nodes.get(ntype, F.copy_to(F.tensor([], graph.idtype), device))) |
| 434 | + for ntype in graph.ntypes] |
439 | 435 |
|
440 | 436 | sgi = _CAPI_DGLInSubgraph(graph._graph, nodes_all_types, relabel_nodes)
|
441 | 437 | induced_nodes_or_device = sgi.induced_nodes if relabel_nodes else device
|
@@ -560,12 +556,9 @@ def out_subgraph(graph, nodes, *, relabel_nodes=False, store_ids=True, output_de
|
560 | 556 | nodes = {graph.ntypes[0] : nodes}
|
561 | 557 | nodes = utils.prepare_tensor_dict(graph, nodes, 'nodes')
|
562 | 558 | device = context_of(nodes)
|
563 |
| - nodes_all_types = [] |
564 |
| - for ntype in graph.ntypes: |
565 |
| - if ntype in nodes: |
566 |
| - nodes_all_types.append(F.to_dgl_nd(nodes[ntype])) |
567 |
| - else: |
568 |
| - nodes_all_types.append(nd.NULL[graph._idtype_str]) |
| 559 | + nodes_all_types = [ |
| 560 | + F.to_dgl_nd(nodes.get(ntype, F.copy_to(F.tensor([], graph.idtype), device))) |
| 561 | + for ntype in graph.ntypes] |
569 | 562 |
|
570 | 563 | sgi = _CAPI_DGLOutSubgraph(graph._graph, nodes_all_types, relabel_nodes)
|
571 | 564 | induced_nodes_or_device = sgi.induced_nodes if relabel_nodes else device
|
@@ -693,7 +686,8 @@ def khop_in_subgraph(graph, nodes, k, *, relabel_nodes=True, store_ids=True, out
|
693 | 686 |
|
694 | 687 | last_hop_nodes = nodes
|
695 | 688 | k_hop_nodes_ = [last_hop_nodes]
|
696 |
| - place_holder = F.copy_to(F.tensor([], dtype=graph.idtype), graph.device) |
| 689 | + device = context_of(nodes) |
| 690 | + place_holder = F.copy_to(F.tensor([], dtype=graph.idtype), device) |
697 | 691 | for _ in range(k):
|
698 | 692 | current_hop_nodes = {nty: [] for nty in graph.ntypes}
|
699 | 693 | for cetype in graph.canonical_etypes:
|
@@ -853,7 +847,8 @@ def khop_out_subgraph(graph, nodes, k, *, relabel_nodes=True, store_ids=True, ou
|
853 | 847 |
|
854 | 848 | last_hop_nodes = nodes
|
855 | 849 | k_hop_nodes_ = [last_hop_nodes]
|
856 |
| - place_holder = F.copy_to(F.tensor([], dtype=graph.idtype), graph.device) |
| 850 | + device = context_of(nodes) |
| 851 | + place_holder = F.copy_to(F.tensor([], dtype=graph.idtype), device) |
857 | 852 | for _ in range(k):
|
858 | 853 | current_hop_nodes = {nty: [] for nty in graph.ntypes}
|
859 | 854 | for cetype in graph.canonical_etypes:
|
|
0 commit comments