Conversation

nshah-sc
Collaborator

Add some utils which are used to write embeddings to BQ and unenumerate the resulting tables for KGE.

Where is the documentation for this feature?: N/A

Did you add automated tests or write a test plan?

Updated Changelog.md? NO

Ready for code review?: YES

exporter = export.EmbeddingExporter(
    export_dir=embedding_dir,
    file_prefix=f"{rank}_of_{world_size}_embeddings_",
    min_shard_size_threshold_bytes=1_000_000_000,  # 1GB threshold for sharding
Collaborator

This seems like a lot; any reason we are using 1 GB here?
IIRC, we found the best results with 200-500 MB shards for downstream distributed reads.
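For illustration, a minimal sketch of the suggested configuration (the 250 MB value is an arbitrary pick from the 200-500 MB range above, not a measured recommendation):

# Hypothetical alternative: target roughly 250 MB shards instead of 1 GB to
# improve downstream distributed read parallelism.
exporter = export.EmbeddingExporter(
    export_dir=embedding_dir,
    file_prefix=f"{rank}_of_{world_size}_embeddings_",
    min_shard_size_threshold_bytes=250_000_000,  # ~250MB threshold for sharding
)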

    embeddings_table_node_id_field: str,
    unenumerated_embeddings_table: str,
    enumerator_mapping_table: str,
):
Collaborator
@svij-sc svij-sc Oct 1, 2025

Can the following function be re-used here instead? gigl.src.post_process.utils.unenumeration._unenumerate_single_inferred_asset()
It could also be moved to some shared place for better visibility than living in the post processor.
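A minimal sketch of that "shared place" idea; the new module path is hypothetical and the helper is simply re-exported under a public name rather than re-implemented:

# gigl/src/common/utils/unenumeration.py (hypothetical shared location)
# Re-export the existing helper so callers outside the post processor do not
# have to import a private function from the post-process package.
from gigl.src.post_process.utils.unenumeration import (
    _unenumerate_single_inferred_asset as unenumerate_single_inferred_asset,
)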

logger.info(f"{rank_prefix_str} Initialized TrainPipelineSparseDist for inference.")

# Run inference in no_grad context to save memory and improve performance
with torch.no_grad():
Collaborator

suggestion / non-blocking: there is also torch.inference_mode(), which is faster than no_grad. It wasn't working with GNN inference for some reason (@mkolodner-sc), but maybe it works here.
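A minimal sketch of the suggestion (torch.inference_mode() is standard PyTorch; the loop and variable names below are illustrative, and whether it works for this inference path is untested):

# inference_mode() disables autograd tracking more aggressively than no_grad()
# (no view/version-counter bookkeeping), so it is usually faster for pure inference.
with torch.inference_mode():
    for batch in dataloader:  # illustrative name
        embeddings = model(batch)  # illustrative name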


def infer_and_export_embeddings(
    applied_task_identifier: AppliedTaskIdentifier,
    rank_prefix_str: str,
Collaborator

this feels like a weird argument.
Any reason it cannot be inferred automatically?
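For illustration, one way the prefix could be derived inside the function instead of being threaded through as an argument (a sketch, assuming rank and world_size are already available):

# Hypothetical: build the log prefix locally from rank/world_size instead of
# accepting a separate rank_prefix_str parameter.
rank_prefix_str = f"[rank {rank}/{world_size}]"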

logger.info(
f"""{rank_prefix_str} Running inference for edge type {edge_type} on
src node type {edge_type.src_node_type} and dst node type {edge_type.dst_node_type}."""
)
Collaborator
@svij-sc svij-sc Oct 1, 2025

Use textwrap.dedent?
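A sketch of what that could look like; textwrap.dedent is from the standard library and strips the common leading whitespace, so the multi-line f-string does not carry source indentation into the log message:

import textwrap

logger.info(
    textwrap.dedent(
        f"""\
        {rank_prefix_str} Running inference for edge type {edge_type} on
        src node type {edge_type.src_node_type} and dst node type {edge_type.dst_node_type}."""
    )
)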

)
unenumerate_embeddings_table(
    enumerated_embeddings_table=enum_src_node_embedding_table,
    embeddings_table_node_id_field=export._NODE_ID_KEY,
Collaborator

Use of a cross-module private constant:
_NODE_ID_KEY
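A minimal sketch of one way to avoid depending on the private name; the public alias NODE_ID_KEY is hypothetical:

# In the export module: expose a public alias for the BQ field name.
NODE_ID_KEY = _NODE_ID_KEY

# At the call site, reference export.NODE_ID_KEY instead of export._NODE_ID_KEY.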

QUALIFY RANK() OVER (PARTITION BY mapping.{original_node_id_field} ORDER BY RAND()) = 1
"""

bq_utils = BqUtils(project=get_resource_config().project)
Collaborator

We have been trying to prevent coupling of get_resource_config() in low-level / leaf APIs and function calls.
It is better to dereference the resource config as high in the stack as we can and pass those values through. Not doing this leads to usability and extensibility challenges.

Collaborator

Recommendation: inject project as a param.
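A sketch of the recommendation (the signature is assumed from the parameters quoted above; only the project parameter is new):

def unenumerate_embeddings_table(
    enumerated_embeddings_table: str,
    embeddings_table_node_id_field: str,
    unenumerated_embeddings_table: str,
    enumerator_mapping_table: str,
    project: str,  # injected by the caller instead of read from get_resource_config()
) -> None:
    bq_utils = BqUtils(project=project)
    ...

# Higher in the stack, the caller dereferences the resource config once:
# unenumerate_embeddings_table(..., project=get_resource_config().project)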

)

# Process destination nodes for this edge type
infer_and_export_node_embeddings(
Collaborator

Opportunity to reduce complexity here by having infer_and_export_node_embeddings automatically export both src and dst node embeddings.
We are essentially calling infer_and_export_node_embeddings with the same arguments but different values of is_src.
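A sketch of that simplification; the wrapper name and the keyword-argument passthrough are illustrative:

# Hypothetical wrapper: run the same export for both endpoints of the edge type
# instead of duplicating the call with is_src=True and is_src=False.
def infer_and_export_edge_endpoint_embeddings(edge_type, **common_kwargs) -> None:
    for is_src in (True, False):
        infer_and_export_node_embeddings(
            edge_type=edge_type,
            is_src=is_src,
            **common_kwargs,
        )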

    world_size: int,
    device: torch.device,
    kge_config: HeterogeneousGraphSparseEmbeddingConfig,
    model_and_loss: Union[
Collaborator

Why do we need the loss if we are just doing inference?
