Skip to content

Commit

Permalink
rename variables
Browse files Browse the repository at this point in the history
  • Loading branch information
gau-nernst committed Aug 20, 2023
1 parent efe4c9e commit d7560ca
Showing 1 changed file with 9 additions and 9 deletions.
18 changes: 9 additions & 9 deletions vision_toolbox/backbones/vit.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,8 @@ def forward(self, x: Tensor) -> Tensor:
class ViT(nn.Module):
def __init__(
self,
n_layers: int,
d_model: int,
depth: int,
n_heads: int,
patch_size: int,
img_size: int,
Expand All @@ -113,7 +113,7 @@ def __init__(
nn.init.normal_(self.pe, 0, 0.02)

self.layers = nn.Sequential()
for _ in range(n_layers):
for _ in range(depth):
block = ViTBlock(d_model, n_heads, bias, mlp_ratio, dropout, layer_scale_init, stochastic_depth, norm, act)
self.layers.append(block)

Expand Down Expand Up @@ -145,14 +145,14 @@ def resize_pe(self, size: int, interpolation_mode: str = "bicubic") -> None:

@staticmethod
def from_config(variant: str, patch_size: int, img_size: int, pretrained: bool = False) -> ViT:
n_layers, d_model, n_heads = dict(
Ti=(12, 192, 3),
S=(12, 384, 6),
B=(12, 768, 12),
L=(24, 1024, 16),
H=(32, 1280, 16),
d_model, depth, n_heads = dict(
Ti=(192, 12, 3),
S=(384, 12, 6),
B=(768, 12, 12),
L=(1024, 24, 16),
H=(1280, 32, 16),
)[variant]
m = ViT(n_layers, d_model, n_heads, patch_size, img_size)
m = ViT(d_model, depth, n_heads, patch_size, img_size)

if pretrained:
ckpt = {
Expand Down

0 comments on commit d7560ca

Please sign in to comment.