diff --git a/boxmot/motion/kalman_filters/xywha_kf.py b/boxmot/motion/kalman_filters/xywha_kf.py new file mode 100644 index 0000000000..73d1b04d5f --- /dev/null +++ b/boxmot/motion/kalman_filters/xywha_kf.py @@ -0,0 +1,468 @@ +""" +This module implements the linear Kalman filter in both an object +oriented and procedural form. The KalmanFilter class implements +the filter by storing the various matrices in instance variables, +minimizing the amount of bookkeeping you have to do. +All Kalman filters operate with a predict->update cycle. The +predict step, implemented with the method or function predict(), +uses the state transition matrix F to predict the state in the next +time period (epoch). The state is stored as a gaussian (x, P), where +x is the state (column) vector, and P is its covariance. Covariance +matrix Q specifies the process covariance. In Bayesian terms, this +prediction is called the *prior*, which you can think of colloquially +as the estimate prior to incorporating the measurement. +The update step, implemented with the method or function `update()`, +incorporates the measurement z with covariance R, into the state +estimate (x, P). The class stores the system uncertainty in S, +the innovation (residual between prediction and measurement in +measurement space) in y, and the Kalman gain in k. The procedural +form returns these variables to you. In Bayesian terms this computes +the *posterior* - the estimate after the information from the +measurement is incorporated. +Whether you use the OO form or procedural form is up to you. If +matrices such as H, R, and F are changing each epoch, you'll probably +opt to use the procedural form. If they are unchanging, the OO +form is perhaps easier to use since you won't need to keep track +of these matrices. This is especially useful if you are implementing +banks of filters or comparing various KF designs for performance; +a trivial coding bug could lead to using the wrong sets of matrices. +This module also offers an implementation of the RTS smoother, and +other helper functions, such as log likelihood computations. +The Saver class allows you to easily save the state of the +KalmanFilter class after every update. +""" + +from __future__ import absolute_import, division + +from copy import deepcopy +from math import log, exp, sqrt +import sys +import numpy as np +from numpy import dot, zeros, eye, isscalar, shape +import numpy.linalg as linalg +from filterpy.stats import logpdf +from filterpy.common import pretty_str, reshape_z +from collections import deque + + +class KalmanFilterXYWHA(object): + """ Implements a Kalman filter. You are responsible for setting the + various state variables to reasonable values; the defaults will + not give you a functional filter. + """ + + def __init__(self, dim_x, dim_z, dim_u=0, max_obs=50): + if dim_x < 1: + raise ValueError('dim_x must be 1 or greater') + if dim_z < 1: + raise ValueError('dim_z must be 1 or greater') + if dim_u < 0: + raise ValueError('dim_u must be 0 or greater') + + self.dim_x = dim_x + self.dim_z = dim_z + self.dim_u = dim_u + + self.x = zeros((dim_x, 1)) # state + self.P = eye(dim_x) # uncertainty covariance + self.Q = eye(dim_x) # process uncertainty + self.B = None # control transition matrix + self.F = eye(dim_x) # state transition matrix + self.H = zeros((dim_z, dim_x)) # measurement function + self.R = eye(dim_z) # measurement uncertainty + self._alpha_sq = 1. 
# fading memory control + self.M = np.zeros((dim_x, dim_z)) # process-measurement cross correlation + self.z = np.array([[None]*self.dim_z]).T + + # gain and residual are computed during the innovation step. We + # save them so that in case you want to inspect them for various + # purposes + self.K = np.zeros((dim_x, dim_z)) # kalman gain + self.y = zeros((dim_z, 1)) + self.S = np.zeros((dim_z, dim_z)) # system uncertainty + self.SI = np.zeros((dim_z, dim_z)) # inverse system uncertainty + + # identity matrix. Do not alter this. + self._I = np.eye(dim_x) + + # these will always be a copy of x,P after predict() is called + self.x_prior = self.x.copy() + self.P_prior = self.P.copy() + + # these will always be a copy of x,P after update() is called + self.x_post = self.x.copy() + self.P_post = self.P.copy() + + # Only computed only if requested via property + self._log_likelihood = log(sys.float_info.min) + self._likelihood = sys.float_info.min + self._mahalanobis = None + + # keep all observations + self.max_obs = max_obs + self.history_obs = deque([], maxlen=self.max_obs) + + self.inv = np.linalg.inv + + self.attr_saved = None + self.observed = False + self.last_measurement = None + + + def apply_affine_correction(self, m, t): + #TODO ADAPT FOR OBB + """ + Apply to both last state and last observation for OOS smoothing. + + Messy due to internal logic for kalman filter being messy. + """ + + scale = np.linalg.norm(m[:, 0]) + self.x[:2] = m @ self.x[:2] + t + self.x[4:6] = m @ self.x[4:6] + + self.P[:2, :2] = m @ self.P[:2, :2] @ m.T + self.P[4:6, 4:6] = m @ self.P[4:6, 4:6] @ m.T + + # If frozen, also need to update the frozen state for OOS + if not self.observed and self.attr_saved is not None: + self.attr_saved["x"][:2] = m @ self.attr_saved["x"][:2] + t + self.attr_saved["x"][4:6] = m @ self.attr_saved["x"][4:6] + + self.attr_saved["P"][:2, :2] = m @ self.attr_saved["P"][:2, :2] @ m.T + self.attr_saved["P"][4:6, 4:6] = m @ self.attr_saved["P"][4:6, 4:6] @ m.T + + self.attr_saved["last_measurement"][:2] = m @ self.attr_saved["last_measurement"][:2] + t + + + def predict(self, u=None, B=None, F=None, Q=None): + """ + Predict next state (prior) using the Kalman filter state propagation + equations. + Parameters + ---------- + u : np.array, default 0 + Optional control vector. + B : np.array(dim_x, dim_u), or None + Optional control transition matrix; a value of None + will cause the filter to use `self.B`. + F : np.array(dim_x, dim_x), or None + Optional state transition matrix; a value of None + will cause the filter to use `self.F`. + Q : np.array(dim_x, dim_x), scalar, or None + Optional process noise matrix; a value of None will cause the + filter to use `self.Q`. 
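Note on `apply_affine_correction` above: it warps the track state with the linear part `m` and translation `t` of a camera-motion estimate. A minimal sketch of how such a 2x3 warp could be split and applied, assuming the 10-dim `(cx, cy, w, h, a, vcx, vcy, vw, vh, va)` layout used by `KalmanBoxTrackerOBB` below (the warp matrix `A` is hypothetical):

```python
import numpy as np

# Hypothetical 2x3 affine warp from a camera-motion-compensation step (not part of this diff).
A = np.array([[1.0, 0.02, 3.5],
              [-0.02, 1.0, -1.2]])
m, t = A[:, :2], A[:, 2:3]     # 2x2 linear part, 2x1 translation

x = np.zeros((10, 1))          # assumed state layout: cx, cy, w, h, a, vcx, vcy, vw, vh, va
x[:2] = [[120.0], [80.0]]      # position
x[5:7] = [[2.0], [0.5]]        # xy velocity

x[:2] = m @ x[:2] + t          # warp position
x[5:7] = m @ x[5:7]            # rotate/scale velocity (no translation)
```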
+ """ + if B is None: + B = self.B + if F is None: + F = self.F + if Q is None: + Q = self.Q + elif isscalar(Q): + Q = eye(self.dim_x) * Q + + # x = Fx + Bu + if B is not None and u is not None: + self.x = dot(F, self.x) + dot(B, u) + else: + self.x = dot(F, self.x) + + # P = FPF' + Q + self.P = self._alpha_sq * dot(dot(F, self.P), F.T) + Q + + # save prior + self.x_prior = self.x.copy() + self.P_prior = self.P.copy() + + def freeze(self): + """ + Save the parameters before non-observation forward + """ + self.attr_saved = deepcopy(self.__dict__) + + def unfreeze(self): + if self.attr_saved is not None: + new_history = deepcopy(list(self.history_obs)) + self.__dict__ = self.attr_saved + self.history_obs = deque(list(self.history_obs)[:-1], maxlen=self.max_obs) + occur = [int(d is None) for d in new_history] + indices = np.where(np.array(occur) == 0)[0] + index1, index2 = indices[-2], indices[-1] + box1, box2 = new_history[index1], new_history[index2] + x1, y1, w1, h1, a1 = box1 + x2, y2, w2, h2, a2 = box2 + time_gap = index2 - index1 + dx, dy = (x2 - x1) / time_gap, (y2 - y1) / time_gap + dw, dh = (w2 - w1) / time_gap, (h2 - h1) / time_gap + da = (a2 - a1) / time_gap + for i in range(index2 - index1): + x, y = x1 + (i + 1) * dx, y1 + (i + 1) * dy + w, h = w1 + (i + 1) * dw, h1 + (i + 1) * dh + a = a1 + (i + 1) * da + new_box = np.array([x, y, w, h, a]).reshape((5, 1)) + self.update(new_box) + if not i == (index2 - index1 - 1): + self.predict() + self.history_obs.pop() + self.history_obs.pop() + + def update(self, z, R=None, H=None): + """ + Add a new measurement (z) to the Kalman filter. If z is None, nothing is changed. + Parameters + ---------- + z : np.array + Measurement for this update. z can be a scalar if dim_z is 1, + otherwise it must be a column vector. + R : np.array, scalar, or None + Measurement noise. If None, the filter's self.R value is used. + H : np.array, or None + Measurement function. If None, the filter's self.H value is used. + """ + + # set to None to force recompute + self._log_likelihood = None + self._likelihood = None + self._mahalanobis = None + + # append the observation + self.history_obs.append(z) + + if z is None: + if self.observed: + """ + Got no observation so freeze the current parameters for future + potential online smoothing. 
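Note on `unfreeze()` above: it re-runs the filter across a detection gap by linearly interpolating between the last two real observations before replaying `update`/`predict`. A self-contained sketch of just the interpolation step:

```python
import numpy as np

def interpolate_gap(box1, box2, index1, index2):
    """Linearly interpolate (cx, cy, w, h, angle) boxes for the frames
    between two observations, as unfreeze() does before re-updating."""
    box1, box2 = np.asarray(box1, float), np.asarray(box2, float)
    time_gap = index2 - index1
    step = (box2 - box1) / time_gap
    return [box1 + (i + 1) * step for i in range(time_gap)]

# Observation at frame 10, next observation at frame 13 -> three virtual boxes,
# the last of which coincides with the frame-13 observation.
virtual = interpolate_gap([100, 50, 20, 40, 0.1], [112, 56, 22, 44, 0.4], 10, 13)
```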
+ """ + self.last_measurement = self.history_obs[-2] + self.freeze() + self.observed = False + self.z = np.array([[None] * self.dim_z]).T + self.x_post = self.x.copy() + self.P_post = self.P.copy() + self.y = zeros((self.dim_z, 1)) + return + + # self.observed = True + if not self.observed: + """ + Get observation, use online smoothing to re-update parameters + """ + self.unfreeze() + self.observed = True + + if R is None: + R = self.R + elif isscalar(R): + R = eye(self.dim_z) * R + if H is None: + z = reshape_z(z, self.dim_z, self.x.ndim) + H = self.H + + # y = z - Hx + # error (residual) between measurement and prediction + self.y = z - dot(H, self.x) + + # common subexpression for speed + PHT = dot(self.P, H.T) + + # S = HPH' + R + self.S = dot(H, PHT) + R + self.SI = self.inv(self.S) + + # K = PH'inv(S) + self.K = PHT.dot(self.SI) + + # x = x + Ky + self.x = self.x + dot(self.K, self.y) + + # P = (I-KH)P(I-KH)' + KRK' + I_KH = self._I - dot(self.K, H) + self.P = dot(dot(I_KH, self.P), I_KH.T) + dot(dot(self.K, R), self.K.T) + + # save measurement and posterior state + self.z = deepcopy(z) + self.x_post = self.x.copy() + self.P_post = self.P.copy() + + # save history of observations + self.history_obs.append(z) + + def update_steadystate(self, z, H=None): + """ Update Kalman filter using the Kalman gain and state covariance + matrix as computed for the steady state. Only x is updated, and the + new value is stored in self.x. P is left unchanged. Must be called + after a prior call to compute_steady_state(). + """ + if z is None: + self.history_obs.append(z) + return + + if H is None: + H = self.H + + H = np.asarray(H) + # error (residual) between measurement and prediction + self.y = z - dot(H, self.x) + + # x = x + Ky + self.x = self.x + dot(self.K_steady_state, self.y) + + # save measurement and posterior state + self.z = deepcopy(z) + self.x_post = self.x.copy() + + # save history of observations + self.history_obs.append(z) + + def log_likelihood(self, z=None): + """ log-likelihood of the measurement z. Computed from the + system uncertainty S. + """ + + if z is None: + z = self.z + return logpdf(z, dot(self.H, self.x), self.S) + + def likelihood(self, z=None): + """ likelihood of the measurement z. Computed from the + system uncertainty S. + """ + + if z is None: + z = self.z + return exp(self.log_likelihood(z)) + + @property + def log_likelihood(self): + """ log-likelihood of the last measurement. + """ + + return self._log_likelihood + + @property + def likelihood(self): + """ likelihood of the last measurement. + """ + + return self._likelihood + + +def batch_filter(x, P, zs, Fs, Qs, Hs, Rs, Bs=None, us=None, update_first=False, saver=None): + """ + Batch processes a sequences of measurements. + Parameters + ---------- + zs : list-like + list of measurements at each time step. Missing measurements must be + represented by None. + Fs : list-like + list of values to use for the state transition matrix matrix. + Qs : list-like + list of values to use for the process error + covariance. + Hs : list-like + list of values to use for the measurement matrix. + Rs : list-like + list of values to use for the measurement error + covariance. + Bs : list-like, optional + list of values to use for the control transition matrix; + a value of None in any position will cause the filter + to use `self.B` for that time step. + us : list-like, optional + list of values to use for the control input vector; + a value of None in any position will cause the filter to use + 0 for that time step. 
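Note on `update()` above: it follows the standard Kalman update with the Joseph-form covariance update, which keeps `P` symmetric and positive semi-definite. A minimal NumPy sketch of the same computation on a toy 2-state / 1-measurement system:

```python
import numpy as np

x = np.array([[0.0], [1.0]])          # state: position, velocity
P = np.eye(2) * 10.0
H = np.array([[1.0, 0.0]])            # only position is measured
R = np.array([[0.5]])
z = np.array([[0.7]])

y = z - H @ x                         # innovation
S = H @ P @ H.T + R                   # innovation covariance
K = P @ H.T @ np.linalg.inv(S)        # Kalman gain
x = x + K @ y
I_KH = np.eye(2) - K @ H
P = I_KH @ P @ I_KH.T + K @ R @ K.T   # Joseph form, as in update() above
```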
+ update_first : bool, optional + controls whether the order of operations is update followed by + predict, or predict followed by update. Default is predict->update. + saver : filterpy.common.Saver, optional + filterpy.common.Saver object. If provided, saver.save() will be + called after every epoch + Returns + ------- + means : np.array((n,dim_x,1)) + array of the state for each time step after the update. Each entry + is an np.array. In other words `means[k,:]` is the state at step + `k`. + covariance : np.array((n,dim_x,dim_x)) + array of the covariances for each time step after the update. + In other words `covariance[k,:,:]` is the covariance at step `k`. + means_predictions : np.array((n,dim_x,1)) + array of the state for each time step after the predictions. Each + entry is an np.array. In other words `means[k,:]` is the state at + step `k`. + covariance_predictions : np.array((n,dim_x,dim_x)) + array of the covariances for each time step after the prediction. + In other words `covariance[k,:,:]` is the covariance at step `k`. + Examples + -------- + .. code-block:: Python + zs = [t + random.randn()*4 for t in range (40)] + Fs = [kf.F for t in range (40)] + Hs = [kf.H for t in range (40)] + (mu, cov, _, _) = kf.batch_filter(zs, Rs=R_list, Fs=Fs, Hs=Hs, Qs=None, + Bs=None, us=None, update_first=False) + (xs, Ps, Ks, Pps) = kf.rts_smoother(mu, cov, Fs=Fs, Qs=None) + """ + + n = np.size(zs, 0) + dim_x = x.shape[0] + + # mean estimates from Kalman Filter + if x.ndim == 1: + means = zeros((n, dim_x)) + means_p = zeros((n, dim_x)) + else: + means = zeros((n, dim_x, 1)) + means_p = zeros((n, dim_x, 1)) + + # state covariances from Kalman Filter + covariances = zeros((n, dim_x, dim_x)) + covariances_p = zeros((n, dim_x, dim_x)) + + if us is None: + us = [0.0] * n + Bs = [0.0] * n + + if update_first: + for i, (z, F, Q, H, R, B, u) in enumerate(zip(zs, Fs, Qs, Hs, Rs, Bs, us)): + + x, P = update(x, P, z, R=R, H=H) + means[i, :] = x + covariances[i, :, :] = P + + x, P = predict(x, P, u=u, B=B, F=F, Q=Q) + means_p[i, :] = x + covariances_p[i, :, :] = P + if saver is not None: + saver.save() + else: + for i, (z, F, Q, H, R, B, u) in enumerate(zip(zs, Fs, Qs, Hs, Rs, Bs, us)): + + x, P = predict(x, P, u=u, B=B, F=F, Q=Q) + means_p[i, :] = x + covariances_p[i, :, :] = P + + x, P = update(x, P, z, R=R, H=H) + means[i, :] = x + covariances[i, :, :] = P + if saver is not None: + saver.save() + + return (means, covariances, means_p, covariances_p) + + def batch_filter(self, zs, Rs=None): + """ + Batch process a sequence of measurements. This method is suitable + for cases where the measurement noise varies with each measurement. 
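Note on the module-level `batch_filter` above: it calls free `predict`/`update` functions that are not imported in this file; presumably these are meant to be filterpy's procedural helpers. A hedged usage sketch of that predict->update batch loop under that assumption:

```python
import numpy as np
# Assumption: the procedural helpers come from filterpy, as in the original filterpy module.
from filterpy.kalman import predict, update

x = np.array([0.0, 0.0])                       # position, velocity
P = np.eye(2) * 100.0
F = np.array([[1.0, 1.0], [0.0, 1.0]])
Q = np.eye(2) * 1e-3
H = np.array([[1.0, 0.0]])
R = np.array([[2.0]])

for z in [1.1, 2.0, 2.9, 4.2]:                 # predict -> update per epoch
    x, P = predict(x, P, F=F, Q=Q)
    x, P = update(x, P, np.array([z]), R=R, H=H)
```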
+ """ + means, covariances = [], [] + for z, R in zip(zs, Rs): + self.predict() + self.update(z, R=R) + means.append(self.x.copy()) + covariances.append(self.P.copy()) + return np.array(means), np.array(covariances) \ No newline at end of file diff --git a/boxmot/trackers/basetracker.py b/boxmot/trackers/basetracker.py index c56ac0f669..c7b6ab6c5d 100644 --- a/boxmot/trackers/basetracker.py +++ b/boxmot/trackers/basetracker.py @@ -17,7 +17,8 @@ def __init__( max_obs: int = 50, nr_classes: int = 80, per_class: bool = False, - asso_func: str = 'iou' + asso_func: str = 'iou', + is_obb: bool = False ): """ Initialize the BaseTracker object with detection threshold, maximum age, minimum hits, @@ -41,8 +42,9 @@ def __init__( self.nr_classes = nr_classes self.iou_threshold = iou_threshold self.last_emb_size = None - self.asso_func_name = asso_func - + self.asso_func_name = asso_func+"_obb" if is_obb else asso_func + self.is_obb = is_obb + self.frame_count = 0 self.active_tracks = [] # This might be handled differently in derived classes self.per_class_active_tracks = None @@ -178,9 +180,17 @@ def check_inputs(self, dets, img): assert ( len(dets.shape) == 2 ), "Unsupported 'dets' dimensions, valid number of dimensions is two" - assert ( - dets.shape[1] == 6 - ), "Unsupported 'dets' 2nd dimension lenght, valid lenghts is 6" + if self.is_obb : + + assert ( + dets.shape[1] == 7 + ), "Unsupported 'dets' 2nd dimension lenght, valid lenghts is 6 (cx,cy,w,h,angle,conf,cls)" + + else : + assert ( + dets.shape[1] == 6 + ), "Unsupported 'dets' 2nd dimension lenght, valid lenghts is 6 (x1,y1,x2,y2,conf,cls)" + def id_to_color(self, id: int, saturation: float = 0.75, value: float = 0.95) -> tuple: """ diff --git a/boxmot/trackers/ocsort/ocsort.py b/boxmot/trackers/ocsort/ocsort.py index 0131463a32..573d24df3d 100644 --- a/boxmot/trackers/ocsort/ocsort.py +++ b/boxmot/trackers/ocsort/ocsort.py @@ -6,16 +6,21 @@ import numpy as np from collections import deque +import cv2 as cv from boxmot.motion.kalman_filters.xysr_kf import KalmanFilterXYSR +from boxmot.motion.kalman_filters.xywha_kf import KalmanFilterXYWHA from boxmot.utils.association import associate, linear_assignment from boxmot.trackers.basetracker import BaseTracker from boxmot.utils.ops import xyxy2xysr -def k_previous_obs(observations, cur_age, k): +def k_previous_obs(observations, cur_age, k, is_obb=False): if len(observations) == 0: - return [-1, -1, -1, -1, -1] + if is_obb: + return [-1, -1, -1, -1, -1, -1] + else : + return [-1, -1, -1, -1, -1] for i in range(k): dt = k - i if cur_age - dt in observations: @@ -48,6 +53,13 @@ def speed_direction(bbox1, bbox2): norm = np.sqrt((cy2 - cy1) ** 2 + (cx2 - cx1) ** 2) + 1e-6 return speed / norm +def speed_direction_obb(bbox1, bbox2): + cx1, cy1 = bbox1[0], bbox1[1] + cx2, cy2 = bbox2[0], bbox2[1] + speed = np.array([cy2 - cy1, cx2 - cx1]) + norm = np.sqrt((cy2 - cy1) ** 2 + (cx2 - cx1) ** 2) + 1e-6 + return speed / norm + class KalmanBoxTracker(object): """ @@ -179,6 +191,141 @@ def get_state(self): return convert_x_to_bbox(self.kf.x) +class KalmanBoxTrackerOBB(object): + """ + This class represents the internal state of individual tracked objects observed as oriented bbox. + """ + + count = 0 + + def __init__(self, bbox, cls, det_ind, delta_t=3, max_obs=50, Q_xy_scaling = 0.01, Q_a_scaling = 0.01): + """ + Initialises a tracker using initial bounding box. 
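Note on the `is_obb` flag introduced above: with `is_obb=True`, `check_inputs` expects detections as `(cx, cy, w, h, angle, conf, cls)` and the association function name gets the `_obb` suffix. A usage sketch (angle assumed to be in radians, matching the degree conversion done later in `plot_box_on_img`; other constructor arguments keep their defaults):

```python
import numpy as np
from boxmot.trackers.ocsort.ocsort import OcSort

tracker = OcSort(asso_func="centroid", is_obb=True)   # resolves to "centroid_obb"

# One row per detection: (cx, cy, w, h, angle, conf, cls) -- 7 columns instead of 6.
dets = np.array([
    [120.0,  80.0, 40.0, 20.0,  0.30, 0.90, 0],
    [300.0, 210.0, 60.0, 25.0, -0.80, 0.75, 0],
])
img = np.zeros((480, 640, 3), dtype=np.uint8)         # dummy frame

tracks = tracker.update(dets, img)
# each output row: cx, cy, w, h, angle, id, conf, cls, det_ind
```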
+ + """ + # define constant velocity model + self.det_ind = det_ind + + self.Q_xy_scaling = Q_xy_scaling + self.Q_a_scaling = Q_a_scaling + + self.kf = KalmanFilterXYWHA(dim_x=10, dim_z=5, max_obs=max_obs) + self.kf.F = np.array( + [ + [1, 0, 0, 0, 0, 1, 0, 0, 0, 0], # cx = cx + vx + [0, 1, 0, 0, 0, 0, 1, 0, 0, 0], # cy = cy + vy + [0, 0, 1, 0, 0, 0, 0, 1, 0, 0], # w = w + vw + [0, 0, 0, 1, 0, 0, 0, 0, 1, 0], # h = h + vh + [0, 0, 0, 0, 1, 0, 0, 0, 0, 1], # a = a + va + [0, 0, 0, 0, 0, 1, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 1, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 1, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 1] + ] + ) + self.kf.H = np.array( + [ + [1, 0, 0, 0, 0, 0, 0, 0, 0 ,0], # cx + [0, 1, 0, 0, 0, 0, 0, 0, 0, 0], # cy + [0, 0, 1, 0, 0, 0, 0, 0, 0, 0], # w + [0, 0, 0, 1, 0, 0, 0, 0, 0, 0], # h + [0, 0, 0, 0, 1, 0, 0, 0, 0, 0], # angle + ] + ) + + self.kf.R[2:, 2:] *= 10.0 + self.kf.P[ + 5:, 5: + ] *= 1000.0 # give high uncertainty to the unobservable initial velocities + self.kf.P *= 10.0 + + self.kf.Q[5:7, 5:7] *= self.Q_xy_scaling + self.kf.Q[-1, -1] *= self.Q_a_scaling + + self.kf.x[:5] = bbox[:5].reshape((5, 1)) # x, y, w, h, angle (dont take confidence score) + self.time_since_update = 0 + self.id = KalmanBoxTrackerOBB.count + KalmanBoxTrackerOBB.count += 1 + self.max_obs = max_obs + self.history = deque([], maxlen=self.max_obs) + self.hits = 0 + self.hit_streak = 0 + self.age = 0 + self.conf = bbox[-1] + self.cls = cls + """ + NOTE: [-1,-1,-1,-1,-1] is a compromising placeholder for non-observation status, the same for the return of + function k_previous_obs. It is ugly and I do not like it. But to support generate observation array in a + fast and unified way, which you would see below k_observations = np.array([k_previous_obs(...]]), + let's bear it for now. + """ + self.last_observation = np.array([-1, -1, -1, -1, -1, -1]) #WARNING : -1 is a valid angle value + self.observations = dict() + self.history_observations = deque([], maxlen=self.max_obs) + self.velocity = None + self.delta_t = delta_t + + def update(self, bbox, cls, det_ind): + """ + Updates the state vector with observed bbox. + """ + self.det_ind = det_ind + if bbox is not None: + self.conf = bbox[-1] + self.cls = cls + if self.last_observation.sum() >= 0: # no previous observation + previous_box = None + for i in range(self.delta_t): + dt = self.delta_t - i + if self.age - dt in self.observations: + previous_box = self.observations[self.age - dt] + break + if previous_box is None: + previous_box = self.last_observation + """ + Estimate the track speed direction with observations \Delta t steps away + """ + self.velocity = speed_direction_obb(previous_box, bbox) + + """ + Insert new observations. This is a ugly way to maintain both self.observations + and self.history_observations. Bear it for the moment. + """ + self.last_observation = bbox + self.observations[self.age] = bbox + self.history_observations.append(bbox) + + self.time_since_update = 0 + self.hits += 1 + self.hit_streak += 1 + self.kf.update(bbox[:5].reshape((5, 1))) # x, y, w, h, angle as column vector (dont take confidence score) + else: + self.kf.update(bbox) + + def predict(self): + """ + Advances the state vector and returns the predicted bounding box estimate. 
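Note on the constant-velocity model above: the tracker keeps a 10-dimensional state `(cx, cy, w, h, a, vcx, vcy, vw, vh, va)` and observes its first five components. A short sketch of one tracker's predict/update cycle (values are illustrative):

```python
import numpy as np
from boxmot.trackers.ocsort.ocsort import KalmanBoxTrackerOBB

# (cx, cy, w, h, angle, conf) -- the constructor keeps the first five values
# as the observed state and stores the confidence separately.
det = np.array([120.0, 80.0, 40.0, 20.0, 0.3, 0.9])
trk = KalmanBoxTrackerOBB(det, cls=0, det_ind=0)

pred = trk.predict()                  # (1, 5) predicted (cx, cy, w, h, angle)
trk.update(np.array([123.0, 82.0, 41.0, 20.5, 0.32, 0.88]), cls=0, det_ind=1)
state = trk.get_state()               # (1, 5) posterior box estimate
```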
+ """ + if (self.kf.x[7] + self.kf.x[2]) <= 0: # Negative width + self.kf.x[7] *= 0.0 + if (self.kf.x[8] + self.kf.x[3]) <= 0: # Negative Height + self.kf.x[8] *= 0.0 + self.kf.predict() + self.age += 1 + if self.time_since_update > 0: + self.hit_streak = 0 + self.time_since_update += 1 + self.history.append(self.kf.x[0:5].reshape((1, 5))) + return self.history[-1] + + def get_state(self): + """ + Returns the current bounding box estimate. + """ + return self.kf.x[0:5].reshape((1, 5)) + + class OcSort(BaseTracker): """ OCSort Tracker: A tracking algorithm that utilizes motion-based tracking. @@ -208,9 +355,10 @@ def __init__( inertia: float = 0.2, use_byte: bool = False, Q_xy_scaling: float = 0.01, - Q_s_scaling: float = 0.0001 + Q_s_scaling: float = 0.0001, + is_obb: bool = False ): - super().__init__(max_age=max_age, per_class=per_class, asso_func=asso_func) + super().__init__(max_age=max_age, per_class=per_class, asso_func=asso_func, is_obb=is_obb) """ Sets key parameters for SORT """ @@ -245,7 +393,7 @@ def update(self, dets: np.ndarray, img: np.ndarray, embs: np.ndarray = None) -> h, w = img.shape[0:2] dets = np.hstack([dets, np.arange(len(dets)).reshape(-1, 1)]) - confs = dets[:, 4] + confs = dets[:, 4+self.is_obb] inds_low = confs > 0.1 inds_high = confs < self.det_thresh @@ -257,12 +405,12 @@ def update(self, dets: np.ndarray, img: np.ndarray, embs: np.ndarray = None) -> dets = dets[remain_inds] # get predicted locations from existing trackers. - trks = np.zeros((len(self.active_tracks), 5)) + trks = np.zeros((len(self.active_tracks), 5+self.is_obb)) to_del = [] ret = [] for t, trk in enumerate(trks): pos = self.active_tracks[t].predict()[0] - trk[:] = [pos[0], pos[1], pos[2], pos[3], 0] + trk[:] = [pos[i] for i in range(4+self.is_obb)] + [0] if np.any(np.isnan(pos)): to_del.append(t) trks = np.ma.compress_rows(np.ma.masked_invalid(trks)) @@ -276,9 +424,10 @@ def update(self, dets: np.ndarray, img: np.ndarray, embs: np.ndarray = None) -> ] ) last_boxes = np.array([trk.last_observation for trk in self.active_tracks]) + k_observations = np.array( [ - k_previous_obs(trk.observations, trk.age, self.delta_t) + k_previous_obs(trk.observations, trk.age, self.delta_t, is_obb=self.is_obb) for trk in self.active_tracks ] ) @@ -287,10 +436,10 @@ def update(self, dets: np.ndarray, img: np.ndarray, embs: np.ndarray = None) -> First round of association """ matched, unmatched_dets, unmatched_trks = associate( - dets[:, 0:5], trks, self.asso_func, self.asso_threshold, velocities, k_observations, self.inertia, w, h + dets[:, 0:5+self.is_obb], trks, self.asso_func, self.asso_threshold, velocities, k_observations, self.inertia, w, h ) for m in matched: - self.active_tracks[m[1]].update(dets[m[0], :5], dets[m[0], 5], dets[m[0], 6]) + self.active_tracks[m[1]].update(dets[m[0], :-2], dets[m[0], -2], dets[m[0], -1]) """ Second round of associaton by OCR @@ -315,7 +464,7 @@ def update(self, dets: np.ndarray, img: np.ndarray, embs: np.ndarray = None) -> if iou_left[m[0], m[1]] < self.asso_threshold: continue self.active_tracks[trk_ind].update( - dets_second[det_ind, :5], dets_second[det_ind, 5], dets_second[det_ind, 6] + dets_second[det_ind, :-2], dets_second[det_ind, -2], dets_second[det_ind, -1] ) to_remove_trk_indices.append(trk_ind) unmatched_trks = np.setdiff1d( @@ -340,7 +489,7 @@ def update(self, dets: np.ndarray, img: np.ndarray, embs: np.ndarray = None) -> det_ind, trk_ind = unmatched_dets[m[0]], unmatched_trks[m[1]] if iou_left[m[0], m[1]] < self.asso_threshold: continue - 
self.active_tracks[trk_ind].update(dets[det_ind, :5], dets[det_ind, 5], dets[det_ind, 6]) + self.active_tracks[trk_ind].update(dets[det_ind, :-2], dets[det_ind, -2], dets[det_ind, -1]) to_remove_det_indices.append(det_ind) to_remove_trk_indices.append(trk_ind) unmatched_dets = np.setdiff1d( @@ -355,7 +504,10 @@ def update(self, dets: np.ndarray, img: np.ndarray, embs: np.ndarray = None) -> # create and initialise new trackers for unmatched detections for i in unmatched_dets: - trk = KalmanBoxTracker(dets[i, :5], dets[i, 5], dets[i, 6], delta_t=self.delta_t, Q_xy_scaling=self.Q_xy_scaling, Q_s_scaling=self.Q_s_scaling, max_obs=self.max_obs) + if self.is_obb: + trk = KalmanBoxTrackerOBB(dets[i, :-2], dets[i, -2], dets[i, -1], delta_t=self.delta_t, Q_xy_scaling=self.Q_xy_scaling, Q_a_scaling=self.Q_s_scaling, max_obs=self.max_obs) + else: + trk = KalmanBoxTracker(dets[i, :5], dets[i, 5], dets[i, 6], delta_t=self.delta_t, Q_xy_scaling=self.Q_xy_scaling, Q_s_scaling=self.Q_s_scaling, max_obs=self.max_obs) self.active_tracks.append(trk) i = len(self.active_tracks) for trk in reversed(self.active_tracks): @@ -366,7 +518,7 @@ def update(self, dets: np.ndarray, img: np.ndarray, embs: np.ndarray = None) -> this is optional to use the recent observation or the kalman filter prediction, we didn't notice significant difference here """ - d = trk.last_observation[:4] + d = trk.last_observation[:4+self.is_obb] if (trk.time_since_update < 1) and ( trk.hit_streak >= self.min_hits or self.frame_count <= self.min_hits ): @@ -382,4 +534,140 @@ def update(self, dets: np.ndarray, img: np.ndarray, embs: np.ndarray = None) -> self.active_tracks.pop(i) if len(ret) > 0: return np.concatenate(ret) - return np.array([]) \ No newline at end of file + return np.array([]) + + def plot_box_on_img(self, img: np.ndarray, box: tuple, conf: float, cls: int, id: int, thickness: int = 2, fontscale: float = 0.5) -> np.ndarray: + """ + Draws a bounding box with ID, confidence, and class information on an image. + + Parameters: + - img (np.ndarray): The image array to draw on. + - box (tuple): The bounding box coordinates as (x1, y1, x2, y2). + - conf (float): Confidence score of the detection. + - cls (int): Class ID of the detection. + - id (int): Unique identifier for the detection. + - thickness (int): The thickness of the bounding box. + - fontscale (float): The font scale for the text. + + Returns: + - np.ndarray: The image array with the bounding box drawn on it. 
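Note on the `:-2 / -2 / -1` indexing that replaces the hard-coded `:5 / 5 / 6` slices above: it works for both the axis-aligned and the OBB layout because `update()` appends a detection-index column before association. A small sketch of the resulting column layout for the OBB case:

```python
import numpy as np

# Two hypothetical OBB detections: (cx, cy, w, h, angle, conf, cls).
dets = np.array([
    [120.0,  80.0, 40.0, 20.0,  0.30, 0.90, 0],
    [300.0, 210.0, 60.0, 25.0, -0.80, 0.75, 2],
])
# OcSort.update() appends the original detection index as the last column.
dets = np.hstack([dets, np.arange(len(dets)).reshape(-1, 1)])

box_and_conf = dets[0, :-2]   # (cx, cy, w, h, angle, conf)
cls_id       = dets[0, -2]    # class id
det_ind      = dets[0, -1]    # index into the original detections
```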
+ """ + if self.is_obb: + + angle = box[4] * 180.0 / np.pi # Convert radians to degrees + box_poly = ((box[0], box[1]), (box[2], box[3]), angle) + # print((width, height)) + rotrec = cv.boxPoints(box_poly) + box_poly = np.int_(rotrec) # Convert to integer + + # Draw the rectangle on the image + img = cv.polylines(img, [box_poly], isClosed=True, color=self.id_to_color(id), thickness=thickness) + + img = cv.putText( + img, + f'id: {int(id)}, conf: {conf:.2f}, c: {int(cls)}, a: {box[4]:.2f}', + (int(box[0]), int(box[1]) - 10), + cv.FONT_HERSHEY_SIMPLEX, + fontscale, + self.id_to_color(id), + thickness + ) + else : + + img = cv.rectangle( + img, + (int(box[0]), int(box[1])), + (int(box[2]), int(box[3])), + self.id_to_color(id), + thickness + ) + img = cv.putText( + img, + f'id: {int(id)}, conf: {conf:.2f}, c: {int(cls)}', + (int(box[0]), int(box[1]) - 10), + cv.FONT_HERSHEY_SIMPLEX, + fontscale, + self.id_to_color(id), + thickness + ) + return img + + + def plot_trackers_trajectories(self, img: np.ndarray, observations: list, id: int) -> np.ndarray: + """ + Draws the trajectories of tracked objects based on historical observations. Each point + in the trajectory is represented by a circle, with the thickness increasing for more + recent observations to visualize the path of movement. + + Parameters: + - img (np.ndarray): The image array on which to draw the trajectories. + - observations (list): A list of bounding box coordinates representing the historical + observations of a tracked object. Each observation is in the format (x1, y1, x2, y2). + - id (int): The unique identifier of the tracked object for color consistency in visualization. + + Returns: + - np.ndarray: The image array with the trajectories drawn on it. + """ + for i, box in enumerate(observations): + trajectory_thickness = int(np.sqrt(float (i + 1)) * 1.2) + if self.is_obb: + img = cv.circle( + img, + (int(box[0]), int(box[1])), + 2, + color=self.id_to_color(int(id)), + thickness=trajectory_thickness + ) + else: + + img = cv.circle( + img, + (int((box[0] + box[2]) / 2), + int((box[1] + box[3]) / 2)), + 2, + color=self.id_to_color(int(id)), + thickness=trajectory_thickness + ) + return img + + def plot_results(self, img: np.ndarray, show_trajectories: bool, thickness: int = 2, fontscale: float = 0.5) -> np.ndarray: + """ + Visualizes the trajectories of all active tracks on the image. For each track, + it draws the latest bounding box and the path of movement if the history of + observations is longer than two. This helps in understanding the movement patterns + of each tracked object. + + Parameters: + - img (np.ndarray): The image array on which to draw the trajectories and bounding boxes. + - show_trajectories (bool): Whether to show the trajectories. + - thickness (int): The thickness of the bounding box. + - fontscale (float): The font scale for the text. + + Returns: + - np.ndarray: The image array with trajectories and bounding boxes of all active tracks. 
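Note on `plot_box_on_img` above: for OBB tracks it converts the stored angle from radians to degrees before handing the rotated rectangle to OpenCV. A standalone sketch of that drawing path:

```python
import numpy as np
import cv2 as cv

img = np.zeros((480, 640, 3), dtype=np.uint8)
cx, cy, w, h, angle_rad = 320.0, 240.0, 120.0, 60.0, 0.4

# OpenCV's RotatedRect expects the angle in degrees.
rect = ((cx, cy), (w, h), angle_rad * 180.0 / np.pi)
corners = cv.boxPoints(rect).astype(np.int32)   # 4x2 polygon corners
cv.polylines(img, [corners], isClosed=True, color=(0, 255, 0), thickness=2)
```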
+ """ + + # if values in dict + if self.per_class_active_tracks is not None: + for k in self.per_class_active_tracks.keys(): + active_tracks = self.per_class_active_tracks[k] + for a in active_tracks: + if a.history_observations: + if a.hits >= self.min_hits: + if len(a.history_observations) > 2: + box = a.history_observations[-1] + img = self.plot_box_on_img(img, box, a.conf, a.cls, a.id, thickness, fontscale) + if show_trajectories: + img = self.plot_trackers_trajectories(img, a.history_observations, a.id) + else: + for a in self.active_tracks: + if a.history_observations: + if a.hits >= self.min_hits: + if len(a.history_observations) > 2: + box = a.history_observations[-1] + img = self.plot_box_on_img(img, box, a.conf, a.cls, a.id, thickness, fontscale) + if show_trajectories: + img = self.plot_trackers_trajectories(img, a.history_observations, a.id) + + return img + \ No newline at end of file diff --git a/boxmot/utils/iou.py b/boxmot/utils/iou.py index f1b53dfc31..19fdcc9f7c 100644 --- a/boxmot/utils/iou.py +++ b/boxmot/utils/iou.py @@ -1,4 +1,35 @@ import numpy as np +import cv2 as cv + +def iou_obb_pair(i, j, bboxes1, bboxes2): + """ + Compute IoU for the rotated rectangles at index i and j in the batches `bboxes1`, `bboxes2` . + """ + rect1 = bboxes1[int(i)] + rect2 = bboxes2[int(j)] + + (cx1, cy1, w1, h1, angle1) = rect1[0:5] + (cx2, cy2, w2, h2, angle2) = rect2[0:5] + + + r1 = ((cx1, cy1), (w1, h1), angle1) + r2 = ((cx2, cy2), (w2, h2), angle2) + + # Compute intersection + ret, intersect = cv.rotatedRectangleIntersection(r1, r2) + if ret == 0 or intersect is None: + return 0.0 # No intersection + + # Calculate intersection area + intersection_area = cv.contourArea(intersect) + + # Calculate union area + area1 = w1 * h1 + area2 = w2 * h2 + union_area = area1 + area2 - intersection_area + + # Compute IoU + return intersection_area / union_area if union_area > 0 else 0.0 class AssociationFunction: def __init__(self, w, h, asso_mode="iou"): @@ -34,7 +65,17 @@ def iou_batch(bboxes1, bboxes2) -> np.ndarray: wh ) return o + + @staticmethod + def iou_batch_obb(bboxes1, bboxes2) -> np.ndarray: + N, M = len(bboxes1), len(bboxes2) + + def wrapper(i, j): + return iou_obb_pair(i, j, bboxes1, bboxes2) + + iou_matrix = np.fromfunction(np.vectorize(wrapper), shape=(N, M), dtype=int) + return iou_matrix @staticmethod def hmiou_batch(bboxes1, bboxes2): @@ -144,6 +185,19 @@ def centroid_batch(self, bboxes1, bboxes2) -> np.ndarray: return 1 - normalized_distances + def centroid_batch_obb(self, bboxes1, bboxes2) -> np.ndarray: + centroids1 = np.stack((bboxes1[..., 0], bboxes1[..., 1]),axis=-1) + centroids2 = np.stack((bboxes2[..., 0], bboxes2[..., 1]),axis=-1) + + centroids1 = np.expand_dims(centroids1, 1) + centroids2 = np.expand_dims(centroids2, 0) + + distances = np.sqrt(np.sum((centroids1 - centroids2) ** 2, axis=-1)) + norm_factor = np.sqrt(self.w ** 2 + self.h ** 2) + normalized_distances = distances / norm_factor + + return 1 - normalized_distances + @staticmethod def ciou_batch(bboxes1, bboxes2) -> np.ndarray: @@ -279,11 +333,13 @@ def _get_asso_func(self, asso_mode): """ ASSO_FUNCS = { "iou": AssociationFunction.iou_batch, + "iou_obb": AssociationFunction.iou_batch_obb, "hmiou": AssociationFunction.hmiou_batch, "giou": AssociationFunction.giou_batch, "ciou": AssociationFunction.ciou_batch, "diou": AssociationFunction.diou_batch, - "centroid": self.centroid_batch # only not being staticmethod + "centroid": self.centroid_batch, # only not being staticmethod + "centroid_obb": 
self.centroid_batch_obb } if self.asso_mode not in ASSO_FUNCS:
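Note on `iou_obb_pair` above: rotated-box IoU is obtained from OpenCV's rotated-rectangle intersection. A standalone sketch of the same computation for a single pair (`cv.rotatedRectangleIntersection` interprets the angle in degrees):

```python
import numpy as np
import cv2 as cv

def rotated_iou(rect1, rect2):
    """IoU of two rotated rects given as ((cx, cy), (w, h), angle_degrees)."""
    ret, pts = cv.rotatedRectangleIntersection(rect1, rect2)
    if ret == cv.INTERSECT_NONE or pts is None:
        return 0.0
    inter = cv.contourArea(pts)
    union = rect1[1][0] * rect1[1][1] + rect2[1][0] * rect2[1][1] - inter
    return inter / union if union > 0 else 0.0

print(rotated_iou(((100, 100), (40, 20), 0.0), ((110, 100), (40, 20), 30.0)))
```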