Skip to content

Commit 03cbd3e

Browse files
authored
Merge pull request #3 from roboflow/feat/deepsort-tracker
feat(trackers): Implement `DeepSORTTracker`
2 parents ac42b7e + abc248a commit 03cbd3e

File tree

10 files changed

+2760
-195
lines changed

10 files changed

+2760
-195
lines changed

pyproject.toml

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,8 @@ classifiers = [
3333
keywords = ["tracking", "machine-learning", "deep-learning", "vision", "ML", "DL", "AI", "DETR", "YOLO", "Roboflow"]
3434

3535
dependencies = [
36+
"firerequests>=0.1.2",
37+
"numpy>=2.0.2",
3638
"supervision>=0.25.1",
3739
]
3840

@@ -58,6 +60,49 @@ docs = [
5860
"mike>=2.1.3",
5961
]
6062

63+
cpu = [
64+
"torch>=2.6.0",
65+
"torchvision>=0.21.0",
66+
]
67+
cu124 = [
68+
"torch>=2.6.0",
69+
"torchvision>=0.21.0",
70+
]
71+
72+
deepsort = [
73+
"scipy>=1.13.1",
74+
"timm>=1.0.15",
75+
"validators>=0.34.0",
76+
]
77+
78+
[tool.uv]
79+
conflicts = [
80+
[
81+
{ extra = "cpu" },
82+
{ extra = "cu124" },
83+
],
84+
]
85+
86+
[tool.uv.sources]
87+
torch = [
88+
{ index = "pytorch-cpu", extra = "cpu" },
89+
{ index = "pytorch-cu124", extra = "cu124" },
90+
]
91+
torchvision = [
92+
{ index = "pytorch-cpu", extra = "cpu" },
93+
{ index = "pytorch-cu124", extra = "cu124" },
94+
]
95+
96+
[[tool.uv.index]]
97+
name = "pytorch-cpu"
98+
url = "https://download.pytorch.org/whl/cpu"
99+
explicit = true
100+
101+
[[tool.uv.index]]
102+
name = "pytorch-cu124"
103+
url = "https://download.pytorch.org/whl/cu124"
104+
explicit = true
105+
61106
[build-system]
62107
requires = ["hatchling"]
63108
build-backend = "hatchling.build"

trackers/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from trackers.core.deepsort.feature_extractor import DeepSORTFeatureExtractor
2+
from trackers.core.deepsort.tracker import DeepSORTTracker
13
from trackers.core.sort.tracker import SORTTracker
24

3-
__all__ = ["SORTTracker"]
5+
__all__ = ["DeepSORTFeatureExtractor", "DeepSORTTracker", "SORTTracker"]

trackers/core/base.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,16 @@
11
from abc import ABC, abstractmethod
22

3+
import numpy as np
34
import supervision as sv
45

56

67
class BaseTracker(ABC):
78
@abstractmethod
89
def update(self, detections: sv.Detections) -> sv.Detections:
910
pass
11+
12+
13+
class BaseTrackerWithFeatures(ABC):
14+
@abstractmethod
15+
def update(self, detections: sv.Detections, frame: np.ndarray) -> sv.Detections:
16+
pass

trackers/core/deepsort/__init__.py

Whitespace-only changes.
Lines changed: 157 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,157 @@
1+
from typing import Optional, Tuple, Union
2+
3+
import numpy as np
4+
import supervision as sv
5+
import timm
6+
import torch
7+
import torch.nn as nn
8+
import torchvision.transforms as transforms
9+
import validators
10+
from firerequests import FireRequests
11+
12+
from trackers.utils.torch_utils import parse_device_spec
13+
14+
15+
class FeatureExtractionBackbone(nn.Module):
16+
"""
17+
A simple backbone model for feature extraction.
18+
19+
Args:
20+
backbone_model (nn.Module): The backbone model to use for feature extraction.
21+
"""
22+
23+
def __init__(self, backbone_model: nn.Module):
24+
super(FeatureExtractionBackbone, self).__init__()
25+
self.backbone_model = backbone_model
26+
27+
def forward(self, input_tensor: torch.Tensor) -> torch.Tensor:
28+
"""
29+
Forward pass on a single input tensor.
30+
31+
Args:
32+
input_tensor (torch.Tensor): The input tensor.
33+
"""
34+
output = self.backbone_model(input_tensor)
35+
output = torch.squeeze(output)
36+
return output
37+
38+
39+
class DeepSORTFeatureExtractor:
40+
"""
41+
Feature extractor for DeepSORT that loads a PyTorch model and
42+
extracts appearance features from detection crops.
43+
44+
Args:
45+
model_or_checkpoint_path (Union[str, torch.nn.Module]): Path/URL to the PyTorch
46+
model checkpoint or the model itself.
47+
device (str): Device to run the model on.
48+
input_size (Tuple[int, int]): Size to which the input images are resized.
49+
"""
50+
51+
def __init__(
52+
self,
53+
model_or_checkpoint_path: Union[str, torch.nn.Module],
54+
device: Optional[str] = "auto",
55+
input_size: Tuple[int, int] = (128, 128),
56+
):
57+
self.device = parse_device_spec(device or "auto")
58+
self.input_size = input_size
59+
60+
self._initialize_model(model_or_checkpoint_path)
61+
62+
self.transform = transforms.Compose(
63+
[
64+
transforms.ToPILImage(),
65+
transforms.Resize(self.input_size),
66+
transforms.ToTensor(),
67+
]
68+
)
69+
70+
@classmethod
71+
def from_timm(
72+
cls,
73+
model_name: str,
74+
device: Optional[str] = "auto",
75+
input_size: Tuple[int, int] = (128, 128),
76+
pretrained: bool = True,
77+
*args,
78+
**kwargs,
79+
):
80+
"""
81+
Create a feature extractor from a timm model.
82+
83+
Args:
84+
model_name (str): Name of the timm model to use.
85+
device (str): Device to run the model on.
86+
input_size (Tuple[int, int]): Size to which the input images are resized.
87+
pretrained (bool): Whether to use pretrained weights from timm or not.
88+
*args: Additional arguments to pass to
89+
[`timm.create_model`](https://huggingface.co/docs/timm/en/reference/models#timm.create_model).
90+
**kwargs: Additional keyword arguments to pass to
91+
[`timm.create_model`](https://huggingface.co/docs/timm/en/reference/models#timm.create_model).
92+
93+
Returns:
94+
DeepSORTFeatureExtractor: A new instance of DeepSORTFeatureExtractor.
95+
"""
96+
if model_name not in timm.list_models(filter=model_name, pretrained=pretrained):
97+
raise ValueError(
98+
f"Model {model_name} not found in timm. "
99+
+ "Please check the model name and try again."
100+
)
101+
model = timm.create_model(model_name, pretrained=pretrained, *args, **kwargs)
102+
backbone_model = FeatureExtractionBackbone(model)
103+
return cls(backbone_model, device, input_size)
104+
105+
def _initialize_model(
106+
self, model_or_checkpoint_path: Union[str, torch.nn.Module, None]
107+
):
108+
if isinstance(model_or_checkpoint_path, str):
109+
if validators.url(model_or_checkpoint_path):
110+
checkpoint_path = FireRequests().download(model_or_checkpoint_path)[0]
111+
self._load_model_from_path(checkpoint_path)
112+
else:
113+
self._load_model_from_path(model_or_checkpoint_path)
114+
else:
115+
self.model = model_or_checkpoint_path
116+
self.model.to(self.device)
117+
self.model.eval()
118+
119+
def _load_model_from_path(self, model_path):
120+
"""
121+
Load the PyTorch model from the given path.
122+
123+
Args:
124+
model_path (str): Path to the model checkpoint.
125+
126+
Returns:
127+
torch.nn.Module: Loaded PyTorch model.
128+
"""
129+
self.model = FeatureExtractionBackbone(torch.load(model_path))
130+
self.model.to(self.device)
131+
self.model.eval()
132+
133+
def extract_features(
134+
self, frame: np.ndarray, detections: sv.Detections
135+
) -> np.ndarray:
136+
"""
137+
Extract features from detection crops in the frame.
138+
139+
Args:
140+
frame (np.ndarray): The input frame.
141+
detections (sv.Detections): Detections from which to extract features.
142+
143+
Returns:
144+
np.ndarray: Extracted features for each detection.
145+
"""
146+
if len(detections) == 0:
147+
return np.array([])
148+
149+
features = []
150+
with torch.no_grad():
151+
for box in detections.xyxy:
152+
crop = sv.crop_image(image=frame, xyxy=[*box.astype(int)])
153+
tensor = self.transform(crop).unsqueeze(0).to(self.device)
154+
feature = self.model(tensor).cpu().numpy().flatten()
155+
features.append(feature)
156+
157+
return np.array(features)
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
from typing import Optional, Union
2+
3+
import numpy as np
4+
5+
from trackers.core.sort.kalman_box_tracker import SORTKalmanBoxTracker
6+
7+
8+
class DeepSORTKalmanBoxTracker(SORTKalmanBoxTracker):
9+
"""
10+
The `DeepSORTKalmanBoxTracker` class represents the internals of a single
11+
tracked object (bounding box), with a Kalman filter to predict and update
12+
its position. It also maintains a feature vector for the object, which is
13+
used to identify the object across frames.
14+
"""
15+
16+
def __init__(self, bbox: np.ndarray, feature: Optional[np.ndarray] = None):
17+
super().__init__(bbox)
18+
self.features: list[np.ndarray] = []
19+
if feature is not None:
20+
self.features.append(feature)
21+
22+
def update_feature(self, feature: np.ndarray):
23+
self.features.append(feature)
24+
25+
def get_feature(self) -> Union[np.ndarray, None]:
26+
"""
27+
Get the mean feature vector for this tracker.
28+
29+
Returns:
30+
np.ndarray: Mean feature vector.
31+
"""
32+
if len(self.features) > 0:
33+
# Return the mean of all features, thus (in theory) capturing the
34+
# "average appearance" of the object, which should be more robust
35+
# to minor appearance changes. Otherwise, the last feature can
36+
# also be returned like the following:
37+
# return self.features[-1]
38+
return np.mean(self.features, axis=0)
39+
return None

0 commit comments

Comments
 (0)