make timm optional dependency
Summary: Pull Request resolved: #4739

Reviewed By: lyttonhao

Differential Revision: D42397187

fbshipit-source-id: 435394600fec5d6db5feee3eb91cea110b73f85a
ppwwyyxx authored and facebook-github-bot committed Jan 7, 2023
1 parent b4b6cc9 · commit 95a87b8
Showing 6 changed files with 31 additions and 23 deletions.
11 changes: 6 additions & 5 deletions detectron2/modeling/backbone/mvit.py
@@ -2,7 +2,6 @@
 import numpy as np
 import torch
 import torch.nn as nn
-from timm.models.layers import DropPath, Mlp, trunc_normal_
 
 from .backbone import Backbone
 from .utils import (
@@ -123,8 +122,8 @@ def __init__(
             self.rel_pos_w = nn.Parameter(torch.zeros(rel_dim, head_dim))
 
             if not rel_pos_zero_init:
-                trunc_normal_(self.rel_pos_h, std=0.02)
-                trunc_normal_(self.rel_pos_w, std=0.02)
+                nn.init.trunc_normal_(self.rel_pos_h, std=0.02)
+                nn.init.trunc_normal_(self.rel_pos_w, std=0.02)
 
     def forward(self, x):
         B, H, W, _ = x.shape
@@ -235,6 +234,8 @@ def __init__(
             input_size=input_size,
         )
 
+        from timm.models.layers import DropPath, Mlp
+
         self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
         self.norm2 = norm_layer(dim_out)
         self.mlp = Mlp(
@@ -414,13 +415,13 @@ def __init__(
         self._last_block_indexes = last_block_indexes
 
         if self.pos_embed is not None:
-            trunc_normal_(self.pos_embed, std=0.02)
+            nn.init.trunc_normal_(self.pos_embed, std=0.02)
 
         self.apply(self._init_weights)
 
     def _init_weights(self, m):
         if isinstance(m, nn.Linear):
-            trunc_normal_(m.weight, std=0.02)
+            nn.init.trunc_normal_(m.weight, std=0.02)
             if isinstance(m, nn.Linear) and m.bias is not None:
                 nn.init.constant_(m.bias, 0)
         elif isinstance(m, nn.LayerNorm):
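A note on the initializer swap above: `nn.init.trunc_normal_` ships with PyTorch and mirrors timm's `trunc_normal_` (same signature and the same defaults of mean=0, std=1 and truncation bounds [-2, 2]), so dropping the timm import should not change initialization. A minimal standalone sketch, with an illustrative tensor shape that is not taken from this diff:

import torch
import torch.nn as nn

# Illustrative parameter; the shape is made up for the example.
pos_embed = nn.Parameter(torch.zeros(1, 196, 768))

# Same call as timm.models.layers.trunc_normal_(tensor, std=0.02):
# values ~ N(0, 0.02^2), redrawn if they fall outside the default bounds [-2, 2].
nn.init.trunc_normal_(pos_embed, std=0.02)

assert pos_embed.abs().max().item() <= 2.0          # default truncation bounds
assert abs(pos_embed.std().item() - 0.02) < 5e-3    # empirical std near the requested 0.02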
24 changes: 15 additions & 9 deletions detectron2/modeling/backbone/swin.py
@@ -17,10 +17,11 @@
 import torch.nn as nn
 import torch.nn.functional as F
 import torch.utils.checkpoint as checkpoint
-from timm.models.layers import DropPath, to_2tuple, trunc_normal_
 
 from detectron2.modeling.backbone.backbone import Backbone
 
+_to_2tuple = nn.modules.utils._ntuple(2)
+
 
 class Mlp(nn.Module):
     """Multilayer perceptron."""
@@ -130,7 +131,7 @@ def __init__(
         self.proj = nn.Linear(dim, dim)
         self.proj_drop = nn.Dropout(proj_drop)
 
-        trunc_normal_(self.relative_position_bias_table, std=0.02)
+        nn.init.trunc_normal_(self.relative_position_bias_table, std=0.02)
         self.softmax = nn.Softmax(dim=-1)
 
     def forward(self, x, mask=None):
@@ -219,15 +220,20 @@ def __init__(
         self.norm1 = norm_layer(dim)
         self.attn = WindowAttention(
             dim,
-            window_size=to_2tuple(self.window_size),
+            window_size=_to_2tuple(self.window_size),
             num_heads=num_heads,
             qkv_bias=qkv_bias,
             qk_scale=qk_scale,
             attn_drop=attn_drop,
             proj_drop=drop,
         )
 
-        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+        if drop_path > 0.0:
+            from timm.models.layers import DropPath
+
+            self.drop_path = DropPath(drop_path)
+        else:
+            self.drop_path = nn.Identity()
         self.norm2 = norm_layer(dim)
         mlp_hidden_dim = int(dim * mlp_ratio)
         self.mlp = Mlp(
@@ -470,7 +476,7 @@ class PatchEmbed(nn.Module):
 
     def __init__(self, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
         super().__init__()
-        patch_size = to_2tuple(patch_size)
+        patch_size = _to_2tuple(patch_size)
         self.patch_size = patch_size
 
         self.in_chans = in_chans
@@ -571,8 +577,8 @@ def __init__(
 
         # absolute position embedding
         if self.ape:
-            pretrain_img_size = to_2tuple(pretrain_img_size)
-            patch_size = to_2tuple(patch_size)
+            pretrain_img_size = _to_2tuple(pretrain_img_size)
+            patch_size = _to_2tuple(patch_size)
             patches_resolution = [
                 pretrain_img_size[0] // patch_size[0],
                 pretrain_img_size[1] // patch_size[1],
@@ -581,7 +587,7 @@
             self.absolute_pos_embed = nn.Parameter(
                 torch.zeros(1, embed_dim, patches_resolution[0], patches_resolution[1])
            )
-            trunc_normal_(self.absolute_pos_embed, std=0.02)
+            nn.init.trunc_normal_(self.absolute_pos_embed, std=0.02)
 
         self.pos_drop = nn.Dropout(p=drop_rate)
 
@@ -648,7 +654,7 @@ def _freeze_stages(self):
 
     def _init_weights(self, m):
         if isinstance(m, nn.Linear):
-            trunc_normal_(m.weight, std=0.02)
+            nn.init.trunc_normal_(m.weight, std=0.02)
             if isinstance(m, nn.Linear) and m.bias is not None:
                 nn.init.constant_(m.bias, 0)
         elif isinstance(m, nn.LayerNorm):
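The swin change also replaces timm's `to_2tuple` with a module-local `_to_2tuple = nn.modules.utils._ntuple(2)`, the same helper PyTorch uses internally for `kernel_size`-style arguments: a scalar becomes a pair, while an existing pair passes through. A quick standalone check of that behavior (illustrative, not part of the diff):

import torch.nn as nn

# _ntuple(2) builds a "make it a 2-tuple" function, matching what
# timm.models.layers.to_2tuple provides.
_to_2tuple = nn.modules.utils._ntuple(2)

assert _to_2tuple(7) == (7, 7)                       # scalar window size -> (7, 7)
assert tuple(_to_2tuple((224, 192))) == (224, 192)   # an (h, w) pair is left as-is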
11 changes: 6 additions & 5 deletions detectron2/modeling/backbone/vit.py
@@ -3,7 +3,6 @@
 import fvcore.nn.weight_init as weight_init
 import torch
 import torch.nn as nn
-from timm.models.layers import DropPath, Mlp, trunc_normal_
 
 from detectron2.layers import CNNBlockBase, Conv2d, get_norm
 from detectron2.modeling.backbone.fpn import _assert_strides_are_log2_contiguous
@@ -60,8 +59,8 @@ def __init__(
             self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim))
 
             if not rel_pos_zero_init:
-                trunc_normal_(self.rel_pos_h, std=0.02)
-                trunc_normal_(self.rel_pos_w, std=0.02)
+                nn.init.trunc_normal_(self.rel_pos_h, std=0.02)
+                nn.init.trunc_normal_(self.rel_pos_w, std=0.02)
 
     def forward(self, x):
         B, H, W, _ = x.shape
@@ -189,6 +188,8 @@ def __init__(
             input_size=input_size if window_size == 0 else (window_size, window_size),
         )
 
+        from timm.models.layers import DropPath, Mlp
+
         self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
         self.norm2 = norm_layer(dim)
         self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer)
@@ -332,13 +333,13 @@ def __init__(
         self._out_features = [out_feature]
 
         if self.pos_embed is not None:
-            trunc_normal_(self.pos_embed, std=0.02)
+            nn.init.trunc_normal_(self.pos_embed, std=0.02)
 
         self.apply(self._init_weights)
 
     def _init_weights(self, m):
         if isinstance(m, nn.Linear):
-            trunc_normal_(m.weight, std=0.02)
+            nn.init.trunc_normal_(m.weight, std=0.02)
             if isinstance(m, nn.Linear) and m.bias is not None:
                 nn.init.constant_(m.bias, 0)
         elif isinstance(m, nn.LayerNorm):
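In vit.py and mvit.py the remaining timm usage (`DropPath`, `Mlp`) now lives inside the block constructor, so the backbone modules import cleanly without timm and an ImportError, if any, only surfaces when such a block is actually instantiated. A toy sketch of the deferred-import pattern (hypothetical class, not the detectron2 code):

import torch
import torch.nn as nn

class LazyTimmBlock(nn.Module):
    """Illustrates a function-local import of an optional dependency."""

    def __init__(self, dim: int, drop_path: float = 0.0):
        super().__init__()
        # Importing this file never touches timm; a missing package only
        # raises ImportError here, when a block is constructed.
        from timm.models.layers import DropPath, Mlp

        self.norm = nn.LayerNorm(dim)
        self.mlp = Mlp(in_features=dim, hidden_features=4 * dim)
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

    def forward(self, x):
        return x + self.drop_path(self.mlp(self.norm(x)))

# Usage (requires timm at construction time):
#   block = LazyTimmBlock(dim=768)
#   y = block(torch.randn(2, 16, 768))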
2 changes: 1 addition & 1 deletion setup.cfg
@@ -6,7 +6,7 @@ known_standard_library=numpy,setuptools,mock
 skip=./datasets,docs
 skip_glob=*/__init__.py,**/configs/**,**/tests/config/**
 known_myself=detectron2
-known_third_party=fvcore,matplotlib,cv2,torch,torchvision,PIL,pycocotools,yacs,termcolor,cityscapesscripts,tabulate,tqdm,scipy,lvis,psutil,pkg_resources,caffe2,onnx,panopticapi,black,isort,av,iopath,omegaconf,hydra,yaml,pydoc,submitit,cloudpickle,packaging,timm,pandas,fairscale,pytorch3d
+known_third_party=fvcore,matplotlib,cv2,torch,torchvision,PIL,pycocotools,yacs,termcolor,cityscapesscripts,tabulate,tqdm,scipy,lvis,psutil,pkg_resources,caffe2,onnx,panopticapi,black,isort,av,iopath,omegaconf,hydra,yaml,pydoc,submitit,cloudpickle,packaging,timm,pandas,fairscale,pytorch3d,pytorch_lightning
 no_lines_before=STDLIB,THIRDPARTY
 sections=FUTURE,STDLIB,THIRDPARTY,myself,FIRSTPARTY,LOCALFOLDER
 default_section=FIRSTPARTY
2 changes: 1 addition & 1 deletion setup.py
@@ -185,7 +185,6 @@ def get_model_zoo_configs() -> List[str]:
         "omegaconf>=2.1",
         "hydra-core>=1.1",
         "black",
-        "timm",
         "packaging",
         # NOTE: When adding new dependencies, if it is required at import time (in addition
         # to runtime), it probably needs to appear in docs/requirements.txt, or as a mock
@@ -195,6 +194,7 @@
     # optional dependencies, required by some features
     "all": [
         "fairscale",
+        "timm",  # Used by a few ViT models.
         "scipy>1.5.1",
         "shapely",
         "pygments>=2.2",
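With `timm` moved from `install_requires` into the `all` extra, a plain install of detectron2 no longer pulls it in; users who need the timm-backed backbones install it explicitly (for example `pip install timm`, or via the `all` extra). A common companion pattern, sketched here for illustration only (this helper is not part of the commit), is to fail with an actionable message when an optional dependency is missing:

import importlib.util

def require_timm(feature: str) -> None:
    # Check availability without importing; raise a readable error only when
    # a timm-backed feature is actually requested.
    if importlib.util.find_spec("timm") is None:
        raise ImportError(
            f"{feature} requires the optional dependency 'timm'; "
            "install it with `pip install timm`."
        )

# Usage: call right before constructing a block that needs timm, e.g.
#   require_timm("MViT transformer block")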
4 changes: 2 additions & 2 deletions tools/lightning_train_net.py
@@ -11,6 +11,8 @@
 import weakref
 from collections import OrderedDict
 from typing import Any, Dict, List
+import pytorch_lightning as pl  # type: ignore
+from pytorch_lightning import LightningDataModule, LightningModule
 
 import detectron2.utils.comm as comm
 from detectron2.checkpoint import DetectionCheckpointer
@@ -31,8 +33,6 @@
 from detectron2.utils.events import EventStorage
 from detectron2.utils.logger import setup_logger
 
-import pytorch_lightning as pl  # type: ignore
-from pytorch_lightning import LightningDataModule, LightningModule
 from train_net import build_evaluator
 
 logging.basicConfig(level=logging.INFO)
