6
6
import supervision as sv
7
7
from scipy .optimize import linear_sum_assignment
8
8
9
- from trackers .dataset .base import Dataset
10
- from trackers .dataset .utils import _relabel_ids
9
+ from trackers .dataset .base import EvaluationDataset
10
+ from trackers .dataset .utils import relabel_ids
11
11
from trackers .log import get_logger
12
12
13
13
# --- Define MOT Constants needed for preprocessing ---
25
25
logger = get_logger (__name__ )
26
26
27
27
28
- class MOTChallengeDataset (Dataset ):
28
+ class MOTChallengeDataset (EvaluationDataset ):
29
29
"""
30
30
Dataset class for loading sequences in the MOTChallenge format.
31
31
Handles parsing `seqinfo.ini`, `gt/gt.txt`, and optionally `det/det.txt`.
@@ -61,9 +61,8 @@ def __init__(self, dataset_path: Union[str, Path]):
61
61
if not self .root_path .is_dir ():
62
62
raise FileNotFoundError (f"Dataset path not found: { self .root_path } " )
63
63
self ._sequence_names = self ._find_sequences ()
64
- self ._public_detections : Optional [Dict [str , sv .Detections ]] = (
65
- None # Cache for public detections
66
- )
64
+ self ._public_detections : Dict [str , sv .Detections ] = {}
65
+ self ._frame_maps : Dict [str , Dict [int , str ]] = {}
67
66
68
67
def _find_sequences (self ) -> List [str ]:
69
68
"""
@@ -362,6 +361,7 @@ def load_public_detections(self, min_confidence: Optional[float] = None) -> None
362
361
"""
363
362
logger .info ("Loading public detections..." )
364
363
self ._public_detections = {}
364
+ self ._frame_maps = {}
365
365
loaded_count = 0
366
366
total_dets = 0
367
367
@@ -387,6 +387,8 @@ def load_public_detections(self, min_confidence: Optional[float] = None) -> None
387
387
for info in self .get_frame_iterator (seq_name )
388
388
}
389
389
390
+ self ._frame_maps [seq_name ] = frame_map
391
+
390
392
for frame_idx , detections in frame_detections .items ():
391
393
if frame_idx not in frame_map :
392
394
logger .warning (
@@ -429,7 +431,7 @@ def has_public_detections(self) -> bool:
429
431
`load_public_detections`."""
430
432
return self ._public_detections is not None
431
433
432
- def get_public_detections (self , image_path : str ) -> sv .Detections :
434
+ def get_public_detections_from_image_path (self , image_path : str ) -> sv .Detections :
433
435
"""
434
436
Retrieves the loaded public detections associated with a specific image path.
435
437
@@ -456,6 +458,41 @@ def get_public_detections(self, image_path: str) -> sv.Detections:
456
458
abs_image_path , sv .Detections .empty ()
457
459
)
458
460
461
+ def get_public_detections_from_frame_index (
462
+ self , sequence_name : str , frame_idx : int
463
+ ) -> sv .Detections :
464
+ """
465
+ Retrieves the loaded public detections for a specific frame index in a
466
+ sequence.
467
+ Requires `load_public_detections()` to have been called first.
468
+ Args:
469
+ sequence_name: The name of the sequence (e.g., 'MOT17-02-SDP').
470
+ frame_idx: The frame index (1-based).
471
+ Returns:
472
+ An sv.Detections object containing the public detections for the
473
+ specified frame index. Returns `sv.Detections.empty()` if no detections
474
+ were loaded for this frame or if `load_public_detections()` was not
475
+ called.
476
+ """
477
+ if not self .has_public_detections :
478
+ logger .warning (
479
+ "Public detections requested but not loaded. \
480
+ Call load_public_detections() first."
481
+ )
482
+ return sv .Detections .empty ()
483
+
484
+ frame_map = self ._frame_maps .get (sequence_name , {})
485
+ abs_image_path = frame_map .get (frame_idx )
486
+
487
+ if abs_image_path is None :
488
+ logger .warning (
489
+ f"No public detections found for sequence { sequence_name } at frame \
490
+ { frame_idx } "
491
+ )
492
+ return sv .Detections .empty ()
493
+
494
+ return self .get_public_detections_from_image_path (abs_image_path )
495
+
459
496
def preprocess (
460
497
self ,
461
498
gt_dets : sv .Detections ,
@@ -596,7 +633,7 @@ def preprocess(
596
633
)
597
634
598
635
# --- TrackEval Preprocessing Step 6: Relabel IDs using the utility function ---
599
- gt_processed = _relabel_ids (gt_processed )
600
- pred_processed = _relabel_ids (pred_processed )
636
+ gt_processed = relabel_ids (gt_processed )
637
+ pred_processed = relabel_ids (pred_processed )
601
638
602
639
return gt_processed , pred_processed
0 commit comments