Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 28 additions & 10 deletions docs/trackers/core/reid/reid.md
Original file line number Diff line number Diff line change
Expand Up @@ -48,22 +48,40 @@ The `ReIDModel` class provides a flexible interface to extract appearance featur

### Loading a ReIDModel

You can initialize a `ReIDModel` from any supported pretrained model in the [`timm`](https://huggingface.co/docs/timm/en/index) library using the `from_timm` method.
=== "timm"
You can initialize a `ReIDModel` from any supported pretrained model in the [`timm`](https://huggingface.co/docs/timm/en/index) library using the `from_timm` method.

```python
from trackers import ReIDModel
```python
from trackers import ReIDModel

reid_model = ReIDModel.from_timm("resnetv2_50.a1h_in1k")
```
reid_model = ReIDModel.from_timm("resnetv2_50.a1h_in1k")
```

=== "torchreid"
You can initialize a `ReIDModel` from any supported pretrained or fine-tuned model in the [`torchreid`](https://github.com/KaiyangZhou/deep-person-reid) library using the `from_torchreid` method.

```python
from trackers import ReIDModel

reid_model = ReIDModel.from_torchreid(
model_name="resnet50", num_classes=751
)
```

### Supported Models

The `ReIDModel` supports all models available in the timm library. You can list available models using:
=== "timm"

```python
import timm
print(timm.list_models())
```
The `ReIDModel` supports all models available in the timm library. You can list available models using:

```python
import timm
print(timm.list_models())
```

=== "torchreid"

The `ReIDModel` supports all models listed in the torchreid [model zoo](https://kaiyangzhou.github.io/deep-person-reid/MODEL_ZOO.html).

### Extracting Embeddings

Expand Down
83 changes: 77 additions & 6 deletions trackers/core/reid/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,13 @@
import timm
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from safetensors.torch import save_file
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform
from torch.utils.data import DataLoader
from torchvision.transforms import Compose, ToPILImage
from torchvision.transforms import Compose, Normalize, Resize, ToPILImage, ToTensor
from tqdm.auto import tqdm

from trackers.core.reid.callbacks import BaseCallback
Expand Down Expand Up @@ -60,6 +61,7 @@ def _initialize_reid_model_from_timm(
config = resolve_data_config(model.pretrained_cfg)
transforms = create_transform(**config)
model_metadata = {
"model_source": "timm",
"model_name_or_checkpoint_path": model_name_or_checkpoint_path,
"get_pooled_features": get_pooled_features,
"kwargs": kwargs,
Expand Down Expand Up @@ -105,7 +107,6 @@ def __init__(
self.backbone_model = backbone_model
self.device = parse_device_spec(device or "auto")
self.backbone_model.to(self.device)
self.backbone_model.eval()
self.train_transforms = (
(Compose(*transforms) if isinstance(transforms, list) else transforms)
if transforms is not None
Expand Down Expand Up @@ -156,6 +157,61 @@ def from_timm(
**kwargs,
)

@classmethod
def from_torchreid(
    cls,
    model_name: str,
    num_classes: int,
    checkpoint_path: Optional[str] = None,
    loss_name: str = "softmax",
    device: str = "auto",
    crop_size: tuple[int, int] = (256, 128),
) -> ReIDModel:
    """
    Build a `ReIDModel` whose backbone is a
    [torchreid](https://github.com/KaiyangZhou/deep-person-reid) network.

    The backbone is created with torchreid's `build_model` (requested with
    `pretrained=True`). When `checkpoint_path` is provided, that checkpoint
    is loaded into the backbone afterwards with `load_pretrained_weights`.

    Args:
        model_name (str): Name of the torchreid architecture to build.
        num_classes (int): Number of training identities.
        checkpoint_path (Optional[str]): Path of a checkpoint file to load
            into the backbone. Nothing extra is loaded when `None`.
        loss_name (str): Loss the model is meant to be optimized with.
            Currently supports "softmax" and "triplet".
        device (str): Device to run the model on.
        crop_size (tuple[int, int]): Size each bounding box crop is resized
            to before being fed to the model.

    Returns:
        ReIDModel: A new instance wrapping the torchreid backbone, with the
            resize / to-tensor / normalize pipeline as its transforms and
            `{"model_source": "torchreid"}` as its metadata.
    """
    # Imported inside the method, so torchreid is only imported when this
    # constructor is actually called.
    from torchreid.models import build_model
    from torchreid.utils import load_pretrained_weights

    backbone = build_model(
        name=model_name,
        num_classes=num_classes,
        loss=loss_name,
        pretrained=True,
        use_gpu=False,
    )
    if checkpoint_path is not None:
        load_pretrained_weights(backbone, checkpoint_path)

    # Per-channel normalization constants
    # (presumably the ImageNet statistics - TODO confirm).
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]
    preprocessing_steps = [
        Resize(crop_size),
        ToTensor(),
        Normalize(mean=channel_mean, std=channel_std),
    ]
    preprocessing = Compose(preprocessing_steps)

    # Tag the instance with the library its backbone came from.
    metadata = {"model_source": "torchreid"}

    return cls(
        backbone,
        device=device,
        transforms=preprocessing,
        model_metadata=metadata,
    )

def extract_features(
self, detections: sv.Detections, frame: Union[np.ndarray, PIL.Image.Image]
) -> np.ndarray:
Expand All @@ -179,11 +235,26 @@ def extract_features(
with torch.inference_mode():
for box in detections.xyxy:
crop = sv.crop_image(image=frame, xyxy=[*box.astype(int)])
tensor = self.inference_transforms(crop).unsqueeze(0).to(self.device)
feature = (
torch.squeeze(self.backbone_model(tensor)).cpu().numpy().flatten()
input_tensor = (
self.inference_transforms(crop).unsqueeze(0).to(self.device)
)
features.append(feature)
if self.model_metadata["model_source"] == "timm":
feature = (
torch.squeeze(self.backbone_model(input_tensor))
.cpu()
.numpy()
.flatten()
)
features.append(feature)
elif self.model_metadata["model_source"] == "torchreid":
model_output = self.backbone_model.featuremaps(input_tensor)
pooled_featuremaps = F.adaptive_avg_pool2d(model_output, (1, 1))
feature = torch.squeeze(pooled_featuremaps).cpu().numpy().flatten()
features.append(feature)
else:
raise ValueError(
f"Model source {self.model_metadata['model_source']} not supported." # noqa: E501
)

return np.array(features)

Expand Down