Fix Segformer decoder performance #998

Merged 1 commit into qubvel-org:main on Dec 6, 2024

Conversation

@brianhou0208 (Contributor) commented on Dec 5, 2024

Hi @omarequalmars,

This PR addresses a performance issue in the Segformer decoder caused by an unnecessary .contiguous() operation, which significantly slowed down training. Removing the redundant call restores throughput to the expected level.

Fixes #996

I ran the following experiments and checks to replicate the anomaly under your configuration:

  • Compared the implementations of Transformers and SMP, along with their throughput.
  • Tested throughput using your hyperparameters.
  • Tested throughput using default hyperparameters.

OS: Linux
GPU: V100
Reference: transformers/segformer

Issue

Resolves #996.

Fix

The slowdown was caused by an unnecessary .contiguous() call in the decoder pipeline.

x = x.transpose(1, 2).reshape(batch, -1, height, width).contiguous()

This PR removes the redundant operation, as shown below:

x = x.transpose(1, 2).reshape(batch, -1, height, width)
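
As a sanity check, here is a minimal micro-benchmark of the two variants (a sketch with illustrative shapes, not the exact decoder dimensions; it assumes a CUDA device is available):

import time
import torch

batch, height, width, channels = 8, 128, 128, 256
x = torch.randn(batch, height * width, channels, device="cuda")

def bench(fn, iters=200):
    # Time iters calls of fn, synchronizing around the loop.
    torch.cuda.synchronize()
    start = time.time()
    for _ in range(iters):
        fn(x)
    torch.cuda.synchronize()
    return time.time() - start

# Old variant: forces a full-tensor copy on every call.
t_copy = bench(lambda t: t.transpose(1, 2).reshape(batch, -1, height, width).contiguous())
# Fixed variant: reshape alone already yields a usable tensor.
t_fixed = bench(lambda t: t.transpose(1, 2).reshape(batch, -1, height, width))
print(f"with .contiguous(): {t_copy:.4f}s, without: {t_fixed:.4f}s")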

Transformers & SMP

Since our implementation references Transformers, I compared the throughput of the two directly after applying the fix; they now align:

Transformers: 1068.51 images/s @ batch size 100
SMP: 1073.67 images/s @ batch size 100
Test code:
import time
import torch
import requests
from PIL import Image
import segmentation_models_pytorch as smp
from transformers import SegformerImageProcessor, SegformerForSemanticSegmentation
import warnings
warnings.filterwarnings("ignore", category=FutureWarning)

T0 = 5   # warm-up duration (seconds)
T1 = 10  # measurement duration (seconds)
def map_weights(state_dict: dict):
    if "state_dict" in state_dict:
        state_dict = state_dict["state_dict"]

    mapped_state_dict = {
        # Map backbone to encoder
        k.replace("backbone", "encoder"): v
        for k, v in state_dict.items()
        if k.startswith("backbone")
    }

    # Map the linear_cX layers to MLP stages
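    # (the stage order is reversed between the two implementations, hence the 3 - i index)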
    for i in range(4):
        base = f"decode_head.linear_c{i+1}.proj"
        mapped_state_dict[f"decoder.mlp_stage.{3-i}.linear.weight"] = state_dict[
            f"{base}.weight"
        ]
        mapped_state_dict[f"decoder.mlp_stage.{3-i}.linear.bias"] = state_dict[
            f"{base}.bias"
        ]

    # Map fuse_stage components
    fuse_base = "decode_head.linear_fuse"
    mapped_state_dict.update(
        {
            "decoder.fuse_stage.0.weight": state_dict[f"{fuse_base}.conv.weight"],
            "decoder.fuse_stage.1.weight": state_dict[f"{fuse_base}.bn.weight"],
            "decoder.fuse_stage.1.bias": state_dict[f"{fuse_base}.bn.bias"],
            "decoder.fuse_stage.1.running_mean": state_dict[
                f"{fuse_base}.bn.running_mean"
            ],
            "decoder.fuse_stage.1.running_var": state_dict[
                f"{fuse_base}.bn.running_var"
            ],
            "decoder.fuse_stage.1.num_batches_tracked": state_dict[
                f"{fuse_base}.bn.num_batches_tracked"
            ],
        }
    )

    # Map final layer components
    mapped_state_dict["segmentation_head.0.weight"] = state_dict[
        "decode_head.linear_pred.weight"
    ]
    mapped_state_dict["segmentation_head.0.bias"] = state_dict[
        "decode_head.linear_pred.bias"
    ]

    return mapped_state_dict


def segformer_smp(inputs, path="./segformer.b0.512x1024.city.160k.pth"):
    original_checkpoint = torch.load(
        path, map_location="cpu", weights_only=False
    )

    num_classes = int(
        original_checkpoint["meta"]["config"].split("num_classes=")[1].split(",\n")[0]
    )
    decoder_dims = int(original_checkpoint["meta"]["config"].split("embed_dim=")[1][:3])
    smp_state_dict = map_weights(original_checkpoint)
    print(
        f"Pretrain Weight setting: num_classes={num_classes} "
        f"decoder_embed_dim={decoder_dims}"
    )
    inputs['pixel_values'] = inputs['pixel_values'].to('cuda:0')
    model = smp.create_model(
        in_channels=3,
        classes=num_classes,
        arch="segformer",
        encoder_name="mit_b0",
        encoder_weights=None,
        decoder_segmentation_channels=decoder_dims,
    ).eval().to('cuda:0')
    model.load_state_dict(smp_state_dict, strict=False)  # strict=False tolerates keys not covered by the mapping
    torch.cuda.empty_cache()
    torch.cuda.synchronize()
    start = time.time()
    while time.time() - start < T0:
        output = model(inputs['pixel_values'])
        output = torch.softmax(output, dim=1)
        output = torch.argmax(output, dim=1)
    timing = []
    torch.cuda.synchronize()
    while sum(timing) < T1:
        start = time.time()
        output = model(inputs['pixel_values'])
        output = torch.softmax(output, dim=1)
        output = torch.argmax(output, dim=1)
        torch.cuda.synchronize()
        timing.append(time.time() - start)
    timing = torch.as_tensor(timing, dtype=torch.float32)
    print('SMP', 100 / timing.mean().item(), 'images/s @ batch size', 100)
    del model
    return output

def segformer_tf(inputs, path):
    model = SegformerForSemanticSegmentation.from_pretrained(path)
    model.eval().to('cuda:0')
    inputs = inputs.to('cuda:0')
    torch.cuda.empty_cache()
    torch.cuda.synchronize()
    start = time.time()
    while time.time() - start < T0:
        outputs = model(**inputs)
        logits = outputs.logits
        logits = torch.nn.functional.interpolate(
            logits, scale_factor=4, mode='bilinear', align_corners=True)
        logits = torch.softmax(logits, dim=1)
        logits = torch.argmax(logits, dim=1)
    timing = []
    torch.cuda.synchronize()
    while sum(timing) < T1:
        start = time.time()
        outputs = model(**inputs)
        logits = outputs.logits
        logits = torch.nn.functional.interpolate(
            logits, scale_factor=4, mode='bilinear', align_corners=True)
        logits = torch.softmax(logits, dim=1)
        logits = torch.argmax(logits, dim=1)
        torch.cuda.synchronize()
        timing.append(time.time() - start)
    timing = torch.as_tensor(timing, dtype=torch.float32)
    print('Transformers', 100 / timing.mean().item(), 'images/s @ batch size', 100)
    del model
    return logits

    
if __name__ == "__main__":
    model_path_hugging = "nvidia/segformer-b0-finetuned-cityscapes-512-1024"
    url = "https://huggingface.co/datasets/hf-internal-testing/fixtures_ade20k/resolve/main/ADE_val_00000001.jpg"
    image = Image.open(requests.get(url, stream=True).raw).resize((256, 256))
    processor = SegformerImageProcessor.from_pretrained(model_path_hugging)
    inputs = processor(images=image, return_tensors="pt")
    inputs['pixel_values'] = inputs['pixel_values'].repeat(100, 1, 1, 1)
    inputs['pixel_values'] = torch.nn.functional.interpolate(inputs['pixel_values'], (256, 256))
    output_tf = segformer_tf(inputs, model_path_hugging)
    output_smp = segformer_smp(inputs)
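
Note: segformer_smp assumes the original SegFormer Cityscapes checkpoint (segformer.b0.512x1024.city.160k.pth, presumably from the official NVlabs SegFormer release) has been downloaded locally; the Transformers model is fetched from the Hugging Face Hub automatically.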

Custom hyperparameters

Tested with the hyperparameters from your configuration:
[screenshot: throughput results with custom hyperparameters]

Test code:
import torch
import segmentation_models_pytorch as smp
import time

T0 = 5   # warm-up duration (seconds)
T1 = 10  # measurement duration (seconds)
def get_throughput(model, batch_size=100, resolution=128):
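    # Warm up for T0 seconds, then time forward passes until the recorded
    # iterations sum to at least T1 seconds; report the mean throughput.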
    model.eval()
    model.to('cuda:0')
    inputs = torch.randn(batch_size, 3, resolution, resolution, device='cuda:0')
    torch.cuda.empty_cache()
    torch.cuda.synchronize()
    start = time.time()
    while time.time() - start < T0:
        model(inputs)
    timing = []
    torch.cuda.synchronize()
    while sum(timing) < T1:
        start = time.time()
        model(inputs)
        torch.cuda.synchronize()
        timing.append(time.time() - start)
    timing = torch.as_tensor(timing, dtype=torch.float32)
    print(batch_size / timing.mean().item(), 'images/s @ batch size', batch_size)

timm_name = 'tu-mobilevit_xxs'

model_unet = smp.Unet(
    encoder_name=timm_name,
    encoder_depth=5,
    encoder_weights=None,
    decoder_use_batchnorm=True,
    decoder_channels=[256, 128, 64, 32, 16]
)
model_deeplabv3p = smp.DeepLabV3Plus(
    encoder_name=timm_name,
    encoder_depth=5,
    encoder_weights=None,
    encoder_output_stride=16,
    decoder_channels=512,
    decoder_atrous_rates=(16, 32, 128),
)
model_segformer = smp.Segformer(
    encoder_name=timm_name,
    encoder_depth=5,
    encoder_weights=None,
    decoder_segmentation_channels=512
)

if __name__ == '__main__':
    get_throughput(model_unet, 100, resolution=64)
    get_throughput(model_unet, 100, resolution=128)
    get_throughput(model_unet, 100, resolution=256)
    # get_throughput(model_unet, 100, resolution=512)
    del model_unet
    get_throughput(model_deeplabv3p, 100, resolution=64)
    get_throughput(model_deeplabv3p, 100, resolution=128)
    get_throughput(model_deeplabv3p, 100, resolution=256)
    # get_throughput(model_deeplabv3p, 100, resolution=512)
    del model_deeplabv3p
    get_throughput(model_segformer, 100, resolution=64)
    get_throughput(model_segformer, 100, resolution=128)
    get_throughput(model_segformer, 100, resolution=256)
    # get_throughput(model_segformer, 100, resolution=512)
    del model_segformer

Default hyperparameters
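The same benchmark harness as above, with all decoders left at their default settings: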

[screenshot: throughput results with default hyperparameters]

Test code:
import torch
import segmentation_models_pytorch as smp
import time

T0 = 5   # warm-up duration (seconds)
T1 = 10  # measurement duration (seconds)
def get_throughput(model, batch_size=100, resolution=128):
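    # Warm up for T0 seconds, then time forward passes until the recorded
    # iterations sum to at least T1 seconds; report the mean throughput.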
    model.eval()
    model.to('cuda:0')
    inputs = torch.randn(batch_size, 3, resolution, resolution, device='cuda:0')
    torch.cuda.empty_cache()
    torch.cuda.synchronize()
    start = time.time()
    while time.time() - start < T0:
        model(inputs)
    timing = []
    torch.cuda.synchronize()
    while sum(timing) < T1:
        start = time.time()
        model(inputs)
        torch.cuda.synchronize()
        timing.append(time.time() - start)
    timing = torch.as_tensor(timing, dtype=torch.float32)
    print(batch_size / timing.mean().item(), 'images/s @ batch size', batch_size)

timm_name = 'tu-mobilevit_xxs'

model_unet = smp.Unet(
    encoder_name=timm_name,
    encoder_depth=5,
    encoder_weights=None,
)
model_deeplabv3p = smp.DeepLabV3Plus(
    encoder_name=timm_name,
    encoder_depth=5,
    encoder_weights=None,
)
model_segformer = smp.Segformer(
    encoder_name=timm_name,
    encoder_depth=5,
    encoder_weights=None,
)

if __name__ == '__main__':
    get_throughput(model_unet, 100, resolution=64)
    get_throughput(model_unet, 100, resolution=128)
    get_throughput(model_unet, 100, resolution=256)
    # get_throughput(model_unet, 100, resolution=512)
    del model_unet
    get_throughput(model_deeplabv3p, 100, resolution=64)
    get_throughput(model_deeplabv3p, 100, resolution=128)
    get_throughput(model_deeplabv3p, 100, resolution=256)
    # get_throughput(model_deeplabv3p, 100, resolution=512)
    del model_deeplabv3p
    get_throughput(model_segformer, 100, resolution=64)
    get_throughput(model_segformer, 100, resolution=128)
    get_throughput(model_segformer, 100, resolution=256)
    # get_throughput(model_segformer, 100, resolution=512)
    del model_segformer

@qubvel (Collaborator) left a comment:


Thanks for the fix and detailed report!

@qubvel merged commit 5a1a8f1 into qubvel-org:main on Dec 6, 2024
12 checks passed
Successfully merging this pull request closes: SegFormer Training extremely slow (#996)