Commit

add weight loading
gau-nernst committed Aug 19, 2023
1 parent 6205cc3 commit cb2c5f8
Showing 2 changed files with 45 additions and 8 deletions.
14 changes: 7 additions & 7 deletions tests/test_convnext.py
@@ -9,12 +9,12 @@ def test_forward():
     m(torch.randn(1, 3, 224, 224))
 
 
-# def test_from_pretrained():
-#     m = ViT.from_config("Ti", 16, 224, True).eval()
-#     x = torch.randn(1, 3, 224, 224)
-#     out = m(x)
+def test_from_pretrained():
+    m = ConvNeXt.from_config("T", True).eval()
+    x = torch.randn(1, 3, 224, 224)
+    out = m(x)
 
-#     m_timm = timm.create_model("vit_tiny_patch16_224.augreg_in21k", pretrained=True, num_classes=0).eval()
-#     out_timm = m_timm(x)
+    m_timm = timm.create_model("convnext_tiny.fb_in22k", pretrained=True, num_classes=0).eval()
+    out_timm = m_timm(x)
 
-#     torch.testing.assert_close(out, out_timm, rtol=2e-5, atol=2e-5)
+    torch.testing.assert_close(out, out_timm)
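
Note: the hunk above does not show the test file's imports. A minimal self-contained usage sketch of the API this test exercises could look like the following; the import path vision_toolbox.backbones.convnext is an assumption based on the changed file's location, not shown in the diff.

import torch
from vision_toolbox.backbones.convnext import ConvNeXt  # import path assumed from this diff

# Build ConvNeXt-T and fetch the official FB ImageNet-22k weights via the new pretrained flag.
model = ConvNeXt.from_config("T", pretrained=True).eval()
with torch.no_grad():
    features = model(torch.randn(1, 3, 224, 224))
print(features.shape)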
39 changes: 38 additions & 1 deletion vision_toolbox/backbones/convnext.py
@@ -100,6 +100,43 @@ def from_config(variant: str, pretrained: bool = False) -> ConvNeXt:
         m = ConvNeXt(d_model, depths)
 
         if pretrained:
-            pass
+            # TODO: also add torchvision checkpoints?
+            ckpt = dict(
+                T="convnext_tiny_22k_224.pth",
+                S="convnext_small_22k_224.pth",
+                B="convnext_base_22k_224.pth",
+                L="convnext_large_22k_224.pth",
+                XL="convnext_xlarge_22k_224.pth",
+            )[variant]
+            base_url = "https://dl.fbaipublicfiles.com/convnext/"
+            state_dict = torch.hub.load_state_dict_from_url(base_url + ckpt)["model"]
+            m.load_official_ckpt(state_dict)
 
         return m
 
+    @torch.no_grad()
+    def load_official_ckpt(self, state_dict: dict[str, Tensor]) -> None:
+        def copy_(m: nn.Conv2d | nn.Linear | nn.LayerNorm, prefix: str):
+            m.weight.copy_(state_dict.pop(prefix + ".weight"))
+            m.bias.copy_(state_dict.pop(prefix + ".bias"))
+
+        copy_(self.stem[0], "downsample_layers.0.0")  # Conv2d
+        copy_(self.stem[2], "downsample_layers.0.1")  # LayerNorm
+
+        for stage_idx, stage in enumerate(self.stages):
+            if stage_idx > 0:
+                copy_(stage[0][0], f"downsample_layers.{stage_idx}.0")  # LayerNorm
+                copy_(stage[0][2], f"downsample_layers.{stage_idx}.1")  # Conv2d
+
+            for block_idx in range(1, len(stage)):
+                block: ConvNeXtBlock = stage[block_idx]
+                prefix = f"stages.{stage_idx}.{block_idx - 1}."
+
+                copy_(block.layers[1], prefix + "dwconv")
+                copy_(block.layers[3], prefix + "norm")
+                copy_(block.layers[4], prefix + "pwconv1")
+                copy_(block.layers[6], prefix + "pwconv2")
+                block.layer_scale.copy_(state_dict.pop(prefix + "gamma"))
+
+        copy_(self.head_norm, "norm")
+        assert len(state_dict) == 2
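
load_official_ckpt pops each tensor out of the official state dict as it copies it, so the final assert guarantees nothing was silently left unmapped; the two leftover keys are presumably the official classifier's head.weight and head.bias, which this backbone has no counterpart for. A standalone sketch of that pop-and-verify pattern, with illustrative names not taken from the repo:

import torch
from torch import Tensor, nn


@torch.no_grad()
def copy_params(dst: nn.LayerNorm, state_dict: dict[str, Tensor], prefix: str) -> None:
    # Pop instead of indexing so any key we forget to map shows up as a leftover.
    dst.weight.copy_(state_dict.pop(prefix + ".weight"))
    dst.bias.copy_(state_dict.pop(prefix + ".bias"))


# Toy state dict standing in for a downloaded checkpoint.
sd = {"norm.weight": torch.ones(8), "norm.bias": torch.zeros(8), "head.weight": torch.randn(2, 8)}
ln = nn.LayerNorm(8)
copy_params(ln, sd, "norm")
assert set(sd) == {"head.weight"}  # only the keys we deliberately skipped remain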
