Skip to content

Commit 0213989

Browse files
committed
fix multi-gpu training
1 parent 9fe9990 commit 0213989

File tree

1 file changed

+13
-1
lines changed

1 file changed

+13
-1
lines changed

rslp/lightning_cli.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,12 @@
33
import os
44

55
import jsonargparse
6-
import wandb
76
from lightning.pytorch.callbacks import Callback
7+
from lightning.pytorch.utilities import rank_zero_only
88
from rslearn.main import RslearnLightningCLI
99
from upath import UPath
1010

11+
import wandb
1112
from rslp import launcher_lib
1213

1314
CHECKPOINT_DIR = "gs://{rslp_bucket}/projects/{project_id}/{experiment_id}/checkpoints/"
@@ -26,6 +27,7 @@ def __init__(self, project_id: str, experiment_id: str):
2627
self.project_id = project_id
2728
self.experiment_id = experiment_id
2829

30+
@rank_zero_only
2931
def on_fit_start(self, trainer, pl_module):
3032
"""Called just before fit starts I think.
3133
@@ -102,6 +104,16 @@ def before_instantiate_classes(self):
102104
c.trainer.logger.init_args.project = c.rslp_project
103105
c.trainer.logger.init_args.name = c.rslp_experiment
104106

107+
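# Use the DDP strategy with find_unused_parameters=True so that multi-GPU
# training tolerates model parameters that receive no gradient in a step.
# NOTE: this unconditionally replaces any trainer.strategy set in the config.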
108+
c.trainer.strategy = jsonargparse.Namespace(
109+
{
110+
"class_path": "lightning.pytorch.strategies.DDPStrategy",
111+
"init_args": jsonargparse.Namespace(
112+
{"find_unused_parameters": True}
113+
),
114+
}
115+
)
116+
105117
# Set the checkpoint directory to canonical GCS location.
106118
checkpoint_callback = None
107119
upload_wandb_callback = None

0 commit comments

Comments
 (0)