Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 11 additions & 10 deletions python/gigl/common/data/load_torch_tensors.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,11 @@
_FEATURE_FMT = "{entity}_features"
_NODE_KEY = "node"
_EDGE_KEY = "edge"
_POSITIVE_LABEL_KEY = "positive_label"
_NEGATIVE_LABEL_KEY = "negative_label"
_POSITIVE_SUPERVISION_EDGES_KEY = "positive_supervision_edges"
_NEGATIVE_SUPERVISION_EDGES_KEY = "negative_supervision_edges"


# TODO (mkolodner-sc): Change positive/negative label name to positive/negative supervision edges
@dataclass(frozen=True)
class SerializedGraphMetadata:
"""
Expand Down Expand Up @@ -256,7 +257,7 @@ def load_torch_tensors_from_tf_record(
"tf_record_dataloader": tf_record_dataloader,
"output_dict": edge_output_dict,
"error_dict": error_dict,
"entity_type": _POSITIVE_LABEL_KEY,
"entity_type": _POSITIVE_SUPERVISION_EDGES_KEY,
"serialized_tf_record_info": serialized_graph_metadata.positive_label_entity_info,
"rank": rank,
},
Expand All @@ -271,7 +272,7 @@ def load_torch_tensors_from_tf_record(
"tf_record_dataloader": tf_record_dataloader,
"output_dict": edge_output_dict,
"error_dict": error_dict,
"entity_type": _NEGATIVE_LABEL_KEY,
"entity_type": _NEGATIVE_SUPERVISION_EDGES_KEY,
"serialized_tf_record_info": serialized_graph_metadata.negative_label_entity_info,
"rank": rank,
},
Expand Down Expand Up @@ -328,12 +329,12 @@ def load_torch_tensors_from_tf_record(
edge_index = edge_output_dict[_ID_FMT.format(entity=_EDGE_KEY)]
edge_features = edge_output_dict.get(_FEATURE_FMT.format(entity=_EDGE_KEY), None)

positive_labels = edge_output_dict.get(
_ID_FMT.format(entity=_POSITIVE_LABEL_KEY), None
positive_supervision_edges = edge_output_dict.get(
_ID_FMT.format(entity=_POSITIVE_SUPERVISION_EDGES_KEY), None
)

negative_labels = edge_output_dict.get(
_ID_FMT.format(entity=_NEGATIVE_LABEL_KEY), None
negative_supervision_edges = edge_output_dict.get(
_ID_FMT.format(entity=_NEGATIVE_SUPERVISION_EDGES_KEY), None
)

if rpc_is_initialized():
Expand All @@ -351,6 +352,6 @@ def load_torch_tensors_from_tf_record(
node_features=node_features,
edge_index=edge_index,
edge_features=edge_features,
positive_label=positive_labels,
negative_label=negative_labels,
positive_supervision_edges=positive_supervision_edges,
negative_supervision_edges=negative_supervision_edges,
)
32 changes: 18 additions & 14 deletions python/gigl/distributed/dataset_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ def _load_and_build_partitioned_dataset(
set up beforehand, this function will throw an error.
Args:
serialized_graph_metadata (SerializedGraphMetadata): Serialized Graph Metadata contains serialized information for loading TFRecords across node and edge types
should_load_tensors_in_parallel (bool): Whether tensors should be loaded from serialized information in parallel or in sequence across the [node, edge, pos_label, neg_label] entity types.
should_load_tensors_in_parallel (bool): Whether tensors should be loaded from serialized information in parallel or in sequence across the [node, edge, positive_supervision_edges, negative_supervision_edges] entity types.
edge_dir (Literal["in", "out"]): Edge direction of the provided graph
partitioner_class (Optional[Type[DistPartitioner]]): Partitioner class to partition the graph inputs. If provided, this must be a
DistPartitioner or subclass of it. If not provided, will use the DistPartitioner class.
Expand Down Expand Up @@ -238,8 +238,8 @@ def _load_and_build_partitioned_dataset(
# TODO (mkolodner-sc): Move this code block (from here up to start of partitioning) to transductive splitter once that is ready
if _ssl_positive_label_percentage is not None:
if (
loaded_graph_tensors.positive_label is not None
or loaded_graph_tensors.negative_label is not None
loaded_graph_tensors.positive_supervision_edges is not None
or loaded_graph_tensors.negative_supervision_edges is not None
):
raise ValueError(
"Cannot have loaded positive and negative labels when attempting to select self-supervised positive edges from edge index."
Expand Down Expand Up @@ -269,7 +269,7 @@ def _load_and_build_partitioned_dataset(
f"Found an unknown edge index type: {type(loaded_graph_tensors.edge_index)} when attempting to select positive labels"
)

loaded_graph_tensors.positive_label = positive_label_edges
loaded_graph_tensors.positive_supervision_edges = positive_label_edges

if (
isinstance(splitter, NodeAnchorLinkSplitter)
Expand Down Expand Up @@ -304,13 +304,15 @@ def _load_and_build_partitioned_dataset(
partitioner.register_edge_features(
edge_features=loaded_graph_tensors.edge_features
)
if loaded_graph_tensors.positive_label is not None:
partitioner.register_labels(
label_edge_index=loaded_graph_tensors.positive_label, is_positive=True
if loaded_graph_tensors.positive_supervision_edges is not None:
partitioner.register_supervision_edges(
supervision_edge_index=loaded_graph_tensors.positive_supervision_edges,
is_positive=True,
)
if loaded_graph_tensors.negative_label is not None:
partitioner.register_labels(
label_edge_index=loaded_graph_tensors.negative_label, is_positive=False
if loaded_graph_tensors.negative_supervision_edges is not None:
partitioner.register_supervision_edges(
supervision_edge_index=loaded_graph_tensors.negative_supervision_edges,
is_positive=False,
)

# We call del so that the reference count of these registered fields is 1,
Expand All @@ -321,8 +323,8 @@ def _load_and_build_partitioned_dataset(
loaded_graph_tensors.node_features,
loaded_graph_tensors.edge_index,
loaded_graph_tensors.edge_features,
loaded_graph_tensors.positive_label,
loaded_graph_tensors.negative_label,
loaded_graph_tensors.positive_supervision_edges,
loaded_graph_tensors.negative_supervision_edges,
)
del loaded_graph_tensors

Expand Down Expand Up @@ -399,7 +401,8 @@ def _build_dataset_process(
node_rank (int): Rank of the node (machine) on which this process is running
node_world_size (int): World size (total #) of the nodes participating in hosting the dataset
sample_edge_direction (Literal["in", "out"]): Whether edges in the graph are directed inward or outward
should_load_tensors_in_parallel (bool): Whether tensors should be loaded from serialized information in parallel or in sequence across the [node, edge, pos_label, neg_label] entity types.
should_load_tensors_in_parallel (bool): Whether tensors should be loaded from serialized information in parallel or
in sequence across the [node, edge, positive_supervision_edges, negative_supervision_edges] entity types.
partitioner_class (Optional[Type[DistPartitioner]]): Partitioner class to partition the graph inputs. If provided, this must be a
DistPartitioner or subclass of it. If not provided, will use the DistPartitioner class.
node_tf_dataset_options (TFDatasetOptions): Options provided to a tf.data.Dataset to tune how serialized node data is read.
Expand Down Expand Up @@ -498,7 +501,8 @@ def build_dataset(
you need not initialize a process_group; one will be initialized.
sample_edge_direction (Union[Literal["in", "out"], str]): Whether edges in the graph are directed inward or outward. Note that this is
listed as a possible string to satisfy type check, but in practice must be a Literal["in", "out"].
should_load_tensors_in_parallel (bool): Whether tensors should be loaded from serialized information in parallel or in sequence across the [node, edge, pos_label, neg_label] entity types.
should_load_tensors_in_parallel (bool): Whether tensors should be loaded from serialized information in parallel or in sequence across the [node, edge,
positive_supervision_edges, negative_supervision_edges] entity types.
partitioner_class (Optional[Type[DistPartitioner]]): Partitioner class to partition the graph inputs. If provided, this must be a
DistPartitioner or subclass of it. If not provided, will use the DistPartitioner class.
node_tf_dataset_options (TFDatasetOptions): Options provided to a tf.data.Dataset to tune how serialized node data is read.
Expand Down
8 changes: 6 additions & 2 deletions python/gigl/distributed/dist_link_prediction_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -690,8 +690,12 @@ def build(
gc.collect()

# Initializing Positive and Negative Edge Labels
self._positive_edge_label = partition_output.partitioned_positive_labels
self._negative_edge_label = partition_output.partitioned_negative_labels
self._positive_edge_label = (
partition_output.partitioned_positive_supervision_edges
)
self._negative_edge_label = (
partition_output.partitioned_negative_supervision_edges
)

# TODO (mkolodner-sc): Enable custom params for init_graph, init_node_features, and init_edge_features

Expand Down
Loading