From fd9a27001994db517619839c1790edf72f0de64f Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Mon, 14 Aug 2023 22:21:22 +0800 Subject: [PATCH] add Swin --- vision_toolbox/backbones/__init__.py | 2 +- vision_toolbox/backbones/swin.py | 11 +++++++---- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/vision_toolbox/backbones/__init__.py b/vision_toolbox/backbones/__init__.py index 3ca3664..55a1857 100644 --- a/vision_toolbox/backbones/__init__.py +++ b/vision_toolbox/backbones/__init__.py @@ -1,7 +1,7 @@ from .darknet import Darknet, DarknetYOLOv5 from .mlp_mixer import MLPMixer from .patchconvnet import PatchConvNet +from .swin import SwinTransformer from .torchvision_models import EfficientNetExtractor, MobileNetExtractor, RegNetExtractor, ResNetExtractor from .vit import ViT from .vovnet import VoVNet -from .swin import SwinTransformer diff --git a/vision_toolbox/backbones/swin.py b/vision_toolbox/backbones/swin.py index 7ddb57d..124d179 100644 --- a/vision_toolbox/backbones/swin.py +++ b/vision_toolbox/backbones/swin.py @@ -166,11 +166,14 @@ def __init__( self.head_norm = norm(d_model) + def get_feature_maps(self, x: Tensor) -> list[Tensor]: + out = [self.dropout(self.norm(self.patch_embed(x).permute(0, 2, 3, 1)))] + for stage in self.stages: + out.append(stage(out[-1])) + return out[1:] + def forward(self, x: Tensor) -> Tensor: - x = self.dropout(self.norm(self.patch_embed(x).permute(0, 2, 3, 1))) - x = self.stages(x) - x = self.head_norm(x).mean((1, 2)) - return x + return self.head_norm(self.get_feature_maps(x)[-1]).mean((1, 2)) @staticmethod def from_config(variant: str, img_size: int, pretrained: bool = False) -> SwinTransformer: