This code is mainly for reproducing the results reported in our TPAMI paper, Contrastive Bayesian Analysis for Deep Metric Learning. Beyond that purpose, we will continue to maintain this project and provide tools for both supervised and unsupervised metric learning research, aiming to integrate various loss functions and backbones to facilitate academic progress on deep metric learning. The project currently supports the GoogleNet, BN-Inception, ResNet18, ResNet34, ResNet50, ResNet101, and ResNet152 backbones, together with cbml_loss (with log, square-root, and constant variants), crossentropy_loss, ms_loss, rank_loss, softtriple_loss, margin_loss, adv_loss, proxynca_loss, npair_loss, angular_loss, contrastive_loss, triplet_loss, cluster_loss, histogram_loss, center_loss, and combinations of multiple losses.
Recent methods for deep metric learning have focused on designing contrastive loss functions between positive and negative pairs of samples, so that the learned feature embedding pulls positive samples of the same class closer together and pushes negative samples from different classes away from each other. In this work, we recognize that there is a significant semantic gap between features at intermediate feature layers and class label decisions at the final output layer. To bridge this gap, we develop a contrastive Bayesian analysis that characterizes and models the posterior probabilities of image labels conditioned on their metric similarity in a contrastive learning setting. This contrastive Bayesian analysis leads to a new loss function for deep metric learning. To improve the generalization capability of the proposed method to new classes, we further extend the contrastive Bayesian loss with a metric variance constraint. Our experimental results and ablation studies demonstrate that the proposed contrastive Bayesian metric learning method significantly improves the performance of deep metric learning, outperforming existing methods by a large margin.
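As background for the pull/push behaviour described above, the classical contrastive loss is its simplest instance: positive pairs are pulled together and negative pairs are pushed apart up to a margin. The sketch below shows only that baseline formulation in plain Python; it is not the CBML loss proposed in the paper.

```python
def pairwise_contrastive_loss(pairs, margin=0.5):
    """Classical contrastive loss over (distance, is_positive) pairs.

    Positive pairs contribute d^2 (pulled closer); negative pairs
    contribute max(0, margin - d)^2 (pushed apart up to the margin).
    Baseline formulation for illustration only -- NOT the CBML loss.
    """
    total = 0.0
    for d, is_pos in pairs:
        if is_pos:
            total += d * d
        else:
            total += max(0.0, margin - d) ** 2
    return total / len(pairs)
```

A close positive pair and a far negative pair both incur near-zero loss, which is exactly the configuration metric learning seeks.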
- GoogleNet Backbone
Recall@K | 1 | 2 | 4 | 8 |
---|---|---|---|---|
Contrastive | 26.4 | 37.7 | 49.8 | 62.3 |
Triplet | 36.1 | 48.6 | 59.3 | 70.0 |
LiftedStruct | 47.2 | 58.9 | 70.2 | 80.2 |
Binomial Deviance | 52.8 | 64.4 | 74.7 | 83.9 |
Histogram Loss | 50.3 | 61.9 | 72.6 | 82.3 |
HDC | 53.6 | 65.7 | 77.0 | 85.6 |
Angular Loss | 54.7 | 66.3 | 76.0 | 83.9 |
BIER | 55.3 | 67.2 | 76.9 | 85.1 |
A-BIER | 57.5 | 68.7 | 78.3 | 82.6 |
Ours CBML-const-GoogleNet | 62.8 | 73.9 | 83.2 | 89.8 |
Ours CBML-sqrt-GoogleNet | 63.1 | 74.7 | 83.1 | 89.8 |
Ours CBML-log-GoogleNet | 63.8 | 74.8 | 83.6 | 90.3 |
- BN-Inception Backbone
Recall@K | 1 | 2 | 4 | 8 |
---|---|---|---|---|
Ranked List (H) | 57.4 | 69.7 | 79.2 | 86.9 |
Ranked List (L,M,H) | 61.3 | 72.7 | 82.7 | 89.4 |
SoftTriple | 65.4 | 76.4 | 84.5 | 90.4 |
DeML | 65.4 | 75.3 | 83.7 | 89.5 |
MS | 65.7 | 77.0 | 86.3 | 91.2 |
Contrastive+HORDE | 66.8 | 77.4 | 85.1 | 91.0 |
Ours CBML-const-BN-Inception | 68.3 | 78.5 | 86.9 | 92.1 |
Ours CBML-sqrt-BN-Inception | 69.5 | 79.5 | 86.7 | 91.8 |
Ours CBML-log-BN-Inception | 69.5 | 79.4 | 87.0 | 92.4 |
- ResNet50 Backbone
Recall@K | 1 | 2 | 4 | 8 |
---|---|---|---|---|
Divide-Conquer | 65.9 | 76.6 | 84.4 | 90.6 |
MIC+Margin | 66.1 | 76.8 | 85.6 | - |
TML | 62.5 | 73.9 | 83.0 | 89.4 |
Ours CBML-const-ResNet50 | 69.2 | 79.3 | 86.3 | 91.6 |
Ours CBML-sqrt-ResNet50 | 70.0 | 79.9 | 87.0 | 92.0 |
Ours CBML-log-ResNet50 | 69.9 | 80.4 | 87.2 | 92.5 |
- ResNet50 Backbone
Recall@K | 1 | 2 | 4 | 8 |
---|---|---|---|---|
N-Pair | 53.2 | 65.3 | 76.0 | 84.8 |
ProxyNCA | 55.5 | 67.7 | 78.2 | 86.2 |
EPSHN | 57.3 | 68.9 | 79.3 | 87.2 |
MS | 57.4 | 69.8 | 80.0 | 87.8 |
Ours CBML-const-ResNet50 | 65.0 | 76.2 | 84.9 | 90.6 |
Ours CBML-sqrt-ResNet50 | 65.0 | 76.0 | 84.1 | 90.3 |
Ours CBML-log-ResNet50 | 64.3 | 75.7 | 84.1 | 90.1 |
- ResNet18 Backbone
Recall@K | 1 | 2 | 4 | 8 |
---|---|---|---|---|
N-Pair | 52.4 | 65.7 | 76.8 | 84.6 |
ProxyNCA | 51.5 | 63.8 | 74.6 | 84.0 |
EPSHN | 54.2 | 66.6 | 77.4 | 86.0 |
Ours CBML-const-ResNet18 | 58.0 | 69.6 | 80.0 | 87.5 |
Ours CBML-sqrt-ResNet18 | 59.4 | 70.5 | 80.4 | 88.0 |
Ours CBML-log-ResNet18 | 61.3 | 72.6 | 81.9 | 88.7 |
- GoogleNet Backbone
Recall@K | 1 | 2 | 4 | 8 |
---|---|---|---|---|
Triplet | 42.6 | 55.0 | 66.4 | 77.2 |
N-Pair | 45.4 | 58.4 | 69.5 | 79.5 |
ProxyNCA | 49.2 | 61.9 | 67.9 | 72.4 |
EPSHN | 51.7 | 64.1 | 75.3 | 83.9 |
Ours CBML-const-GoogleNet | 56.8 | 69.5 | 79.5 | 87.9 |
Ours CBML-sqrt-GoogleNet | 57.7 | 69.7 | 80.5 | 88.3 |
Ours CBML-log-GoogleNet | 59.3 | 70.7 | 80.6 | 88.1 |
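All tables above report Recall@K, the standard retrieval metric for deep metric learning benchmarks: the fraction of query images whose K nearest neighbours (excluding the query itself) contain at least one image of the same class. A brute-force reference implementation, shown here only to make the metric precise:

```python
def recall_at_k(embeddings, labels, k):
    """Fraction of queries whose k nearest neighbours (by squared
    Euclidean distance, excluding the query) share the query's label."""
    n = len(embeddings)
    hits = 0
    for i in range(n):
        dists = []
        for j in range(n):
            if j == i:
                continue
            d = sum((a - b) ** 2 for a, b in zip(embeddings[i], embeddings[j]))
            dists.append((d, labels[j]))
        dists.sort()
        if any(lbl == labels[i] for _, lbl in dists[:k]):
            hits += 1
    return hits / n
```

In practice the benchmarks compute this on the learned embeddings of the full test set; the O(n^2) loop above is replaced by batched matrix operations.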
The following script prepares the CUB dataset for training by downloading it to the ./resource/datasets/ folder and then building the data lists (train.txt and test.txt):
./scripts/prepare_cub.sh
To reproduce the results of our paper, download the ImageNet-pretrained models of GoogleNet, BN-Inception, and ResNet50, and put them in the folder ~/.cache/torch/checkpoints/.
sudo pip3 install -r requirements.txt
sudo python3 setup.py develop build
./scripts/run_cub_bninception.sh
Trained models will be saved in the ./output-bninception-cub/ folder if using the default config.
./scripts/run_cub_resnet50.sh
Trained models will be saved in the ./output-resnet50-cub/ folder if using the default config.
./scripts/run_cub_googlenet.sh
Trained models will be saved in the ./output-googlenet-cub/ folder if using the default config.
Additional code will be released at a later time.
If you use this method or this code in your research, please cite as:
@article{Shichao-2022,
title={Contrastive Bayesian Analysis for Deep Metric Learning},
author={Shichao Kan and Zhiquan He and Yigang Cen and Yang Li and Mladenovic Vladimir and Zhihai He},
journal={IEEE Transactions on Pattern Analysis and Machine Intelligence},
pages={},
year={2022}
}
This code is built on the framework of MS-Loss; we are grateful to the authors of the MS paper for releasing their code for academic research / non-commercial use. We also thank the authors of the following helpful implementations: histogram, proxynca, n-pair and angular, siamese-triplet, and clustering.
This code is released for academic research / non-commercial use only. If you wish to use for commercial purposes, please contact Shichao Kan by email [email protected].
- Michael Opitz, Georg Waltner, Horst Possegger, Horst Bischof: Deep Metric Learning with BIER: Boosting Independent Embeddings Robustly. IEEE Trans. Pattern Anal. Mach. Intell. 42(2): 276-290 (2020)
- Sungyeon Kim, Dongwon Kim, Minsu Cho, Suha Kwak: Self-Taught Metric Learning without Labels. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR), pp. 7431-7441 (2022)