diff --git a/tests/test_convnext.py b/tests/test_convnext.py
index 97000d9..c516420 100644
--- a/tests/test_convnext.py
+++ b/tests/test_convnext.py
@@ -9,12 +9,12 @@ def test_forward():
     m(torch.randn(1, 3, 224, 224))
 
 
-# def test_from_pretrained():
-#     m = ViT.from_config("Ti", 16, 224, True).eval()
-#     x = torch.randn(1, 3, 224, 224)
-#     out = m(x)
+def test_from_pretrained():
+    m = ConvNeXt.from_config("T", True).eval()
+    x = torch.randn(1, 3, 224, 224)
+    out = m(x)
 
-#     m_timm = timm.create_model("vit_tiny_patch16_224.augreg_in21k", pretrained=True, num_classes=0).eval()
-#     out_timm = m_timm(x)
+    m_timm = timm.create_model("convnext_tiny.fb_in22k", pretrained=True, num_classes=0).eval()
+    out_timm = m_timm(x)
 
-#     torch.testing.assert_close(out, out_timm, rtol=2e-5, atol=2e-5)
+    torch.testing.assert_close(out, out_timm)
diff --git a/vision_toolbox/backbones/convnext.py b/vision_toolbox/backbones/convnext.py
index 2ed5b11..a32ee9a 100644
--- a/vision_toolbox/backbones/convnext.py
+++ b/vision_toolbox/backbones/convnext.py
@@ -100,6 +100,43 @@ def from_config(variant: str, pretrained: bool = False) -> ConvNeXt:
         m = ConvNeXt(d_model, depths)
         if pretrained:
-            pass
+            # TODO: also add torchvision checkpoints?
+            ckpt = dict(
+                T="convnext_tiny_22k_224.pth",
+                S="convnext_small_22k_224.pth",
+                B="convnext_base_22k_224.pth",
+                L="convnext_large_22k_224.pth",
+                XL="convnext_xlarge_22k_224.pth",
+            )[variant]
+            base_url = "https://dl.fbaipublicfiles.com/convnext/"
+            state_dict = torch.hub.load_state_dict_from_url(base_url + ckpt)["model"]
+            m.load_official_ckpt(state_dict)
         return m
+
+    @torch.no_grad()
+    def load_official_ckpt(self, state_dict: dict[str, Tensor]) -> None:
+        def copy_(m: nn.Conv2d | nn.Linear | nn.LayerNorm, prefix: str):
+            m.weight.copy_(state_dict.pop(prefix + ".weight"))
+            m.bias.copy_(state_dict.pop(prefix + ".bias"))
+
+        copy_(self.stem[0], "downsample_layers.0.0")  # Conv2d
+        copy_(self.stem[2], "downsample_layers.0.1")  # LayerNorm
+
+        for stage_idx, stage in enumerate(self.stages):
+            if stage_idx > 0:
+                copy_(stage[0][0], f"downsample_layers.{stage_idx}.0")  # LayerNorm
+                copy_(stage[0][2], f"downsample_layers.{stage_idx}.1")  # Conv2d
+
+            for block_idx in range(1, len(stage)):
+                block: ConvNeXtBlock = stage[block_idx]
+                prefix = f"stages.{stage_idx}.{block_idx - 1}."
+
+                copy_(block.layers[1], prefix + "dwconv")
+                copy_(block.layers[3], prefix + "norm")
+                copy_(block.layers[4], prefix + "pwconv1")
+                copy_(block.layers[6], prefix + "pwconv2")
+                block.layer_scale.copy_(state_dict.pop(prefix + "gamma"))
+
+        copy_(self.head_norm, "norm")
+        assert len(state_dict) == 2
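
For reference, a minimal sketch of how the new pretrained path could be exercised end to end, mirroring the updated test above. The import path `vision_toolbox.backbones.convnext` is an assumption based on the file layout in this diff; the other names (`ConvNeXt.from_config`, the timm model id, the checkpoint file) come from the diff itself.

```python
import timm
import torch

from vision_toolbox.backbones.convnext import ConvNeXt  # import path assumed from the file layout above

# Build the Tiny variant; with pretrained=True this downloads convnext_tiny_22k_224.pth
# from dl.fbaipublicfiles.com and loads it via load_official_ckpt.
m = ConvNeXt.from_config("T", pretrained=True).eval()

# Reference model from timm with the classification head removed (num_classes=0),
# so both models return backbone features.
m_timm = timm.create_model("convnext_tiny.fb_in22k", pretrained=True, num_classes=0).eval()

x = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    torch.testing.assert_close(m(x), m_timm(x))  # features should match the timm reference
```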