From d7560ca73ffe8567e992c9c38e63db02ec953ad6 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Sun, 20 Aug 2023 10:41:00 +0800 Subject: [PATCH] rename variables --- vision_toolbox/backbones/vit.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/vision_toolbox/backbones/vit.py b/vision_toolbox/backbones/vit.py index 5ca4049..da3f919 100644 --- a/vision_toolbox/backbones/vit.py +++ b/vision_toolbox/backbones/vit.py @@ -87,8 +87,8 @@ def forward(self, x: Tensor) -> Tensor: class ViT(nn.Module): def __init__( self, - n_layers: int, d_model: int, + depth: int, n_heads: int, patch_size: int, img_size: int, @@ -113,7 +113,7 @@ def __init__( nn.init.normal_(self.pe, 0, 0.02) self.layers = nn.Sequential() - for _ in range(n_layers): + for _ in range(depth): block = ViTBlock(d_model, n_heads, bias, mlp_ratio, dropout, layer_scale_init, stochastic_depth, norm, act) self.layers.append(block) @@ -145,14 +145,14 @@ def resize_pe(self, size: int, interpolation_mode: str = "bicubic") -> None: @staticmethod def from_config(variant: str, patch_size: int, img_size: int, pretrained: bool = False) -> ViT: - n_layers, d_model, n_heads = dict( - Ti=(12, 192, 3), - S=(12, 384, 6), - B=(12, 768, 12), - L=(24, 1024, 16), - H=(32, 1280, 16), + d_model, depth, n_heads = dict( + Ti=(192, 12, 3), + S=(384, 12, 6), + B=(768, 12, 12), + L=(1024, 24, 16), + H=(1280, 32, 16), )[variant] - m = ViT(n_layers, d_model, n_heads, patch_size, img_size) + m = ViT(d_model, depth, n_heads, patch_size, img_size) if pretrained: ckpt = {