From 6e0515e99c39444caae39472ee1b2fd76ece32f1 Mon Sep 17 00:00:00 2001 From: NielsRogge <48327001+NielsRogge@users.noreply.github.com> Date: Tue, 24 Dec 2024 13:21:59 +0100 Subject: [PATCH] Add DINOv2 with registers (#35348) * added changes from 32905 * fixed mistakes caused by select all paste * rename diff_dinov2... * ran tests * Fix modular * Fix tests * Use new init * Simplify drop path * Convert all checkpoints * Add figure and summary * Update paths * Update docs * Update docs * Update toctree * Update docs --------- Co-authored-by: BernardZach Co-authored-by: Zach Bernard <132859071+BernardZach@users.noreply.github.com> --- docs/source/en/_toctree.yml | 2 + docs/source/en/index.md | 1 + .../en/model_doc/dinov2_with_registers.md | 54 + docs/source/en/perf_infer_gpu_one.md | 1 + src/transformers/__init__.py | 16 + src/transformers/models/__init__.py | 1 + .../models/auto/configuration_auto.py | 2 + src/transformers/models/auto/modeling_auto.py | 4 + .../models/dinov2_with_registers/__init__.py | 27 + .../configuration_dinov2_with_registers.py | 166 ++++ .../convert_dinov2_with_registers_to_hf.py | 291 ++++++ .../modeling_dinov2_with_registers.py | 926 ++++++++++++++++++ .../modular_dinov2_with_registers.py | 381 +++++++ src/transformers/utils/dummy_pt_objects.py | 28 + .../models/dinov2_with_registers/__init__.py | 0 .../test_modeling_dinov2_with_registers.py | 369 +++++++ utils/check_repo.py | 1 + 17 files changed, 2270 insertions(+) create mode 100644 docs/source/en/model_doc/dinov2_with_registers.md create mode 100644 src/transformers/models/dinov2_with_registers/__init__.py create mode 100644 src/transformers/models/dinov2_with_registers/configuration_dinov2_with_registers.py create mode 100644 src/transformers/models/dinov2_with_registers/convert_dinov2_with_registers_to_hf.py create mode 100644 src/transformers/models/dinov2_with_registers/modeling_dinov2_with_registers.py create mode 100644 src/transformers/models/dinov2_with_registers/modular_dinov2_with_registers.py create mode 100644 tests/models/dinov2_with_registers/__init__.py create mode 100644 tests/models/dinov2_with_registers/test_modeling_dinov2_with_registers.py diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index de21cd1408a3..a076f704b8ed 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -655,6 +655,8 @@ title: DiNAT - local: model_doc/dinov2 title: DINOV2 + - local: model_doc/dinov2_with_registers + title: DINOv2 with Registers - local: model_doc/dit title: DiT - local: model_doc/dpt diff --git a/docs/source/en/index.md b/docs/source/en/index.md index 967049d89cbe..dcecfc872d61 100644 --- a/docs/source/en/index.md +++ b/docs/source/en/index.md @@ -127,6 +127,7 @@ Flax), PyTorch, and/or TensorFlow. 
| [DialoGPT](model_doc/dialogpt) | ✅ | ✅ | ✅ |
| [DiNAT](model_doc/dinat) | ✅ | ❌ | ❌ |
| [DINOv2](model_doc/dinov2) | ✅ | ❌ | ✅ |
+| [DINOv2 with Registers](model_doc/dinov2_with_registers) | ✅ | ❌ | ❌ |
| [DistilBERT](model_doc/distilbert) | ✅ | ✅ | ✅ |
| [DiT](model_doc/dit) | ✅ | ❌ | ✅ |
| [DonutSwin](model_doc/donut) | ✅ | ❌ | ❌ |
diff --git a/docs/source/en/model_doc/dinov2_with_registers.md b/docs/source/en/model_doc/dinov2_with_registers.md
new file mode 100644
index 000000000000..360ebf9b8f8a
--- /dev/null
+++ b/docs/source/en/model_doc/dinov2_with_registers.md
@@ -0,0 +1,54 @@
+
+
+# DINOv2 with Registers
+
+## Overview
+
+The DINOv2 with Registers model was proposed in [Vision Transformers Need Registers](https://arxiv.org/abs/2309.16588) by Timothée Darcet, Maxime Oquab, Julien Mairal, Piotr Bojanowski.
+
+The [Vision Transformer](vit) (ViT) is a transformer encoder model (BERT-like) originally introduced to do supervised image classification on ImageNet.
+
+Next, people figured out ways to make ViT work really well on self-supervised image feature extraction (i.e. learning meaningful features, also called embeddings) without requiring any labels. Some example papers here include [DINOv2](dinov2) and [MAE](vit_mae).
+
+The authors of DINOv2 noticed that ViTs have artifacts in their attention maps, caused by the model repurposing some low-informative image patches as "registers" for internal computations. The authors propose a simple fix: add a few extra tokens (called "register" tokens) to the input sequence, use them during pre-training, and discard them afterwards. This results in:
+- no artifacts
+- interpretable attention maps
+- improved performance.
+
+The abstract from the paper is the following:
+
+*Transformers have recently emerged as a powerful tool for learning visual representations. In this paper, we identify and characterize artifacts in feature maps of both supervised and self-supervised ViT networks. The artifacts correspond to high-norm tokens appearing during inference primarily in low-informative background areas of images, that are repurposed for internal computations. We propose a simple yet effective solution based on providing additional tokens to the input sequence of the Vision Transformer to fill that role. We show that this solution fixes that problem entirely for both supervised and self-supervised models, sets a new state of the art for self-supervised visual models on dense visual prediction tasks, enables object discovery methods with larger models, and most importantly leads to smoother feature maps and attention maps for downstream visual processing.*
+
+Visualization of attention maps of various models trained with vs. without registers. Taken from the original paper.
+
+Tips:
+
+- Usage of DINOv2 with Registers is identical to DINOv2 without registers; you'll just get better performance. See the usage example below.
+
+This model was contributed by [nielsr](https://huggingface.co/nielsr).
+The original code can be found [here](https://github.com/facebookresearch/dinov2).
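+
+## Usage example
+
+A minimal sketch of feature extraction with the model (assuming a converted checkpoint such as `facebook/dinov2-with-registers-base` is available on the Hub, as used by the conversion script):
+
+```python
+>>> import torch
+>>> import requests
+>>> from PIL import Image
+>>> from transformers import AutoImageProcessor, AutoModel
+
+>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+>>> image = Image.open(requests.get(url, stream=True).raw)
+
+>>> processor = AutoImageProcessor.from_pretrained("facebook/dinov2-with-registers-base")
+>>> model = AutoModel.from_pretrained("facebook/dinov2-with-registers-base")
+
+>>> inputs = processor(images=image, return_tensors="pt")
+>>> with torch.no_grad():
+...     outputs = model(**inputs)
+
+>>> # the sequence contains the [CLS] token, the register tokens and the patch tokens
+>>> print(outputs.last_hidden_state.shape)
+```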
+ + +## Dinov2WithRegistersConfig + +[[autodoc]] Dinov2WithRegistersConfig + +## Dinov2WithRegistersModel + +[[autodoc]] Dinov2WithRegistersModel + - forward + +## Dinov2WithRegistersForImageClassification + +[[autodoc]] Dinov2WithRegistersForImageClassification + - forward \ No newline at end of file diff --git a/docs/source/en/perf_infer_gpu_one.md b/docs/source/en/perf_infer_gpu_one.md index 930f41b6fefb..d79450964180 100644 --- a/docs/source/en/perf_infer_gpu_one.md +++ b/docs/source/en/perf_infer_gpu_one.md @@ -238,6 +238,7 @@ For now, Transformers supports SDPA inference and training for the following arc * [Dbrx](https://huggingface.co/docs/transformers/model_doc/dbrx#transformers.DbrxModel) * [DeiT](https://huggingface.co/docs/transformers/model_doc/deit#transformers.DeiTModel) * [Dinov2](https://huggingface.co/docs/transformers/en/model_doc/dinov2) +* [Dinov2_with_registers](https://huggingface.co/docs/transformers/en/model_doc/dinov2) * [DistilBert](https://huggingface.co/docs/transformers/model_doc/distilbert#transformers.DistilBertModel) * [Dpr](https://huggingface.co/docs/transformers/model_doc/dpr#transformers.DprReader) * [EncoderDecoder](https://huggingface.co/docs/transformers/model_doc/encoder_decoder#transformers.EncoderDecoderModel) diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index ef140cc6d3a8..7df1af049de6 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -404,6 +404,7 @@ "models.dialogpt": [], "models.dinat": ["DinatConfig"], "models.dinov2": ["Dinov2Config"], + "models.dinov2_with_registers": ["Dinov2WithRegistersConfig"], "models.distilbert": [ "DistilBertConfig", "DistilBertTokenizer", @@ -2160,6 +2161,14 @@ "Dinov2PreTrainedModel", ] ) + _import_structure["models.dinov2_with_registers"].extend( + [ + "Dinov2WithRegistersBackbone", + "Dinov2WithRegistersForImageClassification", + "Dinov2WithRegistersModel", + "Dinov2WithRegistersPreTrainedModel", + ] + ) _import_structure["models.distilbert"].extend( [ "DistilBertForMaskedLM", @@ -5362,6 +5371,7 @@ from .models.detr import DetrConfig from .models.dinat import DinatConfig from .models.dinov2 import Dinov2Config + from .models.dinov2_with_registers import Dinov2WithRegistersConfig from .models.distilbert import ( DistilBertConfig, DistilBertTokenizer, @@ -7019,6 +7029,12 @@ Dinov2Model, Dinov2PreTrainedModel, ) + from .models.dinov2_with_registers import ( + Dinov2WithRegistersBackbone, + Dinov2WithRegistersForImageClassification, + Dinov2WithRegistersModel, + Dinov2WithRegistersPreTrainedModel, + ) from .models.distilbert import ( DistilBertForMaskedLM, DistilBertForMultipleChoice, diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py index 7fcaddde704c..ff03d09966a4 100644 --- a/src/transformers/models/__init__.py +++ b/src/transformers/models/__init__.py @@ -77,6 +77,7 @@ dialogpt, dinat, dinov2, + dinov2_with_registers, distilbert, dit, donut, diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py index 69ce8efa10c7..6c052aa0eaa0 100644 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -94,6 +94,7 @@ ("detr", "DetrConfig"), ("dinat", "DinatConfig"), ("dinov2", "Dinov2Config"), + ("dinov2_with_registers", "Dinov2WithRegistersConfig"), ("distilbert", "DistilBertConfig"), ("donut-swin", "DonutSwinConfig"), ("dpr", "DPRConfig"), @@ -404,6 +405,7 @@ ("dialogpt", "DialoGPT"), ("dinat", "DiNAT"), 
("dinov2", "DINOv2"), + ("dinov2_with_registers", "DINOv2 with Registers"), ("distilbert", "DistilBERT"), ("dit", "DiT"), ("donut-swin", "DonutSwin"), diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index e8a2dece4324..861754f59176 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -92,6 +92,7 @@ ("detr", "DetrModel"), ("dinat", "DinatModel"), ("dinov2", "Dinov2Model"), + ("dinov2_with_registers", "Dinov2WithRegistersModel"), ("distilbert", "DistilBertModel"), ("donut-swin", "DonutSwinModel"), ("dpr", "DPRQuestionEncoder"), @@ -584,6 +585,7 @@ ("detr", "DetrModel"), ("dinat", "DinatModel"), ("dinov2", "Dinov2Model"), + ("dinov2_with_registers", "Dinov2WithRegistersModel"), ("dpt", "DPTModel"), ("efficientformer", "EfficientFormerModel"), ("efficientnet", "EfficientNetModel"), @@ -659,6 +661,7 @@ ), ("dinat", "DinatForImageClassification"), ("dinov2", "Dinov2ForImageClassification"), + ("dinov2_with_registers", "Dinov2WithRegistersForImageClassification"), ( "efficientformer", ( @@ -1373,6 +1376,7 @@ ("convnextv2", "ConvNextV2Backbone"), ("dinat", "DinatBackbone"), ("dinov2", "Dinov2Backbone"), + ("dinov2_with_registers", "Dinov2WithRegistersBackbone"), ("focalnet", "FocalNetBackbone"), ("hiera", "HieraBackbone"), ("maskformer-swin", "MaskFormerSwinBackbone"), diff --git a/src/transformers/models/dinov2_with_registers/__init__.py b/src/transformers/models/dinov2_with_registers/__init__.py new file mode 100644 index 000000000000..2d10027b6a3b --- /dev/null +++ b/src/transformers/models/dinov2_with_registers/__init__.py @@ -0,0 +1,27 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_dinov2_with_registers import * + from .modeling_dinov2_with_registers import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/src/transformers/models/dinov2_with_registers/configuration_dinov2_with_registers.py b/src/transformers/models/dinov2_with_registers/configuration_dinov2_with_registers.py new file mode 100644 index 000000000000..80c095cb4648 --- /dev/null +++ b/src/transformers/models/dinov2_with_registers/configuration_dinov2_with_registers.py @@ -0,0 +1,166 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/dinov2_with_registers/modular_dinov2_with_registers.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_dinov2_with_registers.py file directly. One of our CI enforces this. 
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# coding=utf-8 +# Copyright 2024 Meta Inc. and the HuggingFace Inc. team. All rights reserved. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from ...configuration_utils import PretrainedConfig +from ...utils.backbone_utils import BackboneConfigMixin, get_aligned_output_features_output_indices + + +class Dinov2WithRegistersConfig(BackboneConfigMixin, PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`Dinov2WithRegistersModel`]. It is used to instantiate an + Dinov2WithRegisters model according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of the DINOv2 with Registers + [facebook/dinov2-with-registers-base](https://huggingface.co/facebook/dinov2-with-registers-base) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + hidden_size (`int`, *optional*, defaults to 768): + Dimensionality of the encoder layers and the pooler layer. + num_hidden_layers (`int`, *optional*, defaults to 12): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 12): + Number of attention heads for each attention layer in the Transformer encoder. + mlp_ratio (`int`, *optional*, defaults to 4): + Ratio of the hidden size of the MLPs relative to the `hidden_size`. + hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, + `"relu"`, `"selu"` and `"gelu_new"` are supported. + hidden_dropout_prob (`float`, *optional*, defaults to 0.0): + The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. + attention_probs_dropout_prob (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + layer_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the layer normalization layers. + image_size (`int`, *optional*, defaults to 224): + The size (resolution) of each image. + patch_size (`int`, *optional*, defaults to 16): + The size (resolution) of each patch. + num_channels (`int`, *optional*, defaults to 3): + The number of input channels. + qkv_bias (`bool`, *optional*, defaults to `True`): + Whether to add a bias to the queries, keys and values. + layerscale_value (`float`, *optional*, defaults to 1.0): + Initial value to use for layer scale. + drop_path_rate (`float`, *optional*, defaults to 0.0): + Stochastic depth rate per sample (when applied in the main path of residual layers). 
+        use_swiglu_ffn (`bool`, *optional*, defaults to `False`):
+            Whether to use the SwiGLU feedforward neural network.
+        num_register_tokens (`int`, *optional*, defaults to 4):
+            Number of register tokens to use.
+        interpolate_antialias (`bool`, *optional*, defaults to `True`):
+            Whether to use antialiasing when interpolating the image patches.
+        interpolate_offset (`float`, *optional*, defaults to 0.0):
+            Offset to use when interpolating the image patches.
+        out_features (`List[str]`, *optional*):
+            If used as backbone, list of features to output. Can be any of `"stem"`, `"stage1"`, `"stage2"`, etc.
+            (depending on how many stages the model has). If unset and `out_indices` is set, will default to the
+            corresponding stages. If unset and `out_indices` is unset, will default to the last stage. Must be in the
+            same order as defined in the `stage_names` attribute.
+        out_indices (`List[int]`, *optional*):
+            If used as backbone, list of indices of features to output. Can be any of 0, 1, 2, etc. (depending on how
+            many stages the model has). If unset and `out_features` is set, will default to the corresponding stages.
+            If unset and `out_features` is unset, will default to the last stage. Must be in the
+            same order as defined in the `stage_names` attribute.
+        apply_layernorm (`bool`, *optional*, defaults to `True`):
+            Whether to apply layer normalization to the feature maps in case the model is used as backbone.
+        reshape_hidden_states (`bool`, *optional*, defaults to `True`):
+            Whether to reshape the feature maps to 4D tensors of shape `(batch_size, hidden_size, height, width)` in
+            case the model is used as backbone. If `False`, the feature maps will be 3D tensors of shape `(batch_size,
+            seq_len, hidden_size)`.
+
+    Example:
+
+    ```python
+    >>> from transformers import Dinov2WithRegistersConfig, Dinov2WithRegistersModel
+
+    >>> # Initializing a Dinov2WithRegisters base style configuration
+    >>> configuration = Dinov2WithRegistersConfig()
+
+    >>> # Initializing a model (with random weights) from the base style configuration
+    >>> model = Dinov2WithRegistersModel(configuration)
+
+    >>> # Accessing the model configuration
+    >>> configuration = model.config
+    ```"""
+
+    model_type = "dinov2_with_registers"
+
+    def __init__(
+        self,
+        hidden_size=768,
+        num_hidden_layers=12,
+        num_attention_heads=12,
+        mlp_ratio=4,
+        hidden_act="gelu",
+        hidden_dropout_prob=0.0,
+        attention_probs_dropout_prob=0.0,
+        initializer_range=0.02,
+        layer_norm_eps=1e-6,
+        image_size=224,
+        patch_size=16,
+        num_channels=3,
+        qkv_bias=True,
+        layerscale_value=1.0,
+        drop_path_rate=0.0,
+        use_swiglu_ffn=False,
+        num_register_tokens=4,
+        interpolate_antialias=True,
+        interpolate_offset=0.0,
+        out_features=None,
+        out_indices=None,
+        apply_layernorm=True,
+        reshape_hidden_states=True,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+
+        self.hidden_size = hidden_size
+        self.num_hidden_layers = num_hidden_layers
+        self.num_attention_heads = num_attention_heads
+        self.mlp_ratio = mlp_ratio
+        self.hidden_act = hidden_act
+        self.hidden_dropout_prob = hidden_dropout_prob
+        self.attention_probs_dropout_prob = attention_probs_dropout_prob
+        self.initializer_range = initializer_range
+        self.layer_norm_eps = layer_norm_eps
+        self.image_size = image_size
+        self.patch_size = patch_size
+        self.num_channels = num_channels
+        self.qkv_bias = qkv_bias
+        self.layerscale_value = layerscale_value
+        self.drop_path_rate = drop_path_rate
+        self.use_swiglu_ffn = use_swiglu_ffn
+        self.num_register_tokens = num_register_tokens
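+        # remaining attributes: position embedding interpolation settings and backbone-related options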
self.interpolate_antialias = interpolate_antialias + self.interpolate_offset = interpolate_offset + self.stage_names = ["stem"] + [f"stage{idx}" for idx in range(1, num_hidden_layers + 1)] + self._out_features, self._out_indices = get_aligned_output_features_output_indices( + out_features=out_features, out_indices=out_indices, stage_names=self.stage_names + ) + self.apply_layernorm = apply_layernorm + self.reshape_hidden_states = reshape_hidden_states + + +__all__ = ["Dinov2WithRegistersConfig"] diff --git a/src/transformers/models/dinov2_with_registers/convert_dinov2_with_registers_to_hf.py b/src/transformers/models/dinov2_with_registers/convert_dinov2_with_registers_to_hf.py new file mode 100644 index 000000000000..0ff2697f7466 --- /dev/null +++ b/src/transformers/models/dinov2_with_registers/convert_dinov2_with_registers_to_hf.py @@ -0,0 +1,291 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Convert DINOv2 with Registers checkpoints from the original repository. + +URL: https://github.com/facebookresearch/dinov2/tree/main +""" + +import argparse +import json +from pathlib import Path + +import requests +import torch +import torch.nn as nn +from huggingface_hub import hf_hub_download +from PIL import Image +from torchvision import transforms + +from transformers import ( + BitImageProcessor, + Dinov2WithRegistersConfig, + Dinov2WithRegistersForImageClassification, + Dinov2WithRegistersModel, +) +from transformers.image_utils import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, PILImageResampling +from transformers.utils import logging + + +logging.set_verbosity_info() +logger = logging.get_logger(__name__) + + +def get_dinov2_with_registers_config(model_name, image_classifier=False): + config = Dinov2WithRegistersConfig(image_size=518, patch_size=14) + + # size of the architecture + if "vits" in model_name: + config.hidden_size = 384 + config.num_attention_heads = 6 + elif "vitb" in model_name: + pass + elif "vitl" in model_name: + config.hidden_size = 1024 + config.num_hidden_layers = 24 + config.num_attention_heads = 16 + elif "vitg" in model_name: + config.use_swiglu_ffn = True + config.hidden_size = 1536 + config.num_hidden_layers = 40 + config.num_attention_heads = 24 + else: + raise ValueError("Model not supported") + + if image_classifier: + repo_id = "huggingface/label-files" + filename = "imagenet-1k-id2label.json" + config.num_labels = 1000 + config.id2label = json.load(open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r")) + config.id2label = {int(k): v for k, v in config.id2label.items()} + + return config + + +def create_rename_keys(config): + rename_keys = [] + # fmt: off + + # patch embedding layer + rename_keys.append(("cls_token", "embeddings.cls_token")) + rename_keys.append(("mask_token", "embeddings.mask_token")) + rename_keys.append(("pos_embed", "embeddings.position_embeddings")) + rename_keys.append(("register_tokens", "embeddings.register_tokens")) + 
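+    # patch embedding projection (the fused attention qkv weights are not renamed here; they are split into query/key/value in read_in_q_k_v)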
rename_keys.append(("patch_embed.proj.weight", "embeddings.patch_embeddings.projection.weight")) + rename_keys.append(("patch_embed.proj.bias", "embeddings.patch_embeddings.projection.bias")) + + for i in range(config.num_hidden_layers): + # layernorms + rename_keys.append((f"blocks.{i}.norm1.weight", f"encoder.layer.{i}.norm1.weight")) + rename_keys.append((f"blocks.{i}.norm1.bias", f"encoder.layer.{i}.norm1.bias")) + rename_keys.append((f"blocks.{i}.norm2.weight", f"encoder.layer.{i}.norm2.weight")) + rename_keys.append((f"blocks.{i}.norm2.bias", f"encoder.layer.{i}.norm2.bias")) + # MLP + if config.use_swiglu_ffn: + rename_keys.append((f"blocks.{i}.mlp.w12.weight", f"encoder.layer.{i}.mlp.w12.weight")) + rename_keys.append((f"blocks.{i}.mlp.w12.bias", f"encoder.layer.{i}.mlp.w12.bias")) + rename_keys.append((f"blocks.{i}.mlp.w3.weight", f"encoder.layer.{i}.mlp.w3.weight")) + rename_keys.append((f"blocks.{i}.mlp.w3.bias", f"encoder.layer.{i}.mlp.w3.bias")) + else: + rename_keys.append((f"blocks.{i}.mlp.fc1.weight", f"encoder.layer.{i}.mlp.fc1.weight")) + rename_keys.append((f"blocks.{i}.mlp.fc1.bias", f"encoder.layer.{i}.mlp.fc1.bias")) + rename_keys.append((f"blocks.{i}.mlp.fc2.weight", f"encoder.layer.{i}.mlp.fc2.weight")) + rename_keys.append((f"blocks.{i}.mlp.fc2.bias", f"encoder.layer.{i}.mlp.fc2.bias")) + # layerscale + rename_keys.append((f"blocks.{i}.ls1.gamma", f"encoder.layer.{i}.layer_scale1.lambda1")) + rename_keys.append((f"blocks.{i}.ls2.gamma", f"encoder.layer.{i}.layer_scale2.lambda1")) + # attention projection layer + rename_keys.append((f"blocks.{i}.attn.proj.weight", f"encoder.layer.{i}.attention.output.dense.weight")) + rename_keys.append((f"blocks.{i}.attn.proj.bias", f"encoder.layer.{i}.attention.output.dense.bias")) + + # final layernorm + rename_keys.append(("norm.weight", "layernorm.weight")) + rename_keys.append(("norm.bias", "layernorm.bias")) + + # fmt: on + return rename_keys + + +def rename_key(dct, old, new): + val = dct.pop(old) + dct[new] = val + + +# we split up the matrix of each encoder layer into queries, keys and values +def read_in_q_k_v(state_dict, config): + for i in range(config.num_hidden_layers): + # read in weights + bias of input projection layer (in timm, this is a single matrix + bias) + in_proj_weight = state_dict.pop(f"blocks.{i}.attn.qkv.weight") + in_proj_bias = state_dict.pop(f"blocks.{i}.attn.qkv.bias") + # next, add query, keys and values (in that order) to the state dict + state_dict[f"encoder.layer.{i}.attention.attention.query.weight"] = in_proj_weight[: config.hidden_size, :] + state_dict[f"encoder.layer.{i}.attention.attention.query.bias"] = in_proj_bias[: config.hidden_size] + state_dict[f"encoder.layer.{i}.attention.attention.key.weight"] = in_proj_weight[ + config.hidden_size : config.hidden_size * 2, : + ] + state_dict[f"encoder.layer.{i}.attention.attention.key.bias"] = in_proj_bias[ + config.hidden_size : config.hidden_size * 2 + ] + state_dict[f"encoder.layer.{i}.attention.attention.value.weight"] = in_proj_weight[-config.hidden_size :, :] + state_dict[f"encoder.layer.{i}.attention.attention.value.bias"] = in_proj_bias[-config.hidden_size :] + + +# We will verify our results on an image of cute cats +def prepare_img(): + url = "http://images.cocodataset.org/val2017/000000039769.jpg" + image = Image.open(requests.get(url, stream=True).raw).convert("RGB") + return image + + +@torch.no_grad() +def convert_dinov2_with_registers_checkpoint(model_name, pytorch_dump_folder_path, push_to_hub=False): + """ + Copy/paste/tweak 
model's weights to our Dinov2WithRegisters structure. + """ + + # define default Dinov2WithRegisters configuration + image_classifier = "1layer" in model_name + config = get_dinov2_with_registers_config(model_name, image_classifier=image_classifier) + + # load original model from torch hub + original_model = torch.hub.load("facebookresearch/dinov2", model_name.replace("_1layer", "")) + original_model.eval() + + # load state_dict of original model, remove and rename some keys + state_dict = original_model.state_dict() + rename_keys = create_rename_keys(config) + for src, dest in rename_keys: + rename_key(state_dict, src, dest) + read_in_q_k_v(state_dict, config) + + for key, val in state_dict.copy().items(): + val = state_dict.pop(key) + if "w12" in key: + key = key.replace("w12", "weights_in") + if "w3" in key: + key = key.replace("w3", "weights_out") + state_dict[key] = val + + # load HuggingFace model + if image_classifier: + model = Dinov2WithRegistersForImageClassification(config).eval() + model.dinov2_with_registers.load_state_dict(state_dict) + model_name_to_classifier_dict_url = { + "dinov2_vits14_reg_1layer": "https://dl.fbaipublicfiles.com/dinov2/dinov2_vits14/dinov2_vits14_reg4_linear_head.pth", + "dinov2_vitb14_reg_1layer": "https://dl.fbaipublicfiles.com/dinov2/dinov2_vitb14/dinov2_vitb14_reg4_linear_head.pth", + "dinov2_vitl14_reg_1layer": "https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_reg4_linear_head.pth", + "dinov2_vitg14_reg_1layer": "https://dl.fbaipublicfiles.com/dinov2/dinov2_vitg14/dinov2_vitg14_reg4_linear_head.pth", + } + url = model_name_to_classifier_dict_url[model_name] + classifier_state_dict = torch.hub.load_state_dict_from_url(url, map_location="cpu") + model.classifier.weight = nn.Parameter(classifier_state_dict["weight"]) + model.classifier.bias = nn.Parameter(classifier_state_dict["bias"]) + else: + model = Dinov2WithRegistersModel(config).eval() + model.load_state_dict(state_dict) + + # load image + image = prepare_img() + + # preprocess image + transformations = transforms.Compose( + [ + transforms.Resize(256, interpolation=transforms.InterpolationMode.BICUBIC), + transforms.CenterCrop(224), + transforms.ToTensor(), + transforms.Normalize( + mean=IMAGENET_DEFAULT_MEAN, # these are RGB mean+std values + std=IMAGENET_DEFAULT_STD, # across a large photo dataset. 
+ ), + ] + ) + + original_pixel_values = transformations(image).unsqueeze(0) # insert batch dimension + + processor = BitImageProcessor( + size={"shortest_edge": 256}, + resample=PILImageResampling.BICUBIC, + image_mean=IMAGENET_DEFAULT_MEAN, + image_std=IMAGENET_DEFAULT_STD, + ) + pixel_values = processor(image, return_tensors="pt").pixel_values + + assert torch.allclose(original_pixel_values, pixel_values) + + with torch.no_grad(): + outputs = model(pixel_values, output_hidden_states=True) + original_outputs = original_model(pixel_values) + + # assert values + if image_classifier: + print("Predicted class:") + class_idx = outputs.logits.argmax(-1).item() + print(model.config.id2label[class_idx]) + else: + assert outputs.last_hidden_state[:, 0].shape == original_outputs.shape + assert torch.allclose(outputs.last_hidden_state[:, 0], original_outputs, atol=1e-3) + print("Looks ok!") + + if pytorch_dump_folder_path is not None: + Path(pytorch_dump_folder_path).mkdir(exist_ok=True) + print(f"Saving model {model_name} to {pytorch_dump_folder_path}") + model.save_pretrained(pytorch_dump_folder_path) + print(f"Saving image processor to {pytorch_dump_folder_path}") + processor.save_pretrained(pytorch_dump_folder_path) + + if push_to_hub: + model_name_to_hf_name = { + "dinov2_vits14_reg": "dinov2-with-registers-small", + "dinov2_vitb14_reg": "dinov2-with-registers-base", + "dinov2_vitl14_reg": "dinov2-with-registers-large", + "dinov2_vitg14_reg": "dinov2-with-registers-giant", + "dinov2_vits14_reg_1layer": "dinov2-with-registers-small-imagenet1k-1-layer", + "dinov2_vitb14_reg_1layer": "dinov2-with-registers-base-imagenet1k-1-layer", + "dinov2_vitl14_reg_1layer": "dinov2-with-registers-large-imagenet1k-1-layer", + "dinov2_vitg14_reg_1layer": "dinov2-with-registers-giant-imagenet1k-1-layer", + } + + name = model_name_to_hf_name[model_name] + model.push_to_hub(f"nielsr/{name}") + processor.push_to_hub(f"nielsr/{name}") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + # Required parameters + parser.add_argument( + "--model_name", + default="dinov2_vits14_reg", + type=str, + choices=[ + "dinov2_vits14_reg", + "dinov2_vitb14_reg", + "dinov2_vitl14_reg", + "dinov2_vitg14_reg", + "dinov2_vits14_reg_1layer", + "dinov2_vitb14_reg_1layer", + "dinov2_vitl14_reg_1layer", + "dinov2_vitg14_reg_1layer", + ], + help="Name of the model you'd like to convert.", + ) + parser.add_argument( + "--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model directory." + ) + parser.add_argument( + "--push_to_hub", action="store_true", help="Whether or not to push the converted model to the 🤗 hub." + ) + + args = parser.parse_args() + convert_dinov2_with_registers_checkpoint(args.model_name, args.pytorch_dump_folder_path, args.push_to_hub) diff --git a/src/transformers/models/dinov2_with_registers/modeling_dinov2_with_registers.py b/src/transformers/models/dinov2_with_registers/modeling_dinov2_with_registers.py new file mode 100644 index 000000000000..4ebefa8bded1 --- /dev/null +++ b/src/transformers/models/dinov2_with_registers/modeling_dinov2_with_registers.py @@ -0,0 +1,926 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/dinov2_with_registers/modular_dinov2_with_registers.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. 
If any change should be done, please apply the change to the +# modular_dinov2_with_registers.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# coding=utf-8 +# Copyright 2024 Meta Inc. and the HuggingFace Inc. team. All rights reserved. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import collections.abc +import math +from typing import Dict, List, Optional, Set, Tuple, Union + +import torch +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + +from ...activations import ACT2FN +from ...modeling_outputs import BackboneOutput, BaseModelOutput, BaseModelOutputWithPooling, ImageClassifierOutput +from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer +from ...utils import ( + add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, + replace_return_docstrings, +) +from ...utils.backbone_utils import BackboneMixin +from .configuration_dinov2_with_registers import Dinov2WithRegistersConfig + + +logger = logging.get_logger(__name__) + +# Base docstring +_CHECKPOINT_FOR_DOC = "facebook/dinov2_with_registers-base" + +# General docstring +_CONFIG_FOR_DOC = "Dinov2WithRegistersConfig" + + +class Dinov2WithRegistersPatchEmbeddings(nn.Module): + """ + This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial + `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a + Transformer. + """ + + def __init__(self, config): + super().__init__() + image_size, patch_size = config.image_size, config.patch_size + num_channels, hidden_size = config.num_channels, config.hidden_size + + image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size) + patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size) + num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0]) + self.image_size = image_size + self.patch_size = patch_size + self.num_channels = num_channels + self.num_patches = num_patches + + self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size) + + def forward(self, pixel_values: torch.Tensor) -> torch.Tensor: + num_channels = pixel_values.shape[1] + if num_channels != self.num_channels: + raise ValueError( + "Make sure that the channel dimension of the pixel values match with the one set in the configuration." + f" Expected {self.num_channels} but got {num_channels}." + ) + embeddings = self.projection(pixel_values).flatten(2).transpose(1, 2) + return embeddings + + +class Dinov2WithRegistersEmbeddings(nn.Module): + """ + Construct the CLS token, mask token, register tokens, position and patch embeddings. 
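+    Register tokens are inserted between the [CLS] token and the patch tokens after the positional embeddings have been added, so they carry no positional embedding themselves.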
+ """ + + def __init__(self, config: Dinov2WithRegistersConfig) -> None: + super().__init__() + + self.cls_token = nn.Parameter(torch.randn(1, 1, config.hidden_size)) + self.mask_token = nn.Parameter(torch.zeros(1, config.hidden_size)) + self.register_tokens = nn.Parameter(torch.zeros(1, config.num_register_tokens, config.hidden_size)) + self.patch_embeddings = Dinov2WithRegistersPatchEmbeddings(config) + num_patches = self.patch_embeddings.num_patches + self.position_embeddings = nn.Parameter(torch.randn(1, num_patches + 1, config.hidden_size)) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + self.config = config + + def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor: + """ + This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher + resolution images. + + Source: + https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174 + """ + + num_patches = embeddings.shape[1] - 1 + num_positions = self.position_embeddings.shape[1] - 1 + if num_patches == num_positions and height == width: + return self.position_embeddings + class_pos_embed = self.position_embeddings[:, 0] + patch_pos_embed = self.position_embeddings[:, 1:] + dim = embeddings.shape[-1] + height = height // self.config.patch_size + width = width // self.config.patch_size + # we add a small number to avoid floating point error in the interpolation + # see discussion at https://github.com/facebookresearch/dino/issues/8 + height, width = height + self.config.interpolate_offset, width + self.config.interpolate_offset + patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim) + patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2) + target_dtype = patch_pos_embed.dtype + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed.to(dtype=torch.float32), + scale_factor=(float(height / math.sqrt(num_positions)), float(width / math.sqrt(num_positions))), + mode="bicubic", + align_corners=False, + antialias=self.config.interpolate_antialias, + ) + patch_pos_embed = patch_pos_embed.to(dtype=target_dtype) + if int(height) != patch_pos_embed.shape[-2] or int(width) != patch_pos_embed.shape[-1]: + raise ValueError("Width or height does not match with the interpolated position embeddings") + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1) + + def forward(self, pixel_values: torch.Tensor, bool_masked_pos: Optional[torch.Tensor] = None) -> torch.Tensor: + batch_size, _, height, width = pixel_values.shape + target_dtype = self.patch_embeddings.projection.weight.dtype + embeddings = self.patch_embeddings(pixel_values.to(dtype=target_dtype)) + + if bool_masked_pos is not None: + embeddings = torch.where( + bool_masked_pos.unsqueeze(-1), self.mask_token.to(embeddings.dtype).unsqueeze(0), embeddings + ) + + # add the [CLS] token to the embedded patch tokens + cls_tokens = self.cls_token.expand(batch_size, -1, -1) + embeddings = torch.cat((cls_tokens, embeddings), dim=1) + + # add positional encoding to each token + embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width) + + # add register tokens + embeddings = torch.cat( + (embeddings[:, :1], self.register_tokens.expand(embeddings.shape[0], -1, -1), embeddings[:, 1:]), dim=1 + ) + + embeddings = self.dropout(embeddings) + + return embeddings + + +class 
Dinov2WithRegistersSelfAttention(nn.Module): + def __init__(self, config: Dinov2WithRegistersConfig) -> None: + super().__init__() + if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): + raise ValueError( + f"The hidden size {config.hidden_size,} is not a multiple of the number of attention " + f"heads {config.num_attention_heads}." + ) + + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias) + self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias) + self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias) + + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + + def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor: + new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) + x = x.view(new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward( + self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False + ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]: + mixed_query_layer = self.query(hidden_states) + + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + query_layer = self.transpose_for_scores(mixed_query_layer) + + # Take the dot product between "query" and "key" to get the raw attention scores. + attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + + attention_scores = attention_scores / math.sqrt(self.attention_head_size) + + # Normalize the attention scores to probabilities. + attention_probs = nn.functional.softmax(attention_scores, dim=-1) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + attention_probs = self.dropout(attention_probs) + + # Mask heads if we want to + if head_mask is not None: + attention_probs = attention_probs * head_mask + + context_layer = torch.matmul(attention_probs, value_layer) + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(new_context_layer_shape) + + outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) + + return outputs + + +class Dinov2WithRegistersSdpaSelfAttention(Dinov2WithRegistersSelfAttention): + def __init__(self, config: Dinov2WithRegistersConfig) -> None: + super().__init__(config) + self.attention_probs_dropout_prob = config.attention_probs_dropout_prob + + def forward( + self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False + ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]: + if output_attentions: + # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. + logger.warning_once( + "Dinov2WithRegistersModel is using Dinov2WithRegistersSdpaSelfAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " + 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. 
This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + return super().forward( + hidden_states=hidden_states, head_mask=head_mask, output_attentions=output_attentions + ) + + mixed_query_layer = self.query(hidden_states) + + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + query_layer = self.transpose_for_scores(mixed_query_layer) + + context_layer = torch.nn.functional.scaled_dot_product_attention( + query_layer, + key_layer, + value_layer, + head_mask, + self.attention_probs_dropout_prob if self.training else 0.0, + is_causal=False, + scale=None, + ) + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(new_context_layer_shape) + + return context_layer, None + + +class Dinov2WithRegistersSelfOutput(nn.Module): + """ + The residual connection is defined in Dinov2WithRegistersLayer instead of here (as is the case with other models), due to the + layernorm applied before each block. + """ + + def __init__(self, config: Dinov2WithRegistersConfig) -> None: + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + + return hidden_states + + +class Dinov2WithRegistersAttention(nn.Module): + def __init__(self, config: Dinov2WithRegistersConfig) -> None: + super().__init__() + self.attention = Dinov2WithRegistersSelfAttention(config) + self.output = Dinov2WithRegistersSelfOutput(config) + self.pruned_heads = set() + + def prune_heads(self, heads: Set[int]) -> None: + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices( + heads, self.attention.num_attention_heads, self.attention.attention_head_size, self.pruned_heads + ) + + # Prune linear layers + self.attention.query = prune_linear_layer(self.attention.query, index) + self.attention.key = prune_linear_layer(self.attention.key, index) + self.attention.value = prune_linear_layer(self.attention.value, index) + self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) + + # Update hyper params and store pruned heads + self.attention.num_attention_heads = self.attention.num_attention_heads - len(heads) + self.attention.all_head_size = self.attention.attention_head_size * self.attention.num_attention_heads + self.pruned_heads = self.pruned_heads.union(heads) + + def forward( + self, + hidden_states: torch.Tensor, + head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]: + self_outputs = self.attention(hidden_states, head_mask, output_attentions) + + attention_output = self.output(self_outputs[0], hidden_states) + + outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them + return outputs + + +class Dinov2WithRegistersSdpaAttention(Dinov2WithRegistersAttention): + def __init__(self, config: Dinov2WithRegistersConfig) -> None: + super().__init__(config) + self.attention = Dinov2WithRegistersSdpaSelfAttention(config) + + +class Dinov2WithRegistersLayerScale(nn.Module): + def __init__(self, config) -> None: + super().__init__() + self.lambda1 = 
nn.Parameter(config.layerscale_value * torch.ones(config.hidden_size)) + + def forward(self, hidden_state: torch.Tensor) -> torch.Tensor: + return hidden_state * self.lambda1 + + +def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor: + """ + Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + + Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks, + however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... + See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the + layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the + argument. + """ + if drop_prob == 0.0 or not training: + return input + keep_prob = 1 - drop_prob + shape = (input.shape[0],) + (1,) * (input.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device) + random_tensor.floor_() # binarize + output = input.div(keep_prob) * random_tensor + return output + + +class Dinov2WithRegistersDropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" + + def __init__(self, drop_prob: Optional[float] = None) -> None: + super().__init__() + self.drop_prob = drop_prob + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + return drop_path(hidden_states, self.drop_prob, self.training) + + def extra_repr(self) -> str: + return "p={}".format(self.drop_prob) + + +class Dinov2WithRegistersMLP(nn.Module): + def __init__(self, config) -> None: + super().__init__() + in_features = out_features = config.hidden_size + hidden_features = int(config.hidden_size * config.mlp_ratio) + self.fc1 = nn.Linear(in_features, hidden_features, bias=True) + if isinstance(config.hidden_act, str): + self.activation = ACT2FN[config.hidden_act] + else: + self.activation = config.hidden_act + self.fc2 = nn.Linear(hidden_features, out_features, bias=True) + + def forward(self, hidden_state: torch.Tensor) -> torch.Tensor: + hidden_state = self.fc1(hidden_state) + hidden_state = self.activation(hidden_state) + hidden_state = self.fc2(hidden_state) + return hidden_state + + +class Dinov2WithRegistersSwiGLUFFN(nn.Module): + def __init__(self, config) -> None: + super().__init__() + in_features = out_features = config.hidden_size + hidden_features = int(config.hidden_size * config.mlp_ratio) + hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8 + + self.weights_in = nn.Linear(in_features, 2 * hidden_features, bias=True) + self.weights_out = nn.Linear(hidden_features, out_features, bias=True) + + def forward(self, hidden_state: torch.Tensor) -> torch.Tensor: + hidden_state = self.weights_in(hidden_state) + x1, x2 = hidden_state.chunk(2, dim=-1) + hidden = nn.functional.silu(x1) * x2 + return self.weights_out(hidden) + + +DINOV2_WITH_REGISTERS_ATTENTION_CLASSES = { + "eager": Dinov2WithRegistersAttention, + "sdpa": Dinov2WithRegistersSdpaAttention, +} + + +class Dinov2WithRegistersLayer(nn.Module): + """This corresponds to the Block class in the original implementation.""" + + def __init__(self, config: Dinov2WithRegistersConfig) -> None: + super().__init__() + + self.norm1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.attention = 
DINOV2_WITH_REGISTERS_ATTENTION_CLASSES[config._attn_implementation](config) + self.layer_scale1 = Dinov2WithRegistersLayerScale(config) + self.drop_path = ( + Dinov2WithRegistersDropPath(config.drop_path_rate) if config.drop_path_rate > 0.0 else nn.Identity() + ) + + self.norm2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + if config.use_swiglu_ffn: + self.mlp = Dinov2WithRegistersSwiGLUFFN(config) + else: + self.mlp = Dinov2WithRegistersMLP(config) + self.layer_scale2 = Dinov2WithRegistersLayerScale(config) + + def forward( + self, + hidden_states: torch.Tensor, + head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]: + self_attention_outputs = self.attention( + self.norm1(hidden_states), # in Dinov2WithRegisters, layernorm is applied before self-attention + head_mask, + output_attentions=output_attentions, + ) + attention_output = self_attention_outputs[0] + + attention_output = self.layer_scale1(attention_output) + outputs = self_attention_outputs[1:] # add self attentions if we output attention weights + + # first residual connection + hidden_states = self.drop_path(attention_output) + hidden_states + + # in Dinov2WithRegisters, layernorm is also applied after self-attention + layer_output = self.norm2(hidden_states) + layer_output = self.mlp(layer_output) + layer_output = self.layer_scale2(layer_output) + + # second residual connection + layer_output = self.drop_path(layer_output) + hidden_states + + outputs = (layer_output,) + outputs + + return outputs + + +class Dinov2WithRegistersEncoder(nn.Module): + def __init__(self, config: Dinov2WithRegistersConfig) -> None: + super().__init__() + self.config = config + self.layer = nn.ModuleList([Dinov2WithRegistersLayer(config) for _ in range(config.num_hidden_layers)]) + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ) -> Union[tuple, BaseModelOutput]: + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + + for i, layer_module in enumerate(self.layer): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + layer_head_mask = head_mask[i] if head_mask is not None else None + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + layer_module.__call__, + hidden_states, + layer_head_mask, + output_attentions, + ) + else: + layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_self_attentions = all_self_attentions + (layer_outputs[1],) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None) + return BaseModelOutput( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) + + +class Dinov2WithRegistersPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. 
+ """ + + config_class = Dinov2WithRegistersConfig + base_model_prefix = "dinov2_with_registers" + main_input_name = "pixel_values" + supports_gradient_checkpointing = True + _no_split_modules = ["Dinov2WithRegistersSwiGLUFFN"] + _supports_sdpa = True + + def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None: + """Initialize the weights""" + if isinstance(module, (nn.Linear, nn.Conv2d)): + # Upcast the input in `fp32` and cast it back to desired `dtype` to avoid + # `trunc_normal_cpu` not implemented in `half` issues + module.weight.data = nn.init.trunc_normal_( + module.weight.data.to(torch.float32), mean=0.0, std=self.config.initializer_range + ).to(module.weight.dtype) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + elif isinstance(module, Dinov2WithRegistersEmbeddings): + module.position_embeddings.data = nn.init.trunc_normal_( + module.position_embeddings.data.to(torch.float32), + mean=0.0, + std=self.config.initializer_range, + ).to(module.position_embeddings.dtype) + + module.cls_token.data = nn.init.trunc_normal_( + module.cls_token.data.to(torch.float32), + mean=0.0, + std=self.config.initializer_range, + ).to(module.cls_token.dtype) + + +_EXPECTED_OUTPUT_SHAPE = [1, 257, 768] + + +DINOV2_WITH_REGISTERS_START_DOCSTRING = r""" + This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it + as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and + behavior. + + Parameters: + config ([`Dinov2WithRegistersConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +DINOV2_WITH_REGISTERS_BASE_INPUTS_DOCSTRING = r""" + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See + [`BitImageProcessor.preprocess`] for details. + + bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, sequence_length)`): + Boolean masked positions. Indicates which patches are masked (1) and which aren't (0). Only relevant for + pre-training. + + head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 
+""" + + +@add_start_docstrings( + "The bare Dinov2WithRegisters Model transformer outputting raw hidden-states without any specific head on top.", + DINOV2_WITH_REGISTERS_START_DOCSTRING, +) +class Dinov2WithRegistersModel(Dinov2WithRegistersPreTrainedModel): + def __init__(self, config: Dinov2WithRegistersConfig): + super().__init__(config) + self.config = config + + self.embeddings = Dinov2WithRegistersEmbeddings(config) + self.encoder = Dinov2WithRegistersEncoder(config) + + self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self) -> Dinov2WithRegistersPatchEmbeddings: + return self.embeddings.patch_embeddings + + def _prune_heads(self, heads_to_prune: Dict[int, List[int]]) -> None: + """ + Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + @add_start_docstrings_to_model_forward(DINOV2_WITH_REGISTERS_BASE_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=BaseModelOutputWithPooling, + config_class=_CONFIG_FOR_DOC, + modality="vision", + expected_output=_EXPECTED_OUTPUT_SHAPE, + ) + def forward( + self, + pixel_values: Optional[torch.Tensor] = None, + bool_masked_pos: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPooling]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if pixel_values is None: + raise ValueError("You have to specify pixel_values") + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) + + embedding_output = self.embeddings(pixel_values, bool_masked_pos=bool_masked_pos) + + encoder_outputs = self.encoder( + embedding_output, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + sequence_output = encoder_outputs[0] + sequence_output = self.layernorm(sequence_output) + pooled_output = sequence_output[:, 0, :] + + if not return_dict: + head_outputs = (sequence_output, pooled_output) + return head_outputs + encoder_outputs[1:] + + return BaseModelOutputWithPooling( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + + +# Image classification docstring +_IMAGE_CLASS_CHECKPOINT = "facebook/dinov2_with_registers-small-imagenet1k-1-layer" +_IMAGE_CLASS_EXPECTED_OUTPUT = "tabby, tabby cat" + +DINOV2_WITH_REGISTERS_INPUTS_DOCSTRING = r""" + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, 
width)`): + Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See + [`BitImageProcessor.preprocess`] for details. + + head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +@add_start_docstrings( + """ + Dinov2WithRegisters Model transformer with an image classification head on top (a linear layer on top of the final hidden state + of the [CLS] token) e.g. for ImageNet. + """, + DINOV2_WITH_REGISTERS_START_DOCSTRING, +) +class Dinov2WithRegistersForImageClassification(Dinov2WithRegistersPreTrainedModel): + def __init__(self, config: Dinov2WithRegistersConfig) -> None: + super().__init__(config) + + self.num_labels = config.num_labels + self.dinov2_with_registers = Dinov2WithRegistersModel(config) + + # Classifier head + self.classifier = ( + nn.Linear(config.hidden_size * 2, config.num_labels) if config.num_labels > 0 else nn.Identity() + ) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(DINOV2_WITH_REGISTERS_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_IMAGE_CLASS_CHECKPOINT, + output_type=ImageClassifierOutput, + config_class=_CONFIG_FOR_DOC, + expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT, + ) + def forward( + self, + pixel_values: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple, ImageClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the image classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). 
+        """
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        outputs = self.dinov2_with_registers(
+            pixel_values,
+            head_mask=head_mask,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+        sequence_output = outputs[0]  # batch_size, sequence_length, hidden_size
+
+        cls_token = sequence_output[:, 0]
+        patch_tokens = sequence_output[:, 1:]
+
+        linear_input = torch.cat([cls_token, patch_tokens.mean(dim=1)], dim=1)
+
+        logits = self.classifier(linear_input)
+
+        loss = None
+        if labels is not None:
+            # move labels to correct device to enable model parallelism
+            labels = labels.to(logits.device)
+            if self.config.problem_type is None:
+                if self.num_labels == 1:
+                    self.config.problem_type = "regression"
+                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+                    self.config.problem_type = "single_label_classification"
+                else:
+                    self.config.problem_type = "multi_label_classification"
+
+            if self.config.problem_type == "regression":
+                loss_fct = MSELoss()
+                if self.num_labels == 1:
+                    loss = loss_fct(logits.squeeze(), labels.squeeze())
+                else:
+                    loss = loss_fct(logits, labels)
+            elif self.config.problem_type == "single_label_classification":
+                loss_fct = CrossEntropyLoss()
+                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+            elif self.config.problem_type == "multi_label_classification":
+                loss_fct = BCEWithLogitsLoss()
+                loss = loss_fct(logits, labels)
+
+        if not return_dict:
+            output = (logits,) + outputs[2:]
+            return ((loss,) + output) if loss is not None else output
+
+        return ImageClassifierOutput(
+            loss=loss,
+            logits=logits,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
+
+
+@add_start_docstrings(
+    """
+    Dinov2WithRegisters backbone, to be used with frameworks like DETR and MaskFormer.
+    """,
+    DINOV2_WITH_REGISTERS_START_DOCSTRING,
+)
+class Dinov2WithRegistersBackbone(Dinov2WithRegistersPreTrainedModel, BackboneMixin):
+    def __init__(self, config):
+        super().__init__(config)
+        super()._init_backbone(config)
+        self.num_features = [config.hidden_size for _ in range(config.num_hidden_layers + 1)]
+        self.embeddings = Dinov2WithRegistersEmbeddings(config)
+        self.encoder = Dinov2WithRegistersEncoder(config)
+
+        self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+
+        self.num_register_tokens = config.num_register_tokens
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def get_input_embeddings(self) -> Dinov2WithRegistersPatchEmbeddings:
+        return self.embeddings.patch_embeddings
+
+    @add_start_docstrings_to_model_forward(DINOV2_WITH_REGISTERS_INPUTS_DOCSTRING)
+    @replace_return_docstrings(output_type=BackboneOutput, config_class=_CONFIG_FOR_DOC)
+    def forward(
+        self,
+        pixel_values: torch.Tensor,
+        output_hidden_states: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> BackboneOutput:
+        """
+        Returns:
+
+        Examples:
+
+        ```python
+        >>> from transformers import AutoImageProcessor, AutoBackbone
+        >>> import torch
+        >>> from PIL import Image
+        >>> import requests
+
+        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+        >>> image = Image.open(requests.get(url, stream=True).raw)
+
+        >>> processor = AutoImageProcessor.from_pretrained("facebook/dinov2-with-registers-base")
+        >>> model = AutoBackbone.from_pretrained(
+        ... 
"facebook/dinov2-with-registers-base", out_features=["stage2", "stage5", "stage8", "stage11"] + ... ) + + >>> inputs = processor(image, return_tensors="pt") + + >>> outputs = model(**inputs) + >>> feature_maps = outputs.feature_maps + >>> list(feature_maps[-1].shape) + [1, 768, 16, 16] + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + + embedding_output = self.embeddings(pixel_values) + + outputs = self.encoder( + embedding_output, output_hidden_states=True, output_attentions=output_attentions, return_dict=return_dict + ) + + hidden_states = outputs.hidden_states if return_dict else outputs[1] + + feature_maps = () + for stage, hidden_state in zip(self.stage_names, hidden_states): + if stage in self.out_features: + if self.config.apply_layernorm: + hidden_state = self.layernorm(hidden_state) + if self.config.reshape_hidden_states: + hidden_state = hidden_state[:, self.num_register_tokens + 1 :] + # this was actually a bug in the original implementation that we copied here, + # cause normally the order is height, width + batch_size, _, height, width = pixel_values.shape + patch_size = self.config.patch_size + hidden_state = hidden_state.reshape(batch_size, height // patch_size, width // patch_size, -1) + hidden_state = hidden_state.permute(0, 3, 1, 2).contiguous() + feature_maps += (hidden_state,) + + if not return_dict: + if output_hidden_states: + output = (feature_maps,) + outputs[1:] + else: + output = (feature_maps,) + outputs[2:] + return output + + return BackboneOutput( + feature_maps=feature_maps, + hidden_states=outputs.hidden_states if output_hidden_states else None, + attentions=outputs.attentions if output_attentions else None, + ) + + +__all__ = [ + "Dinov2WithRegistersPreTrainedModel", + "Dinov2WithRegistersModel", + "Dinov2WithRegistersForImageClassification", + "Dinov2WithRegistersBackbone", +] diff --git a/src/transformers/models/dinov2_with_registers/modular_dinov2_with_registers.py b/src/transformers/models/dinov2_with_registers/modular_dinov2_with_registers.py new file mode 100644 index 000000000000..bbfacd2b5f57 --- /dev/null +++ b/src/transformers/models/dinov2_with_registers/modular_dinov2_with_registers.py @@ -0,0 +1,381 @@ +# coding=utf-8 +# Copyright 2024 Meta Inc. and the HuggingFace Inc. team. All rights reserved. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import math +from typing import Optional + +import torch +import torch.utils.checkpoint +from torch import nn + +from ....transformers.models.dinov2.modeling_dinov2 import ( + Dinov2Backbone, + Dinov2Encoder, + Dinov2ForImageClassification, + Dinov2Model, + Dinov2PatchEmbeddings, + Dinov2PreTrainedModel, +) +from ...configuration_utils import PretrainedConfig +from ...modeling_outputs import BackboneOutput +from ...utils import logging +from ...utils.backbone_utils import BackboneConfigMixin, get_aligned_output_features_output_indices + + +logger = logging.get_logger(__name__) + + +class Dinov2WithRegistersConfig(BackboneConfigMixin, PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`Dinov2WithRegistersModel`]. It is used to instantiate an + Dinov2WithRegisters model according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of the DINOv2 with Registers + [facebook/dinov2-with-registers-base](https://huggingface.co/facebook/dinov2-with-registers-base) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + hidden_size (`int`, *optional*, defaults to 768): + Dimensionality of the encoder layers and the pooler layer. + num_hidden_layers (`int`, *optional*, defaults to 12): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 12): + Number of attention heads for each attention layer in the Transformer encoder. + mlp_ratio (`int`, *optional*, defaults to 4): + Ratio of the hidden size of the MLPs relative to the `hidden_size`. + hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, + `"relu"`, `"selu"` and `"gelu_new"` are supported. + hidden_dropout_prob (`float`, *optional*, defaults to 0.0): + The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. + attention_probs_dropout_prob (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + layer_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the layer normalization layers. + image_size (`int`, *optional*, defaults to 224): + The size (resolution) of each image. + patch_size (`int`, *optional*, defaults to 16): + The size (resolution) of each patch. + num_channels (`int`, *optional*, defaults to 3): + The number of input channels. + qkv_bias (`bool`, *optional*, defaults to `True`): + Whether to add a bias to the queries, keys and values. + layerscale_value (`float`, *optional*, defaults to 1.0): + Initial value to use for layer scale. + drop_path_rate (`float`, *optional*, defaults to 0.0): + Stochastic depth rate per sample (when applied in the main path of residual layers). + use_swiglu_ffn (`bool`, *optional*, defaults to `False`): + Whether to use the SwiGLU feedforward neural network. + num_register_tokens (`int`, *optional*, defaults to 4): + Number of register tokens to use. + interpolate_antialias (`bool`, *optional*, defaults to `True`): + Whether to use antialiasing when interpolating the image patches. 
+ interpolate_offset (`float`, *optional*, defaults to 0.0): + Offset to use when interpolating the image patches. + out_features (`List[str]`, *optional*): + If used as backbone, list of features to output. Can be any of `"stem"`, `"stage1"`, `"stage2"`, etc. + (depending on how many stages the model has). If unset and `out_indices` is set, will default to the + corresponding stages. If unset and `out_indices` is unset, will default to the last stage. Must be in the + same order as defined in the `stage_names` attribute. + out_indices (`List[int]`, *optional*): + If used as backbone, list of indices of features to output. Can be any of 0, 1, 2, etc. (depending on how + many stages the model has). If unset and `out_features` is set, will default to the corresponding stages. + If unset and `out_features` is unset, will default to the last stage. Must be in the + same order as defined in the `stage_names` attribute. + apply_layernorm (`bool`, *optional*, defaults to `True`): + Whether to apply layer normalization to the feature maps in case the model is used as backbone. + reshape_hidden_states (`bool`, *optional*, defaults to `True`): + Whether to reshape the feature maps to 4D tensors of shape `(batch_size, hidden_size, height, width)` in + case the model is used as backbone. If `False`, the feature maps will be 3D tensors of shape `(batch_size, + seq_len, hidden_size)`. + + Example: + + ```python + >>> from transformers import Dinov2WithRegistersConfig, Dinov2WithRegistersModel + + >>> # Initializing a Dinov2WithRegisters base style configuration + >>> configuration = Dinov2WithRegistersConfig() + + >>> # Initializing a model (with random weights) from the base style configuration + >>> model = Dinov2WithRegistersModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "dinov2-with-registers-base" + + def __init__( + self, + hidden_size=768, + num_hidden_layers=12, + num_attention_heads=12, + mlp_ratio=4, + hidden_act="gelu", + hidden_dropout_prob=0.0, + attention_probs_dropout_prob=0.0, + initializer_range=0.02, + layer_norm_eps=1e-6, + image_size=224, + patch_size=16, + num_channels=3, + qkv_bias=True, + layerscale_value=1.0, + drop_path_rate=0.0, + use_swiglu_ffn=False, + num_register_tokens=4, + interpolate_antialias=True, + interpolate_offset=0.0, + out_features=None, + out_indices=None, + apply_layernorm=True, + reshape_hidden_states=True, + **kwargs, + ): + super().__init__(**kwargs) + + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.mlp_ratio = mlp_ratio + self.hidden_act = hidden_act + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.initializer_range = initializer_range + self.layer_norm_eps = layer_norm_eps + self.image_size = image_size + self.patch_size = patch_size + self.num_channels = num_channels + self.qkv_bias = qkv_bias + self.layerscale_value = layerscale_value + self.drop_path_rate = drop_path_rate + self.use_swiglu_ffn = use_swiglu_ffn + self.num_register_tokens = num_register_tokens + self.interpolate_antialias = interpolate_antialias + self.interpolate_offset = interpolate_offset + self.stage_names = ["stem"] + [f"stage{idx}" for idx in range(1, num_hidden_layers + 1)] + self._out_features, self._out_indices = get_aligned_output_features_output_indices( + out_features=out_features, out_indices=out_indices, stage_names=self.stage_names + ) + 
self.apply_layernorm = apply_layernorm + self.reshape_hidden_states = reshape_hidden_states + + +class Dinov2WithRegistersPatchEmbeddings(Dinov2PatchEmbeddings): + pass + + +class Dinov2WithRegistersEmbeddings(nn.Module): + """ + Construct the CLS token, mask token, register tokens, position and patch embeddings. + """ + + def __init__(self, config: Dinov2WithRegistersConfig) -> None: + super().__init__() + + self.cls_token = nn.Parameter(torch.randn(1, 1, config.hidden_size)) + self.mask_token = nn.Parameter(torch.zeros(1, config.hidden_size)) + self.register_tokens = nn.Parameter(torch.zeros(1, config.num_register_tokens, config.hidden_size)) + self.patch_embeddings = Dinov2WithRegistersPatchEmbeddings(config) + num_patches = self.patch_embeddings.num_patches + self.position_embeddings = nn.Parameter(torch.randn(1, num_patches + 1, config.hidden_size)) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + self.config = config + + def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor: + """ + This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher + resolution images. + + Source: + https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174 + """ + + num_patches = embeddings.shape[1] - 1 + num_positions = self.position_embeddings.shape[1] - 1 + if num_patches == num_positions and height == width: + return self.position_embeddings + class_pos_embed = self.position_embeddings[:, 0] + patch_pos_embed = self.position_embeddings[:, 1:] + dim = embeddings.shape[-1] + height = height // self.config.patch_size + width = width // self.config.patch_size + # we add a small number to avoid floating point error in the interpolation + # see discussion at https://github.com/facebookresearch/dino/issues/8 + height, width = height + self.config.interpolate_offset, width + self.config.interpolate_offset + patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim) + patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2) + target_dtype = patch_pos_embed.dtype + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed.to(dtype=torch.float32), + scale_factor=(float(height / math.sqrt(num_positions)), float(width / math.sqrt(num_positions))), + mode="bicubic", + align_corners=False, + antialias=self.config.interpolate_antialias, + ) + patch_pos_embed = patch_pos_embed.to(dtype=target_dtype) + if int(height) != patch_pos_embed.shape[-2] or int(width) != patch_pos_embed.shape[-1]: + raise ValueError("Width or height does not match with the interpolated position embeddings") + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1) + + def forward(self, pixel_values: torch.Tensor, bool_masked_pos: Optional[torch.Tensor] = None) -> torch.Tensor: + batch_size, _, height, width = pixel_values.shape + target_dtype = self.patch_embeddings.projection.weight.dtype + embeddings = self.patch_embeddings(pixel_values.to(dtype=target_dtype)) + + if bool_masked_pos is not None: + embeddings = torch.where( + bool_masked_pos.unsqueeze(-1), self.mask_token.to(embeddings.dtype).unsqueeze(0), embeddings + ) + + # add the [CLS] token to the embedded patch tokens + cls_tokens = self.cls_token.expand(batch_size, -1, -1) + embeddings = torch.cat((cls_tokens, embeddings), dim=1) + + # add positional encoding to each token + embeddings = 
embeddings + self.interpolate_pos_encoding(embeddings, height, width) + + # add register tokens + embeddings = torch.cat( + (embeddings[:, :1], self.register_tokens.expand(embeddings.shape[0], -1, -1), embeddings[:, 1:]), dim=1 + ) + + embeddings = self.dropout(embeddings) + + return embeddings + + +class Dinov2WithRegistersEncoder(Dinov2Encoder): + pass + + +class Dinov2WithRegistersPreTrainedModel(Dinov2PreTrainedModel): + pass + + +class Dinov2WithRegistersModel(Dinov2Model): + pass + + +class Dinov2WithRegistersForImageClassification(Dinov2ForImageClassification): + pass + + +class Dinov2WithRegistersBackbone(Dinov2Backbone): + def __init__(self, config): + super().__init__(config) + super()._init_backbone(config) + + self.num_register_tokens = config.num_register_tokens + self.num_features = [config.hidden_size for _ in range(config.num_hidden_layers + 1)] + self.embeddings = Dinov2WithRegistersEmbeddings(config) + self.encoder = Dinov2WithRegistersEncoder(config) + + self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self) -> Dinov2WithRegistersPatchEmbeddings: + return self.embeddings.patch_embeddings + + def forward( + self, + pixel_values: torch.Tensor, + output_hidden_states: Optional[bool] = None, + output_attentions: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> BackboneOutput: + """ + Returns: + + Examples: + + ```python + >>> from transformers import AutoImageProcessor, AutoBackbone + >>> import torch + >>> from PIL import Image + >>> import requests + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> processor = AutoImageProcessor.from_pretrained("facebook/dinov2-with-registers-base") + >>> model = AutoBackbone.from_pretrained( + ... "facebook/dinov2-with-registers-base", out_features=["stage2", "stage5", "stage8", "stage11"] + ... 
) + + >>> inputs = processor(image, return_tensors="pt") + + >>> outputs = model(**inputs) + >>> feature_maps = outputs.feature_maps + >>> list(feature_maps[-1].shape) + [1, 768, 16, 16] + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + + embedding_output = self.embeddings(pixel_values) + + outputs = self.encoder( + embedding_output, output_hidden_states=True, output_attentions=output_attentions, return_dict=return_dict + ) + + hidden_states = outputs.hidden_states if return_dict else outputs[1] + + feature_maps = () + for stage, hidden_state in zip(self.stage_names, hidden_states): + if stage in self.out_features: + if self.config.apply_layernorm: + hidden_state = self.layernorm(hidden_state) + if self.config.reshape_hidden_states: + hidden_state = hidden_state[:, self.num_register_tokens + 1 :] + # this was actually a bug in the original implementation that we copied here, + # cause normally the order is height, width + batch_size, _, height, width = pixel_values.shape + patch_size = self.config.patch_size + hidden_state = hidden_state.reshape(batch_size, height // patch_size, width // patch_size, -1) + hidden_state = hidden_state.permute(0, 3, 1, 2).contiguous() + feature_maps += (hidden_state,) + + if not return_dict: + if output_hidden_states: + output = (feature_maps,) + outputs[1:] + else: + output = (feature_maps,) + outputs[2:] + return output + + return BackboneOutput( + feature_maps=feature_maps, + hidden_states=outputs.hidden_states if output_hidden_states else None, + attentions=outputs.attentions if output_attentions else None, + ) + + +__all__ = [ + "Dinov2WithRegistersConfig", + "Dinov2WithRegistersPreTrainedModel", + "Dinov2WithRegistersModel", + "Dinov2WithRegistersForImageClassification", + "Dinov2WithRegistersBackbone", +] diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index e3463461ea07..922d67264bb1 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -3635,6 +3635,34 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) +class Dinov2WithRegistersBackbone(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class Dinov2WithRegistersForImageClassification(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class Dinov2WithRegistersModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class Dinov2WithRegistersPreTrainedModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + class DistilBertForMaskedLM(metaclass=DummyObject): _backends = ["torch"] diff --git a/tests/models/dinov2_with_registers/__init__.py b/tests/models/dinov2_with_registers/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/models/dinov2_with_registers/test_modeling_dinov2_with_registers.py b/tests/models/dinov2_with_registers/test_modeling_dinov2_with_registers.py new file mode 100644 index 000000000000..6aa62138e620 --- /dev/null +++ 
b/tests/models/dinov2_with_registers/test_modeling_dinov2_with_registers.py @@ -0,0 +1,369 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Testing suite for the PyTorch Dinov2WithRegisters model.""" + +import unittest + +from transformers import Dinov2WithRegistersConfig +from transformers.testing_utils import ( + require_torch, + require_vision, + slow, + torch_device, +) +from transformers.utils import cached_property, is_torch_available, is_vision_available + +from ...test_backbone_common import BackboneTesterMixin +from ...test_configuration_common import ConfigTester +from ...test_modeling_common import ModelTesterMixin, _config_zero_init, floats_tensor, ids_tensor +from ...test_pipeline_mixin import PipelineTesterMixin + + +if is_torch_available(): + import torch + from torch import nn + + from transformers import ( + Dinov2WithRegistersBackbone, + Dinov2WithRegistersForImageClassification, + Dinov2WithRegistersModel, + ) + + +if is_vision_available(): + from PIL import Image + + from transformers import AutoImageProcessor + + +class Dinov2WithRegistersModelTester: + def __init__( + self, + parent, + batch_size=13, + image_size=30, + patch_size=2, + num_channels=3, + is_training=True, + use_labels=True, + hidden_size=32, + num_hidden_layers=2, + num_attention_heads=4, + intermediate_size=37, + hidden_act="gelu", + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + type_sequence_label_size=10, + initializer_range=0.02, + num_register_tokens=2, + mask_ratio=0.5, + scope=None, + ): + self.parent = parent + self.batch_size = batch_size + self.image_size = image_size + self.patch_size = patch_size + self.num_channels = num_channels + self.is_training = is_training + self.use_labels = use_labels + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.intermediate_size = intermediate_size + self.hidden_act = hidden_act + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.type_sequence_label_size = type_sequence_label_size + self.initializer_range = initializer_range + self.num_register_tokens = num_register_tokens + self.scope = scope + + # in DINOv2 with Registers, the seq length equals the number of patches + 1 + num_register_tokens (we add 1 for the [CLS] token) + num_patches = (image_size // patch_size) ** 2 + self.seq_length = num_patches + 1 + self.num_register_tokens + self.mask_ratio = mask_ratio + self.num_masks = int(mask_ratio * self.seq_length) + self.mask_length = num_patches + + def prepare_config_and_inputs(self): + pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size]) + + labels = None + if self.use_labels: + labels = ids_tensor([self.batch_size], self.type_sequence_label_size) + + config = self.get_config() + + return config, pixel_values, labels + + def get_config(self): + return 
Dinov2WithRegistersConfig( + image_size=self.image_size, + patch_size=self.patch_size, + num_channels=self.num_channels, + hidden_size=self.hidden_size, + num_hidden_layers=self.num_hidden_layers, + num_attention_heads=self.num_attention_heads, + intermediate_size=self.intermediate_size, + hidden_act=self.hidden_act, + hidden_dropout_prob=self.hidden_dropout_prob, + attention_probs_dropout_prob=self.attention_probs_dropout_prob, + is_decoder=False, + initializer_range=self.initializer_range, + num_register_tokens=self.num_register_tokens, + ) + + def create_and_check_model(self, config, pixel_values, labels): + model = Dinov2WithRegistersModel(config=config) + model.to(torch_device) + model.eval() + result = model(pixel_values) + self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size)) + + def create_and_check_backbone(self, config, pixel_values, labels): + model = Dinov2WithRegistersBackbone(config=config) + model.to(torch_device) + model.eval() + result = model(pixel_values) + + # verify hidden states + self.parent.assertEqual(len(result.feature_maps), len(config.out_features)) + expected_size = self.image_size // config.patch_size + self.parent.assertListEqual( + list(result.feature_maps[0].shape), [self.batch_size, model.channels[0], expected_size, expected_size] + ) + + # verify channels + self.parent.assertEqual(len(model.channels), len(config.out_features)) + + # verify backbone works with out_features=None + config.out_features = None + model = Dinov2WithRegistersBackbone(config=config) + model.to(torch_device) + model.eval() + result = model(pixel_values) + + # verify feature maps + self.parent.assertEqual(len(result.feature_maps), 1) + self.parent.assertListEqual( + list(result.feature_maps[0].shape), [self.batch_size, model.channels[0], expected_size, expected_size] + ) + + # verify channels + self.parent.assertEqual(len(model.channels), 1) + + # verify backbone works with apply_layernorm=False and reshape_hidden_states=False + config.apply_layernorm = False + config.reshape_hidden_states = False + + model = Dinov2WithRegistersBackbone(config=config) + model.to(torch_device) + model.eval() + result = model(pixel_values) + + # verify feature maps + self.parent.assertEqual(len(result.feature_maps), 1) + self.parent.assertListEqual( + list(result.feature_maps[0].shape), [self.batch_size, self.seq_length, self.hidden_size] + ) + + def create_and_check_for_image_classification(self, config, pixel_values, labels): + config.num_labels = self.type_sequence_label_size + model = Dinov2WithRegistersForImageClassification(config) + model.to(torch_device) + model.eval() + result = model(pixel_values, labels=labels) + self.parent.assertEqual(result.logits.shape, (self.batch_size, self.type_sequence_label_size)) + + # test greyscale images + config.num_channels = 1 + model = Dinov2WithRegistersForImageClassification(config) + model.to(torch_device) + model.eval() + + pixel_values = floats_tensor([self.batch_size, 1, self.image_size, self.image_size]) + result = model(pixel_values) + self.parent.assertEqual(result.logits.shape, (self.batch_size, self.type_sequence_label_size)) + + def prepare_config_and_inputs_for_common(self): + config_and_inputs = self.prepare_config_and_inputs() + ( + config, + pixel_values, + labels, + ) = config_and_inputs + inputs_dict = {"pixel_values": pixel_values} + return config, inputs_dict + + +@require_torch +class Dinov2WithRegistersModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase): + """ + 
Here we also overwrite some of the tests of test_modeling_common.py, as Dinov2WithRegisters does not use input_ids, inputs_embeds,
+    attention_mask and seq_length.
+    """
+
+    all_model_classes = (
+        (
+            Dinov2WithRegistersModel,
+            Dinov2WithRegistersForImageClassification,
+            Dinov2WithRegistersBackbone,
+        )
+        if is_torch_available()
+        else ()
+    )
+    pipeline_model_mapping = (
+        {
+            "image-feature-extraction": Dinov2WithRegistersModel,
+            "image-classification": Dinov2WithRegistersForImageClassification,
+        }
+        if is_torch_available()
+        else {}
+    )
+    fx_compatible = False
+
+    test_pruning = False
+    test_resize_embeddings = False
+    test_head_masking = False
+
+    def setUp(self):
+        self.model_tester = Dinov2WithRegistersModelTester(self)
+        self.config_tester = ConfigTester(
+            self, config_class=Dinov2WithRegistersConfig, has_text_modality=False, hidden_size=37
+        )
+
+    def test_initialization(self):
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+        configs_no_init = _config_zero_init(config)
+        for model_class in self.all_model_classes:
+            model = model_class(config=configs_no_init)
+            for name, param in model.named_parameters():
+                if param.requires_grad and "register_tokens" not in name:
+                    self.assertIn(
+                        ((param.data.mean() * 1e9).round() / 1e9).item(),
+                        [0.0, 1.0],
+                        msg=f"Parameter {name} of model {model_class} seems not properly initialized",
+                    )
+
+    def test_config(self):
+        self.config_tester.run_common_tests()
+
+    @unittest.skip(reason="Dinov2WithRegisters does not use inputs_embeds")
+    def test_inputs_embeds(self):
+        pass
+
+    @unittest.skip(
+        reason="This architecture seems to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
+    )
+    def test_training_gradient_checkpointing(self):
+        pass
+
+    @unittest.skip(
+        reason="This architecture seems to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
+    )
+    def test_training_gradient_checkpointing_use_reentrant(self):
+        pass
+
+    @unittest.skip(
+        reason="This architecture seems to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
+    )
+    def test_training_gradient_checkpointing_use_reentrant_false(self):
+        pass
+
+    def test_model_get_set_embeddings(self):
+        config, _ = self.model_tester.prepare_config_and_inputs_for_common()
+
+        for model_class in self.all_model_classes:
+            model = model_class(config)
+            self.assertIsInstance(model.get_input_embeddings(), (nn.Module))
+            x = model.get_output_embeddings()
+            self.assertTrue(x is None or isinstance(x, nn.Linear))
+
+    def test_model(self):
+        config_and_inputs = self.model_tester.prepare_config_and_inputs()
+        self.model_tester.create_and_check_model(*config_and_inputs)
+
+    def test_backbone(self):
+        config_and_inputs = self.model_tester.prepare_config_and_inputs()
+        self.model_tester.create_and_check_backbone(*config_and_inputs)
+
+    def test_for_image_classification(self):
+        config_and_inputs = self.model_tester.prepare_config_and_inputs()
+        self.model_tester.create_and_check_for_image_classification(*config_and_inputs)
+
+    @unittest.skip(reason="Dinov2WithRegisters does not support feedforward chunking yet")
+    def test_feed_forward_chunking(self):
+        pass
+
+    @slow
+    def test_model_from_pretrained(self):
+        model_name = "facebook/dinov2-with-registers-base"
+        model = Dinov2WithRegistersModel.from_pretrained(model_name)
+        self.assertIsNotNone(model)
+
+
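Both the generated modeling file and the modular file above build backbone feature maps by dropping the `[CLS]` token and the register tokens (`hidden_state[:, self.num_register_tokens + 1 :]`) before reshaping the remaining patch tokens, and the shape assertions in `create_and_check_backbone` rely on exactly that layout. Below is a minimal standalone sketch of the slicing-and-reshape step; the batch size, image size, and register count are illustrative assumptions rather than values read from a real checkpoint.

```python
import torch

# Illustrative shapes only: 224x224 input with patch size 14 -> 16x16 = 256 patch tokens,
# preceded in the sequence by 1 [CLS] token and 4 register tokens.
batch_size, hidden_size, num_register_tokens, patch_size = 2, 768, 4, 14
height = width = 224
num_patches = (height // patch_size) * (width // patch_size)  # 256

hidden_state = torch.randn(batch_size, 1 + num_register_tokens + num_patches, hidden_size)

# Same slicing as the backbone: keep only the patch tokens.
patch_tokens = hidden_state[:, num_register_tokens + 1 :]

# (batch, seq, hidden) -> (batch, hidden, height // patch_size, width // patch_size)
feature_map = (
    patch_tokens.reshape(batch_size, height // patch_size, width // patch_size, hidden_size)
    .permute(0, 3, 1, 2)
    .contiguous()
)
print(feature_map.shape)  # torch.Size([2, 768, 16, 16])
```

Dropping the first `1 + num_register_tokens` positions mirrors the token layout set up in `Dinov2WithRegistersEmbeddings.forward`, which inserts the register tokens between the `[CLS]` token and the patch tokens.

+# We will verify our results on an image of cute cats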
+def prepare_img(): + image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png") + return image + + +@require_torch +@require_vision +class Dinov2WithRegistersModelIntegrationTest(unittest.TestCase): + @cached_property + def default_image_processor(self): + return ( + AutoImageProcessor.from_pretrained("facebook/dinov2-with-registers-base") + if is_vision_available() + else None + ) + + @slow + def test_inference_no_head(self): + model = Dinov2WithRegistersModel.from_pretrained("facebook/dinov2-with-registers-base").to(torch_device) + + image_processor = self.default_image_processor + image = prepare_img() + inputs = image_processor(image, return_tensors="pt").to(torch_device) + + # forward pass + with torch.no_grad(): + outputs = model(**inputs) + + # verify the last hidden states + # in DINOv2 with Registers, the seq length equals the number of patches + 1 + num_register_tokens (we add 1 for the [CLS] token) + num_patches = (image_processor.crop_size["height"] // model.config.patch_size) ** 2 + expected_seq_length = num_patches + 1 + model.config.num_register_tokens + expected_shape = torch.Size((1, expected_seq_length, model.config.hidden_size)) + self.assertEqual(outputs.last_hidden_state.shape, expected_shape) + + expected_slice = torch.tensor( + [[-0.4636, -1.4582, -0.0274], [-1.4738, -0.8858, 0.3002], [0.0714, -0.2407, -1.5940]], + device=torch_device, + ) + self.assertTrue(torch.allclose(outputs.last_hidden_state[0, :3, :3], expected_slice, atol=1e-4)) + + +@require_torch +class Dinov2WithRegistersBackboneTest(unittest.TestCase, BackboneTesterMixin): + all_model_classes = (Dinov2WithRegistersBackbone,) if is_torch_available() else () + config_class = Dinov2WithRegistersConfig + + has_attentions = False + + def setUp(self): + self.model_tester = Dinov2WithRegistersModelTester(self) diff --git a/utils/check_repo.py b/utils/check_repo.py index 3dbe59f19229..130eebf0b838 100644 --- a/utils/check_repo.py +++ b/utils/check_repo.py @@ -1009,6 +1009,7 @@ def find_all_documented_objects() -> List[str]: "ConvNextV2Backbone", "DinatBackbone", "Dinov2Backbone", + "Dinov2WithRegistersBackbone", "FocalNetBackbone", "HieraBackbone", "MaskFormerSwinBackbone",
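To see how the pieces added in this patch fit together at inference time, here is a short usage sketch in the spirit of the integration test above. The hub checkpoint id is the one used throughout this patch's docstrings and tests, so whether that exact id is the final published name is an assumption.

```python
import requests
import torch
from PIL import Image

from transformers import AutoImageProcessor, AutoModel

# Checkpoint id taken from this patch's docstrings and tests; treated as an assumption.
checkpoint = "facebook/dinov2-with-registers-base"

processor = AutoImageProcessor.from_pretrained(checkpoint)
model = AutoModel.from_pretrained(checkpoint)

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
inputs = processor(images=image, return_tensors="pt")

with torch.no_grad():
    outputs = model(**inputs)

# As asserted in the integration test: seq_len = num_patches + 1 ([CLS]) + num_register_tokens.
num_patches = (processor.crop_size["height"] // model.config.patch_size) ** 2
expected_seq_len = num_patches + 1 + model.config.num_register_tokens
assert outputs.last_hidden_state.shape[1] == expected_seq_len
```

Because the output sequence includes the register tokens, downstream code that wants only patch features should skip the first `1 + num_register_tokens` positions, exactly as the backbone does before reshaping.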