francescocastelli/torch-trainer

Torch trainer

Python library for training PyTorch models without having to handle devices, training loops, and the other boilerplate code that training a PyTorch model usually requires.

Model

A model is defined by writing a class that inherits from torchtrainer.model.Model and implements the methods forward, training_step and validation_step. The last two methods contain the logic of a single training step and a single validation step. The library is then responsible for creating the training loop and for sending all the data to the correct device. The method define_optimizer_scheduler returns the optimizer and, optionally, a learning rate scheduler.

from torchtrainer.model import Model
import torch

class MyModel(Model):
    def __init__(self, name):
        super().__init__(name=name)
        # my init
        
    def forward(self, x):
        # my forward
        return x

    def training_step(self, x):
        # all the logic of a single training step
        # goes here. This method must return the value
        # of the loss function at the end of the training step
        
        input, target = x
        prediction = self(input)
        loss = torch.nn.functional.cross_entropy(prediction, target)
        return loss

    def validation_step(self, x):
        # all the logic of a validation step, if needed
        pass

    def define_optimizer_scheduler(self):
        # this method is used to define the optimizer
        # and the learning rate scheduler (not mandatory) to be
        # used for training the model
        
        opt = torch.optim.SGD(self.parameters(), lr=0.001, momentum=0.9)
        scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer=opt, gamma=1.0)
        return (opt, scheduler)
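
For comparison, the boilerplate that the library is meant to take over looks roughly like the plain PyTorch loop below. This is only a sketch of a typical hand-written epoch, not the library's actual implementation; run_epoch, toy_model, toy_batches and toy_optimizer are hypothetical names introduced for the illustration.

```python
import torch
from torch import nn


def run_epoch(model, batches, optimizer, device):
    """Run one pass over `batches` and return the mean training loss."""
    model.train()
    total = 0.0
    count = 0
    for inputs, targets in batches:
        # Move each batch to the device the model lives on.
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        loss = nn.functional.cross_entropy(model(inputs), targets)
        loss.backward()
        optimizer.step()
        total += loss.item()
        count += 1
    return total / count


torch.manual_seed(0)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hypothetical stand-ins: a linear classifier and five random batches.
toy_model = nn.Linear(4, 3).to(device)
toy_batches = [(torch.randn(8, 4), torch.randint(0, 3, (8,))) for _ in range(5)]
toy_optimizer = torch.optim.SGD(toy_model.parameters(), lr=0.1, momentum=0.9)

mean_loss = run_epoch(toy_model, toy_batches, toy_optimizer, device)
print(f"mean training loss: {mean_loss:.4f}")
```

With a Model subclass, the body of the for loop is what goes into training_step, while device placement and the loop itself are left to the library.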

Dataloaders and training

Once the model is defined, you need a loader and an instance of torchtrainer.trainer.Trainer. TrainerLoader is a wrapper of torch.utils.data.DataLoader, which can be used to define custom batchers and samplers. The Trainer requires the training and validation datasets, the model and the loader defined beforehand. Additional options can be used, for instance, to save model checkpoints and to save TensorBoard data (epoch info and hyper-parameters).

from torchtrainer.trainer import Trainer
from torchtrainer.dataloader import TrainerLoader

def main():
  # the model previously defined
  model = MyModel('model')
  # train dataset and valid dataset
  train_dataset = ...
  valid_dataset = ...
  
  # batch_size=32 is an arbitrary example value
  loader = TrainerLoader(batch_size=32, num_workers=0, shuffle=False)
  trainer = Trainer(model=model, train_dataset=train_dataset, 
                    valid_dataset=valid_dataset, epoch_num=10, loader=loader)
  
  trainer.train()
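
The datasets are left unspecified in the example above. As a minimal sketch, and assuming the Trainer accepts any torch.utils.data.Dataset (plausible since TrainerLoader wraps torch.utils.data.DataLoader, but not confirmed here), a toy pair of datasets could be built as follows. The random data, the sizes and the 80/20 split are made up for the illustration.

```python
import torch
from torch.utils.data import TensorDataset, random_split

torch.manual_seed(0)

# 100 random samples with 4 features each and a class label in {0, 1, 2}.
features = torch.randn(100, 4)
labels = torch.randint(0, 3, (100,))
full_dataset = TensorDataset(features, labels)

# 80/20 split into training and validation subsets.
train_dataset, valid_dataset = random_split(full_dataset, [80, 20])

# Each item is an (input, target) pair.
sample_input, sample_target = train_dataset[0]
print(len(train_dataset), len(valid_dataset), tuple(sample_input.shape))
```

Each item is an (input, target) pair, which matches the unpacking done in training_step, provided the library passes batches to that method unchanged.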

Build

Use conda build to build the library as a conda package:

git clone https://github.com/francescocastelli/torch-trainer
cd torch-trainer/conda-recipe
conda build . --output-folder /path/to/output_folder

and install it in the current conda environment:

conda install --use-local /path/to/output_folder/linux-64/torchtrainer-1.1-py38_1.tar.bz2
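
After installing, one quick way to check that the package is visible to the current interpreter is the standard-library sketch below. is_installed is a hypothetical helper; it only confirms that the package can be found, not that it works.

```python
import importlib.util


def is_installed(package_name):
    """Return True if a top-level package can be found by the current interpreter."""
    return importlib.util.find_spec(package_name) is not None


# "torchtrainer" is the import name used in the examples above.
print("torchtrainer installed:", is_installed("torchtrainer"))
```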
