Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

i29 api #39

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions MANIFEST.in
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
include l2gv2/datasets/data/*/*
76 changes: 56 additions & 20 deletions l2gv2/datasets/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
EDGE_COLUMNS = {"source", "dest"} # required columns

EdgeList = list[tuple[str, str]]
NodeIndex = dict[str | int, int]


def is_graph_dataset(p: Path) -> bool:
Expand All @@ -42,6 +43,8 @@ def __init__(self, dset: str | Path, timestamp_fmt: str = "%Y-%m-%d"):
if (nodes_path := self.path / (self.path.stem + "_nodes.parquet")).exists():
self.paths["nodes"] = nodes_path

self._node_index_map = None

self._load_files()

def timestamp_from_string(self, ts: str) -> datetime.datetime:
Expand Down Expand Up @@ -108,15 +111,16 @@ def _load_files(self):
x for x in self.nodes.columns if x not in ["timestamp", "label", "nodes"]
]

def get_dates(self) -> list[str]:
"Returns list of dates"
return self.datelist.to_list()

def get_edges(self) -> pl.DataFrame:
"Returns edges as a polars DataFrame"
return self.edges

def get_nodes(self, ts: str | None = None) -> pl.DataFrame:
@property
def timestamps(self) -> list:
"Returns sorted list of dates"
return sorted(self.datelist.to_list())

def get_nodes(self, ts: str | datetime.datetime | None = None) -> pl.DataFrame:
"""Returns node data as a polars DataFrame

Args:
Expand All @@ -127,11 +131,10 @@ def get_nodes(self, ts: str | None = None) -> pl.DataFrame:
"""
if ts is None:
return self.nodes
if isinstance(ts, str):
ts = self.timestamp_from_string(ts)
return self.nodes.filter(pl.col("timestamp") == ts)
ts_cast = self.timestamp_from_string(ts) if isinstance(ts, str) else ts
return self.nodes.filter(pl.col("timestamp") == ts_cast)

def get_node_list(self, ts: str | None = None) -> list[str]:
def get_node_list(self, ts: str | datetime.datetime | None = None) -> list[str]:
"""Returns node list

Args:
Expand All @@ -143,19 +146,10 @@ def get_node_list(self, ts: str | None = None) -> list[str]:
nodes = self.nodes

if ts is not None:
if isinstance(ts, str):
ts = self.timestamp_from_string(ts)
nodes = nodes.filter(pl.col("timestamp") == ts)
ts_cast = self.timestamp_from_string(ts) if isinstance(ts, str) else ts
nodes = nodes.filter(pl.col("timestamp") == ts_cast)
return nodes.select("nodes").unique(maintain_order=True).to_series().to_list()

def get_node_features(self) -> list[str]:
"Returns node features as a list of strings"
return self.node_features

def get_edge_features(self) -> list[str]:
"Returns edge features as a list of strings"
return self.edge_features

def get_graph(self) -> rp.Graph: # pylint: disable=no-member
"Returns a raphtory.Graph representation"
g = rp.Graph() # pylint: disable=no-member
Expand Down Expand Up @@ -253,6 +247,48 @@ def get_edge_index(
edge_index = torch.tensor(edge_list, dtype=torch.long).t().contiguous()
return edge_index

def get_node_index_map(
self, ts: str | datetime.datetime | None = None
) -> NodeIndex:
"""Returns mapping from node value to an integer index.

Local2Global requires 0..|V|-1 indexing for a node set V.
Nodes have to be indexed from 1 both within a patch (timestamp),
and globally. This function returns a dictionary mapping
node values to the index.

Parameters
----------
ts
If specified, return index map for patch with timestamp `ts`
"""
if ts is None:
if self._node_index_map is None:
all_nodes = self.get_node_list()
self._node_index_map: dict[str | int, int] = {
x: i for i, x in enumerate(all_nodes)
}
else:
return self._node_index_map
ts_cast = self.timestamp_from_string(ts) if isinstance(ts, str) else ts
nodes = self.get_node_list(ts_cast)
return {x: i for i, x in enumerate(nodes)}

def get_renumbered_nodes(self) -> list[list[int]]:
"""Returns a list of renumbered nodes R from a timestamp based patch graph

In this list $R_i$ corresponds to the set of nodes at patch $i$, but
with the global node indexing applied (using
:meth:`DataLoader.get_node_index_map`). The ordering of the patch
graphs in this list is the timestamp ordering, with the earliest
timestamp as index 0.
"""
list_nodes_renumbered = []
node_idx = self.get_node_index_map()
for ts in self.timestamps:
list_nodes_renumbered.append([node_idx[i] for i in self.get_node_list(ts)])
return list_nodes_renumbered

def get_tgeometric(
self, temp: bool = True
) -> torch_geometric.data.Data | dict[datetime.datetime, torch_geometric.data.Data]:
Expand Down
Loading