Skip to content

Commit 390a342

Browse files
committed
Chore: respond to feedback
1 parent df8ae6f commit 390a342

File tree

3 files changed

+50
-13
lines changed

3 files changed

+50
-13
lines changed

trackers/dataset/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66

77
# --- Base Dataset ---
8-
class Dataset(abc.ABC):
8+
class EvaluationDataset(abc.ABC):
99
"""Abstract base class for datasets used in tracking evaluation."""
1010

1111
@abc.abstractmethod

trackers/dataset/mot_challenge.py

Lines changed: 46 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@
66
import supervision as sv
77
from scipy.optimize import linear_sum_assignment
88

9-
from trackers.dataset.base import Dataset
10-
from trackers.dataset.utils import _relabel_ids
9+
from trackers.dataset.base import EvaluationDataset
10+
from trackers.dataset.utils import relabel_ids
1111
from trackers.log import get_logger
1212

1313
# --- Define MOT Constants needed for preprocessing ---
@@ -25,7 +25,7 @@
2525
logger = get_logger(__name__)
2626

2727

28-
class MOTChallengeDataset(Dataset):
28+
class MOTChallengeDataset(EvaluationDataset):
2929
"""
3030
Dataset class for loading sequences in the MOTChallenge format.
3131
Handles parsing `seqinfo.ini`, `gt/gt.txt`, and optionally `det/det.txt`.
@@ -61,9 +61,8 @@ def __init__(self, dataset_path: Union[str, Path]):
6161
if not self.root_path.is_dir():
6262
raise FileNotFoundError(f"Dataset path not found: {self.root_path}")
6363
self._sequence_names = self._find_sequences()
64-
self._public_detections: Optional[Dict[str, sv.Detections]] = (
65-
None # Cache for public detections
66-
)
64+
self._public_detections: Dict[str, sv.Detections] = {}
65+
self._frame_maps: Dict[str, Dict[int, str]] = {}
6766

6867
def _find_sequences(self) -> List[str]:
6968
"""
@@ -362,6 +361,7 @@ def load_public_detections(self, min_confidence: Optional[float] = None) -> None
362361
"""
363362
logger.info("Loading public detections...")
364363
self._public_detections = {}
364+
self._frame_maps = {}
365365
loaded_count = 0
366366
total_dets = 0
367367

@@ -387,6 +387,8 @@ def load_public_detections(self, min_confidence: Optional[float] = None) -> None
387387
for info in self.get_frame_iterator(seq_name)
388388
}
389389

390+
self._frame_maps[seq_name] = frame_map
391+
390392
for frame_idx, detections in frame_detections.items():
391393
if frame_idx not in frame_map:
392394
logger.warning(
@@ -429,7 +431,7 @@ def has_public_detections(self) -> bool:
429431
`load_public_detections`."""
430432
return self._public_detections is not None
431433

432-
def get_public_detections(self, image_path: str) -> sv.Detections:
434+
def get_public_detections_from_image_path(self, image_path: str) -> sv.Detections:
433435
"""
434436
Retrieves the loaded public detections associated with a specific image path.
435437
@@ -456,6 +458,41 @@ def get_public_detections(self, image_path: str) -> sv.Detections:
456458
abs_image_path, sv.Detections.empty()
457459
)
458460

461+
def get_public_detections_from_frame_index(
462+
self, sequence_name: str, frame_idx: int
463+
) -> sv.Detections:
464+
"""
465+
Retrieves the loaded public detections for a specific frame index in a
466+
sequence.
467+
Requires `load_public_detections()` to have been called first.
468+
Args:
469+
sequence_name: The name of the sequence (e.g., 'MOT17-02-SDP').
470+
frame_idx: The frame index (1-based).
471+
Returns:
472+
An sv.Detections object containing the public detections for the
473+
specified frame index. Returns `sv.Detections.empty()` if no detections
474+
were loaded for this frame or if `load_public_detections()` was not
475+
called.
476+
"""
477+
if not self.has_public_detections:
478+
logger.warning(
479+
"Public detections requested but not loaded. \
480+
Call load_public_detections() first."
481+
)
482+
return sv.Detections.empty()
483+
484+
frame_map = self._frame_maps.get(sequence_name, {})
485+
abs_image_path = frame_map.get(frame_idx)
486+
487+
if abs_image_path is None:
488+
logger.warning(
489+
f"No public detections found for sequence {sequence_name} at frame \
490+
{frame_idx}"
491+
)
492+
return sv.Detections.empty()
493+
494+
return self.get_public_detections_from_image_path(abs_image_path)
495+
459496
def preprocess(
460497
self,
461498
gt_dets: sv.Detections,
@@ -596,7 +633,7 @@ def preprocess(
596633
)
597634

598635
# --- TrackEval Preprocessing Step 6: Relabel IDs using the utility function ---
599-
gt_processed = _relabel_ids(gt_processed)
600-
pred_processed = _relabel_ids(pred_processed)
636+
gt_processed = relabel_ids(gt_processed)
637+
pred_processed = relabel_ids(pred_processed)
601638

602639
return gt_processed, pred_processed

trackers/dataset/utils.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
logger = get_logger(__name__) # Added logger instance
77

88

9-
def _relabel_ids(detections: sv.Detections) -> sv.Detections:
9+
def relabel_ids(detections: sv.Detections) -> sv.Detections:
1010
"""
1111
Relabels `tracker_id`s to be contiguous integers starting from 0.
1212
@@ -43,8 +43,8 @@ def _relabel_ids(detections: sv.Detections) -> sv.Detections:
4343
return detections
4444

4545
# Now unique_ids contains only valid integers
46-
max_id = np.max(unique_ids)
47-
min_id = np.min(unique_ids)
46+
max_id: int = np.max(unique_ids)
47+
min_id: int = np.min(unique_ids)
4848

4949
offset = 0
5050
if min_id < 0:

0 commit comments

Comments
 (0)