Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Docs #25

Open
wants to merge 21 commits into
base: main
Choose a base branch
from
Open

Docs #25

Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Empty file.
59 changes: 59 additions & 0 deletions demo/01_cycleGAN/HowTo_train.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
## How to Train a CycleGAN

This is a basic outline of how to train a CycleGAN using the provided train script.

### Prerequisites

1. Python environment with necessary dependencies installed.
2. Image datasets for both source and target domains.
3. `raygun` repository cloned to your local machine.
4. Configuration file (train_conf.json) specifying training parameters.

### Configuration JSON Parameters
The configuration JSON file contains several parameters that you can modify to customize your CycleGAN training. Here are the key parameters:

- "framework": Specifies the deep learning framework to use (e.g. "torch", "tensorflow").
- "system": Specifies the type of system to use for training, so in this case "CycleGAN".
- "job_command": Specifies the job command for running the training script (e.g "bsub", "-n 16", "-gpu "num=1"", "-q gpu_a100").
- "sources": Specifies the source domains and their corresponding paths, real names, and mask names.
- "common_voxel_size": Specifies the voxel size to cast all data into.
- "ndims": Specifies the number of dimensions for the input data.
- "batch_size": Specifies the batch size for training.
- "num_workers": Specifies the number of workers for data loading.
- "cache_size": Specifies the cache size for data loading.
- "scheduler": Specifies the scheduler type for adjusting learning rate during training.
- "scheduler_kwargs": Specifies the arguments for the scheduler.
- "g_optim_type": Specifies the optimizer type for the generator.
- "g_optim_kwargs": Specifies the arguments for the generator optimizer.
- "d_optim_type": Specifies the optimizer type for the discriminator.
- "d_optim_kwargs": Specifies the arguments for the discriminator optimizer.
- "loss_kwargs": Specifies the arguments for the loss functions.
- "gnet_type": Specifies the type of generator network architecture.
- "gnet_kwargs": Specifies the arguments for the generator network architecture.
- "pretrain_gnet": Specifies whether to pretrain the generator network.
- "dnet_type": Specifies the type of discriminator network architecture.
- "dnet_kwargs": Specifies the arguments for the discriminator network architecture.
- "spawn_subprocess": Specifies whether to spawn subprocesses for training.
- "side_length": Specifies the side length of the input image.
- "num_epochs": Specifies the number of training epochs.
- "log_every": Specifies the frequency of logging during training.
- "save_every": Specifies the frequency of saving models during training.
- "snapshot_every": Specifies the frequency of taking snapshots during training.

Here's an example of a CycleGan [configuration file]('../../experiments/ieee-isbi-2023/01_cycleGAN/train_conf.json)

### Training Methods

#### General training
- From the repository directory, run the following command:`rauygun-train CONFIG_FILE_LOCATION` where `CONGIF_FILE_LOCATION` is the relative path to the JSON configuration file for your training objective.

#### Batch training
- From the repository directory, run the following command: `rauygun-train-batch CONFIG_FILE_LOCATION` where `CONGIF_FILE_LOCATION` is the relative path to the JSON

#### Cluster training
- From the repository directory, run the following command: `rauygun-train-cluster CONFIG_FILE_LOCATION` where `CONGIF_FILE_LOCATION` is the relative path to the JSON


The CycleGAN training will start and progress will be displayed in the console.

Once the training is complete, the trained models will be saved in the specified output directory as per the configuration file.
Empty file added demo/01_cycleGAN/README.md
Empty file.
7 changes: 7 additions & 0 deletions notes/naming_conventions.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
CycleGAN System:

- default_config
- config
- common_voxel_size
- ndims
-
97 changes: 71 additions & 26 deletions src/raygun/torch/losses/BaseCompetentLoss.py
brianreicher marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
@@ -1,53 +1,98 @@
from raygun.evaluation.validate_affinities import run_validation
import torch
from torch.utils.tensorboard import SummaryWriter

from raygun.utils import passing_locals
from raygun.torch.losses import GANLoss


class BaseCompetentLoss(torch.nn.Module):
def __init__(self, **kwargs):
super().__init__()
kwargs = passing_locals(locals())
for key, value in kwargs.items():
setattr(self, key, value)

if hasattr(self, "gan_mode"):
self.gan_loss = GANLoss(gan_mode=self.gan_mode)

self.loss_dict = {}

def set_requires_grad(self, nets, requires_grad=False):
"""Set requies_grad=False for all the networks to avoid unnecessary computations
Parameters:
nets (network list) -- a list of networks
requires_grad (bool) -- whether the networks require gradients or not
"""Base loss function, implemented in PyTorch.

Args:
**kwargs:
Optional keyword arguments.
"""

def __init__(self, **kwargs) -> None:
super().__init__()
kwargs: dict = passing_locals(locals())
for key, value in kwargs.items():
setattr(self, key, value)

if hasattr(self, "gan_mode"):
self.gan_loss: GANLoss = GANLoss(gan_mode=self.gan_mode)

self.loss_dict: dict = {}

def set_requires_grad(self, nets:list, requires_grad=False) -> None:
"""Sets requies_grad=False for all the networks to avoid unnecessary computations.

Args:
nets (``list[torch.nn.Module, ...]``):
A list of networks.

requires_grad (``bool``):
Whether the networks require gradients or not.
"""
if not isinstance(nets, list):
nets = [nets]
for net in nets:
if net is not None:
for param in net.parameters():
param.requires_grad = requires_grad

def crop(self, x, shape):
"""Center-crop x to match spatial dimensions given by shape."""
def crop(self, x:torch.Tensor, shape:tuple) -> torch.Tensor:
"""Center-crop x to match spatial dimensions given by shape.

Args:
x (``torch.Tensor``):
The tensor to center-crop.

shape (``tuple``):
The shape to match the crop to.

Returns:
``torch.Tensor``:
The center-cropped tensor to the spatial dimensions given.
"""

x_target_size = x.size()[: -self.dims] + shape
x_target_size:tuple = x.size()[: -self.dims] + shape

offset = tuple((a - b) // 2 for a, b in zip(x.size(), x_target_size))
offset: tuple = tuple((a - b) // 2 for a, b in zip(x.size(), x_target_size))

slices = tuple(slice(o, o + s) for o, s in zip(offset, x_target_size))
slices: tuple = tuple(slice(o, o + s) for o, s in zip(offset, x_target_size))

return x[slices]

def clamp_weights(self, net, min=-0.01, max=0.01):
def clamp_weights(self, net:torch.nn.Module, min=-0.01, max=0.01) -> None:
"""Clamp the weights of a given network.

Args:
net (``torch.nn.Module``):
The network to clamp.

min (``float``, optional):
The minimum value to clamp network weights to.

max (``float``, optional):
The maximum value to clamp network weights to.
"""

for module in net.model:
if hasattr(module, "weight") and hasattr(module.weight, "data"):
temp = module.weight.data
module.weight.data = temp.clamp(min, max)

def add_log(self, writer, step):
def add_log(self, writer, step) -> None:
"""Add an additional log to the writer, containing loss values and image examples.

Args:
writer (``SummaryWriter``):
The display writer to append the losses & images to.

step (``int``):
The current training step.
"""

# add loss values
for key, loss in self.loss_dict.items():
writer.add_scalar(key, loss, step)
Expand All @@ -69,7 +114,7 @@ def add_log(self, writer, step):
img = (img * 0.5) + 0.5
writer.add_image(name, img, global_step=step, dataformats="HW")

def update_status(self, step):
def update_status(self, step) -> None:
if hasattr(self, "validation_config") and (
step % self.validation_config["validate_every"] == 0
):
Expand Down
66 changes: 40 additions & 26 deletions src/raygun/torch/losses/GANLoss.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,40 +3,49 @@


class GANLoss(torch.nn.Module):
"""Define different GAN objectives.
The GANLoss class abstracts away the need to create the target label tensor
that has the same size as the input.
"""
"""Define different GAN objectives. The GANLoss class abstracts away the need to create the target label tensor that has the same size as the input.

Args:
gan_mode (``string``):
The type of GAN objective. It currently supports vanilla, lsgan, and wgangp.

target_real_label (``float``, optional):
Label for a real image, with a default of 1.0.

target_fake_label (``float``, optional):
Label for a fake image fake image, with a default of 0.0.

def __init__(self, gan_mode, target_real_label=1.0, target_fake_label=0.0):
"""Initialize the GANLoss class.
Parameters:
gan_mode (str) - - the type of GAN objective. It currently supports vanilla, lsgan, and wgangp.
target_real_label (bool) - - label for a real image
target_fake_label (bool) - - label of a fake image
Note: Do not use sigmoid as the last layer of Discriminator.
LSGAN needs no sigmoid. vanilla GANs will handle it with BCEWithLogitsLoss.
"""
"""

def __init__(self, gan_mode:str, target_real_label=1.0, target_fake_label=0.0) -> None:
super(GANLoss, self).__init__()
self.register_buffer("real_label", torch.tensor(target_real_label))
self.register_buffer("fake_label", torch.tensor(target_fake_label))
self.gan_mode = gan_mode
self.gan_mode: str = gan_mode
if gan_mode == "lsgan":
self.loss = torch.nn.MSELoss()
self.loss: torch.nn.MSELoss = torch.nn.MSELoss()
elif gan_mode == "vanilla":
self.loss = torch.nn.BCEWithLogitsLoss()
self.loss: torch.nn.BCEWithLogitsLoss = torch.nn.BCEWithLogitsLoss()
elif gan_mode in ["wgangp"]:
self.loss = None
else:
raise NotImplementedError("gan mode %s not implemented" % gan_mode)

def get_target_tensor(self, prediction, target_is_real):
def get_target_tensor(self, prediction:torch.Tensor, target_is_real:bool) -> torch.Tensor:
"""Create label tensors with the same size as the input.
Parameters:
prediction (tensor) - - typically the prediction from a discriminator
target_is_real (bool) - - if the ground truth label is for real images or fake images

Args:
prediction (``torch.Tensor``):
Typically the prediction from a discriminator.

target_is_real (``bool``):
Boolean to determine the ground truth label is for real images or fake images.

Returns:
A label tensor filled with ground truth label, and with the size of the input
``torch.Tensor``:
A label tensor filled with ground truth label, and with the size of the input
"""

if target_is_real:
Expand All @@ -45,17 +54,22 @@ def get_target_tensor(self, prediction, target_is_real):
target_tensor = self.fake_label
return target_tensor.expand_as(prediction)

def __call__(self, prediction, target_is_real):
def __call__(self, prediction:torch.Tensor, target_is_real:bool) -> float:
"""Calculate loss given Discriminator's output and grount truth labels.
Parameters:
prediction (tensor) - - typically the prediction output from a discriminator
target_is_real (bool) - - if the ground truth label is for real images or fake images
Args:
prediction (``torch.Tensor``):
Typically the prediction output from a discriminator.

target_is_real (``bool``):
Boolean to determine the ground truth label is for real images or fake images.

Returns:
the calculated loss.
``float``:
The calculated loss.
"""
if self.gan_mode in ["lsgan", "vanilla"]:
target_tensor = self.get_target_tensor(prediction, target_is_real)
loss = self.loss(prediction, target_tensor)
target_tensor: torch.Tensor = self.get_target_tensor(prediction, target_is_real)
loss: float = self.loss(prediction, target_tensor)
elif self.gan_mode == "wgangp":
if target_is_real:
loss = -prediction.mean()
Expand Down
Loading
Loading