Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,12 @@ build-backend = "hatchling.build"
[tool.hatch.build.targets.wheel]
packages = ["trackers"]

[tool.bandit]
# Ignore B614 temporary until permanent fix
# B614: Test for unsafe PyTorch load or save
# https://bandit.readthedocs.io/en/1.7.10/plugins/b704_pytorch_load_save.html
skips = ["B614"]

[tool.ruff]
target-version = "py39"

Expand Down
7 changes: 6 additions & 1 deletion trackers/core/deepsort/feature_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,10 +112,15 @@ def _initialize_model(
self._load_model_from_path(checkpoint_path)
else:
self._load_model_from_path(model_or_checkpoint_path)
else:
elif isinstance(model_or_checkpoint_path, torch.nn.Module):
self.model = model_or_checkpoint_path
self.model.to(self.device)
self.model.eval()
else:
raise TypeError(
"model_or_checkpoint_path must be a string (path/URL) "
"or a torch.nn.Module instance."
)

def _load_model_from_path(self, model_path):
"""
Expand Down
47 changes: 29 additions & 18 deletions trackers/core/deepsort/tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ def callback(frame: np.ndarray, _: int):
combined distance.

Args:
feature_extractor (Optional[Union[DeepSORTFeatureExtractor, torch.nn.Module, str]]):
feature_extractor (Union[DeepSORTFeatureExtractor, torch.nn.Module, str]):
A feature extractor model checkpoint URL, model checkpoint path, or model
instance or an instance of `DeepSORTFeatureExtractor` to extract
appearance features. By default, the a default model checkpoint is downloaded
Expand Down Expand Up @@ -153,9 +153,7 @@ def callback(frame: np.ndarray, _: int):

def __init__(
self,
feature_extractor: Optional[
Union[DeepSORTFeatureExtractor, torch.nn.Module, str]
] = None,
feature_extractor: Union[DeepSORTFeatureExtractor, torch.nn.Module, str],
device: Optional[str] = None,
lost_track_buffer: int = 30,
frame_rate: float = 30.0,
Expand All @@ -166,20 +164,9 @@ def __init__(
appearance_weight: float = 0.5,
distance_metric: str = "cosine",
):
if feature_extractor is None:
self.feature_extractor = DeepSORTFeatureExtractor(device=device)
elif isinstance(feature_extractor, str):
self.feature_extractor = DeepSORTFeatureExtractor(
model_or_checkpoint_path=feature_extractor,
device=device,
)
elif isinstance(feature_extractor, torch.nn.Module):
self.feature_extractor = DeepSORTFeatureExtractor(
model_or_checkpoint_path=feature_extractor,
device=device,
)
else:
self.feature_extractor = feature_extractor
self.feature_extractor = self._initialize_feature_extractor(
feature_extractor, device
)

self.lost_track_buffer = lost_track_buffer
self.frame_rate = frame_rate
Expand All @@ -198,6 +185,30 @@ def __init__(

self.trackers: list[DeepSORTKalmanBoxTracker] = []

def _initialize_feature_extractor(
self,
feature_extractor: Union[DeepSORTFeatureExtractor, torch.nn.Module, str],
device: Optional[str],
) -> DeepSORTFeatureExtractor:
"""
Initialize the feature extractor based on the input type.

Args:
feature_extractor: The feature extractor input, which can be a model path,
a torch module, or a DeepSORTFeatureExtractor instance.
device: The device to run the model on.

Returns:
DeepSORTFeatureExtractor: The initialized feature extractor.
"""
if isinstance(feature_extractor, (str, torch.nn.Module)):
return DeepSORTFeatureExtractor(
model_or_checkpoint_path=feature_extractor,
device=device,
)
else:
return feature_extractor

def _get_appearance_distance_matrix(
self,
detection_features: np.ndarray,
Expand Down