Skip to content

Commit abc248a

Browse files
committed
add: DeepSORTTracker
1 parent 622c70b commit abc248a

File tree

10 files changed

+2760
-195
lines changed

10 files changed

+2760
-195
lines changed

pyproject.toml

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@ description = "A unified library for object tracking"
55
readme = "README.md"
66
requires-python = ">=3.9"
77
dependencies = [
8+
"firerequests>=0.1.2",
9+
"numpy>=2.0.2",
810
"supervision>=0.25.1",
911
]
1012

@@ -30,6 +32,49 @@ docs = [
3032
"mike>=2.1.3",
3133
]
3234

35+
cpu = [
36+
"torch>=2.6.0",
37+
"torchvision>=0.21.0",
38+
]
39+
cu124 = [
40+
"torch>=2.6.0",
41+
"torchvision>=0.21.0",
42+
]
43+
44+
deepsort = [
45+
"scipy>=1.13.1",
46+
"timm>=1.0.15",
47+
"validators>=0.34.0",
48+
]
49+
50+
[tool.uv]
51+
conflicts = [
52+
[
53+
{ extra = "cpu" },
54+
{ extra = "cu124" },
55+
],
56+
]
57+
58+
[tool.uv.sources]
59+
torch = [
60+
{ index = "pytorch-cpu", extra = "cpu" },
61+
{ index = "pytorch-cu124", extra = "cu124" },
62+
]
63+
torchvision = [
64+
{ index = "pytorch-cpu", extra = "cpu" },
65+
{ index = "pytorch-cu124", extra = "cu124" },
66+
]
67+
68+
[[tool.uv.index]]
69+
name = "pytorch-cpu"
70+
url = "https://download.pytorch.org/whl/cpu"
71+
explicit = true
72+
73+
[[tool.uv.index]]
74+
name = "pytorch-cu124"
75+
url = "https://download.pytorch.org/whl/cu124"
76+
explicit = true
77+
3378
[build-system]
3479
requires = ["hatchling"]
3580
build-backend = "hatchling.build"

trackers/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from trackers.core.deepsort.feature_extractor import DeepSORTFeatureExtractor
2+
from trackers.core.deepsort.tracker import DeepSORTTracker
13
from trackers.core.sort.tracker import SORTTracker
24

3-
__all__ = ["SORTTracker"]
5+
__all__ = ["DeepSORTFeatureExtractor", "DeepSORTTracker", "SORTTracker"]

trackers/core/base.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,16 @@
11
from abc import ABC, abstractmethod
22

3+
import numpy as np
34
import supervision as sv
45

56

67
class BaseTracker(ABC):
    """
    Abstract base class for trackers whose update step takes detections only.
    """

    @abstractmethod
    def update(self, detections: sv.Detections) -> sv.Detections:
        """
        Update the tracker with a new set of detections.

        Concrete subclasses must override this method; the base
        implementation does nothing and returns `None`.

        Args:
            detections (sv.Detections): Detections to feed to the tracker.

        Returns:
            sv.Detections: The detections produced by the concrete tracker.
        """
11+
12+
13+
class BaseTrackerWithFeatures(ABC):
    """
    Abstract base class for trackers whose update step takes both the
    detections and the frame they belong to.
    """

    @abstractmethod
    def update(self, detections: sv.Detections, frame: np.ndarray) -> sv.Detections:
        """
        Update the tracker with a new set of detections and their frame.

        Concrete subclasses must override this method; the base
        implementation does nothing and returns `None`.

        Args:
            detections (sv.Detections): Detections to feed to the tracker.
            frame (np.ndarray): The frame passed alongside the detections.

        Returns:
            sv.Detections: The detections produced by the concrete tracker.
        """

trackers/core/deepsort/__init__.py

Whitespace-only changes.
Lines changed: 157 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,157 @@
1+
from typing import Optional, Tuple, Union
2+
3+
import numpy as np
4+
import supervision as sv
5+
import timm
6+
import torch
7+
import torch.nn as nn
8+
import torchvision.transforms as transforms
9+
import validators
10+
from firerequests import FireRequests
11+
12+
from trackers.utils.torch_utils import parse_device_spec
13+
14+
15+
class FeatureExtractionBackbone(nn.Module):
    """
    Thin `nn.Module` wrapper that runs a backbone model and squeezes its output.

    Args:
        backbone_model (nn.Module): The backbone model to use for feature extraction.
    """

    def __init__(self, backbone_model: nn.Module):
        super().__init__()
        self.backbone_model = backbone_model

    def forward(self, input_tensor: torch.Tensor) -> torch.Tensor:
        """
        Run the backbone on the input and drop every size-1 dimension.

        Note that `torch.squeeze` without a `dim` argument removes *all*
        singleton dimensions, so a leading batch dimension of size 1 is
        removed as well.

        Args:
            input_tensor (torch.Tensor): The input tensor.

        Returns:
            torch.Tensor: The backbone output with all size-1 dimensions removed.
        """
        return torch.squeeze(self.backbone_model(input_tensor))
37+
38+
39+
class DeepSORTFeatureExtractor:
    """
    Feature extractor for DeepSORT that loads a PyTorch model and
    extracts appearance features from detection crops.

    Args:
        model_or_checkpoint_path (Union[str, torch.nn.Module]): Path/URL to the PyTorch
            model checkpoint or the model itself.
        device (str): Device to run the model on.
        input_size (Tuple[int, int]): Size to which the input images are resized.
    """

    def __init__(
        self,
        model_or_checkpoint_path: Union[str, torch.nn.Module],
        device: Optional[str] = "auto",
        input_size: Tuple[int, int] = (128, 128),
    ):
        # A falsy device (e.g. None) falls back to the "auto" spec.
        self.device = parse_device_spec(device or "auto")
        self.input_size = input_size

        self._initialize_model(model_or_checkpoint_path)

        # Per-crop preprocessing: array -> PIL image -> resized -> tensor.
        self.transform = transforms.Compose(
            [
                transforms.ToPILImage(),
                transforms.Resize(self.input_size),
                transforms.ToTensor(),
            ]
        )

    @classmethod
    def from_timm(
        cls,
        model_name: str,
        device: Optional[str] = "auto",
        input_size: Tuple[int, int] = (128, 128),
        pretrained: bool = True,
        *args,
        **kwargs,
    ) -> "DeepSORTFeatureExtractor":
        """
        Create a feature extractor from a timm model.

        Args:
            model_name (str): Name of the timm model to use.
            device (str): Device to run the model on.
            input_size (Tuple[int, int]): Size to which the input images are resized.
            pretrained (bool): Whether to use pretrained weights from timm or not.
            *args: Additional arguments to pass to
                [`timm.create_model`](https://huggingface.co/docs/timm/en/reference/models#timm.create_model).
                They follow `model_name` and `pretrained` positionally.
            **kwargs: Additional keyword arguments to pass to
                [`timm.create_model`](https://huggingface.co/docs/timm/en/reference/models#timm.create_model).

        Returns:
            DeepSORTFeatureExtractor: A new instance of DeepSORTFeatureExtractor.

        Raises:
            ValueError: If `model_name` is not listed by `timm.list_models`.
        """
        if model_name not in timm.list_models(filter=model_name, pretrained=pretrained):
            raise ValueError(
                f"Model {model_name} not found in timm. "
                "Please check the model name and try again."
            )
        # `pretrained` is forwarded positionally on purpose: with
        # `pretrained=pretrained, *args` the first extra positional argument
        # would be bound to the same parameter and raise
        # "got multiple values for argument 'pretrained'".
        model = timm.create_model(model_name, pretrained, *args, **kwargs)
        backbone_model = FeatureExtractionBackbone(model)
        return cls(backbone_model, device, input_size)

    def _initialize_model(
        self, model_or_checkpoint_path: Union[str, torch.nn.Module, None]
    ):
        """
        Set `self.model` from a checkpoint path, a checkpoint URL or a model.

        Strings are treated as a URL when `validators.url` accepts them (the
        file is downloaded first) and as a local path otherwise. Any other
        value is used directly as the model, moved to `self.device` and put
        in eval mode.

        Args:
            model_or_checkpoint_path (Union[str, torch.nn.Module, None]): Path/URL
                to the checkpoint, or the model itself.

        Raises:
            TypeError: If `model_or_checkpoint_path` is None.
        """
        if model_or_checkpoint_path is None:
            # Fail early with a clear message instead of an AttributeError
            # from calling `.to()` on None further down.
            raise TypeError(
                "model_or_checkpoint_path must be a checkpoint path/URL or a "
                "torch.nn.Module, got None."
            )

        if isinstance(model_or_checkpoint_path, str):
            checkpoint_path = model_or_checkpoint_path
            if validators.url(checkpoint_path):
                # Remote checkpoint: fetch it and load the downloaded file.
                checkpoint_path = FireRequests().download(checkpoint_path)[0]
            self._load_model_from_path(checkpoint_path)
            return

        self.model = model_or_checkpoint_path
        self.model.to(self.device)
        self.model.eval()

    def _load_model_from_path(self, model_path) -> None:
        """
        Load the PyTorch model from the given path into `self.model`.

        The loaded object is wrapped in `FeatureExtractionBackbone`, moved to
        `self.device` and put in eval mode. Nothing is returned.

        Args:
            model_path (str): Path to the model checkpoint.
        """
        # `map_location` remaps stored tensors onto the target device so that
        # a checkpoint saved on another device (e.g. CUDA) can still be loaded.
        # NOTE(review): `torch.load` unpickles the file, so only trusted
        # checkpoints should be loaded. Recent torch releases default to
        # `weights_only=True`, which presumably rejects fully pickled modules
        # - confirm which checkpoint format is intended here.
        loaded_model = torch.load(model_path, map_location=self.device)
        self.model = FeatureExtractionBackbone(loaded_model)
        self.model.to(self.device)
        self.model.eval()

    def extract_features(
        self, frame: np.ndarray, detections: sv.Detections
    ) -> np.ndarray:
        """
        Extract features from detection crops in the frame.

        Each detection box is cast to integers, cropped out of the frame,
        preprocessed with `self.transform` and passed through the model on
        its own (batch of one).

        Args:
            frame (np.ndarray): The input frame.
            detections (sv.Detections): Detections from which to extract features.

        Returns:
            np.ndarray: Extracted features for each detection, one flattened
                vector per detection. An empty array when there are no
                detections.
        """
        if len(detections) == 0:
            return np.array([])

        features = []
        with torch.no_grad():
            for box in detections.xyxy:
                crop = sv.crop_image(image=frame, xyxy=[*box.astype(int)])
                tensor = self.transform(crop).unsqueeze(0).to(self.device)
                # Flatten so every detection yields a 1-D feature vector.
                feature = self.model(tensor).cpu().numpy().flatten()
                features.append(feature)

        return np.array(features)
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
from typing import Optional, Union
2+
3+
import numpy as np
4+
5+
from trackers.core.sort.kalman_box_tracker import SORTKalmanBoxTracker
6+
7+
8+
class DeepSORTKalmanBoxTracker(SORTKalmanBoxTracker):
    """
    Single-object tracker state for DeepSORT.

    Extends `SORTKalmanBoxTracker` (which handles the bounding-box state)
    with a history of appearance feature vectors for the tracked object,
    used to identify the object across frames.
    """

    def __init__(self, bbox: np.ndarray, feature: Optional[np.ndarray] = None):
        super().__init__(bbox)
        # Every feature seen for this object, in insertion order.
        self.features: list[np.ndarray] = [] if feature is None else [feature]

    def update_feature(self, feature: np.ndarray):
        """
        Append a new feature vector to this tracker's history.

        Args:
            feature (np.ndarray): Feature vector to store.
        """
        self.features.append(feature)

    def get_feature(self) -> Union[np.ndarray, None]:
        """
        Get the mean feature vector for this tracker.

        Returns:
            Union[np.ndarray, None]: Element-wise mean of all stored feature
                vectors, or None when no feature has been stored yet.
        """
        if not self.features:
            return None

        # Averaging over the whole history aims to capture the "average
        # appearance" of the object, which should be more robust to minor
        # appearance changes than using only the most recent feature.
        return np.mean(self.features, axis=0)

0 commit comments

Comments
 (0)