[KGE] Add embedding export / unenumeration utils #341
base: main
Conversation
exporter = export.EmbeddingExporter(
    export_dir=embedding_dir,
    file_prefix=f"{rank}_of_{world_size}_embeddings_",
    min_shard_size_threshold_bytes=1_000_000_000,  # 1GB threshold for sharding
This seems like a lot; any reason we are using 1GB here? IIRC, we found the best results with 200-500 MB shards for downstream distributed reads.
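A minimal sketch of the suggested change, assuming min_shard_size_threshold_bytes is the right knob to turn; the 300 MB figure is just an illustrative value inside the 200-500 MB range mentioned above:

exporter = export.EmbeddingExporter(
    export_dir=embedding_dir,
    file_prefix=f"{rank}_of_{world_size}_embeddings_",
    # ~300 MB shards; 200-500 MB reportedly reads best in downstream
    # distributed reads (illustrative value, not a measured optimum)
    min_shard_size_threshold_bytes=300_000_000,
)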
    embeddings_table_node_id_field: str,
    unenumerated_embeddings_table: str,
    enumerator_mapping_table: str,
):
Can the following function be re-used here instead: gigl.src.post_process.utils.unenumeration._unenumerate_single_inferred_asset()? It could also be moved to some shared place for better visibility than in the post processor.
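A hedged sketch of the reuse; the helper's real signature is not shown in this diff, so the keyword names below are guesses based on the surrounding code:

from gigl.src.post_process.utils.unenumeration import (
    _unenumerate_single_inferred_asset,  # could be promoted to a shared, public util
)

# Reuse the existing helper instead of a parallel implementation;
# argument names here are assumptions, not the actual signature.
_unenumerate_single_inferred_asset(
    enumerated_assets_table=enumerated_embeddings_table,
    enumerator_mapping_table=enumerator_mapping_table,
)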
logger.info(f"{rank_prefix_str} Initialized TrainPipelineSparseDist for inference.")

# Run inference in no_grad context to save memory and improve performance
with torch.no_grad():
Suggestion / non-blocking: there is also with torch.inference_mode(), which is faster than no_grad. It wasn't working with GNN inference for some reason (@mkolodner-sc), but maybe it works here.
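For reference, a sketch of the swap; the body shown is a placeholder for the existing inference loop:

# inference_mode() disables autograd tracking and also skips tensor
# version-counter bookkeeping, so it is typically faster than no_grad().
with torch.inference_mode():
    run_inference_loop()  # placeholder for the existing inference body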
def infer_and_export_embeddings(
    applied_task_identifier: AppliedTaskIdentifier,
    rank_prefix_str: str,
This feels like a weird argument. Any reason it cannot be inferred automatically?
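A minimal sketch of deriving the prefix inside the function instead of threading it through every call; this assumes torch.distributed is already initialized at this point, and the exact prefix format is a guess:

import torch.distributed as dist

# Derive the log prefix from the process group instead of taking it
# as a parameter (assumes dist.init_process_group() already ran).
rank = dist.get_rank() if dist.is_initialized() else 0
rank_prefix_str = f"[rank {rank}]"  # format is an assumption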
logger.info(
    f"""{rank_prefix_str} Running inference for edge type {edge_type} on
    src node type {edge_type.src_node_type} and dst node type {edge_type.dst_node_type}."""
)
Use textwrap.dedent?
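For reference, a sketch of the dedent version; it assumes the message content stays the same:

import textwrap

logger.info(
    textwrap.dedent(
        f"""\
        {rank_prefix_str} Running inference for edge type {edge_type} on
        src node type {edge_type.src_node_type} and dst node type {edge_type.dst_node_type}.
        """
    )
)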
)
unenumerate_embeddings_table(
    enumerated_embeddings_table=enum_src_node_embedding_table,
    embeddings_table_node_id_field=export._NODE_ID_KEY,
This uses another module's private constant, _NODE_ID_KEY.
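One hedged option: expose a public alias in the export module and import that instead; the alias name is an assumption:

# In the export module: public alias so callers need not reach for
# the private name (alias name is an assumption).
NODE_ID_KEY = _NODE_ID_KEY

# At the call site:
embeddings_table_node_id_field=export.NODE_ID_KEY,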
QUALIFY RANK() OVER (PARTITION BY mapping.{original_node_id_field} ORDER BY RAND()) = 1
"""

bq_utils = BqUtils(project=get_resource_config().project)
We have been trying to prevent coupling of get_resource_config() in low-level / leaf APIs and function calls. It is better to dereference the resource config as high in the stack as we can and pass those values through. Not doing this leads to usability and extensibility challenges.
Recommendation: inject project as a parameter.
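A sketch of the injection; the parameter list mirrors the diff above, and the return type is an assumption:

def unenumerate_embeddings_table(
    enumerated_embeddings_table: str,
    embeddings_table_node_id_field: str,
    unenumerated_embeddings_table: str,
    enumerator_mapping_table: str,
    project: str,  # injected instead of read via get_resource_config()
) -> None:
    bq_utils = BqUtils(project=project)
    ...

# Caller, higher in the stack, dereferences the resource config once:
# unenumerate_embeddings_table(..., project=get_resource_config().project)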
)

# Process destination nodes for this edge type
infer_and_export_node_embeddings(
Opportunity to reduce complexity here by having infer_and_export_node_embeddings automatically export both src and dst node embeddings. We are essentially calling infer_and_export_node_embeddings with the same arguments but different values of is_src.
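A sketch of the consolidation; the wrapper name is hypothetical and assumes is_src is the only argument that differs between the two calls:

def infer_and_export_edge_endpoint_embeddings(**shared_kwargs) -> None:
    # Hypothetical wrapper: export embeddings for both endpoints of the
    # edge type so call sites do not repeat near-identical calls.
    for is_src in (True, False):
        infer_and_export_node_embeddings(is_src=is_src, **shared_kwargs)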
    world_size: int,
    device: torch.device,
    kge_config: HeterogeneousGraphSparseEmbeddingConfig,
    model_and_loss: Union[
Why do we need the loss if we are just doing inference?
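A hedged sketch of narrowing the parameter for inference; the other parameter names mirror the diff, and the torch.nn.Module annotation is an assumption about what the model half of the pair is:

def infer_and_export_embeddings(
    rank_prefix_str: str,
    world_size: int,
    device: torch.device,
    kge_config: HeterogeneousGraphSparseEmbeddingConfig,
    model: torch.nn.Module,  # just the model; unwrap (model, loss) at the call site
) -> None:
    ...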
Adds some utils used to write embeddings to BigQuery and unenumerate the resulting tables for KGE.
Where is the documentation for this feature?: N/A
Did you add automated tests or write a test plan?
Updated Changelog.md? NO
Ready for code review?: YES