diff --git a/docs/trackers/core/reid/reid.md b/docs/trackers/core/reid/reid.md
index db9358e0..b0f5aa02 100644
--- a/docs/trackers/core/reid/reid.md
+++ b/docs/trackers/core/reid/reid.md
@@ -48,22 +48,40 @@ The `ReIDModel` class provides a flexible interface to extract appearance featur
 
 ### Loading a ReIDModel
 
-You can initialize a `ReIDModel` from any supported pretrained model in the [`timm`](https://huggingface.co/docs/timm/en/index) library using the `from_timm` method.
+=== "timm"
+    You can initialize a `ReIDModel` from any supported pretrained model in the [`timm`](https://huggingface.co/docs/timm/en/index) library using the `from_timm` method.
 
-```python
-from trackers import ReIDModel
+    ```python
+    from trackers import ReIDModel
 
-reid_model = ReIDModel.from_timm("resnetv2_50.a1h_in1k")
-```
+    reid_model = ReIDModel.from_timm("resnetv2_50.a1h_in1k")
+    ```
+
+=== "torchreid"
+    You can initialize a `ReIDModel` from any supported pretrained or fine-tuned model in the [`torchreid`](https://github.com/KaiyangZhou/deep-person-reid) library using the `from_torchreid` method.
+
+    ```python
+    from trackers import ReIDModel
+
+    reid_model = ReIDModel.from_torchreid(
+        model_name="resnet50", num_classes=751
+    )
+    ```
 
 ### Supported Models
 
-The `ReIDModel` supports all models available in the timm library. You can list available models using:
+=== "timm"
 
-```python
-import timm
-print(timm.list_models())
-```
+    The `ReIDModel` supports all models available in the timm library. You can list available models using:
+
+    ```python
+    import timm
+    print(timm.list_models())
+    ```
+
+=== "torchreid"
+
+    The `ReIDModel` supports all models available in the torchreid [model zoo](https://kaiyangzhou.github.io/deep-person-reid/MODEL_ZOO.html).
 
 ### Extracting Embeddings
 
diff --git a/trackers/core/reid/model.py b/trackers/core/reid/model.py
index bbd7aacf..eead1d2f 100644
--- a/trackers/core/reid/model.py
+++ b/trackers/core/reid/model.py
@@ -10,12 +10,13 @@
 import timm
 import torch
 import torch.nn as nn
+import torch.nn.functional as F
 import torch.optim as optim
 from safetensors.torch import save_file
 from timm.data import resolve_data_config
 from timm.data.transforms_factory import create_transform
 from torch.utils.data import DataLoader
-from torchvision.transforms import Compose, ToPILImage
+from torchvision.transforms import Compose, Normalize, Resize, ToPILImage, ToTensor
 from tqdm.auto import tqdm
 
 from trackers.core.reid.callbacks import BaseCallback
@@ -60,6 +61,7 @@ def _initialize_reid_model_from_timm(
     config = resolve_data_config(model.pretrained_cfg)
     transforms = create_transform(**config)
     model_metadata = {
+        "model_source": "timm",
         "model_name_or_checkpoint_path": model_name_or_checkpoint_path,
         "get_pooled_features": get_pooled_features,
         "kwargs": kwargs,
@@ -105,7 +107,6 @@ def __init__(
         self.backbone_model = backbone_model
         self.device = parse_device_spec(device or "auto")
         self.backbone_model.to(self.device)
-        self.backbone_model.eval()
         self.train_transforms = (
             (Compose(*transforms) if isinstance(transforms, list) else transforms)
             if transforms is not None
@@ -156,6 +157,61 @@ def from_timm(
             **kwargs,
         )
 
+    @classmethod
+    def from_torchreid(
+        cls,
+        model_name: str,
+        num_classes: int,
+        checkpoint_path: Optional[str] = None,
+        loss_name: str = "softmax",
+        device: str = "auto",
+        crop_size: tuple[int, int] = (256, 128),
+    ) -> ReIDModel:
+        """
+        Create a `ReIDModel` with a [torchreid](https://github.com/KaiyangZhou/deep-person-reid)
+        model as the backbone.
+
+        Args:
+            model_name (str): Name of the torchreid model to use.
+            num_classes (int): Number of training identities.
+            checkpoint_path (Optional[str]): Path to the checkpoint file to load.
+            loss_name (str): Loss function used to optimize the model. Currently
+                supports "softmax" and "triplet".
+            device (str): Device to run the model on.
+            crop_size (tuple[int, int]): The (height, width) a bounding box crop
+                is resized to before being fed to the model.
+        """
+        from torchreid.models import build_model
+        from torchreid.utils import load_pretrained_weights
+
+        model = build_model(
+            name=model_name,
+            num_classes=num_classes,
+            loss=loss_name,
+            pretrained=True,
+            use_gpu=False,  # device placement is handled in __init__
+        )
+
+        if checkpoint_path is not None:
+            load_pretrained_weights(model, checkpoint_path)
+
+        transforms = Compose(
+            [
+                Resize(crop_size),
+                ToTensor(),
+                Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
+            ]
+        )
+
+        model_metadata = {"model_source": "torchreid"}
+
+        return cls(
+            model,
+            device=device,
+            transforms=transforms,
+            model_metadata=model_metadata,
+        )
+
     def extract_features(
         self, detections: sv.Detections, frame: Union[np.ndarray, PIL.Image.Image]
     ) -> np.ndarray:
@@ -179,11 +235,26 @@ def extract_features(
         with torch.inference_mode():
             for box in detections.xyxy:
                 crop = sv.crop_image(image=frame, xyxy=[*box.astype(int)])
-                tensor = self.inference_transforms(crop).unsqueeze(0).to(self.device)
-                feature = (
-                    torch.squeeze(self.backbone_model(tensor)).cpu().numpy().flatten()
+                input_tensor = (
+                    self.inference_transforms(crop).unsqueeze(0).to(self.device)
                 )
-                features.append(feature)
+                if self.model_metadata.get("model_source") == "timm":
+                    feature = (
+                        torch.squeeze(self.backbone_model(input_tensor))
+                        .cpu()
+                        .numpy()
+                        .flatten()
+                    )
+                    features.append(feature)
+                elif self.model_metadata.get("model_source") == "torchreid":
+                    # Pool the backbone feature maps into a single embedding
+                    # vector, bypassing the classification head.
+                    model_output = self.backbone_model.featuremaps(input_tensor)
+                    pooled_featuremaps = F.adaptive_avg_pool2d(model_output, (1, 1))
+                    feature = torch.squeeze(pooled_featuremaps).cpu().numpy().flatten()
+                    features.append(feature)
+                else:
+                    raise ValueError(
+                        f"Model source {self.model_metadata.get('model_source')} is not supported."  # noqa: E501
+                    )
 
         return np.array(features)
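
A minimal end-to-end sketch of the new `from_torchreid` path, for reference. It assumes `trackers`, `torchreid`, and `supervision` are installed, uses the torchreid model-zoo name `osnet_x1_0`, and fabricates a dummy frame and detection purely for illustration; `num_classes=751` matches the Market-1501 identity count used in the docs example:

```python
import numpy as np
import supervision as sv

from trackers import ReIDModel

# Build a torchreid backbone; num_classes must match the identity count
# the weights were trained with (751 for Market-1501, as in the docs).
reid_model = ReIDModel.from_torchreid(
    model_name="osnet_x1_0",
    num_classes=751,
)

# Dummy 720p frame and a single person-sized box in (x1, y1, x2, y2) form.
frame = np.zeros((720, 1280, 3), dtype=np.uint8)
detections = sv.Detections(xyxy=np.array([[100.0, 50.0, 228.0, 306.0]]))

# extract_features routes through the torchreid branch added above and
# returns one embedding row per detection crop.
features = reid_model.extract_features(detections, frame)
print(features.shape)  # (1, feature_dim)
```

Passing `checkpoint_path` would route through `load_pretrained_weights` to load fine-tuned ReID weights instead of the default pretrained initialization.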