From fa5041d1333f4bdcd693550d431538d87e8805b9 Mon Sep 17 00:00:00 2001 From: Dhyey Mavani Date: Sat, 4 Oct 2025 08:33:51 +0000 Subject: [PATCH] Add support for custom callbacks in graphgym.train Allow users to pass custom PyTorch Lightning callbacks via the trainer_config parameter. This enables extending the training process with custom monitoring, logging, or other callback functionality without modifying the core train function. Fixes #10386 Co-authored-by: Ona --- torch_geometric/graphgym/train.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/torch_geometric/graphgym/train.py b/torch_geometric/graphgym/train.py index d391e9de21a8..6060bbd5d6c1 100644 --- a/torch_geometric/graphgym/train.py +++ b/torch_geometric/graphgym/train.py @@ -62,6 +62,15 @@ def train( callbacks.append(ckpt_cbk) trainer_config = trainer_config or {} + + # Allow custom callbacks to be passed via trainer_config + if 'callbacks' in trainer_config: + custom_callbacks = trainer_config.pop('callbacks') + if isinstance(custom_callbacks, list): + callbacks.extend(custom_callbacks) + else: + callbacks.append(custom_callbacks) + trainer = pl.Trainer( **trainer_config, enable_checkpointing=cfg.train.enable_ckpt,