Skip to content

Commit

Permalink
add Swin
Browse files Browse the repository at this point in the history
  • Loading branch information
gau-nernst committed Aug 14, 2023
1 parent 232a50b commit fd9a270
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 5 deletions.
2 changes: 1 addition & 1 deletion vision_toolbox/backbones/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from .darknet import Darknet, DarknetYOLOv5
from .mlp_mixer import MLPMixer
from .patchconvnet import PatchConvNet
from .swin import SwinTransformer
from .torchvision_models import EfficientNetExtractor, MobileNetExtractor, RegNetExtractor, ResNetExtractor
from .vit import ViT
from .vovnet import VoVNet
from .swin import SwinTransformer
11 changes: 7 additions & 4 deletions vision_toolbox/backbones/swin.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,11 +166,14 @@ def __init__(

self.head_norm = norm(d_model)

def get_feature_maps(self, x: Tensor) -> list[Tensor]:
out = [self.dropout(self.norm(self.patch_embed(x).permute(0, 2, 3, 1)))]
for stage in self.stages:
out.append(stage(out[-1]))
return out[1:]

def forward(self, x: Tensor) -> Tensor:
x = self.dropout(self.norm(self.patch_embed(x).permute(0, 2, 3, 1)))
x = self.stages(x)
x = self.head_norm(x).mean((1, 2))
return x
return self.head_norm(self.get_feature_maps(x)[-1]).mean((1, 2))

@staticmethod
def from_config(variant: str, img_size: int, pretrained: bool = False) -> SwinTransformer:
Expand Down

0 comments on commit fd9a270

Please sign in to comment.