From c5f18896cd41b3933d8686be52ae249539398c05 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Sun, 29 Oct 2023 14:48:55 +0800 Subject: [PATCH] fix deit --- vision_toolbox/backbones/deit.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/vision_toolbox/backbones/deit.py b/vision_toolbox/backbones/deit.py index a5252c8..ae3e486 100644 --- a/vision_toolbox/backbones/deit.py +++ b/vision_toolbox/backbones/deit.py @@ -28,8 +28,8 @@ def __init__( ) -> None: # fmt: off super().__init__( - d_model, depth, n_heads, patch_size, img_size, True, bias, mlp_ratio, - dropout, layer_scale_init, stochastic_depth, norm_eps + d_model, depth, n_heads, patch_size, img_size, True, "cls_token", bias, + mlp_ratio, dropout, layer_scale_init, stochastic_depth, norm_eps, ) # fmt: on self.dist_token = nn.Parameter(torch.zeros(1, 1, d_model)) @@ -133,7 +133,7 @@ def __init__( ): # fmt: off super().__init__( - d_model, depth, n_heads, patch_size, img_size, cls_token, bias, + d_model, depth, n_heads, patch_size, img_size, cls_token, "cls_token", bias, mlp_ratio, dropout, layer_scale_init, stochastic_depth, norm_eps, ) # fmt: on