Skip to content

Commit d0036fa

Browse files
set all seeds during training
1 parent 7f37d20 commit d0036fa

File tree

1 file changed

+12
-0
lines changed

1 file changed

+12
-0
lines changed

lightning_pose/train.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
11
"""Example model training function."""
22

33
import os
4+
import random
45

56
import lightning.pytorch as pl
7+
import numpy as np
8+
import torch
69
from omegaconf import DictConfig, ListConfig, OmegaConf, open_dict
710
from typeguard import typechecked
811

@@ -32,6 +35,15 @@
3235
@typechecked
3336
def train(cfg: DictConfig) -> None:
3437

38+
# reset all seeds
39+
seed = 0
40+
os.environ["PYTHONHASHSEED"] = str(seed)
41+
torch.manual_seed(seed)
42+
np.random.seed(seed)
43+
random.seed(seed)
44+
torch.backends.cudnn.benchmark = False
45+
torch.backends.cudnn.deterministic = True
46+
3547
# record lightning-pose version
3648
from lightning_pose import __version__ as lightning_pose_version
3749
with open_dict(cfg):

0 commit comments

Comments
 (0)