Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 45 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ description = "A unified library for object tracking"
readme = "README.md"
requires-python = ">=3.9"
dependencies = [
"firerequests>=0.1.2",
"numpy>=2.0.2",
"supervision>=0.25.1",
]

Expand All @@ -30,6 +32,49 @@ docs = [
"mike>=2.1.3",
]

cpu = [
"torch>=2.6.0",
"torchvision>=0.21.0",
]
cu124 = [
"torch>=2.6.0",
"torchvision>=0.21.0",
]

deepsort = [
"scipy>=1.13.1",
"timm>=1.0.15",
"validators>=0.34.0",
]

[tool.uv]
conflicts = [
[
{ extra = "cpu" },
{ extra = "cu124" },
],
]

[tool.uv.sources]
torch = [
{ index = "pytorch-cpu", extra = "cpu" },
{ index = "pytorch-cu124", extra = "cu124" },
]
torchvision = [
{ index = "pytorch-cpu", extra = "cpu" },
{ index = "pytorch-cu124", extra = "cu124" },
]

[[tool.uv.index]]
name = "pytorch-cpu"
url = "https://download.pytorch.org/whl/cpu"
explicit = true

[[tool.uv.index]]
name = "pytorch-cu124"
url = "https://download.pytorch.org/whl/cu124"
explicit = true

[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"
Expand Down
4 changes: 3 additions & 1 deletion trackers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from trackers.core.deepsort.feature_extractor import DeepSORTFeatureExtractor
from trackers.core.deepsort.tracker import DeepSORTTracker
from trackers.core.sort.tracker import SORTTracker

__all__ = ["SORTTracker"]
__all__ = ["DeepSORTFeatureExtractor", "DeepSORTTracker", "SORTTracker"]
7 changes: 7 additions & 0 deletions trackers/core/base.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,16 @@
from abc import ABC, abstractmethod

import numpy as np
import supervision as sv


class BaseTracker(ABC):
@abstractmethod
def update(self, detections: sv.Detections) -> sv.Detections:
pass


class BaseTrackerWithFeatures(ABC):
@abstractmethod
def update(self, detections: sv.Detections, frame: np.ndarray) -> sv.Detections:
pass
Empty file.
157 changes: 157 additions & 0 deletions trackers/core/deepsort/feature_extractor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
from typing import Optional, Tuple, Union

import numpy as np
import supervision as sv
import timm
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import validators
from firerequests import FireRequests

from trackers.utils.torch_utils import parse_device_spec


class FeatureExtractionBackbone(nn.Module):
"""
A simple backbone model for feature extraction.

Args:
backbone_model (nn.Module): The backbone model to use for feature extraction.
"""

def __init__(self, backbone_model: nn.Module):
super(FeatureExtractionBackbone, self).__init__()
self.backbone_model = backbone_model

def forward(self, input_tensor: torch.Tensor) -> torch.Tensor:
"""
Forward pass on a single input tensor.

Args:
input_tensor (torch.Tensor): The input tensor.
"""
output = self.backbone_model(input_tensor)
output = torch.squeeze(output)
return output


class DeepSORTFeatureExtractor:
"""
Feature extractor for DeepSORT that loads a PyTorch model and
extracts appearance features from detection crops.

Args:
model_or_checkpoint_path (Union[str, torch.nn.Module]): Path/URL to the PyTorch
model checkpoint or the model itself.
device (str): Device to run the model on.
input_size (Tuple[int, int]): Size to which the input images are resized.
"""

def __init__(
self,
model_or_checkpoint_path: Union[str, torch.nn.Module],
device: Optional[str] = "auto",
input_size: Tuple[int, int] = (128, 128),
):
self.device = parse_device_spec(device or "auto")
self.input_size = input_size

self._initialize_model(model_or_checkpoint_path)

self.transform = transforms.Compose(
[
transforms.ToPILImage(),
transforms.Resize(self.input_size),
transforms.ToTensor(),
]
)

@classmethod
def from_timm(
cls,
model_name: str,
device: Optional[str] = "auto",
input_size: Tuple[int, int] = (128, 128),
pretrained: bool = True,
*args,
**kwargs,
):
"""
Create a feature extractor from a timm model.

Args:
model_name (str): Name of the timm model to use.
device (str): Device to run the model on.
input_size (Tuple[int, int]): Size to which the input images are resized.
pretrained (bool): Whether to use pretrained weights from timm or not.
*args: Additional arguments to pass to
[`timm.create_model`](https://huggingface.co/docs/timm/en/reference/models#timm.create_model).
**kwargs: Additional keyword arguments to pass to
[`timm.create_model`](https://huggingface.co/docs/timm/en/reference/models#timm.create_model).

Returns:
DeepSORTFeatureExtractor: A new instance of DeepSORTFeatureExtractor.
"""
if model_name not in timm.list_models(filter=model_name, pretrained=pretrained):
raise ValueError(
f"Model {model_name} not found in timm. "
+ "Please check the model name and try again."
)
model = timm.create_model(model_name, pretrained=pretrained, *args, **kwargs)
backbone_model = FeatureExtractionBackbone(model)
return cls(backbone_model, device, input_size)

def _initialize_model(
self, model_or_checkpoint_path: Union[str, torch.nn.Module, None]
):
if isinstance(model_or_checkpoint_path, str):
if validators.url(model_or_checkpoint_path):
checkpoint_path = FireRequests().download(model_or_checkpoint_path)[0]
self._load_model_from_path(checkpoint_path)
else:
self._load_model_from_path(model_or_checkpoint_path)
else:
self.model = model_or_checkpoint_path
self.model.to(self.device)
self.model.eval()

def _load_model_from_path(self, model_path):
"""
Load the PyTorch model from the given path.

Args:
model_path (str): Path to the model checkpoint.

Returns:
torch.nn.Module: Loaded PyTorch model.
"""
self.model = FeatureExtractionBackbone(torch.load(model_path))
self.model.to(self.device)
self.model.eval()

def extract_features(
self, frame: np.ndarray, detections: sv.Detections
) -> np.ndarray:
"""
Extract features from detection crops in the frame.

Args:
frame (np.ndarray): The input frame.
detections (sv.Detections): Detections from which to extract features.

Returns:
np.ndarray: Extracted features for each detection.
"""
if len(detections) == 0:
return np.array([])

features = []
with torch.no_grad():
for box in detections.xyxy:
crop = sv.crop_image(image=frame, xyxy=[*box.astype(int)])
tensor = self.transform(crop).unsqueeze(0).to(self.device)
feature = self.model(tensor).cpu().numpy().flatten()
features.append(feature)

return np.array(features)
39 changes: 39 additions & 0 deletions trackers/core/deepsort/kalman_box_tracker.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
from typing import Optional, Union

import numpy as np

from trackers.core.sort.kalman_box_tracker import SORTKalmanBoxTracker


class DeepSORTKalmanBoxTracker(SORTKalmanBoxTracker):
"""
The `DeepSORTKalmanBoxTracker` class represents the internals of a single
tracked object (bounding box), with a Kalman filter to predict and update
its position. It also maintains a feature vector for the object, which is
used to identify the object across frames.
"""

def __init__(self, bbox: np.ndarray, feature: Optional[np.ndarray] = None):
super().__init__(bbox)
self.features: list[np.ndarray] = []
if feature is not None:
self.features.append(feature)

def update_feature(self, feature: np.ndarray):
self.features.append(feature)

def get_feature(self) -> Union[np.ndarray, None]:
"""
Get the mean feature vector for this tracker.

Returns:
np.ndarray: Mean feature vector.
"""
if len(self.features) > 0:
# Return the mean of all features, thus (in theory) capturing the
# "average appearance" of the object, which should be more robust
# to minor appearance changes. Otherwise, the last feature can
# also be returned like the following:
# return self.features[-1]
return np.mean(self.features, axis=0)
return None
Loading