From b44b83609734255d2985e8e268e88a3cd49595ad Mon Sep 17 00:00:00 2001
From: kmontemayor
Date: Wed, 1 Oct 2025 18:40:30 +0000
Subject: [PATCH] Log memory usage in _get_padded_labels

---
 python/gigl/utils/data_splitters.py | 48 +++++++++++++++++++++++++++++++++++++++++++-----
 1 file changed, 43 insertions(+), 5 deletions(-)

diff --git a/python/gigl/utils/data_splitters.py b/python/gigl/utils/data_splitters.py
index a7272cd4d..052479c9c 100644
--- a/python/gigl/utils/data_splitters.py
+++ b/python/gigl/utils/data_splitters.py
@@ -32,6 +32,27 @@
 
 PADDING_NODE: Final[torch.Tensor] = torch.tensor(-1, dtype=torch.int64)
 
+
+def _debug_memory_usage(prefix: str) -> None:
+    """Log the current process's resident memory alongside total system memory.
+
+    Debugging aid for tracking memory growth between steps of a computation.
+
+    Args:
+        prefix: Text placed at the start of the log line to identify the call site.
+    """
+    # Imported here rather than at module level so that psutil is only needed
+    # when this debug helper is actually called.
+    import psutil
+
+    bytes_per_mb = 1024 * 1024
+    # psutil.Process() with no pid refers to the current process.
+    rss_mb = psutil.Process().memory_info().rss / bytes_per_mb
+    total_mb = psutil.virtual_memory().total / bytes_per_mb
+    logger.info(
+        f"{prefix} Memory usage: {rss_mb:.2f} MB (out of {total_mb:.2f} MB)"
+    )
+
 
 # We need to make the protocols for the node splitter and node anchor linked spliter runtime checkable so that
 # we can make isinstance() checks on them at runtime.
@@ -654,6 +675,7 @@ def _get_padded_labels(
     # and indices is the COL_INDEX of a CSR matrix.
     # See https://en.wikipedia.org/wiki/Sparse_matrix#Compressed_sparse_row_(CSR,_CRS_or_Yale_format)
     # Note that GLT defaults to CSR under the hood, if this changes, we will need to update this.
+    _debug_memory_usage("Before indptr and indices")
     indptr = topo.indptr  # [N]
     indices = topo.indices  # [M]
     extra_nodes_to_pad = 0
@@ -663,20 +685,35 @@ def _get_padded_labels(
         anchor_node_ids = anchor_node_ids[valid_ids]
     starts = indptr[anchor_node_ids]  # [N]
     ends = indptr[anchor_node_ids + 1]  # [N]
-
+    _debug_memory_usage("After starts and ends")
     max_range = int(torch.max(ends - starts).item())
+    # Mask out the parts of "ranges" that are not applicable to the current label
+    # filling out the rest with `PADDING_NODE`.
+    mask = torch.arange(max_range) >= (ends - starts).unsqueeze(1)
+    max_end_value = ends.max().item()
+    _debug_memory_usage("After max_end_value")
+    del ends
+    gc.collect()
+    _debug_memory_usage("After ends gc")
 
     # Sample all labels based on the CSR start/stop indices.
     # Creates "indices" for us to us, e.g [[0, 1], [2, 3]]
     ranges = starts.unsqueeze(1) + torch.arange(max_range)  # [N, max_range]
+    _debug_memory_usage("After ranges")
+    del starts
+    gc.collect()
+    _debug_memory_usage("After starts gc")
+
     # Clamp the ranges to be valid indices into `indices`.
-    ranges.clamp_(min=0, max=ends.max().item() - 1)
-    # Mask out the parts of "ranges" that are not applicable to the current label
-    # filling out the rest with `PADDING_NODE`.
-    mask = torch.arange(max_range) >= (ends - starts).unsqueeze(1)
+    ranges.clamp_(min=0, max=max_end_value - 1)
+    _debug_memory_usage("After clamp")
     labels = torch.where(
         mask, torch.full_like(ranges, PADDING_NODE.item()), indices[ranges]
     )
+    _debug_memory_usage("After labels")
+    del ranges
+    gc.collect()
+    _debug_memory_usage("After ranges gc")
     labels = torch.cat(
         [
             labels,
@@ -684,6 +721,7 @@ def _get_padded_labels(
         ],
         dim=0,
     )
+    _debug_memory_usage("After cat")
     return labels
 
 