diff --git a/README.md b/README.md index cea4603..3816183 100755 --- a/README.md +++ b/README.md @@ -137,6 +137,22 @@ Default embedder is a pytorch MobilenetV2 (trained on Imagenet). For convenience (I know it's not exactly best practice) & since the weights file is quite small, it is pushed in this github repo and will be installed to your Python environment when you install deep_sort_realtime. +#### TorchReID + +[Torchreid](https://github.com/KaiyangZhou/deep-person-reid) is a person re-identification library, and is supported here especially useful for extracting features of humans. It provides a zoo of [models](https://kaiyangzhou.github.io/deep-person-reid/MODEL_ZOO). Select model type to use, note the model name and provide as arguments. Download the corresponding model weights file on the model zoo site and point to the downloaded file. Model 'osnet_ain_x1_0' with domain generalized training on (MS+D+C) is provide by default, together with the corresponding weights. If `embedder='torchreid'` when initalizing `DeepSort` object without specifying `embedder_model_name` or `embedder_wts`, it will default to that. + +```python +from deep_sort_realtime.deepsort_tracker import DeepSort +tracker = DeepSort(max_age=5, embedder='torchreid') +bbs = object_detector.detect(frame) +tracks = tracker.update_tracks(bbs, frame=frame) # bbs expected to be a list of detections, each in tuples of ( [left,top,w,h], confidence, detection_class ) +for track in tracks: + if not track.is_confirmed(): + continue + track_id = track.track_id + ltrb = track.to_ltrb() +``` + #### CLIP [CLIP](https://github.com/openai/CLIP) is added as another option of embedder due to its proven flexibility and generalisability. Download the CLIP model weights you want at [deep_sort_realtime/embedder/weights/download_clip_wts.sh](deep_sort_realtime/embedder/weights/download_clip_wts.sh) and store the weights at that directory as well, or you can provide your own CLIP weights through `embedder_wts` argument of the `DeepSort` object. @@ -146,3 +162,19 @@ For convenience (I know it's not exactly best practice) & since the weights file Available now at `deep_sort_realtime/embedder/embedder_tf.py`, as alternative to (the default) pytorch embedder. Tested on Tensorflow 2.3.1. You need to make your own code change to use it. The tf MobilenetV2 weights (pretrained on imagenet) are not available in this github repo (unlike the torch one). Download from this [link](https://drive.google.com/file/d/1RBroAFc0tmfxgvrh7iXc2e1EK8TVzXkA/view?usp=sharing) or run [download script](./deep_sort_realtime/embedder/weights/download_tf_wts.sh). You may drop it into `deep_sort_realtime/embedder/weights/` before pip installing. + +### Example + +Example cosine distances between images in `./test/` ("diff": rock vs smallapple, "close": smallapple vs smallapple slightly augmented) + +``` +.Testing pytorch embedder +close: 0.012196660041809082 vs diff: 0.4409685730934143 + +.Testing Torchreid embedder +Model: osnet_ain_x1_0 +- params: 2,193,616 +- flops: 978,878,352 +Successfully loaded pretrained weights from "/Users/levan/Workspace/deep_sort_realtime/deep_sort_realtime/embedder/weights/osnet_ain_ms_d_c_wtsonly.pth" +close: 0.012312591075897217 vs diff: 0.4590487480163574 +``` \ No newline at end of file diff --git a/deep_sort_realtime/deepsort_tracker.py b/deep_sort_realtime/deepsort_tracker.py index b07ea56..4534c21 100644 --- a/deep_sort_realtime/deepsort_tracker.py +++ b/deep_sort_realtime/deepsort_tracker.py @@ -14,6 +14,7 @@ EMBEDDER_CHOICES = [ "mobilenet", + "torchreid", "clip_RN50", "clip_RN101", "clip_RN50x4", @@ -36,6 +37,7 @@ def __init__( half=True, bgr=True, embedder_gpu=True, + embedder_model_name=None, embedder_wts=None, polygon=False, today=None, @@ -58,13 +60,15 @@ def __init__( Giving this will override default Track class, this must inherit Track embedder : Optional[str] = 'mobilenet' Whether to use in-built embedder or not. If None, then embeddings must be given during update. - Choice of ['mobilenet', 'clip_RN50', 'clip_RN101', 'clip_RN50x4', 'clip_RN50x16', 'clip_ViT-B/32', 'clip_ViT-B/16'] + Choice of ['mobilenet', 'torchreid', 'clip_RN50', 'clip_RN101', 'clip_RN50x4', 'clip_RN50x16', 'clip_ViT-B/32', 'clip_ViT-B/16'] half : Optional[bool] = True Whether to use half precision for deep embedder (applicable for mobilenet only) bgr : Optional[bool] = True Whether frame given to embedder is expected to be BGR or not (RGB) embedder_gpu: Optional[bool] = True Whether embedder uses gpu or not + embedder_model_name: Optional[str] = None + Only used when embedder=='torchreid'. This provides which model to use within torchreid library. Check out torchreid's model zoo. embedder_wts: Optional[str] = None Optional specification of path to embedder's model weights. Will default to looking for weights in `deep_sort_realtime/embedder/weights`. If deep_sort_realtime is installed as a package and CLIP models is used as embedder, best to provide path. polygon: Optional[bool] = False @@ -99,7 +103,17 @@ def __init__( gpu=embedder_gpu, model_wts_path=embedder_wts, ) - else: + elif embedder == 'torchreid': + from deep_sort_realtime.embedder.embedder_pytorch import TorchReID_Embedder as Embedder + + self.embedder = Embedder( + bgr=bgr, + gpu=embedder_gpu, + model_name=embedder_model_name, + model_wts_path=embedder_wts, + ) + + elif embedder.startswith('clip_'): from deep_sort_realtime.embedder.embedder_clip import ( Clip_Embedder as Embedder, ) @@ -164,9 +178,9 @@ def update_tracks(self, raw_detections, embeds=None, frame=None, today=None, oth raise Exception("either embeddings or frame must be given!") assert isinstance(raw_detections,Iterable) - assert len(raw_detections[0][0])==4 if not self.polygon: + assert len(raw_detections[0][0])==4 raw_detections = [d for d in raw_detections if d[0][2] > 0 and d[0][3] > 0] if embeds is None: diff --git a/deep_sort_realtime/embedder/embedder_pytorch.py b/deep_sort_realtime/embedder/embedder_pytorch.py index e031fc5..9847edf 100644 --- a/deep_sort_realtime/embedder/embedder_pytorch.py +++ b/deep_sort_realtime/embedder/embedder_pytorch.py @@ -14,6 +14,11 @@ MOBILENETV2_BOTTLENECK_WTS = pkg_resources.resource_filename( "deep_sort_realtime", "embedder/weights/mobilenetv2_bottleneck_wts.pt" ) + +TORCHREID_OSNET_AIN_X1_0_MS_D_C_WTS = pkg_resources.resource_filename( + "deep_sort_realtime", "embedder/weights/osnet_ain_ms_d_c_wtsonly.pth" +) + INPUT_WIDTH = 224 @@ -132,3 +137,96 @@ def predict(self, np_images): all_feats.extend(output.cpu().data.numpy()) return all_feats + + +class TorchReID_Embedder(object): + """ + Embedder that works with torchreid (https://github.com/KaiyangZhou/deep-person-reid). Model zoo: https://kaiyangzhou.github.io/deep-person-reid/MODEL_ZOO + + Params + ------ + - model_name (optional, str): name of model, see torchreid model zoo. defaults to osnet_ain_x1_0 + - model_wts_path (optional, str) : path to torchreid model weights, defaults to TORCHREID_OSNET_AIN_X1_0_MS_D_C_WTS if model_name=='osnet_ain_x1_0' (default) and else, imagenet pretrained weights of given model + - bgr (optional, Bool) : boolean flag indicating if input frames are bgr or not, defaults to True + - gpu (optional, Bool) : boolean flag indicating if gpu is enabled or not + - max_batch_size: Does nothing, just for compatibility to other embedder classes + """ + + def __init__( + self, model_name=None, model_wts_path=None, bgr=True, gpu=True, max_batch_size=None, + ): + try: + import torchreid + except ImportError: + raise Exception('ImportError: torchreid is not installed, please install and try again or choose another embedder') + + from torchreid.utils import FeatureExtractor + + if model_name is None: + model_name = 'osnet_ain_x1_0' + + if model_wts_path is None: + model_wts_path = '' + + if model_name=='osnet_ain_x1_0' and model_wts_path=='': + model_wts_path = TORCHREID_OSNET_AIN_X1_0_MS_D_C_WTS + + self.gpu = gpu and torch.cuda.is_available() + if self.gpu: + device = 'cuda' + else: + device = 'cpu' + + self.model = FeatureExtractor( + model_name=model_name, + model_path=model_wts_path, + device=device, + ) + + self.bgr = bgr + + logger.info("TorchReID Embedder for Deep Sort initialised") + logger.info(f"- gpu enabled: {self.gpu}") + logger.info(f"- expects BGR: {self.bgr}") + + zeros = np.zeros((100, 100, 3), dtype=np.uint8) + self.predict([zeros]) # warmup + + def preprocess(self, np_image): + """ + Preprocessing for embedder network: Flips BGR to RGB, resize, convert to torch tensor, normalise with imagenet mean and variance, reshape. Note: input image yet to be loaded to GPU through tensor.cuda() + + Parameters + ---------- + np_image : ndarray + (H x W x C) + + Returns + ------- + Torch Tensor + + """ + if self.bgr: + np_image_rgb = np_image[..., ::-1] + else: + np_image_rgb = np_image + # torchreid handles the rest of the preprocessing + return np_image_rgb + + def predict(self, np_images): + """ + batch inference + + Params + ------ + np_images : list of ndarray + list of (H x W x C), bgr or rgb according to self.bgr + + Returns + ------ + list of features (np.array with dim = 1280) + + """ + preproc_imgs = [self.preprocess(img) for img in np_images] + output = self.model(preproc_imgs) + return output.cpu().data.numpy() diff --git a/deep_sort_realtime/embedder/weights/osnet_ain_ms_d_c_wtsonly.pth b/deep_sort_realtime/embedder/weights/osnet_ain_ms_d_c_wtsonly.pth new file mode 100644 index 0000000..20e46f4 Binary files /dev/null and b/deep_sort_realtime/embedder/weights/osnet_ain_ms_d_c_wtsonly.pth differ diff --git a/setup.py b/setup.py index ecec39f..43b8427 100644 --- a/setup.py +++ b/setup.py @@ -21,6 +21,7 @@ package_data={ "deep_sort_realtime.embedder": [ "weights/mobilenetv2_bottleneck_wts.pt", + "weights/osnet_ain_ms_d_c_wtsonly.pth" "weights/download_clip_wts.sh", "weights/download_tf_wts.sh", ] diff --git a/test/test_embedder.py b/test/test_embedder.py index fa66187..63c8a02 100644 --- a/test/test_embedder.py +++ b/test/test_embedder.py @@ -21,6 +21,12 @@ CLIP_INSTALLED = True except ModuleNotFoundError: CLIP_INSTALLED = False + + try: + import torchreid + TORCHREID_INSTALLED = True + except ModuleNotFoundError: + TORCHREID_INSTALLED = False try: import tensorflow @@ -117,6 +123,20 @@ def test_embedder_clip_cpu(self): print("Testing CLIP embedder") return test_embedder_generic(Clip_Embedder, gpu=False) + @unittest.skipIf(not TORCHREID_INSTALLED, "Torchreid is not installed") + def test_embedder_torchreid(self): + from deep_sort_realtime.embedder.embedder_pytorch import TorchReID_Embedder + + print("Testing Torchreid embedder") + return test_embedder_generic(TorchReID_Embedder) + + @unittest.skipIf(not TORCHREID_INSTALLED, "Torchreid is not installed") + def test_embedder_torchreid_cpu(self): + from deep_sort_realtime.embedder.embedder_pytorch import TorchReID_Embedder + + print("Testing Torchreid embedder") + return test_embedder_generic(TorchReID_Embedder, gpu=False) + if __name__ == "__main__": unittest.main()