new: preserve_adb_keys in PyG to ArangoDB #11

Merged
merged 70 commits
Sep 21, 2022
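This PR adds a `preserve_adb_keys` flag to the PyG → ArangoDB direction, backed by new per-graph `adb_map` / `pyg_map` attributes on the adapter (see the adapter.py diff below). A minimal, hedged sketch of the intended "full cycle" — the connection details, graph name, and collection/attribute names are illustrative assumptions, not taken from this PR:

```python
from arango import ArangoClient
from adbpyg_adapter import ADBPyG_Adapter

# Connect to ArangoDB and instantiate the adapter (credentials are placeholders).
db = ArangoClient().db("_system", username="root", password="")
adbpyg = ADBPyG_Adapter(db)

# ArangoDB -> PyG. This populates adbpyg.adb_map["FakeHetero"] with an
# {ArangoDB _id: PyG ID} map per vertex collection and per edge type.
metagraph = {
    "vertexCollections": {
        "v0": {"x": "x", "y": "y"},  # assumed attribute names
        "v1": {"x": "x"},
    },
    "edgeCollections": {
        "e0": {"edge_weight": "edge_weight"},
    },
}
pyg_g = adbpyg.arangodb_to_pyg("FakeHetero", metagraph)

# ... train on / modify pyg_g ...

# PyG -> ArangoDB. With preserve_adb_keys=True the adapter re-uses the stored
# map to write back into the original documents rather than minting new keys.
adbpyg.pyg_to_arangodb(
    "FakeHetero", pyg_g, preserve_adb_keys=True, on_duplicate="update"
)
```

In this revision the flag rejects homogeneous graphs, so the sketch uses two vertex collections to ensure a `HeteroData` result.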
Changes from 29 commits
Commits
70 commits
0bf2609
initial commit
aMahanna Aug 1, 2022
42bb841
temp: create py310 database for 3.10 testing
aMahanna Aug 1, 2022
d8d151b
temp: black hack
aMahanna Aug 1, 2022
d6dab74
Update conftest.py
aMahanna Aug 1, 2022
e91cff1
Update build.yml
aMahanna Aug 1, 2022
29eeb63
cleanup build.yml & release.yml
aMahanna Aug 1, 2022
fde393d
remove: temp conftest hack
aMahanna Aug 1, 2022
acbd39c
cleanup: abc
aMahanna Aug 1, 2022
97084a7
new: optional & partial edge collectiond data transfer (ArangoDB to PyG)
aMahanna Aug 1, 2022
eef7018
fix: black
aMahanna Aug 1, 2022
f5f3f0a
temp: disable HeterogeneousPartialEdgeCollectionImport
aMahanna Aug 1, 2022
2e2d6f4
new: test_adb_partial_to_pyg
aMahanna Aug 1, 2022
1ac09cf
new: query/dataframe optimization
aMahanna Aug 1, 2022
32a234b
cleanup
aMahanna Aug 2, 2022
78b15b2
Update README.md
aMahanna Aug 2, 2022
d4a4608
Update README.md
aMahanna Aug 2, 2022
6333dd4
cleanup: `validate_adb_metagraph`
aMahanna Aug 2, 2022
b6455a5
fix: black
aMahanna Aug 2, 2022
ca09c73
Merge branch 'master' into feature/adbpyg-map
aMahanna Aug 2, 2022
2b1d1ea
initial (experimental) commit
aMahanna Aug 2, 2022
4bb66f1
checkpoint
aMahanna Aug 3, 2022
aca385a
new: lazy attempt at #4
aMahanna Aug 3, 2022
6c5ed73
new: `preserve_adb_keys` docstring
aMahanna Aug 3, 2022
0249467
new: `pytest_exception_interact`
aMahanna Aug 3, 2022
84d62b1
temp: disable (partial) feature validation in `assert_arangodb_data`
aMahanna Aug 3, 2022
f91c8f0
cleanup: adapter.py
aMahanna Aug 3, 2022
a8b8b73
move: `preserve_adb_keys`
aMahanna Aug 3, 2022
e17be93
cleanup: `pyg_keys`
aMahanna Aug 3, 2022
63ea00d
new: test cases to cover `preserve_adb_keys`
aMahanna Aug 3, 2022
f787e64
temp: `# flake8: noqa`
aMahanna Aug 3, 2022
bf14832
debug: `pytest_exception_interact`
aMahanna Aug 3, 2022
03e6ea3
temp: fix cudf to_dict error
aMahanna Aug 3, 2022
17d008d
fix: black
aMahanna Aug 3, 2022
ab41595
fix: typo
aMahanna Aug 3, 2022
847ebae
remove: `cudf` imports
aMahanna Aug 3, 2022
a9734a9
fix: `test_adb_partial_to_pyg` RNG
aMahanna Aug 3, 2022
ef135a6
cleanup: `__finish_adb_dataframe` and `__build_dataframe_from_tensor`
aMahanna Aug 3, 2022
bd384b9
fix: map typings
aMahanna Aug 4, 2022
4d01224
new: test_adapter.py refactor
aMahanna Aug 4, 2022
a49d9d3
new: `preserve_adb_keys` refactor
aMahanna Aug 4, 2022
82e02f8
debug: test `HeterogeneousTurnedHomogeneous`
aMahanna Aug 4, 2022
79ae252
cleanup: test_adapter
aMahanna Aug 4, 2022
83799a7
update docstring
aMahanna Aug 4, 2022
d9073fe
fix: flake8
aMahanna Aug 4, 2022
e83bcd8
update: docstrings
aMahanna Aug 4, 2022
011911b
fix: default param value
aMahanna Aug 4, 2022
ee40271
fix: docstring
aMahanna Aug 4, 2022
f114b9a
cleanup
aMahanna Aug 4, 2022
c09987a
fix: black
aMahanna Aug 4, 2022
c4182c0
new: Full Cycle README section
aMahanna Aug 4, 2022
d7e2e19
update release.yml
aMahanna Aug 4, 2022
fa2e4cb
update `explicit_metagraph` docstring
aMahanna Aug 4, 2022
c56196b
Update README.md
aMahanna Aug 4, 2022
751fd28
move: __build_tensor_from_dataframe
aMahanna Aug 4, 2022
abf477b
bump
aMahanna Aug 4, 2022
d627902
Revert "bump"
aMahanna Aug 4, 2022
f47a12e
Update test_adapter.py
aMahanna Aug 4, 2022
c0eb5b4
new: `set[str]` metagraph value type
aMahanna Aug 5, 2022
4637506
fix: flake8
aMahanna Aug 5, 2022
123df0e
Update README.md
aMahanna Aug 5, 2022
e72a396
Update README.md
aMahanna Aug 5, 2022
4dcd032
cleanup: test_adapter
aMahanna Aug 5, 2022
ae47341
cleanup: progress bars
aMahanna Aug 5, 2022
df07def
update: documentation
aMahanna Aug 5, 2022
a292744
new: address comments
aMahanna Aug 5, 2022
7e31443
new: `test_full_cycle_homogeneous_with_preserve_adb_keys`
aMahanna Aug 5, 2022
c9f66a8
fix: black & mypy
aMahanna Aug 5, 2022
867729f
Update README.md
aMahanna Aug 5, 2022
92bf916
Update README.md
aMahanna Aug 5, 2022
2d5c135
new: adbpyg 1.1.0 notebook
aMahanna Aug 5, 2022
194 changes: 154 additions & 40 deletions adbpyg_adapter/adapter.py
@@ -4,9 +4,14 @@
from collections import defaultdict
from typing import Any, DefaultDict, Dict, List, Set, Union

try:
# https://github.com/arangoml/pyg-adapter/issues/4
from cudf import DataFrame
except ModuleNotFoundError:
from pandas import DataFrame

from arango.database import Database
from arango.graph import Graph as ADBGraph
from pandas import DataFrame
from torch import Tensor, cat, tensor
from torch_geometric.data import Data, HeteroData
from torch_geometric.data.storage import EdgeStorage, NodeStorage
@@ -16,9 +21,11 @@
from .controller import ADBPyG_Controller
from .exceptions import ADBMetagraphError, PyGMetagraphError
from .typings import (
ADBMap,
ADBMetagraph,
ADBMetagraphValues,
Json,
PyGMap,
PyGMetagraph,
PyGMetagraphValues,
)
@@ -59,6 +66,12 @@ def __init__(
self.__db = db
self.__cntrl = controller

# Maps ArangoDB vertex keys to PyG node IDs
self.adb_map: ADBMap = defaultdict(lambda: defaultdict(dict))

# Maps PyG node IDs to ArangoDB vertex keys
self.pyg_map: PyGMap = defaultdict(lambda: defaultdict(dict))

logger.info(f"Instantiated ADBPyG_Adapter with database '{db.name}'")

@property
@@ -99,13 +112,14 @@ def arangodb_to_pyg(
The current supported **metagraph** values are:
1) str: The name of the ArangoDB attribute that stores your PyG-ready data

2) Dict[str, Callable[[pandas.DataFrame], torch.Tensor] | None]:
2) Dict[str, Callable[[(pandas | cudf).DataFrame], torch.Tensor] | None]:
A dictionary mapping ArangoDB attributes to a callable Python Class
(i.e has a `__call__` function defined), or to None
(if the ArangoDB attribute is already a list of numerics).

3) Callable[[pandas.DataFrame], torch.Tensor]: A user-defined function for
custom behaviour. NOTE: The function return type MUST be a tensor.
3) Callable[[(pandas | cudf).DataFrame], torch.Tensor]:
A user-defined function for custom behaviour.
NOTE: The function return type MUST be a tensor.

1)
.. code-block:: python
@@ -188,7 +202,7 @@ def udf_v1_x(v1_df):
}

The metagraph above provides an interface for a user-defined function to
build a PyG-ready Tensor from a Pandas DataFrame equivalent to the
build a PyG-ready Tensor from a DataFrame equivalent to the
associated ArangoDB collection.
"""
logger.debug(f"--arangodb_to_pyg('{name}')--")
@@ -200,21 +214,21 @@ def udf_v1_x(v1_df):
and len(metagraph["edgeCollections"]) == 1
)

# Maps ArangoDB vertex IDs to PyG node IDs
adb_map: Dict[str, Json] = dict()
adb_map = self.adb_map[name]

data = Data() if is_homogeneous else HeteroData()

for v_col, meta in metagraph["vertexCollections"].items():
logger.debug(f"Preparing '{v_col}' vertices")

df = self.__fetch_adb_docs(v_col, meta == {}, query_options)
adb_map.update({adb_id: pyg_id for pyg_id, adb_id in enumerate(df["_id"])})
adb_map[v_col] = {adb_id: pyg_id for pyg_id, adb_id in enumerate(df["_id"])}

node_data: NodeStorage = data if is_homogeneous else data[v_col]
for k, v in meta.items():
node_data[k] = self.__build_tensor_from_dataframe(df, k, v)

et_df: DataFrame
v_cols: List[str] = list(metagraph["vertexCollections"].keys())
for e_col, meta in metagraph.get("edgeCollections", {}).items():
logger.debug(f"Preparing '{e_col}' edges")
@@ -235,8 +249,12 @@ def udf_v1_x(v1_df):

# Get the edge data corresponding to the current edge type
et_df = df[(df["from_col"] == from_col) & (df["to_col"] == to_col)]
from_nodes = [adb_map[_id] for _id in et_df["_from"]]
to_nodes = [adb_map[_id] for _id in et_df["_to"]]
adb_map[edge_type] = {
adb_id: pyg_id for pyg_id, adb_id in enumerate(et_df["_id"])
}

from_nodes = et_df["_from"].map(adb_map[from_col]).tolist()
to_nodes = et_df["_to"].map(adb_map[to_col]).tolist()

edge_data: EdgeStorage = data if is_homogeneous else data[edge_type]
edge_data.edge_index = tensor([from_nodes, to_nodes])
@@ -306,6 +324,7 @@ def pyg_to_arangodb(
metagraph: PyGMetagraph = {},
explicit_metagraph: bool = True,
overwrite_graph: bool = False,
preserve_adb_keys: bool = False, # TODO: explain
**import_options: Any,
) -> ADBGraph:
"""Create an ArangoDB graph from a PyG graph.
@@ -326,6 +345,18 @@
:param overwrite_graph: Overwrites the graph if it already exists.
Does not drop associated collections. Defaults to False.
:type overwrite_graph: bool
:param preserve_adb_keys: NOTE: EXPERIMENTAL FEATURE. USE AT OWN RISK.
If True, relies on **adbpyg_adapter.adb_map[**name**]** to map the
PyG Node & Edge IDs back into ArangoDB Vertex & Edge IDs. Assumes that
the user has a valid ADB Map for **name**.
An ADB Map can be built by running ArangoDB to PyG operation for
the same **name**, or by creating the map manually. Defaults to False.

.. code-block:: python
adbpyg_adapter.pyg_to_arangodb(
name, pyg_g, preserve_adb_keys=True, on_duplicate="update"
)
:type preserve_adb_keys: bool
:param import_options: Keyword arguments to specify additional
parameters for ArangoDB document insertion. Full parameter list:
https://docs.python-arango.com/en/main/specs.html#arango.collection.Collection.import_bulk
@@ -340,8 +371,9 @@
2) List[str]: A list of ArangoDB attribute names that will break down
your tensor data to have one ArangoDB attribute per tensor value.

3) Callable[[torch.Tensor], pandas.DataFrame]: A user-defined function for
custom behaviour. NOTE: The function return type MUST be a DataFrame.
3) Callable[[torch.Tensor], (pandas | cudf).DataFrame]:
A user-defined function for custom behaviour.
NOTE: The function return type MUST be a DataFrame (pandas or cudf).

1) Here is an example entry for parameter **metagraph**:
.. code-block:: python
@@ -350,7 +382,7 @@ def v2_x_to_pandas_dataframe(t: Tensor):
# The parameter **t** is the tensor representing
# the feature matrix 'x' of the 'v2' node type.

df = pandas.DataFrame(columns=["v2_features"])
df = (pandas | cudf).DataFrame(columns=["v2_features"])
df["v2_features"] = t.tolist()
# do more things with df["v2_features"] here ...
return df
@@ -380,9 +412,31 @@ def v2_x_to_pandas_dataframe(t: Tensor):

is_homogeneous = type(pyg_g) is Data

pyg_map = self.pyg_map[name]
if preserve_adb_keys:
if is_homogeneous:
msg = """
**preserve_adb_keys** does not yet support
homogeneous graphs (i.e type(pyg_g) is Data).
"""
raise ValueError(msg)

if self.adb_map[name] == {}:
msg = f"""
Parameter **preserve_adb_keys** was enabled,
but no ArangoDB Map was found for graph {name} in
**self.adb_map**.
"""
raise ValueError(msg)

# Build the reverse map
for k, map in self.adb_map[name].items():
pyg_map[k].update({pyg_id: adb_id for adb_id, pyg_id in map.items()})

node_types: List[str]
edge_types: List[EdgeType]
if metagraph and explicit_metagraph:
explicit_metagraph = metagraph != {} and explicit_metagraph
if explicit_metagraph:
node_types = metagraph.get("nodeTypes", {}).keys() # type: ignore
edge_types = metagraph.get("edgeTypes", {}).keys() # type: ignore

@@ -415,13 +469,31 @@ def v2_x_to_pandas_dataframe(t: Tensor):
n_meta = metagraph.get("nodeTypes", {})
for n_type in node_types:
node_data = pyg_g if is_homogeneous else pyg_g[n_type]
df = DataFrame([{"_key": str(i)} for i in range(node_data.num_nodes)])
num_nodes = node_data.num_nodes

df = DataFrame(index=range(num_nodes))
if preserve_adb_keys:
num_node_keys = len(pyg_map[n_type])

if num_nodes != num_node_keys:
msg = f"""
{num_nodes} does not match
number of node keys in pyg_map
({num_node_keys}) for {n_type}
"""
raise ValueError(msg)

df["_id"] = df.index.map(pyg_map[n_type])
else:
df["_key"] = df.index.astype(str)

meta = n_meta.get(n_type, {})
for k, t in node_data.items():
if type(t) is Tensor and len(t) == node_data.num_nodes:
v = meta.get(k, k)
df = df.join(self.__build_dataframe_from_tensor(t, k, v))
pyg_keys = (
set(meta.keys())
if explicit_metagraph
else {k for k, _ in node_data.items()} # can't do node_data.keys()
)
df = self.__finish_adb_dataframe(df, meta, pyg_keys, node_data)

if type(self.__cntrl) is not ADBPyG_Controller:
f = lambda n: self.__cntrl._prepare_pyg_node(n, n_type)
@@ -436,17 +508,21 @@ def v2_x_to_pandas_dataframe(t: Tensor):

columns = ["_from", "_to"]
df = DataFrame(zip(*(edge_data.edge_index.tolist())), columns=columns)
df["_from"] = from_col + "/" + df["_from"].astype(str)
df["_to"] = to_col + "/" + df["_to"].astype(str)
if preserve_adb_keys:
df["_id"] = df.index.map(pyg_map[e_type])
df["_from"] = df["_from"].map(pyg_map[from_col])
df["_to"] = df["_to"].map(pyg_map[to_col])
else:
df["_from"] = from_col + "/" + df["_from"].astype(str)
df["_to"] = to_col + "/" + df["_to"].astype(str)

meta = e_meta.get(e_type, {})
for k, t in edge_data.items():
if k == "edge_index":
continue

if type(t) is Tensor and len(t) == edge_data.num_edges:
v = meta.get(k, k)
df = df.join(self.__build_dataframe_from_tensor(t, k, v))
pyg_keys = (
set(meta.keys())
if explicit_metagraph
else {k for k, _ in edge_data.items()} # can't do edge_data.keys()
)
df = self.__finish_adb_dataframe(df, meta, pyg_keys, edge_data)

if type(self.__cntrl) is not ADBPyG_Controller:
f = lambda e: self.__cntrl._prepare_pyg_edge(e, e_type)
@@ -527,7 +603,7 @@ def __fetch_adb_docs(
self, col: str, empty_meta: bool, query_options: Any
) -> DataFrame:
"""Fetches ArangoDB documents within a collection. Returns the
documents in a Pandas DataFrame.
documents in a DataFrame.

:param col: The ArangoDB collection.
:type col: str
@@ -537,8 +613,8 @@ def __fetch_adb_docs(
:param query_options: Keyword arguments to specify AQL query options
when fetching documents from the ArangoDB instance.
:type query_options: Any
:return: A Pandas DataFrame representing the ArangoDB documents.
:rtype: pandas.DataFrame
:return: A DataFrame representing the ArangoDB documents.
:rtype: (pandas | cudf).DataFrame
"""
# Only return the entire document if **empty_meta** is False
data = "{_id: doc._id, _from: doc._from, _to: doc._to}" if empty_meta else "doc"
@@ -591,11 +667,11 @@ def __build_tensor_from_dataframe(
meta_key: str,
meta_val: ADBMetagraphValues,
) -> Tensor:
"""Constructs a PyG-ready Tensor from a Pandas Dataframe, based on
"""Constructs a PyG-ready Tensor from a DataFrame, based on
the nature of the user-defined metagraph.

:param adb_df: The Pandas Dataframe representing ArangoDB data.
:type adb_df: pandas.DataFrame
:param adb_df: The DataFrame representing ArangoDB data.
:type adb_df: (pandas | cudf).DataFrame
:param meta_key: The current ArangoDB-PyG metagraph key
:type meta_key: str
:param meta_val: The value mapped to **meta_key** to
@@ -636,25 +712,63 @@ def __build_tensor_from_dataframe(

raise ADBMetagraphError(f"Invalid {meta_val} type") # pragma: no cover

def __finish_adb_dataframe(
self,
df: DataFrame,
meta: Dict[Any, PyGMetagraphValues],
pyg_keys: Set[Any],
pyg_data: Union[NodeStorage, EdgeStorage],
) -> DataFrame:
"""A helper method to complete the ArangoDB Dataframe for the given
collection. Is responsible for creating DataFrames from PyG tensors,
and appending them to the main dataframe **df**.

:param df: The main ArangoDB DataFrame containing (at minimum)
the vertex/edge _id or _key attribute.
:type df: (pandas | cudf).DataFrame
:param meta: The metagraph associated to the
current PyG node or edge type.
:type meta: Dict[Any, adbpyg_adapter.typings.PyGMetagraphValues]
:param pyg_keys: The set of PyG NodeStorage or EdgeStorage keys, retrieved
either from the **meta** parameter (if **explicit_metagraph** is True),
or from the **pyg_data** parameter (if **explicit_metagraph** is False).
:type pyg_keys: Set[Any]
:param pyg_data: The NodeStorage or EdgeStorage of the current
PyG node or edge type.
:type pyg_data: torch_geometric.data.storage.(NodeStorage | EdgeStorage)
:return: The completed DataFrame for the (soon-to-be) ArangoDB collection.
:rtype: (pandas | cudf).DataFrame
"""
for k in pyg_keys:
if k == "edge_index":
continue

t = pyg_data[k]
if type(t) is Tensor and len(t) == len(df):
v = meta.get(k, str(k))
df = df.join(self.__build_dataframe_from_tensor(t, k, v))

return df

def __build_dataframe_from_tensor(
self,
pyg_tensor: Tensor,
meta_key: str,
meta_key: Any,
meta_val: PyGMetagraphValues,
) -> DataFrame:
"""Builds a Pandas DataFrame from PyG Tensor, based on
"""Builds a DataFrame from PyG Tensor, based on
the nature of the user-defined metagraph.

:param pyg_tensor: The Tensor representing PyG data.
:type pyg_tensor: torch.Tensor
:param meta_key: The current PyG-ArangoDB metagraph key
:type meta_key
:type meta_key: Any
:param meta_val: The value mapped to the PyG-ArangoDB metagraph key to
help convert **tensor** into a Pandas Dataframe.
help convert **tensor** into a DataFrame.
e.g the value of `metagraph['nodeTypes']['users']['x']`.
:type meta_val: adbpyg_adapter.typings.PyGMetagraphValues
:return: A Pandas DataFrame equivalent to the Tensor
:rtype: pandas.DataFrame
:return: A DataFrame equivalent to the Tensor
:rtype: (pandas | cudf).DataFrame
:raise adbpyg_adapter.exceptions.PyGMetagraphError: If invalid **meta_val**.
"""
logger.debug(f"__build_dataframe_from_tensor(df, '{meta_key}', {meta_val})")
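The `preserve_adb_keys` docstring above also allows the ADB map to be built manually instead of via a prior ArangoDB → PyG run. Continuing the earlier sketch (same `adbpyg` adapter and `pyg_g` graph), a hedged illustration of that structure — the `_id` values and counts are assumptions:

```python
# adb_map[graph_name] maps each vertex collection, and each (from, edge, to)
# edge type, to an {ArangoDB _id: PyG ID} dictionary. The node maps must cover
# every PyG node ID (pyg_to_arangodb() raises a ValueError on a count
# mismatch), and the edge map is used to restore each edge's _id/_from/_to.
adbpyg.adb_map["FakeHetero"] = {
    "v0": {"v0/1": 0, "v0/2": 1},
    "v1": {"v1/1": 0, "v1/2": 1},
    ("v0", "e0", "v1"): {"e0/1": 0, "e0/2": 1},
}

adbpyg.pyg_to_arangodb(
    "FakeHetero", pyg_g, preserve_adb_keys=True, on_duplicate="update"
)
```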
7 changes: 6 additions & 1 deletion adbpyg_adapter/encoders.py
@@ -5,7 +5,12 @@

from typing import Any, Dict, Optional

from pandas import DataFrame
try:
# https://github.com/arangoml/pyg-adapter/issues/4
from cudf import DataFrame
except ModuleNotFoundError:
from pandas import DataFrame

from torch import Tensor, from_numpy, zeros


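Since `DataFrame` can now resolve to either cudf or pandas in both adapter.py and encoders.py, a user-defined PyG → ArangoDB metagraph function can follow the same fallback pattern. A hedged sketch combining the three supported metagraph value types from the `pyg_to_arangodb` docstring, re-using the `adbpyg` adapter from the first sketch — it assumes a heterogeneous `pyg_g` with node types `v0`, `v1`, `v2`, that `v1`'s feature tensor has three columns, and that the graph name and attribute names are placeholders:

```python
from torch import Tensor

try:
    # mirror the adapter's optional cudf import
    from cudf import DataFrame
except ModuleNotFoundError:
    from pandas import DataFrame


def v2_x_to_dataframe(t: Tensor) -> DataFrame:
    # Callable metagraph value: must return a DataFrame built from the tensor.
    df = DataFrame(columns=["v2_features"])
    df["v2_features"] = t.tolist()
    return df


metagraph = {
    "nodeTypes": {
        "v0": {"x": "features", "y": "label"},  # str: rename the attribute on export
        "v1": {"x": ["f0", "f1", "f2"]},        # List[str]: one attribute per tensor column
        "v2": {"x": v2_x_to_dataframe},         # Callable: custom DataFrame builder
    },
}

adbpyg.pyg_to_arangodb("FakeHetero_2", pyg_g, metagraph)
```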