diff --git a/CHANGELOG.md b/CHANGELOG.md
index 59065b21..ef8743ce 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -5,6 +5,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ## [Unreleased]
 ### Added
+- Added `hetero_subgraph` kernel ([#43](https://github.com/pyg-team/pyg-lib/pull/43))
 - Added `pyg::sampler::Mapper` utility for mapping global to local node indices ([#45](https://github.com/pyg-team/pyg-lib/pull/45))
 - Added benchmark script ([#45](https://github.com/pyg-team/pyg-lib/pull/45))
 - Added download script for benchmark data ([#44](https://github.com/pyg-team/pyg-lib/pull/44))
diff --git a/pyg_lib/csrc/sampler/cpu/mapper.h b/pyg_lib/csrc/sampler/cpu/mapper.h
index 47144ea6..a54fe0f8 100644
--- a/pyg_lib/csrc/sampler/cpu/mapper.h
+++ b/pyg_lib/csrc/sampler/cpu/mapper.h
@@ -23,11 +23,13 @@ class Mapper {
   void fill(const scalar_t* nodes_data, const scalar_t size) {
     if (use_vec) {
-      for (scalar_t i = 0; i < size; ++i)
+      for (scalar_t i = 0; i < size; ++i) {
         to_local_vec[nodes_data[i]] = i;
+      }
     } else {
-      for (scalar_t i = 0; i < size; ++i)
+      for (scalar_t i = 0; i < size; ++i) {
         to_local_map.insert({nodes_data[i], i});
+      }
     }
   }
 
@@ -35,14 +37,14 @@ class Mapper {
     fill(nodes.data_ptr<scalar_t>(), nodes.numel());
   }
 
-  bool exists(const scalar_t& node) {
+  bool exists(const scalar_t& node) const {
     if (use_vec)
       return to_local_vec[node] >= 0;
     else
       return to_local_map.count(node) > 0;
   }
 
-  scalar_t map(const scalar_t& node) {
+  scalar_t map(const scalar_t& node) const {
     if (use_vec)
       return to_local_vec[node];
     else {
diff --git a/pyg_lib/csrc/sampler/cpu/subgraph_kernel.cpp b/pyg_lib/csrc/sampler/cpu/subgraph_kernel.cpp
index 9e67cdde..98330cca 100644
--- a/pyg_lib/csrc/sampler/cpu/subgraph_kernel.cpp
+++ b/pyg_lib/csrc/sampler/cpu/subgraph_kernel.cpp
@@ -3,6 +3,7 @@
 #include <torch/library.h>
 
 #include "pyg_lib/csrc/sampler/cpu/mapper.h"
+#include "pyg_lib/csrc/sampler/subgraph.h"
 #include "pyg_lib/csrc/utils/cpu/convert.h"
 
 namespace pyg {
@@ -10,89 +11,131 @@ namespace sampler {
 
 namespace {
 
-std::tuple<at::Tensor, at::Tensor, c10::optional<at::Tensor>> subgraph_kernel(
-    const at::Tensor& rowptr,
-    const at::Tensor& col,
-    const at::Tensor& nodes,
-    const bool return_edge_id) {
-  TORCH_CHECK(rowptr.is_cpu(), "'rowptr' must be a CPU tensor");
-  TORCH_CHECK(col.is_cpu(), "'col' must be a CPU tensor");
-  TORCH_CHECK(nodes.is_cpu(), "'nodes' must be a CPU tensor");
-
+// The template parameter is named `node_t` so that it does not shadow the
+// `scalar_t` alias introduced by the nested `AT_DISPATCH_INTEGRAL_TYPES`:
+template <typename node_t>
+std::tuple<at::Tensor, at::Tensor, c10::optional<at::Tensor>>
+subgraph_with_mapper(const at::Tensor& rowptr,
+                     const at::Tensor& col,
+                     const at::Tensor& nodes,
+                     const Mapper<node_t>& mapper,
+                     const bool return_edge_id) {
   const auto num_nodes = rowptr.size(0) - 1;
   const auto out_rowptr = rowptr.new_empty({nodes.size(0) + 1});
   at::Tensor out_col;
   c10::optional<at::Tensor> out_edge_id = c10::nullopt;
 
-  AT_DISPATCH_INTEGRAL_TYPES(nodes.scalar_type(), "subgraph_kernel", [&] {
-    auto mapper = pyg::sampler::Mapper<scalar_t>(num_nodes, nodes.size(0));
-    mapper.fill(nodes);
-
-    const auto rowptr_data = rowptr.data_ptr<scalar_t>();
-    const auto col_data = col.data_ptr<scalar_t>();
-    const auto nodes_data = nodes.data_ptr<scalar_t>();
-
-    // We first iterate over all nodes and collect information about the number
-    // of edges in the induced subgraph.
-    const auto deg = rowptr.new_empty({nodes.size(0)});
-    auto deg_data = deg.data_ptr<scalar_t>();
-    auto grain_size = at::internal::GRAIN_SIZE;
-    at::parallel_for(0, nodes.size(0), grain_size, [&](int64_t _s, int64_t _e) {
-      for (size_t i = _s; i < _e; ++i) {
-        const auto v = nodes_data[i];
-        // Iterate over all neighbors and check if they are part of `nodes`:
-        scalar_t d = 0;
-        for (size_t j = rowptr_data[v]; j < rowptr_data[v + 1]; ++j) {
-          if (mapper.exists(col_data[j]))
-            d++;
-        }
-        deg_data[i] = d;
-      }
-    });
-
-    auto out_rowptr_data = out_rowptr.data_ptr<scalar_t>();
-    out_rowptr_data[0] = 0;
-    auto tmp = out_rowptr.narrow(0, 1, nodes.size(0));
-    at::cumsum_out(tmp, deg, /*dim=*/0);
-
-    out_col = col.new_empty({out_rowptr_data[nodes.size(0)]});
-    auto out_col_data = out_col.data_ptr<scalar_t>();
-    scalar_t* out_edge_id_data;
-    if (return_edge_id) {
-      out_edge_id = col.new_empty({out_rowptr_data[nodes.size(0)]});
-      out_edge_id_data = out_edge_id.value().data_ptr<scalar_t>();
-    }
-
-    // Customize `grain_size` based on the work each thread does (it will need
-    // to find `col.size(0) / nodes.size(0)` neighbors on average).
-    // TODO Benchmark this customization
-    grain_size = std::max<int64_t>(out_col.size(0) / nodes.size(0), 1);
-    grain_size = at::internal::GRAIN_SIZE / grain_size;
-    at::parallel_for(0, nodes.size(0), grain_size, [&](int64_t _s, int64_t _e) {
-      for (scalar_t i = _s; i < _e; ++i) {
-        const auto v = nodes_data[i];
-        // Iterate over all neighbors and check if they are part of `nodes`:
-        scalar_t offset = out_rowptr_data[i];
-        for (scalar_t j = rowptr_data[v]; j < rowptr_data[v + 1]; ++j) {
-          const auto w = mapper.map(col_data[j]);
-          if (w >= 0) {
-            out_col_data[offset] = w;
-            if (return_edge_id)
-              out_edge_id_data[offset] = j;
-            offset++;
-          }
-        }
-      }
-    });
-  });
+  AT_DISPATCH_INTEGRAL_TYPES(
+      nodes.scalar_type(), "subgraph_kernel_with_mapper", [&] {
+        const auto rowptr_data = rowptr.data_ptr<scalar_t>();
+        const auto col_data = col.data_ptr<scalar_t>();
+        const auto nodes_data = nodes.data_ptr<scalar_t>();
+
+        // We first iterate over all nodes and collect information about the
+        // number of edges in the induced subgraph.
+        const auto deg = rowptr.new_empty({nodes.size(0)});
+        auto deg_data = deg.data_ptr<scalar_t>();
+        auto grain_size = at::internal::GRAIN_SIZE;
+        at::parallel_for(
+            0, nodes.size(0), grain_size, [&](int64_t _s, int64_t _e) {
+              for (size_t i = _s; i < _e; ++i) {
+                const auto v = nodes_data[i];
+                // Iterate over all neighbors and check if they are part of
+                // `nodes`:
+                scalar_t d = 0;
+                for (size_t j = rowptr_data[v]; j < rowptr_data[v + 1]; ++j) {
+                  if (mapper.exists(col_data[j]))
+                    d++;
+                }
+                deg_data[i] = d;
+              }
+            });
+
+        auto out_rowptr_data = out_rowptr.data_ptr<scalar_t>();
+        out_rowptr_data[0] = 0;
+        auto tmp = out_rowptr.narrow(0, 1, nodes.size(0));
+        at::cumsum_out(tmp, deg, /*dim=*/0);
+
+        out_col = col.new_empty({out_rowptr_data[nodes.size(0)]});
+        auto out_col_data = out_col.data_ptr<scalar_t>();
+        scalar_t* out_edge_id_data;
+        if (return_edge_id) {
+          out_edge_id = col.new_empty({out_rowptr_data[nodes.size(0)]});
+          out_edge_id_data = out_edge_id.value().data_ptr<scalar_t>();
+        }
+
+        // Customize `grain_size` based on the work each thread does (it will
+        // need to find `col.size(0) / nodes.size(0)` neighbors on average).
+        // TODO Benchmark this customization
+        grain_size = std::max<int64_t>(out_col.size(0) / nodes.size(0), 1);
+        grain_size = at::internal::GRAIN_SIZE / grain_size;
+        at::parallel_for(
+            0, nodes.size(0), grain_size, [&](int64_t _s, int64_t _e) {
+              for (scalar_t i = _s; i < _e; ++i) {
+                const auto v = nodes_data[i];
+                // Iterate over all neighbors and check if they
+                // are part of `nodes`:
+                scalar_t offset = out_rowptr_data[i];
+                for (scalar_t j = rowptr_data[v]; j < rowptr_data[v + 1]; ++j) {
+                  const auto w = mapper.map(col_data[j]);
+                  if (w >= 0) {
+                    out_col_data[offset] = w;
+                    if (return_edge_id)
+                      out_edge_id_data[offset] = j;
+                    offset++;
+                  }
+                }
+              }
+            });
+      });
 
   return std::make_tuple(out_rowptr, out_col, out_edge_id);
 }
 
+std::tuple<at::Tensor, at::Tensor, c10::optional<at::Tensor>>
+subgraph_bipartite_kernel(const at::Tensor& rowptr,
+                          const at::Tensor& col,
+                          const at::Tensor& src_nodes,
+                          const at::Tensor& dst_nodes,
+                          const bool return_edge_id) {
+  TORCH_CHECK(rowptr.is_cpu(), "'rowptr' must be a CPU tensor");
+  TORCH_CHECK(col.is_cpu(), "'col' must be a CPU tensor");
+  TORCH_CHECK(src_nodes.is_cpu(), "'src_nodes' must be a CPU tensor");
+  TORCH_CHECK(dst_nodes.is_cpu(), "'dst_nodes' must be a CPU tensor");
+
+  const auto num_nodes = rowptr.size(0) - 1;
+  at::Tensor out_rowptr, out_col;
+  c10::optional<at::Tensor> out_edge_id;
+
+  AT_DISPATCH_INTEGRAL_TYPES(
+      src_nodes.scalar_type(), "subgraph_bipartite_kernel", [&] {
+        // TODO: `at::max` is parallel, but still a little expensive:
+        Mapper<scalar_t> mapper(at::max(col).item<scalar_t>() + 1,
+                                dst_nodes.size(0));
+        mapper.fill(dst_nodes);
+
+        auto res = subgraph_with_mapper(rowptr, col, src_nodes, mapper,
+                                        return_edge_id);
+        out_rowptr = std::get<0>(res);
+        out_col = std::get<1>(res);
+        out_edge_id = std::get<2>(res);
+      });
+
+  return {out_rowptr, out_col, out_edge_id};
+}
+
+std::tuple<at::Tensor, at::Tensor, c10::optional<at::Tensor>> subgraph_kernel(
+    const at::Tensor& rowptr,
+    const at::Tensor& col,
+    const at::Tensor& nodes,
+    const bool return_edge_id) {
+  return subgraph_bipartite_kernel(rowptr, col, nodes, nodes, return_edge_id);
+}
+
 }  // namespace
 
 TORCH_LIBRARY_IMPL(pyg, CPU, m) {
   m.impl(TORCH_SELECTIVE_NAME("pyg::subgraph"), TORCH_FN(subgraph_kernel));
+  m.impl(TORCH_SELECTIVE_NAME("pyg::subgraph_bipartite"),
+         TORCH_FN(subgraph_bipartite_kernel));
 }
 
 }  // namespace sampler
diff --git a/pyg_lib/csrc/sampler/subgraph.cpp b/pyg_lib/csrc/sampler/subgraph.cpp
index dbedb5b6..24316d9f 100644
--- a/pyg_lib/csrc/sampler/subgraph.cpp
+++ b/pyg_lib/csrc/sampler/subgraph.cpp
@@ -1,8 +1,11 @@
 #include "subgraph.h"
 
+#include <ATen/core/dispatch/Dispatcher.h>
 #include <torch/library.h>
 
+#include "pyg_lib/csrc/utils/hetero_dispatch.h"
+
 namespace pyg {
 namespace sampler {
 
@@ -11,7 +14,7 @@ std::tuple<at::Tensor, at::Tensor, c10::optional<at::Tensor>> subgraph(
     const at::Tensor& col,
     const at::Tensor& nodes,
     const bool return_edge_id) {
-  at::TensorArg rowptr_t{rowptr, "rowtpr", 1};
+  at::TensorArg rowptr_t{rowptr, "rowptr", 1};
   at::TensorArg col_t{col, "col", 1};
   at::TensorArg nodes_t{nodes, "nodes", 1};
 
@@ -25,10 +28,76 @@ std::tuple<at::Tensor, at::Tensor, c10::optional<at::Tensor>> subgraph(
   return op.call(rowptr, col, nodes, return_edge_id);
 }
 
+std::tuple<at::Tensor, at::Tensor, c10::optional<at::Tensor>>
+subgraph_bipartite(const at::Tensor& rowptr,
+                   const at::Tensor& col,
+                   const at::Tensor& src_nodes,
+                   const at::Tensor& dst_nodes,
+                   const bool return_edge_id) {
+  at::TensorArg rowptr_t{rowptr, "rowptr", 1};
+  at::TensorArg col_t{col, "col", 1};
+  at::TensorArg src_nodes_t{src_nodes, "src_nodes", 1};
+  at::TensorArg dst_nodes_t{dst_nodes, "dst_nodes", 1};
+
+  at::CheckedFrom c = "subgraph_bipartite";
+  at::checkAllDefined(c, {rowptr_t, col_t, src_nodes_t, dst_nodes_t});
+  at::checkAllSameType(c, {rowptr_t, col_t, src_nodes_t, dst_nodes_t});
+
+  static auto op = c10::Dispatcher::singleton()
+                       .findSchemaOrThrow("pyg::subgraph_bipartite", "")
+                       .typed<decltype(subgraph_bipartite)>();
+  return op.call(rowptr, col, src_nodes, dst_nodes, return_edge_id);
+}
+
+c10::Dict<utils::EdgeType,
+          std::tuple<at::Tensor, at::Tensor, c10::optional<at::Tensor>>>
+hetero_subgraph(const utils::EdgeTensorDict& rowptr,
+                const utils::EdgeTensorDict& col,
+                const utils::NodeTensorDict& src_nodes,
+                const utils::NodeTensorDict& dst_nodes,
+                const c10::Dict<utils::EdgeType, bool>& return_edge_id) {
+  c10::Dict<utils::EdgeType,
+            std::tuple<at::Tensor, at::Tensor, c10::optional<at::Tensor>>>
+      res;
+
+  // Construct dispatchable arguments:
+  utils::HeteroDispatchArg<utils::NodeTensorDict, at::Tensor,
+                           utils::NodeSrcMode>
+      src_nodes_arg(src_nodes);
+  utils::HeteroDispatchArg<utils::NodeTensorDict, at::Tensor,
+                           utils::NodeDstMode>
+      dst_nodes_arg(dst_nodes);
+  utils::HeteroDispatchArg<c10::Dict<utils::EdgeType, bool>, bool,
+                           utils::EdgeMode>
+      edge_id_arg(return_edge_id);
+
+  for (const auto& kv : rowptr) {
+    const auto& edge_type = kv.key();
+    bool pass = src_nodes_arg.filter_by_edge(edge_type) &&
+                dst_nodes_arg.filter_by_edge(edge_type) &&
+                edge_id_arg.filter_by_edge(edge_type);
+    if (pass) {
+      const auto& r = rowptr.at(edge_type);
+      const auto& c = col.at(edge_type);
+      res.insert(edge_type, subgraph_bipartite(
+                                r, c, src_nodes_arg.value_by_edge(edge_type),
+                                dst_nodes_arg.value_by_edge(edge_type),
+                                edge_id_arg.value_by_edge(edge_type)));
+    }
+  }
+
+  return res;
+}
+
 TORCH_LIBRARY_FRAGMENT(pyg, m) {
   m.def(TORCH_SELECTIVE_SCHEMA(
       "pyg::subgraph(Tensor rowptr, Tensor col, Tensor "
       "nodes, bool return_edge_id) -> (Tensor, Tensor, Tensor?)"));
+  m.def(TORCH_SELECTIVE_SCHEMA(
+      "pyg::subgraph_bipartite(Tensor rowptr, Tensor col, Tensor "
+      "src_nodes, Tensor dst_nodes, bool return_edge_id) -> (Tensor, Tensor, "
+      "Tensor?)"));
+  m.def("hetero_subgraph", hetero_subgraph);
 }
 
 }  // namespace sampler
diff --git a/pyg_lib/csrc/sampler/subgraph.h b/pyg_lib/csrc/sampler/subgraph.h
index a2f8de54..ebb32e9d 100644
--- a/pyg_lib/csrc/sampler/subgraph.h
+++ b/pyg_lib/csrc/sampler/subgraph.h
@@ -1,7 +1,11 @@
 #pragma once
 
 #include <ATen/ATen.h>
+#include <ATen/core/Dict.h>
+
 #include "pyg_lib/csrc/macros.h"
+#include "pyg_lib/csrc/sampler/cpu/mapper.h"
+#include "pyg_lib/csrc/utils/types.h"
 
 namespace pyg {
 namespace sampler {
@@ -15,5 +19,23 @@ PYG_API std::tuple<at::Tensor, at::Tensor, c10::optional<at::Tensor>> subgraph(
     const at::Tensor& nodes,
     const bool return_edge_id = true);
 
+// A bipartite version of the above function.
+PYG_API std::tuple<at::Tensor, at::Tensor, c10::optional<at::Tensor>>
+subgraph_bipartite(const at::Tensor& rowptr,
+                   const at::Tensor& col,
+                   const at::Tensor& src_nodes,
+                   const at::Tensor& dst_nodes,
+                   const bool return_edge_id);
+
+// A heterogeneous version of the above function.
+// Returns a dict from each relation type to its result.
+PYG_API c10::Dict<utils::EdgeType,
+                  std::tuple<at::Tensor, at::Tensor, c10::optional<at::Tensor>>>
+hetero_subgraph(const utils::EdgeTensorDict& rowptr,
+                const utils::EdgeTensorDict& col,
+                const utils::NodeTensorDict& src_nodes,
+                const utils::NodeTensorDict& dst_nodes,
+                const c10::Dict<utils::EdgeType, bool>& return_edge_id);
+
 }  // namespace sampler
 }  // namespace pyg
diff --git a/pyg_lib/csrc/utils/hetero_dispatch.h b/pyg_lib/csrc/utils/hetero_dispatch.h
new file mode 100644
index 00000000..2569bae7
--- /dev/null
+++ b/pyg_lib/csrc/utils/hetero_dispatch.h
@@ -0,0 +1,164 @@
+#pragma once
+
+#include "types.h"
+
+#include <type_traits>
+
+namespace pyg {
+
+namespace utils {
+
+// Base class for easier type checks:
+struct HeteroDispatchMode {};
+
+// List hetero dispatch modes as different types to avoid non-type template
+// specialization:
+struct SkipMode : public HeteroDispatchMode {};
+
+struct NodeSrcMode : public HeteroDispatchMode {};
+
+struct NodeDstMode : public HeteroDispatchMode {};
+
+struct EdgeMode : public HeteroDispatchMode {};
+
+// Check if the argument is a c10::Dict so that it can be filtered by an edge
+// type:
+template <typename T>
+struct is_c10_dict : std::false_type {};
+
+template <typename K, typename V>
+struct is_c10_dict<c10::Dict<K, V>> : std::true_type {};
+
+// TODO: Specialize via `if constexpr` once C++17 is available.
+template <typename T, typename V, typename Mode>
+class HeteroDispatchArg {};
+
+// In `SkipMode`, we do not filter this argument:
+template <typename T, typename V>
+class HeteroDispatchArg<T, V, SkipMode> {
+ public:
+  using ValueType = V;
+  HeteroDispatchArg(const T& val) : val_(val) {}
+
+  // If we pass the filter, we will obtain the value of the argument:
+  V value_by_edge(const EdgeType& edge) { return val_; }
+
+  bool filter_by_edge(const EdgeType& edge) { return true; }
+
+ private:
+  T val_;
+};
+
+// In `NodeSrcMode`, we check if the source node type is in the dict:
+template <typename T, typename V>
+class HeteroDispatchArg<T, V, NodeSrcMode> {
+ public:
+  using ValueType = V;
+  HeteroDispatchArg(const T& val) : val_(val) {
+    static_assert(is_c10_dict<T>::value, "Should be a c10::Dict");
+  }
+
+  // Dict value lookup:
+  V value_by_edge(const EdgeType& edge) { return val_.at(get_src(edge)); }
+
+  // Check if the key exists in the dict:
+  bool filter_by_edge(const EdgeType& edge) {
+    return val_.contains(get_src(edge));
+  }
+
+ private:
+  T val_;
+};
+
+// In `NodeDstMode`, we check if the destination node type is in the dict:
+template <typename T, typename V>
+class HeteroDispatchArg<T, V, NodeDstMode> {
+ public:
+  using ValueType = V;
+  HeteroDispatchArg(const T& val) : val_(val) {
+    static_assert(is_c10_dict<T>::value, "Should be a c10::Dict");
+  }
+
+  V value_by_edge(const EdgeType& edge) { return val_.at(get_dst(edge)); }
+
+  bool filter_by_edge(const EdgeType& edge) {
+    return val_.contains(get_dst(edge));
+  }
+
+ private:
+  T val_;
+};
+
+// In `EdgeMode`, we check if the edge type is in the dict:
+template <typename T, typename V>
+class HeteroDispatchArg<T, V, EdgeMode> {
+ public:
+  using ValueType = V;
+  HeteroDispatchArg(const T& val) : val_(val) {
+    static_assert(is_c10_dict<T>::value, "Should be a c10::Dict");
+  }
+
+  V value_by_edge(const EdgeType& edge) { return val_.at(edge); }
+
+  bool filter_by_edge(const EdgeType& edge) { return val_.contains(edge); }
+
+ private:
+  T val_;
+};
+
+// The following will help static type checks:
+template <typename T>
+struct is_hetero_arg : std::false_type {};
+
+// Just check inheritance - a workaround without introducing concepts:
+template <typename T, typename V, typename Mode>
+struct is_hetero_arg<HeteroDispatchArg<T, V, Mode>> : std::true_type {
+  static_assert(std::is_base_of<HeteroDispatchMode, Mode>::value,
+                "Must pass a mode for dispatching");
+};
+
+// Declaration only - the empty-pack base case is specialized below:
+template <typename... Args>
+bool filter_args_by_edge(const EdgeType& edge, Args&&... args);
+
+// Stop condition of argument filtering:
+template <>
+inline bool filter_args_by_edge(const EdgeType& edge) {
+  return true;
+}
+
+// We filter each argument individually by the given edge using a variadic
+// template:
+template <typename T, typename... Args>
+bool filter_args_by_edge(const EdgeType& edge, T&& t, Args&&... args) {
+  static_assert(
+      is_hetero_arg<std::remove_const_t<std::remove_reference_t<T>>>::value,
+      "args should be HeteroDispatchArg");
+  return t.filter_by_edge(edge) && filter_args_by_edge(edge, args...);
+}
+
+// Declaration only - the empty-pack base case is specialized below:
+template <typename... Args>
+auto value_args_by_edge(const EdgeType& edge, Args&&... args);
+
+// Stop condition of argument gathering:
+template <>
+inline auto value_args_by_edge(const EdgeType& edge) {
+  return std::tuple<>();
+}
+
+// We gather each argument's value individually for the given edge using a
+// variadic template:
+template <typename T, typename... Args>
+auto value_args_by_edge(const EdgeType& edge, T&& t, Args&&... args) {
+  using ArgType = std::remove_const_t<std::remove_reference_t<T>>;
+  static_assert(is_hetero_arg<ArgType>::value,
+                "args should be HeteroDispatchArg");
+  return std::tuple_cat(
+      std::tuple<typename ArgType::ValueType>(t.value_by_edge(edge)),
+      value_args_by_edge(edge, args...));
+}
+
+}  // namespace utils
+
+}  // namespace pyg
diff --git a/pyg_lib/csrc/utils/types.h b/pyg_lib/csrc/utils/types.h
new file mode 100644
index 00000000..7e407f9d
--- /dev/null
+++ b/pyg_lib/csrc/utils/types.h
@@ -0,0 +1,34 @@
+#pragma once
+
+#include <string>
+
+#include <ATen/ATen.h>
+
+namespace pyg {
+namespace utils {
+
+const std::string SPLIT_TOKEN = "__";
+
+using EdgeType = std::string;
+using NodeType = std::string;
+using RelationType = std::string;
+
+using EdgeTensorDict = c10::Dict<EdgeType, at::Tensor>;
+using NodeTensorDict = c10::Dict<NodeType, at::Tensor>;
+
+// Note: `find`/`rfind` locate the full "__" token; `find_first_of` and
+// `find_last_of` would match any single '_' and mis-split type names that
+// contain single underscores:
+inline NodeType get_src(const EdgeType& e) {
+  return e.substr(0, e.find(SPLIT_TOKEN));
+}
+
+inline RelationType get_rel(const EdgeType& e) {
+  const auto beg = e.find(SPLIT_TOKEN) + SPLIT_TOKEN.size();
+  return e.substr(beg, e.rfind(SPLIT_TOKEN) - beg);
+}
+
+inline NodeType get_dst(const EdgeType& e) {
+  return e.substr(e.rfind(SPLIT_TOKEN) + SPLIT_TOKEN.size());
+}
+
+}  // namespace utils
+
+}  // namespace pyg
diff --git a/test/csrc/sampler/test_subgraph.cpp b/test/csrc/sampler/test_subgraph.cpp
index ea923514..25c143c2 100644
--- a/test/csrc/sampler/test_subgraph.cpp
+++ b/test/csrc/sampler/test_subgraph.cpp
@@ -20,3 +20,59 @@ TEST(SubgraphTest, BasicAssertions) {
   auto expected_edge_id = at::tensor({3, 4, 5, 6, 7, 8}, options);
   EXPECT_TRUE(at::equal(std::get<2>(out).value(), expected_edge_id));
 }
+
+TEST(HeteroSubgraphPassFilterTest, BasicAssertions) {
+  auto options = at::TensorOptions().dtype(at::kLong);
+
+  auto nodes = at::arange(1, 5, options);
+  auto graph = cycle_graph(/*num_nodes=*/6, options);
+
+  pyg::utils::NodeType node_name = "node";
+  pyg::utils::EdgeType edge_name = "node__to__node";
+
+  pyg::utils::EdgeTensorDict rowptr_dict;
+  rowptr_dict.insert(edge_name, std::get<0>(graph));
+  pyg::utils::EdgeTensorDict col_dict;
+  col_dict.insert(edge_name, std::get<1>(graph));
+  pyg::utils::NodeTensorDict nodes_dict;
+  nodes_dict.insert(node_name, nodes);
+  c10::Dict<pyg::utils::EdgeType, bool> edge_id_dict;
+  edge_id_dict.insert(edge_name, true);
+
+  auto res = pyg::sampler::hetero_subgraph(rowptr_dict, col_dict, nodes_dict,
+                                           nodes_dict, edge_id_dict);
+
+  EXPECT_EQ(res.size(), 1);
+  auto out = res.at(edge_name);
+
+  auto expected_rowptr = at::tensor({0, 1, 3, 5, 6}, options);
+  EXPECT_TRUE(at::equal(std::get<0>(out), expected_rowptr));
+  auto expected_col = at::tensor({1, 0, 2, 1, 3, 2}, options);
+  EXPECT_TRUE(at::equal(std::get<1>(out), expected_col));
+  auto expected_edge_id = at::tensor({3, 4, 5, 6, 7, 8}, options);
+  EXPECT_TRUE(at::equal(std::get<2>(out).value(), expected_edge_id));
+}
+
+TEST(HeteroSubgraphFailFilterTest, BasicAssertions) {
+  auto options = at::TensorOptions().dtype(at::kLong);
+
+  auto nodes = at::arange(1, 5, options);
+  auto graph = cycle_graph(/*num_nodes=*/6, options);
+
+  pyg::utils::NodeType node_name = "node";
+  pyg::utils::EdgeType edge_name = "node123__to456__node321";
+
+  pyg::utils::EdgeTensorDict rowptr_dict;
+  rowptr_dict.insert(edge_name, std::get<0>(graph));
+  pyg::utils::EdgeTensorDict col_dict;
+  col_dict.insert(edge_name, std::get<1>(graph));
+  pyg::utils::NodeTensorDict nodes_dict;
+  nodes_dict.insert(node_name, nodes);
+  c10::Dict<pyg::utils::EdgeType, bool> edge_id_dict;
+  edge_id_dict.insert(edge_name, true);
+
+  auto res = pyg::sampler::hetero_subgraph(rowptr_dict, col_dict, nodes_dict,
+                                           nodes_dict, edge_id_dict);
+
+  EXPECT_EQ(res.size(), 0);
+}
diff --git a/test/csrc/utils/test_utils.cpp b/test/csrc/utils/test_utils.cpp
new file mode 100644
index 00000000..cf5bf1fc
--- /dev/null
+++ b/test/csrc/utils/test_utils.cpp
@@ -0,0 +1,15 @@
+#include <gtest/gtest.h>
+
+#include "pyg_lib/csrc/utils/types.h"
+
+TEST(UtilsTypeTest, BasicAssertions) {
+  pyg::utils::EdgeType edge = "node1__to__node2";
+
+  auto src = pyg::utils::get_src(edge);
+  auto dst = pyg::utils::get_dst(edge);
+  auto rel = pyg::utils::get_rel(edge);
+
+  EXPECT_EQ(src, std::string("node1"));
+  EXPECT_EQ(dst, std::string("node2"));
+  EXPECT_EQ(rel, std::string("to"));
+}
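
A minimal usage sketch of the `filter_args_by_edge`/`value_args_by_edge` helpers, which this patch defines but only exercises implicitly through `hetero_subgraph`. The `main` function and its dict contents below are hypothetical and serve only to illustrate how the `HeteroDispatchArg` modes compose; they are not part of the PR:

```cpp
#include <iostream>
#include <tuple>

#include <ATen/ATen.h>

#include "pyg_lib/csrc/utils/hetero_dispatch.h"
#include "pyg_lib/csrc/utils/types.h"

int main() {
  using namespace pyg::utils;

  // Hypothetical inputs: one dict keyed by node type, one keyed by edge type:
  NodeTensorDict src_dict;
  src_dict.insert("node", at::arange(4));
  c10::Dict<EdgeType, bool> edge_id_dict;
  edge_id_dict.insert("node__to__node", true);

  // Wrap each dict in the mode that decides which key gets looked up:
  HeteroDispatchArg<NodeTensorDict, at::Tensor, NodeSrcMode> src_arg(src_dict);
  HeteroDispatchArg<c10::Dict<EdgeType, bool>, bool, EdgeMode>
      edge_id_arg(edge_id_dict);

  const EdgeType edge = "node__to__node";
  // Passes only if both the source type "node" and the edge type
  // "node__to__node" are present in their respective dicts:
  if (filter_args_by_edge(edge, src_arg, edge_id_arg)) {
    // Gathers the per-edge values into a std::tuple<at::Tensor, bool>:
    auto values = value_args_by_edge(edge, src_arg, edge_id_arg);
    std::cout << std::get<1>(values) << std::endl;  // prints 1
  }
}
```

This is the same filter-then-lookup pattern `hetero_subgraph` applies per relation type, spelled out with the variadic helpers instead of explicit `filter_by_edge`/`value_by_edge` calls.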