
Train a ControlNet to Control SD

You are here because you want to control SD in your own way. Maybe you have an idea for your perfect research project, and you plan to annotate some data or have already annotated your own dataset, automatically or manually. Here, the control can be anything that can be converted to images, such as edges, keypoints, or segments.

Before moving on to your own dataset, we highly recommend first trying the toy dataset, Fill50K, as a sanity check. This will help you get a "feeling" for the training: you will know how long the model takes to converge, whether your device can complete the training in an acceptable amount of time, and what it "feels" like when the model converges.

We hope that after you read this page, you will find that training a ControlNet is as easy as (or easier than) training a pix2pix.

Step 0 - Design your control

Let us take a look at a very simple task: controlling SD to fill circles with color.

[Image: a prompt, a control image (circle outline), and the target filled circle]

This is simple: we want to control SD to fill a circle with colors, and the prompt contains some description of our target.

Stable Diffusion is trained on billions of images, and it already knows what is "cyan", what is "circle", what is "pink", and what is "background".

But it does not know the meaning of that "Control Image (Source Image)". Our target is to let it know.

Step 1 - Get a dataset

Just download the Fill50K dataset from our Hugging Face page (training/fill50k.zip; the file is only 200 MB!). Make sure that the data is decompressed as

ControlNet/training/fill50k/prompt.json
ControlNet/training/fill50k/source/X.png
ControlNet/training/fill50k/target/X.png

In the folder "fill50k/source", you will have 50k images of circle lines.

[Image: sample source images (circle outlines)]

In the folder "fill50k/target", you will have 50k images of filled circles.

[Image: sample target images (filled circles)]

In "fill50k/prompt.json", you will have their filenames and prompts. Each prompt looks like "a [some color] circle with a [some other color] background", for example "burly wood circle with orange background".

[Image: sample entries of prompt.json]
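
The loader in Step 2 reads prompt.json one line at a time, so each line is expected to be a standalone JSON object. The following is a minimal sketch for checking such lines before training. The helper name `parse_prompt_lines` and the sample file names are hypothetical; only the keys "source", "target", and "prompt" come from the tutorial code.

```python
import json

# Keys that the tutorial's dataset loader reads from every record.
REQUIRED_KEYS = ("source", "target", "prompt")


def parse_prompt_lines(lines):
    """Parse JSON-lines records and check that each has the required keys."""
    records = []
    for number, line in enumerate(lines, start=1):
        line = line.strip()
        if not line:
            continue  # tolerate blank lines
        record = json.loads(line)
        missing = [key for key in REQUIRED_KEYS if key not in record]
        if missing:
            raise ValueError(f"line {number}: missing keys {missing}")
        records.append(record)
    return records


# Hypothetical sample line; the file names are made up for illustration.
sample = [
    '{"source": "source/0.png", "target": "target/0.png", '
    '"prompt": "burly wood circle with orange background"}'
]
records = parse_prompt_lines(sample)
print(records[0]["prompt"])
```

The same check can be useful for your own dataset: pass the open file object instead of the sample list, and a malformed line is reported with its line number.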

Step 2 - Load the dataset

Then you need to write a simple script to read this dataset for PyTorch. (In fact, we have already written it for you in "tutorial_dataset.py".)

import json
import cv2
import numpy as np

from torch.utils.data import Dataset


class MyDataset(Dataset):
    def __init__(self):
        self.data = []
        with open('./training/fill50k/prompt.json', 'rt') as f:
            for line in f:
                self.data.append(json.loads(line))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]

        source_filename = item['source']
        target_filename = item['target']
        prompt = item['prompt']

        source = cv2.imread('./training/fill50k/' + source_filename)
        target = cv2.imread('./training/fill50k/' + target_filename)

        # Do not forget that OpenCV reads images in BGR order.
        source = cv2.cvtColor(source, cv2.COLOR_BGR2RGB)
        target = cv2.cvtColor(target, cv2.COLOR_BGR2RGB)

        # Normalize source images to [0, 1].
        source = source.astype(np.float32) / 255.0

        # Normalize target images to [-1, 1].
        target = (target.astype(np.float32) / 127.5) - 1.0

        return dict(jpg=target, txt=prompt, hint=source)

This turns your dataset into an array-like object in Python. You can test it simply by indexing into it, like this:

from tutorial_dataset import MyDataset

dataset = MyDataset()
print(len(dataset))

item = dataset[1234]
jpg = item['jpg']
txt = item['txt']
hint = item['hint']
print(txt)
print(jpg.shape)
print(hint.shape)

The outputs of this simple test on my machine are

50000
burly wood circle with orange background
(512, 512, 3)
(512, 512, 3)

This code is in "tutorial_dataset_test.py".

In this way, the dataset is an array-like object with 50000 items. Each item is a dict with three entries: "jpg", "txt", and "hint". The "jpg" is the target image, the "hint" is the control image, and the "txt" is the prompt.
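
The two images in each item use different value ranges: the "hint" is scaled to [0, 1] and the "jpg" to [-1, 1]. The sketch below repeats those two conversions on a tiny array and adds an inverse for viewing or saving a target. It assumes NumPy is installed; the function names are hypothetical and not part of the repository.

```python
import numpy as np


def normalize_hint(image_u8):
    """Map uint8 pixels [0, 255] to float32 [0, 1], as done for 'hint'."""
    return image_u8.astype(np.float32) / 255.0


def normalize_target(image_u8):
    """Map uint8 pixels [0, 255] to float32 [-1, 1], as done for 'jpg'."""
    return image_u8.astype(np.float32) / 127.5 - 1.0


def target_to_uint8(image_f32):
    """Invert normalize_target so the array can be saved or displayed."""
    restored = (image_f32 + 1.0) * 127.5
    return np.clip(np.rint(restored), 0, 255).astype(np.uint8)


# A tiny stand-in for a 512x512 RGB image: one pixel with three channels.
pixels = np.array([[[0, 51, 255]]], dtype=np.uint8)
hint = normalize_hint(pixels)
jpg = normalize_target(pixels)
print(hint.min(), hint.max())
print(jpg.min(), jpg.max())
print(np.array_equal(target_to_uint8(jpg), pixels))
```

If your own loader produces values outside these ranges, for example by forgetting the division, that is worth ruling out before you start a long training run.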

Do not ask us why we use these three names - this is related to the dark history of a library called LDM.

Step 3 - What SD model do you want to control?

Then you need to decide which Stable Diffusion model you want to control. In this example, we will just use standard SD1.5. You can download it from the official page of Stability. You want the file "v1-5-pruned.ckpt".

Then you need to attach a ControlNet to the SD model. The architecture is

img

Note that all weights inside the ControlNet are also copied from SD so that no layer is trained from scratch, and you are still finetuning the entire model.

We provide a simple script for you to achieve this easily. If your SD filename is "./models/v1-5-pruned.ckpt" and you want the script to save the processed model (SD+ControlNet) at location "./models/control_sd15_ini.ckpt", you can just run:

python tool_add_control.py ./models/v1-5-pruned.ckpt ./models/control_sd15_ini.ckpt

You may also use other filenames as long as the command is "python tool_add_control.py input_path output_path".

This is the correct output from my machine:

img
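
To illustrate the idea of copying SD weights into the attached ControlNet, here is a minimal sketch that works on plain dictionaries instead of real checkpoints. It is not the code of "tool_add_control.py". The function name, the two key prefixes, and the parameter names in the example are assumptions made for illustration.

```python
def init_control_weights(scratch, pretrained,
                         control_prefix="control_",
                         sd_prefix="model.diffusion_"):
    """Build initial weights for SD+ControlNet from an SD checkpoint.

    Both arguments are plain dicts mapping parameter names to values.
    The two prefixes are assumptions about the naming scheme.
    """
    merged, newly_added = {}, []
    for name, fresh_value in scratch.items():
        if name.startswith(control_prefix):
            # Look up the SD weight that this ControlNet weight mirrors.
            source_name = sd_prefix + name[len(control_prefix):]
        else:
            source_name = name
        if source_name in pretrained:
            merged[name] = pretrained[source_name]
        else:
            # No counterpart in SD: keep the freshly initialized value.
            merged[name] = fresh_value
            newly_added.append(name)
    return merged, newly_added


# Hypothetical parameter names; numbers stand in for weight tensors.
pretrained = {"model.diffusion_model.input_blocks.0.weight": 1.0}
scratch = {
    "model.diffusion_model.input_blocks.0.weight": 0.0,
    "control_model.input_blocks.0.weight": 0.0,
    "control_model.extra_layer.weight": 0.0,
}
merged, newly_added = init_control_weights(scratch, pretrained)
print(merged["control_model.input_blocks.0.weight"])  # 1.0, copied from SD
print(newly_added)  # ['control_model.extra_layer.weight']
```

With real checkpoints the values would be tensors, and you would normally copy them rather than share references.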

Step 4 - Train!

Happy! We finally come to the most exciting part: training!

The training code in "tutorial_train.py" is actually surprisingly simple:

import pytorch_lightning as pl
from torch.utils.data import DataLoader
from tutorial_dataset import MyDataset
from cldm.logger import ImageLogger
from cldm.model import create_model, load_state_dict


# Configs
resume_path = './models/control_sd15_ini.ckpt'
batch_size = 4
logger_freq = 300
learning_rate = 1e-5
sd_locked = True
only_mid_control = False


# First use cpu to load models. Pytorch Lightning will automatically move it to GPUs.
model = create_model('./models/cldm_v15.yaml').cpu()
model.load_state_dict(load_state_dict(resume_path, location='cpu'))
model.learning_rate = learning_rate
model.sd_locked = sd_locked
model.only_mid_control = only_mid_control


# Misc
dataset = MyDataset()
dataloader = DataLoader(dataset, num_workers=0, batch_size=batch_size, shuffle=True)
logger = ImageLogger(batch_frequency=logger_freq)
trainer = pl.Trainer(gpus=1, precision=32, callbacks=[logger])


# Train!
trainer.fit(model, dataloader)

Thanks to our well-organized PyTorch dataset object and the power of pytorch_lightning, the entire script is very short.

Now you may take a look at the official PyTorch Lightning documentation to find out how to enable many useful features such as gradient accumulation, multi-GPU training, accelerated dataset loading, and flexible checkpoint saving. Each of these needs only about one line of code. Great!

Note that if you run out of GPU memory (OOM), you may need to enable Low VRAM mode, and perhaps also use a smaller batch size together with gradient accumulation. You may also want to use some “advanced” tricks like sliced attention or xformers. For example:

# Configs
batch_size = 1

# Misc
trainer = pl.Trainer(gpus=1, precision=32, callbacks=[logger], accumulate_grad_batches=4)  # But this will be 4x slower

Note that training with an 8 GB laptop GPU is challenging. We will need GPU memory optimization at least as good as that of Automatic1111’s UI. This may require expert modifications to the code.
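
To see why batch size 1 with accumulate_grad_batches=4 can stand in for batch size 4, consider the toy calculation below. It uses a one-parameter model in plain Python and assumes that accumulated gradients are averaged over the micro-batches before the optimizer step. The model, data, and function name are made up for illustration.

```python
def grad_mse(w, batch):
    """Gradient of the mean squared error for the model y = w * x."""
    return sum(2 * (w * x - y) * x for x, y in batch) / len(batch)


data = [(1.0, 2.0), (2.0, 4.5), (3.0, 5.5), (4.0, 8.5)]
w = 0.5

# One optimizer step with batch_size = 4.
full_batch_grad = grad_mse(w, data)

# Four micro-batches of size 1, averaged before the optimizer step.
accumulate_grad_batches = 4
micro_grads = [grad_mse(w, [sample]) for sample in data]
accumulated_grad = sum(micro_grads) / accumulate_grad_batches

effective_batch_size = 1 * accumulate_grad_batches
print(effective_batch_size)                 # 4
print(full_batch_grad, accumulated_grad)    # -23.25 -23.25
```

The gradient is the same, but each optimizer step now needs four forward and backward passes, which is where the slowdown noted in the comment above comes from.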

Screenshots

The training is fast. After 4000 steps (batch size 4, learning rate 1e-5, about 50 minutes on a PCIe 40 GB GPU), the results on my machine (in the output folder "image_log") are

Control:

img

Prompt:

img

Prediction:

img

Ground Truth:

img

Note that SD's capability is preserved. Even when trained on this super-aligned dataset, it still draws some random textures and those snow decorations. (Also note that the ground truth looks slightly modified because it is converted from SD's latent image.)

Larger batch size and longer training will further improve this. Adequate training will make the filling perfect.
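
To put the numbers quoted above in perspective, the short calculation below derives how much data the 4000-step run has seen and how long one step takes. The 50-minute figure is approximate, so treat the results as rough estimates. The function names and the 10000-step example are hypothetical.

```python
def training_stats(steps, batch_size, minutes, dataset_size):
    """Derive simple throughput numbers from a finished training run."""
    samples_seen = steps * batch_size
    return {
        "samples_seen": samples_seen,
        "epochs": samples_seen / dataset_size,
        "seconds_per_step": minutes * 60 / steps,
    }


def estimate_minutes(steps, seconds_per_step):
    """Estimate wall-clock minutes for a run at a measured step time."""
    return steps * seconds_per_step / 60


stats = training_stats(steps=4000, batch_size=4, minutes=50, dataset_size=50000)
print(stats)
# {'samples_seen': 16000, 'epochs': 0.32, 'seconds_per_step': 0.75}

# Hypothetical longer run at the same speed.
print(estimate_minutes(10000, stats["seconds_per_step"]))  # 125.0
```

In other words, the screenshots above were produced after the model had seen roughly a third of the Fill50K images once.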

Of course, training SD to fill circles is meaningless, but this is a successful beginning of your story.

Let us work together to control large models more and more.

Other options

Beyond the standard settings, we also provide two important parameters, "sd_locked" and "only_mid_control", that you need to know.

only_mid_control

By default, only_mid_control is False. When it is True, you will train the architecture below.

img

This can be helpful when your computation power is limited and you want to speed up the training, or when you want to facilitate "global" context learning. Note that you can also switch it during training: pause, set it to True, and resume; later pause again, change the setting again, and resume again.

If your computation device is good, perhaps you do not need this. But I also know some artists are willing to train a model on their laptop for a month - in that case, perhaps this option can be useful.
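
As a rough mental model of this option, the toy decoder below uses integers in place of feature maps. It assumes that the control signal for the middle block is always applied and that only_mid_control=True drops the control signals that would otherwise be added to the skip connections. This is an illustrative sketch under that assumption, not the repository's implementation, and all names in it are hypothetical.

```python
def toy_decoder(mid_feature, skip_features, control_features,
                only_mid_control=False):
    """Toy U-Net decoder where integers stand in for feature maps.

    control_features holds one value per skip connection, followed by
    one value for the middle block.
    """
    skips = list(skip_features)
    controls = list(control_features)
    # The middle-block control signal is added in both modes.
    h = mid_feature + controls.pop()
    trace = [h]
    while skips:
        skip = skips.pop()
        control = controls.pop()
        if not only_mid_control:
            skip += control  # inject control into the skip connection
        h += skip  # stands in for "concatenate, then convolve"
        trace.append(h)
    return trace


full = toy_decoder(1, [10, 20], [100, 200, 500])
mid_only = toy_decoder(1, [10, 20], [100, 200, 500], only_mid_control=True)
print(full)      # [501, 721, 831]
print(mid_only)  # [501, 521, 531]
```

In the mid-only case the decoder receives the control information through a single path, which is why it relates to "global" context rather than fine spatial detail.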

sd_locked

By default, sd_locked is True. When it is False, you will train the architecture below.

img

This will unlock some layers in SD and you will train them as a whole.

This option is DANGEROUS! If your dataset is not good enough, this may downgrade the capability of your SD model.

However, this option is also very useful when you are training on images with some specific style, or when you are training with special datasets (like medical dataset with X-ray images or geographic datasets with lots of Google Maps). You can understand this as simultaneously training the ControlNet and something like a DreamBooth.

Also, if your dataset is large, you may want to end the training with a few thousand steps with those layers unlocked. This usually improves the "problem-specific" solutions a little. You may try it yourself to feel the difference.

Also, if you unlock some original layers, you may want a lower learning rate, like 2e-6.
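
The effect of sd_locked on what gets optimized can be sketched as follows. The parameter names and prefixes are hypothetical: "control." marks ControlNet weights and "sd.decoder." marks the SD layers assumed to be unlocked. The two learning rates are the ones mentioned in this tutorial.

```python
def select_trainable(param_names, sd_locked=True):
    """Return the parameter names an optimizer would receive.

    Prefixes are hypothetical: 'control.' marks ControlNet weights and
    'sd.decoder.' marks the SD layers assumed to be unlocked.
    """
    trainable = [n for n in param_names if n.startswith("control.")]
    if not sd_locked:
        trainable += [n for n in param_names if n.startswith("sd.decoder.")]
    return trainable


def suggested_learning_rate(sd_locked=True):
    """Learning rates from this tutorial: 1e-5 locked, 2e-6 unlocked."""
    return 1e-5 if sd_locked else 2e-6


names = ["sd.encoder.w", "sd.decoder.w", "control.encoder.w"]
for locked in (True, False):
    print(locked, select_trainable(names, locked), suggested_learning_rate(locked))
```

A two-phase schedule then amounts to running most of the training with sd_locked=True and resuming from the last checkpoint with sd_locked=False and the lower learning rate for the final few thousand steps.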