Skip to content
This repository has been archived by the owner on Nov 22, 2022. It is now read-only.

Commit

Permalink
Correctly set random seed for pytext training (#269)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #269

Random seed was configured in Trainer but unused; move this to Task (because model parameters are initialized in Task, before Trainer) and correctly set random / numpy / torch seeds.

Reviewed By: seayoung1112

Differential Revision: D13950646

fbshipit-source-id: f37427311466bb6ffcf00677f81bd63cc858f8c8
  • Loading branch information
Michael Wu authored and facebook-github-bot committed Feb 6, 2019
1 parent 5df1377 commit 38df626
Show file tree
Hide file tree
Showing 6 changed files with 24 additions and 5 deletions.
1 change: 0 additions & 1 deletion demo/configs/rnng.json
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@
},
"trainer": {
"real_trainer": {
"random_seed": 0,
"epochs": 1,
"early_stop_after": 0,
"max_clip_norm": null,
Expand Down
1 change: 1 addition & 0 deletions pytext/task/disjoint_multitask.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ def from_config(cls, task_config, metadata=None, model_state=None):
lr_scheduler=Scheduler(
optimizer, task_config.scheduler, metric_reporter.lower_is_better
),
random_seed=task_config.random_seed,
)

def __init__(self, target_task_name, exporters, **kwargs):
Expand Down
11 changes: 10 additions & 1 deletion pytext/task/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ class Config(ConfigBase):
optimizer: Optimizer.Config = Adam.Config()
scheduler: Optional[Scheduler.Config] = Scheduler.Config()
exporter: Optional[ModelExporter.Config] = None
random_seed: int = 0

@classmethod
def from_config(cls, task_config, metadata=None, model_state=None):
Expand Down Expand Up @@ -116,6 +117,7 @@ def from_config(cls, task_config, metadata=None, model_state=None):
optimizer, task_config.scheduler, metric_reporter.lower_is_better
),
exporter=exporter,
random_seed=task_config.random_seed,
)

def __init__(
Expand All @@ -127,6 +129,7 @@ def __init__(
optimizer: torch.optim.Optimizer,
lr_scheduler: List[torch.optim.lr_scheduler._LRScheduler],
exporter: Optional[ModelExporter],
random_seed: int,
) -> None:
self.trainer: Trainer = trainer
self.data_handler: DataHandler = data_handler
Expand All @@ -135,6 +138,7 @@ def __init__(
self.optimizer: torch.optim.Optimizer = optimizer
self.lr_scheduler: List[torch.optim.lr_scheduler._LRScheduler] = lr_scheduler
self.exporter = exporter
self.random_seed = random_seed

def train(self, train_config, rank=0, world_size=1):
"""
Expand All @@ -146,7 +150,9 @@ def train(self, train_config, rank=0, world_size=1):
world_size (int): for distributed training only, total gpu to use, default
is 1
"""
return self.trainer.train(
# Check seed is set correctly.
assert torch.initial_seed() == self.random_seed
result = self.trainer.train(
self.data_handler.get_train_iter(rank, world_size),
self.data_handler.get_eval_iter(),
self.model,
Expand All @@ -156,6 +162,9 @@ def train(self, train_config, rank=0, world_size=1):
self.lr_scheduler,
rank=rank,
)
# Check seed is not tampered with by other code.
assert torch.initial_seed() == self.random_seed
return result

def test(self, test_path):
"""
Expand Down
3 changes: 0 additions & 3 deletions pytext/trainers/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ class Trainer(TrainerBase):
2 Test trained model, compute and publish metrics against a blind test set.
Attributes:
random_seed (int): Manual random seed
epochs (int): Training epochs
early_stop_after (int): Stop after how many epochs when the eval metric
is not improving
Expand All @@ -40,8 +39,6 @@ class Trainer(TrainerBase):
"""

class Config(ConfigBase):
# Manual random seed
random_seed: int = 0
# Training epochs
epochs: int = 10
# Stop after how many epochs when the eval metric is not improving
Expand Down
10 changes: 10 additions & 0 deletions pytext/utils/python_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,16 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import random

import numpy as np
import torch


def cls_vars(cls):
return [v for n, v in vars(cls).items() if not n.startswith("_")]


def set_random_seeds(seed):
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
3 changes: 3 additions & 0 deletions pytext/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from pytext.metric_reporters.channel import Channel
from pytext.task import Task, create_task, load, save
from pytext.utils.dist_utils import dist_init
from pytext.utils.python_utils import set_random_seeds

from .utils import cuda_utils

Expand Down Expand Up @@ -90,6 +91,8 @@ def prepare_task(

print("\nParameters: {}\n".format(config))
_set_cuda(config.use_cuda_if_available, device_id, world_size)
set_random_seeds(config.task.random_seed)

if config.load_snapshot_path and os.path.isfile(config.load_snapshot_path):
task = load(config.load_snapshot_path)
else:
Expand Down

0 comments on commit 38df626

Please sign in to comment.