Skip to content

Commit 3eaa07c

Browse files
committed
[Bugfix] Fix UVA sampling with partially specified node types (#3897)
* fix uva with partial node types * lint * skip tensorflow unit test
1 parent 34bcb11 commit 3eaa07c

File tree

4 files changed

+59
-34
lines changed

4 files changed

+59
-34
lines changed

python/dgl/subgraph.py

Lines changed: 21 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
from . import backend as F
1111
from . import graph_index
1212
from . import heterograph_index
13-
from . import ndarray as nd
1413
from .heterograph import DGLHeteroGraph
1514
from . import utils
1615
from .utils import recursive_apply, context_of
@@ -142,12 +141,12 @@ def _process_nodes(ntype, v):
142141
return F.astype(F.nonzero_1d(F.copy_to(v, graph.device)), graph.idtype)
143142
else:
144143
return utils.prepare_tensor(graph, v, 'nodes["{}"]'.format(ntype))
144+
nodes = {ntype: _process_nodes(ntype, v) for ntype, v in nodes.items()}
145+
device = context_of(nodes)
145146

146-
induced_nodes = []
147-
for ntype in graph.ntypes:
148-
nids = nodes.get(ntype, F.copy_to(F.tensor([], graph.idtype), graph.device))
149-
induced_nodes.append(_process_nodes(ntype, nids))
150-
device = context_of(induced_nodes)
147+
induced_nodes = [
148+
nodes.get(ntype, F.copy_to(F.tensor([], graph.idtype), device))
149+
for ntype in graph.ntypes]
151150
sgi = graph._graph.node_subgraph(induced_nodes, relabel_nodes)
152151
induced_edges = sgi.induced_edges
153152
# (BarclayII) should not write induced_nodes = sgi.induced_nodes due to the same
@@ -301,13 +300,13 @@ def _process_edges(etype, e):
301300
return F.astype(F.nonzero_1d(F.copy_to(e, graph.device)), graph.idtype)
302301
else:
303302
return utils.prepare_tensor(graph, e, 'edges["{}"]'.format(etype))
304-
305303
edges = {graph.to_canonical_etype(etype): e for etype, e in edges.items()}
306-
induced_edges = []
307-
for cetype in graph.canonical_etypes:
308-
eids = edges.get(cetype, F.copy_to(F.tensor([], graph.idtype), graph.device))
309-
induced_edges.append(_process_edges(cetype, eids))
310-
device = context_of(induced_edges)
304+
edges = {etype: _process_edges(etype, e) for etype, e in edges.items()}
305+
device = context_of(edges)
306+
induced_edges = [
307+
edges.get(cetype, F.copy_to(F.tensor([], graph.idtype), device))
308+
for cetype in graph.canonical_etypes]
309+
311310
sgi = graph._graph.edge_subgraph(induced_edges, not relabel_nodes)
312311
induced_nodes_or_device = sgi.induced_nodes if relabel_nodes else device
313312
subg = _create_hetero_subgraph(
@@ -430,12 +429,9 @@ def in_subgraph(graph, nodes, *, relabel_nodes=False, store_ids=True, output_dev
430429
nodes = {graph.ntypes[0] : nodes}
431430
nodes = utils.prepare_tensor_dict(graph, nodes, 'nodes')
432431
device = context_of(nodes)
433-
nodes_all_types = []
434-
for ntype in graph.ntypes:
435-
if ntype in nodes:
436-
nodes_all_types.append(F.to_dgl_nd(nodes[ntype]))
437-
else:
438-
nodes_all_types.append(nd.NULL[graph._idtype_str])
432+
nodes_all_types = [
433+
F.to_dgl_nd(nodes.get(ntype, F.copy_to(F.tensor([], graph.idtype), device)))
434+
for ntype in graph.ntypes]
439435

440436
sgi = _CAPI_DGLInSubgraph(graph._graph, nodes_all_types, relabel_nodes)
441437
induced_nodes_or_device = sgi.induced_nodes if relabel_nodes else device
@@ -560,12 +556,9 @@ def out_subgraph(graph, nodes, *, relabel_nodes=False, store_ids=True, output_de
560556
nodes = {graph.ntypes[0] : nodes}
561557
nodes = utils.prepare_tensor_dict(graph, nodes, 'nodes')
562558
device = context_of(nodes)
563-
nodes_all_types = []
564-
for ntype in graph.ntypes:
565-
if ntype in nodes:
566-
nodes_all_types.append(F.to_dgl_nd(nodes[ntype]))
567-
else:
568-
nodes_all_types.append(nd.NULL[graph._idtype_str])
559+
nodes_all_types = [
560+
F.to_dgl_nd(nodes.get(ntype, F.copy_to(F.tensor([], graph.idtype), device)))
561+
for ntype in graph.ntypes]
569562

570563
sgi = _CAPI_DGLOutSubgraph(graph._graph, nodes_all_types, relabel_nodes)
571564
induced_nodes_or_device = sgi.induced_nodes if relabel_nodes else device
@@ -693,7 +686,8 @@ def khop_in_subgraph(graph, nodes, k, *, relabel_nodes=True, store_ids=True, out
693686

694687
last_hop_nodes = nodes
695688
k_hop_nodes_ = [last_hop_nodes]
696-
place_holder = F.copy_to(F.tensor([], dtype=graph.idtype), graph.device)
689+
device = context_of(nodes)
690+
place_holder = F.copy_to(F.tensor([], dtype=graph.idtype), device)
697691
for _ in range(k):
698692
current_hop_nodes = {nty: [] for nty in graph.ntypes}
699693
for cetype in graph.canonical_etypes:
@@ -853,7 +847,8 @@ def khop_out_subgraph(graph, nodes, k, *, relabel_nodes=True, store_ids=True, ou
853847

854848
last_hop_nodes = nodes
855849
k_hop_nodes_ = [last_hop_nodes]
856-
place_holder = F.copy_to(F.tensor([], dtype=graph.idtype), graph.device)
850+
device = context_of(nodes)
851+
place_holder = F.copy_to(F.tensor([], dtype=graph.idtype), device)
857852
for _ in range(k):
858853
current_hop_nodes = {nty: [] for nty in graph.ntypes}
859854
for cetype in graph.canonical_etypes:

src/graph/subgraph.cc

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,12 @@ HeteroSubgraph InEdgeGraphRelabelNodes(
1313
CHECK_EQ(vids.size(), graph->NumVertexTypes())
1414
<< "Invalid input: the input list size must be the same as the number of vertex types.";
1515
std::vector<IdArray> eids(graph->NumEdgeTypes());
16+
DLContext ctx = aten::GetContextOf(vids);
1617
for (dgl_type_t etype = 0; etype < graph->NumEdgeTypes(); ++etype) {
1718
auto pair = graph->meta_graph()->FindEdge(etype);
1819
const dgl_type_t dst_vtype = pair.second;
1920
if (aten::IsNullArray(vids[dst_vtype])) {
20-
eids[etype] = IdArray::Empty({0}, graph->DataType(), graph->Context());
21+
eids[etype] = IdArray::Empty({0}, graph->DataType(), ctx);
2122
} else {
2223
const auto& earr = graph->InEdges(etype, {vids[dst_vtype]});
2324
eids[etype] = earr.id;
@@ -33,6 +34,7 @@ HeteroSubgraph InEdgeGraphNoRelabelNodes(
3334
<< "Invalid input: the input list size must be the same as the number of vertex types.";
3435
std::vector<HeteroGraphPtr> subrels(graph->NumEdgeTypes());
3536
std::vector<IdArray> induced_edges(graph->NumEdgeTypes());
37+
DLContext ctx = aten::GetContextOf(vids);
3638
for (dgl_type_t etype = 0; etype < graph->NumEdgeTypes(); ++etype) {
3739
auto pair = graph->meta_graph()->FindEdge(etype);
3840
const dgl_type_t src_vtype = pair.first;
@@ -44,7 +46,7 @@ HeteroSubgraph InEdgeGraphNoRelabelNodes(
4446
relgraph->NumVertexTypes(),
4547
graph->NumVertices(src_vtype),
4648
graph->NumVertices(dst_vtype),
47-
graph->DataType(), graph->Context());
49+
graph->DataType(), ctx);
4850
induced_edges[etype] = IdArray::Empty({0}, graph->DataType(), graph->Context());
4951
} else {
5052
const auto& earr = graph->InEdges(etype, {vids[dst_vtype]});
@@ -77,11 +79,12 @@ HeteroSubgraph OutEdgeGraphRelabelNodes(
7779
CHECK_EQ(vids.size(), graph->NumVertexTypes())
7880
<< "Invalid input: the input list size must be the same as the number of vertex types.";
7981
std::vector<IdArray> eids(graph->NumEdgeTypes());
82+
DLContext ctx = aten::GetContextOf(vids);
8083
for (dgl_type_t etype = 0; etype < graph->NumEdgeTypes(); ++etype) {
8184
auto pair = graph->meta_graph()->FindEdge(etype);
8285
const dgl_type_t src_vtype = pair.first;
8386
if (aten::IsNullArray(vids[src_vtype])) {
84-
eids[etype] = IdArray::Empty({0}, graph->DataType(), graph->Context());
87+
eids[etype] = IdArray::Empty({0}, graph->DataType(), ctx);
8588
} else {
8689
const auto& earr = graph->OutEdges(etype, {vids[src_vtype]});
8790
eids[etype] = earr.id;
@@ -97,6 +100,7 @@ HeteroSubgraph OutEdgeGraphNoRelabelNodes(
97100
<< "Invalid input: the input list size must be the same as the number of vertex types.";
98101
std::vector<HeteroGraphPtr> subrels(graph->NumEdgeTypes());
99102
std::vector<IdArray> induced_edges(graph->NumEdgeTypes());
103+
DLContext ctx = aten::GetContextOf(vids);
100104
for (dgl_type_t etype = 0; etype < graph->NumEdgeTypes(); ++etype) {
101105
auto pair = graph->meta_graph()->FindEdge(etype);
102106
const dgl_type_t src_vtype = pair.first;
@@ -108,7 +112,7 @@ HeteroSubgraph OutEdgeGraphNoRelabelNodes(
108112
relgraph->NumVertexTypes(),
109113
graph->NumVertices(src_vtype),
110114
graph->NumVertices(dst_vtype),
111-
graph->DataType(), graph->Context());
115+
graph->DataType(), ctx);
112116
induced_edges[etype] = IdArray::Empty({0}, graph->DataType(), graph->Context());
113117
} else {
114118
const auto& earr = graph->OutEdges(etype, {vids[src_vtype]});

src/graph/unit_graph.cc

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -381,7 +381,8 @@ class UnitGraph::COO : public BaseHeteroGraph {
381381
CHECK(aten::IsValidIdArray(dstvids)) << "Invalid vertex id array.";
382382
HeteroSubgraph subg;
383383
const auto& submat = aten::COOSliceMatrix(adj_, srcvids, dstvids);
384-
IdArray sub_eids = aten::Range(0, submat.data->shape[0], NumBits(), Context());
384+
DLContext ctx = aten::GetContextOf(vids);
385+
IdArray sub_eids = aten::Range(0, submat.data->shape[0], NumBits(), ctx);
385386
subg.graph = std::make_shared<COO>(meta_graph(), submat.num_rows, submat.num_cols,
386387
submat.row, submat.col);
387388
subg.induced_vertices = vids;
@@ -801,7 +802,8 @@ class UnitGraph::CSR : public BaseHeteroGraph {
801802
CHECK(aten::IsValidIdArray(dstvids)) << "Invalid vertex id array.";
802803
HeteroSubgraph subg;
803804
const auto& submat = aten::CSRSliceMatrix(adj_, srcvids, dstvids);
804-
IdArray sub_eids = aten::Range(0, submat.data->shape[0], NumBits(), Context());
805+
DLContext ctx = aten::GetContextOf(vids);
806+
IdArray sub_eids = aten::Range(0, submat.data->shape[0], NumBits(), ctx);
805807
subg.graph = std::make_shared<CSR>(meta_graph(), submat.num_rows, submat.num_cols,
806808
submat.indptr, submat.indices, sub_eids);
807809
subg.induced_vertices = vids;

tests/compute/test_subgraph.py

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,8 @@ def create_test_heterograph(idtype):
100100
('user', 'wishes', 'game'): ([0, 2], [1, 0]),
101101
('developer', 'develops', 'game'): ([0, 1], [0, 1])
102102
}, idtype=idtype, device=F.ctx())
103+
for etype in g.etypes:
104+
g.edges[etype].data['weight'] = F.randn((g.num_edges(etype),))
103105
assert g.idtype == idtype
104106
assert g.device == F.ctx()
105107
return g
@@ -629,6 +631,28 @@ def test_subframes(parent_idx_device, child_device):
629631
if parent_device == 'uva':
630632
g.unpin_memory_()
631633

634+
@unittest.skipIf(F._default_context_str != "gpu", reason="UVA only available on GPU")
635+
@pytest.mark.parametrize('device', [F.cpu(), F.cuda()])
636+
@parametrize_dtype
637+
def test_uva_subgraph(idtype, device):
638+
g = create_test_heterograph(idtype)
639+
g = g.to(F.cpu())
640+
g.create_formats_()
641+
g.pin_memory_()
642+
indices = {'user': F.copy_to(F.tensor([0], idtype), device)}
643+
edge_indices = {'follows': F.copy_to(F.tensor([0], idtype), device)}
644+
assert g.subgraph(indices).device == device
645+
assert g.edge_subgraph(edge_indices).device == device
646+
assert g.in_subgraph(indices).device == device
647+
assert g.out_subgraph(indices).device == device
648+
if dgl.backend.backend_name != 'tensorflow':
649+
# (BarclayII) Most of Tensorflow functions somehow do not preserve device: a CPU tensor
650+
# becomes a GPU tensor after operations such as concat(), unique() or even sin().
651+
# Not sure what should be the best fix.
652+
assert g.khop_in_subgraph(indices, 1)[0].device == device
653+
assert g.khop_out_subgraph(indices, 1)[0].device == device
654+
assert g.sample_neighbors(indices, 1).device == device
655+
g.unpin_memory_()
656+
632657
if __name__ == '__main__':
633-
test_khop_out_subgraph(F.int64)
634-
test_subframes(('cpu', F.cpu()), F.cuda())
658+
test_uva_subgraph(F.int64, F.cpu())

0 commit comments

Comments
 (0)