Commit a504634

Update docstrings and types.

mihaeladuta committed Nov 7, 2024
1 parent 9743ffb commit a504634
Showing 8 changed files with 168 additions and 138 deletions.
26 changes: 13 additions & 13 deletions l2gv2/anomaly_detection.py
@@ -8,11 +8,11 @@ def raw_anomaly_score_node_patch(aligned_patch_emb, emb, node) -> float:
  """TODO: docstring for `raw_anomaly_score_node_patch`
  Args:
- aligned_patch_emb ():
+ aligned_patch_emb ([type]): [description]
- emb ():
+ emb ([type]): [description]
- node ():
+ node ([type]): [description]
  Returns:
  float: Raw anomaly score of the node in the patch.
@@ -25,10 +25,10 @@ def nodes_in_patches(patch_data: list[Patch]) -> list:
  """TODO: docstring for `nodes_in_patches`
  Args:
- patch_data (list[Patch]):
+ patch_data (list[Patch]): [description]
  Returns:
- list:
+ list: [description]
  """

  return [set(p.nodes.numpy()) for p in patch_data]
@@ -40,14 +40,14 @@ def normalized_anomaly(
  """TODO: docstring for `normalized_anomaly`
  Args:
- patch_emb (list[Patch]):
+ patch_emb (list[Patch]): [description]
- patch_data (list[Patch]):
+ patch_data (list[Patch]): [description]
- emb (np.array):
+ emb (np.array): [description]
  Returns:
- np.array:
+ np.array: [description]
  """

  nodes = nodes_in_patches(patch_data)
@@ -95,16 +95,16 @@ def get_outliers(
  """TODO: docstring for `get_outliers`
  Args:
- patch_emb (list):
+ patch_emb (list): [description]
- patch_data (list):
+ patch_data (list): [description]
- emb (np.array):
+ emb (np.array): [description]
  k (float): Threshold for outliers as multiplier of the standard deviation.
  Returns:
- list[int]:
+ list[int]: [description]
  """

  out = []
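A note on the `k` parameter documented above: outliers are thresholded at `k` standard deviations. A minimal standalone sketch of that rule (the anomaly scoring itself is not reproduced; the scores below are stand-ins):

    import numpy as np

    def outliers_by_std(scores: np.ndarray, k: float) -> list[int]:
        """Flag indices whose score exceeds mean + k * std."""
        mu, sigma = scores.mean(), scores.std()
        return [int(i) for i in np.nonzero(scores > mu + k * sigma)[0]]

    rng = np.random.default_rng(42)
    scores = rng.normal(size=100)  # stand-in for normalized anomaly scores
    scores[17] = 8.0               # one obvious anomaly
    print(outliers_by_std(scores, k=3.0))  # expect [17] for this seed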
8 changes: 4 additions & 4 deletions l2gv2/embedding/embeddings.py
@@ -33,11 +33,11 @@ def fit(
  data (torch_geometric.data.Data):
- logger ():
+ logger ([type]): [description]
  Returns:
- torch.nn.Module:
+ torch.nn.Module: [description]
  """
  optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr)
  # optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
@@ -98,11 +98,11 @@ def loss_fun(self, data) -> torch.Tensor:
  Args:
- data (torch_geometric.data.Data)
+ data (torch_geometric.data.Data): [description]
  Returns:
- torch.Tensor:
+ torch.Tensor: [description]
  """
  return (
  self.model.recon_loss(self.model.encode(data), data.edge_index)
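The `loss_fun` body visible above (reconstruction loss on the encoded graph) matches torch_geometric's (V)GAE pattern, and `fit` appears to wrap a standard Adam loop. A minimal sketch under those assumptions; the encoder, layer sizes, learning rate, and toy graph are illustrative, not from this commit:

    import torch
    import torch_geometric as tg
    from torch_geometric.nn import GAE, GCNConv

    class Encoder(torch.nn.Module):
        """One-layer GCN encoder taking a Data object, as model.encode(data) implies."""
        def __init__(self, in_dim: int, emb_dim: int):
            super().__init__()
            self.conv = GCNConv(in_dim, emb_dim)

        def forward(self, data: tg.data.Data) -> torch.Tensor:
            return self.conv(data.x, data.edge_index)

    data = tg.data.Data(x=torch.randn(10, 16),
                        edge_index=torch.randint(0, 10, (2, 40)))
    model = GAE(Encoder(in_dim=16, emb_dim=8))
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

    for _ in range(5):  # a few illustrative epochs
        optimizer.zero_grad()
        # mirrors loss_fun: reconstruction loss of the encoded graph
        loss = model.recon_loss(model.encode(data), data.edge_index)
        loss.backward()
        optimizer.step()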
8 changes: 4 additions & 4 deletions l2gv2/induced_subgraph.py
@@ -8,14 +8,14 @@ def induced_subgraph(data: tg.data.Data, nodes, extend_hops: int = 0) -> tg.data
  """TODO: docstring for `induced_subgraph`
  Args:
- data (torch_geometric.data.Data):
+ data (torch_geometric.data.Data): [description]
- nodes (int):
+ nodes (int): [description]
- extend_hops (int, optional): default is 0.
+ extend_hops (int, optional): [description], default is 0.
  Returns:
- torch_geometric.data.Data:
+ torch_geometric.data.Data: [description]
  """

  nodes = torch.as_tensor(nodes, dtype=torch.long)
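A guess at the behavior behind this signature — take the subgraph induced by `nodes`, optionally grown by `extend_hops` hops — sketched with torch_geometric utilities; this is not the module's actual implementation:

    import torch
    import torch_geometric as tg
    from torch_geometric.utils import k_hop_subgraph, subgraph

    def induced_subgraph_sketch(data: tg.data.Data, nodes,
                                extend_hops: int = 0) -> tg.data.Data:
        nodes = torch.as_tensor(nodes, dtype=torch.long)
        if extend_hops > 0:
            # grow the node set by extend_hops hops around the seed nodes
            nodes, _, _, _ = k_hop_subgraph(nodes, extend_hops, data.edge_index,
                                            num_nodes=data.num_nodes)
        edge_index, _ = subgraph(nodes, data.edge_index,
                                 relabel_nodes=True, num_nodes=data.num_nodes)
        return tg.data.Data(edge_index=edge_index, num_nodes=nodes.numel())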
30 changes: 15 additions & 15 deletions l2gv2/manopt_optimization.py
@@ -19,10 +19,10 @@ def double_intersections_nodes(
  """TODO: docstring for `double_intersections_nodes`.
  Args:
- patches (list[Patch]):
+ patches (list[Patch]): [description]
  Returns:
- dict[tuple[int, int], list[int]]:
+ dict[tuple[int, int], list[int]]: [description]
  """

  double_intersections = {}
@@ -46,22 +46,22 @@ def anp_loss_nodes_consecutive_patches(
  """TODO: docstring for `anp_loss_nodes_consecutive_patches`.
  Args:
- rotations ():
+ rotations ([type]): [description]
- scales ():
+ scales ([type]): [description]
- translations (int):
+ translations (int): [description]
- patches ():
+ patches ([type]): [description]
- nodes ():
+ nodes ([type]): [description]
- dim (int):
+ dim (int): [description]
- random_choice (bool, optional): default is True.
+ random_choice (bool, optional): [description] default is True.
  Returns:
- float: loss function.
+ float: loss function value.
  """

  loss_function = 0
@@ -98,16 +98,16 @@ def optimization(
  """TODO: docstring for `optimization`.
  Args:
- patches ():
+ patches ([type]): [description]
- nodes ():
+ nodes ([type]): [description]
- dim (int):
+ dim (int): [description]
  Returns:
- result:
+ result: [description]
- embedding:
+ embedding: [description]
  """
  n_patches = len(patches)

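The argument names above (rotations, scales, translations over nodes shared between patches) point to a similarity-transform alignment loss. A toy sketch of that idea for two patches sharing the same nodes; the module's actual parametrization and manifold-optimization (Pymanopt) setup are not reproduced:

    import numpy as np

    def pairwise_alignment_loss(emb1, emb2, rot1, rot2, s1, s2, t1, t2):
        # apply each patch's similarity transform, then compare on shared nodes
        x1 = s1 * emb1 @ rot1.T + t1
        x2 = s2 * emb2 @ rot2.T + t2
        return float(np.sum((x1 - x2) ** 2))

    rng = np.random.default_rng(0)
    dim, shared = 2, 5
    emb1 = rng.normal(size=(shared, dim))
    emb2 = emb1 + 0.01 * rng.normal(size=(shared, dim))  # nearly aligned patches
    eye, zero = np.eye(dim), np.zeros(dim)
    print(pairwise_alignment_loss(emb1, emb2, eye, eye, 1.0, 1.0, zero, zero))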
4 changes: 2 additions & 2 deletions l2gv2/models.py
@@ -217,7 +217,7 @@ def node2vec_(
  data(torch_geometric.data.Data)
- emb_dim (int):
+ emb_dim (int): [description]
  w_length (int, optional): The walk length, default is 20.
@@ -296,7 +296,7 @@ def node2vec_patch_embeddings(
  patch_data (list) torch_geometric.data.Data objects.
- emb_dim (int):
+ emb_dim (int): [description]
  w_length (int, optional): The walk length, default is 20.
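For reference, a minimal sketch of computing node2vec embeddings with torch_geometric's Node2Vec, matching the `emb_dim` and `w_length` parameters documented above; the remaining hyperparameters and the toy graph are assumptions, and torch-cluster is required for the random walks:

    import torch
    from torch_geometric.nn import Node2Vec

    edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])
    model = Node2Vec(edge_index, embedding_dim=8, walk_length=20,
                     context_size=10, walks_per_node=10)
    loader = model.loader(batch_size=32, shuffle=True)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

    for pos_rw, neg_rw in loader:  # one illustrative epoch
        optimizer.zero_grad()
        loss = model.loss(pos_rw, neg_rw)
        loss.backward()
        optimizer.step()

    emb = model().detach()  # shape: (num_nodes, emb_dim)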
15 changes: 13 additions & 2 deletions l2gv2/patch/lazy.py
@@ -246,7 +246,11 @@ def __array__(self, dtype=None):
  return self.as_array(out)

  def as_array(self, out=None):
- """ TODO: docstring for `as_array` """
+ """ TODO: docstring for `as_array`
+ Args:
+     out (Optional[type]): [description], defaults to None
+ """
  if out is None:
  out = np.zeros(self.shape)
  index = np.empty((self.nodes.max() + 1,), dtype=np.int64)
@@ -260,7 +264,14 @@ def as_array(self, out=None):
  return out

  def get_coordinates(self, nodes, out=None):
- """ TODO: docstring for `get_coordinates` """
+ """ TODO: docstring for `get_coordinates`
+ Args:
+     nodes ([type]): [description]
+     out (Optional[type]): [description], default is None
+ """
  nodes = np.asanyarray(nodes)
  if out is None:
  out = np.zeros((len(nodes), self.dim))
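The context lines above hint at the lookup pattern inside `as_array` and `get_coordinates`: an inverse index array maps global node ids to local rows before gathering coordinates. A standalone illustration of that pattern (names and data made up):

    import numpy as np

    nodes = np.array([4, 7, 2])  # global ids of the patch's nodes
    coords = np.array([[0.0, 1.0], [2.0, 3.0], [4.0, 5.0]])  # one row per node

    index = np.empty((nodes.max() + 1,), dtype=np.int64)
    index[nodes] = np.arange(len(nodes))  # inverse map: global id -> local row

    query = np.asanyarray([7, 4])
    out = np.zeros((len(query), coords.shape[1]))
    out[:] = coords[index[query]]  # gather coordinates for the queried nodes
    print(out)  # [[2. 3.], [0. 1.]]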
49 changes: 25 additions & 24 deletions l2gv2/patch/patches.py
@@ -2,6 +2,7 @@

  from random import choice
  from math import ceil
+ from typing import List, Tuple, Optional, Literal
  from collections.abc import Iterable

  import torch
@@ -30,16 +31,16 @@ def geodesic_expand_overlap(
  """ Expand patch
  Args:
- subgraph (): graph containing patch nodes and all target nodes for potential expansion
+ subgraph ([type]): graph containing patch nodes and all target nodes for potential expansion
- seed_mask ():
+ seed_mask ([type]): [description]
- min_overlap (): minimum overlap before stopping expansion
+ min_overlap ([type]): minimum overlap before stopping expansion
- target_overlap (): maximum overlap
+ target_overlap ([type]): maximum overlap
  (if expansion step results in more overlap, the nodes added are sampled at random)
- reseed_samples (): default is 10
+ reseed_samples ([type]): [description] default is 10
  Returns:
  index tensor of new nodes to add to patch
@@ -148,13 +149,13 @@ def create_overlapping_patches(
  partition_tensor (torch.LongTensor): partition of input graph
- patch_graph (): graph where nodes are clusters of partition
+ patch_graph ([type]): graph where nodes are clusters of partition
  and an edge indicates that the corresponding patches
  in the output should have at least ``min_overlap`` nodes in common
- min_overlap (): minimum overlap for connected patches
+ min_overlap ([type]): minimum overlap for connected patches
- target_overlap (): maximum overlap during expansion
+ target_overlap ([type]): maximum overlap during expansion
  for an edge (additional overlap may result from expansion of other edges)
  Returns:
@@ -232,40 +233,40 @@ def _patch_overlaps(
  )
  return patches


  def create_patch_data(
  graph: TGraph,
  partition_tensor: torch.LongTensor,
- min_overlap,
- target_overlap,
- min_patch_size=None,
- sparsify_method: str="resistance",
- target_patch_degree: int=4,
- gamma: int=0,
- verbose: bool=False,
- ):
+ min_overlap: int,
+ target_overlap: int,
+ min_patch_size: Optional[int] = None,
+ sparsify_method: Literal["resistance", "rmst", "none"] = "resistance",
+ target_patch_degree: int = 4,
+ gamma: int = 0,
+ verbose: bool = False,
+ ) -> Tuple[List, object]:
  """ Divide data into overlapping patches
  Args:
  graph (TGraph): input data
  partition_tensor (torch.LongTensor): starting partition for creating patches
- min_overlap: minimum patch overlap for connected patches
+ min_overlap ([type]): minimum patch overlap for connected patches
- target_overlap: maximum patch overlap during expansion of an edge of the patch graph
+ target_overlap ([type]): maximum patch overlap during expansion
+ of an edge of the patch graph
- min_patch_size (optional): minimum size of patches, defauls is None
+ min_patch_size (Optional[type]): minimum size of patches, defauls is None
- sparsify_method (str): method for sparsifying patch graph
+ sparsify_method (Optional[[str]): method for sparsifying patch graph
  (one of ``'resistance'``, ``'rmst'``, ``'none'``), default is ``'resistance'``
- target_patch_degree (optional): target patch degree for
+ target_patch_degree (Optional[str]): target patch degree for
  ``sparsify_method='resistance'``, default is 4
- gamma (int): ``gamma`` value for use with ``sparsify_method='rmst'``, default is 0
+ gamma (Optional[int]): ``gamma`` value for use with ``sparsify_method='rmst'``, default is 0
- verbose (bool): if true, print some info about created patches, default is False
+ verbose (Optional[bool]): if true, print some info about created patches, default is False
  Returns:
  list of patch data, patch graph
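A hedged usage sketch of the newly typed `create_patch_data` signature; the TGraph import path and constructor arguments are guesses, not shown in this commit:

    import torch
    # import paths below are assumptions; adjust to where TGraph and
    # create_patch_data live in the installed l2gv2 package
    from l2gv2.network import TGraph
    from l2gv2.patch.patches import create_patch_data

    # a tiny 4-node path graph, partitioned into two clusters
    edge_index = torch.tensor([[0, 1, 1, 2, 2, 3], [1, 0, 2, 1, 3, 2]])
    graph = TGraph(edge_index, num_nodes=4)
    partition = torch.tensor([0, 0, 1, 1], dtype=torch.long)

    patches, patch_graph = create_patch_data(
        graph,
        partition,
        min_overlap=1,     # minimum overlap for connected patches
        target_overlap=2,  # maximum overlap during expansion
        sparsify_method="none",
        verbose=True,
    )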