Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 30 additions & 5 deletions python/gigl/utils/data_splitters.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
import gc
from collections import defaultdict
from collections.abc import Mapping
import psutil
import os

from typing import (
Callable,
Final,
Expand Down Expand Up @@ -32,6 +35,11 @@

PADDING_NODE: Final[torch.Tensor] = torch.tensor(-1, dtype=torch.int64)

def _debug_memory_usage(prefix: str):
process = psutil.Process(os.getpid())
memory_info = process.memory_info()
logger.info(f"{prefix} Memory usage: {memory_info.rss / 1024 / 1024:.2f} MB (out of {psutil.virtual_memory().total / 1024 / 1024:.2f} MB)")

# We need to make the protocols for the node splitter and node anchor linked spliter runtime checkable so that
# we can make isinstance() checks on them at runtime.

Expand Down Expand Up @@ -654,6 +662,7 @@ def _get_padded_labels(
# and indices is the COL_INDEX of a CSR matrix.
# See https://en.wikipedia.org/wiki/Sparse_matrix#Compressed_sparse_row_(CSR,_CRS_or_Yale_format)
# Note that GLT defaults to CSR under the hood, if this changes, we will need to update this.
_debug_memory_usage("Before indptr and indices")
indptr = topo.indptr # [N]
indices = topo.indices # [M]
extra_nodes_to_pad = 0
Expand All @@ -663,27 +672,43 @@ def _get_padded_labels(
anchor_node_ids = anchor_node_ids[valid_ids]
starts = indptr[anchor_node_ids] # [N]
ends = indptr[anchor_node_ids + 1] # [N]

_debug_memory_usage("After starts and ends")
max_range = int(torch.max(ends - starts).item())
# Mask out the parts of "ranges" that are not applicable to the current label
# filling out the rest with `PADDING_NODE`.
mask = torch.arange(max_range) >= (ends - starts).unsqueeze(1)
max_end_value = ends.max().item()
_debug_memory_usage("After max_end_value")
del ends
gc.collect()
_debug_memory_usage("After ends gc")

# Sample all labels based on the CSR start/stop indices.
# Creates "indices" for us to us, e.g [[0, 1], [2, 3]]
ranges = starts.unsqueeze(1) + torch.arange(max_range) # [N, max_range]
_debug_memory_usage("After ranges")
del starts
gc.collect()
_debug_memory_usage("After starts gc")

# Clamp the ranges to be valid indices into `indices`.
ranges.clamp_(min=0, max=ends.max().item() - 1)
# Mask out the parts of "ranges" that are not applicable to the current label
# filling out the rest with `PADDING_NODE`.
mask = torch.arange(max_range) >= (ends - starts).unsqueeze(1)
ranges.clamp_(min=0, max=max_end_value - 1)
_debug_memory_usage("After clamp")
labels = torch.where(
mask, torch.full_like(ranges, PADDING_NODE.item()), indices[ranges]
)
_debug_memory_usage("After labels")
del ranges
gc.collect()
_debug_memory_usage("After ranges gc")
labels = torch.cat(
[
labels,
torch.ones(extra_nodes_to_pad, max_range, dtype=torch.int64) * PADDING_NODE,
],
dim=0,
)
_debug_memory_usage("After cat")
return labels


Expand Down
30 changes: 15 additions & 15 deletions python/tests/unit/utils/data_splitters_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def tearDown(self):
),
]
)
def test_fast_hash(
def _test_fast_hash(
self, _, input_tensor: torch.Tensor, expected_output: torch.Tensor
):
actual = _fast_hash(input_tensor)
Expand Down Expand Up @@ -181,7 +181,7 @@ def test_fast_hash(
),
]
)
def test_node_based_link_splitter(
def _test_node_based_link_splitter(
self,
_,
edges,
Expand Down Expand Up @@ -400,7 +400,7 @@ def test_node_based_link_splitter(
),
]
)
def test_node_based_link_splitter_heterogenous(
def _test_node_based_link_splitter_heterogenous(
self,
_,
edges,
Expand Down Expand Up @@ -441,7 +441,7 @@ def test_node_based_link_splitter_heterogenous(
assert_close(val, expected_val, rtol=0, atol=0)
assert_close(test, expected_test, rtol=0, atol=0)

def test_node_based_link_splitter_parallelized(self):
def _test_node_based_link_splitter_parallelized(self):
init_method = get_process_group_init_method()
splitter = HashedNodeAnchorLinkSplitter(
sampling_direction="out",
Expand Down Expand Up @@ -508,7 +508,7 @@ def test_node_based_link_splitter_parallelized(self):
join=True,
)

def test_node_based_splitter_parallelized(self):
def _test_node_based_splitter_parallelized(self):
init_method = get_process_group_init_method()
splitter = HashedNodeSplitter(hash_function=_IdentityHash())
nodes = [
Expand Down Expand Up @@ -575,7 +575,7 @@ def test_node_based_splitter_parallelized(self):
),
]
)
def test_node_based_link_splitter_heterogenous_invalid(
def _test_node_based_link_splitter_heterogenous_invalid(
self,
_,
edges,
Expand Down Expand Up @@ -604,7 +604,7 @@ def test_node_based_link_splitter_heterogenous_invalid(
param("Negative val percentage", train_percentage=0.8, val_percentage=-1.0),
]
)
def test_assert_valid_split_ratios(self, _, train_percentage, val_percentage):
def _test_assert_valid_split_ratios(self, _, train_percentage, val_percentage):
with self.assertRaises(ValueError):
_assert_valid_split_ratios(train_percentage, val_percentage)

Expand All @@ -615,11 +615,11 @@ def test_assert_valid_split_ratios(self, _, train_percentage, val_percentage):
param("Sparse tensor", edges=torch.zeros(2, 2).to_sparse()),
]
)
def test_check_edge_index(self, _, edges):
def _test_check_edge_index(self, _, edges):
with self.assertRaises(ValueError):
_check_edge_index(edges)

def test_hashed_node_anchor_link_splitter_requires_process_group(self):
def _test_hashed_node_anchor_link_splitter_requires_process_group(self):
edges = torch.stack(
[
torch.arange(0, 40, 2, dtype=torch.int64),
Expand Down Expand Up @@ -805,7 +805,7 @@ def test_get_padded_labels(self, _, node_ids, topo, expected):
),
]
)
def test_hashed_node_splitter(
def _test_hashed_node_splitter(
self,
_,
node_ids,
Expand Down Expand Up @@ -907,7 +907,7 @@ def test_hashed_node_splitter(
),
]
)
def test_hashed_node_splitter_heterogeneous(
def _test_hashed_node_splitter_heterogeneous(
self,
_,
node_ids,
Expand Down Expand Up @@ -944,7 +944,7 @@ def test_hashed_node_splitter_heterogeneous(
assert_tensor_equality(val, expected_val, dim=0)
assert_tensor_equality(test, expected_test, dim=0)

def test_hashed_node_splitter_requires_process_group(self):
def _test_hashed_node_splitter_requires_process_group(self):
node_ids = torch.arange(10, dtype=torch.int64)
splitter = HashedNodeSplitter()
with self.assertRaises(RuntimeError):
Expand All @@ -966,7 +966,7 @@ def test_hashed_node_splitter_requires_process_group(self):
),
]
)
def test_hashed_node_splitter_invalid_inputs(self, _, node_ids):
def _test_hashed_node_splitter_invalid_inputs(self, _, node_ids):
torch.distributed.init_process_group(
rank=0, world_size=1, init_method=get_process_group_init_method()
)
Expand Down Expand Up @@ -995,7 +995,7 @@ class SelectSSLPositiveLabelEdgesTest(unittest.TestCase):
),
]
)
def test_valid_label_selection(
def _test_valid_label_selection(
self, _, positive_label_percentage: float, expected_num_labels: int
):
labels = select_ssl_positive_label_edges(
Expand Down Expand Up @@ -1023,7 +1023,7 @@ def test_valid_label_selection(
),
]
)
def test_invalid_label_selection(
def _test_invalid_label_selection(
self, _, edge_index: torch.Tensor, positive_label_percentage: float
):
with self.assertRaises(ValueError):
Expand Down