Skip to content

Commit

Permalink
rename
Browse files Browse the repository at this point in the history
  • Loading branch information
gau-nernst committed Jul 9, 2023
1 parent a36e947 commit ee98523
Showing 1 changed file with 9 additions and 9 deletions.
18 changes: 9 additions & 9 deletions vision_toolbox/backbones/mobile_vit.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import torch
from torch import Tensor, nn

from ..components import ConvBnAct
from ..components import ConvNormAct
from .base import BaseBackbone


Expand Down Expand Up @@ -37,9 +37,9 @@ def __init__(self, in_channels: int, out_channels: int, stride: int, expand_rati

self.layers = nn.Sequential()
if expand_ratio != 1:
self.layers.append(ConvBnAct(in_channels, hidden_dim, 1, 1, 0, act="swish"))
self.layers.append(ConvBnAct(hidden_dim, hidden_dim, 3, 2, 1, act="swish"))
self.layers.append(ConvBnAct(hidden_dim, out_channels, 1, 1, 0, act="none"))
self.layers.append(ConvNormAct(in_channels, hidden_dim, 1, act="swish"))
self.layers.append(ConvNormAct(hidden_dim, hidden_dim, 3, stride, act="swish"))
self.layers.append(ConvNormAct(hidden_dim, out_channels, 1, act="none"))

def forward(self, x: Tensor) -> Tensor:
    """Run the inverted-residual stack; add the identity skip connection when enabled.

    Args:
        x: input feature map.

    Returns:
        ``layers(x)`` plain, or ``x + layers(x)`` when ``self.res`` is set.
    """
    out = self.layers(x)
    if self.res:
        out = x + out
    return out
Expand Down Expand Up @@ -69,16 +69,16 @@ def __init__(self, in_channels: int, d_model: int, n_layers: int, dropout: float
super().__init__()
self.local_rep = nn.Sequential(
dict(
conv_3x3=ConvBnAct(in_channels, in_channels, act="swish"),
conv_1x1=ConvBnAct(in_channels, d_model, 1, 1, 0, act="swish"),
conv_3x3=ConvNormAct(in_channels, in_channels, act="swish"),
conv_1x1=ConvNormAct(in_channels, d_model, 1, act="swish"),
)
)
layer = nn.TransformerEncoderLayer(
d_model, d_model // self.head_dim, d_model * 2, dropout, "gelu", norm_eps, True, True
)
self.global_rep = nn.TransformerEncoder(layer, n_layers)
self.conv_proj = (ConvBnAct(d_model, in_channels, 1, 1, 0, act="swish"),)
self.fusion = ConvBnAct(in_channels * 2, in_channels, act="swish")
self.conv_proj = (ConvNormAct(d_model, in_channels, 1, act="swish"),)
self.fusion = ConvNormAct(in_channels * 2, in_channels, act="swish")

def forward(self, x: Tensor) -> Tensor:
out = self.local_rep(x)
Expand All @@ -97,7 +97,7 @@ class MobileViTStageConfig(NamedTuple):
class MobileViT(BaseBackbone):
def __init__(self, out_channels_list: list[int], out_channels: int, mv2_expand_ratio: int, d_models: list[int]):
super().__init__()
self.stem = ConvBnAct(3, 16, 3, 2, 1, act="swish")
self.stem = ConvNormAct(3, 16, 3, 2, act="swish")
self.stages = nn.ModuleList()
stage1 = nn.Sequential(
InvertedResidual(16, out_channels_list[0], 1, mv2_expand_ratio),
Expand Down

0 comments on commit ee98523

Please sign in to comment.