diff --git a/.gitignore b/.gitignore index e2a8cd7..5975e82 100644 --- a/.gitignore +++ b/.gitignore @@ -1,6 +1,6 @@ -wbia_tbd/data/ -wbia_tbd/wandb/ +wbia_miew_id/data/ +wbia_miew_id/wandb/ *.pyc -wbia_tbd/runs/ +wbia_miew_id/runs/ .env TODO.md \ No newline at end of file diff --git a/README.md b/README.md index 7f325fd..c0baf2a 100644 --- a/README.md +++ b/README.md @@ -1,7 +1,7 @@ -# WILDBOOK IA - ID Plugin +# WILDBOOK IA - MIEW-ID Plugin -A plugin for re-identificaiton of wildlife individuals using learned embeddings. +A plugin for matching and interpreting embeddings for wildlife identification. ## Setup @@ -19,7 +19,7 @@ WANDB_MODE={'online'/'offline'} You can create a new line in a code block in markdown by using two spaces at the end of the line followed by a line break. Here's an example: ``` -cd wbia_tbd +cd wbia_miew_id python train.py ``` @@ -80,4 +80,4 @@ A config file path can be set by: ## Notes -This is an initial commit which includes training, inference and WBIA integration capabilities. Release of additional features is underway. \ No newline at end of file +This is an initial commit which includes training, inference and WBIA integration capabilities. Release of additional features is underway. 
diff --git a/requirements.txt b/requirements.txt index 298775d..7cc71e2 100644 --- a/requirements.txt +++ b/requirements.txt @@ -10,4 +10,6 @@ timm==0.6.12 torch==2.0.0 torchvision==0.15.1 tqdm==4.65.0 -python-dotenv=1.0.0 \ No newline at end of file +python-dotenv==1.0.0 +grad-cam==1.4.6 +optuna==3.2.0 \ No newline at end of file diff --git a/wbia_miew_id/__init__.py b/wbia_miew_id/__init__.py new file mode 100644 index 0000000..7406388 --- /dev/null +++ b/wbia_miew_id/__init__.py @@ -0,0 +1,4 @@ +from wbia_miew_id import _plugin # NOQA + + +__version__ = '0.0.0' diff --git a/wbia_tbd/_plugin.py b/wbia_miew_id/_plugin.py similarity index 63% rename from wbia_tbd/_plugin.py rename to wbia_miew_id/_plugin.py index 7ba041e..916aed3 100644 --- a/wbia_tbd/_plugin.py +++ b/wbia_miew_id/_plugin.py @@ -12,13 +12,15 @@ import torch import torchvision.transforms as transforms # noqa: E402 from scipy.spatial import distance_matrix +import pandas as pd import tqdm -from wbia_tbd.helpers import get_config, read_json -from wbia_tbd.models import get_model -from wbia_tbd.datasets import PluginDataset, get_test_transforms -from wbia_tbd.metrics import pred_light, compute_distance_matrix, eval_onevsall +from wbia_miew_id.helpers import get_config, read_json +from wbia_miew_id.models import get_model +from wbia_miew_id.datasets import PluginDataset, get_test_transforms +from wbia_miew_id.metrics import pred_light, compute_distance_matrix, eval_onevsall +from wbia_miew_id.visualization import draw_one, draw_batch (print, rrr, profile) = ut.inject2(__name__) @@ -36,32 +38,74 @@ } CONFIGS = { - 'whale_beluga': 'https://cthulhu.dyn.wildme.io/public/models/tbd.beluga.yaml', - 'delphinapterus_leucas': 'https://cthulhu.dyn.wildme.io/public/models/tbd.beluga.yaml', - 'tursiops_truncatus': 'https://cthulhu.dyn.wildme.io/public/models/tbd.bottlenose.yaml', - 'dolphin_whitesided+fin_dorsal': 'https://cthulhu.dyn.wildme.io/public/models/tbd.bottlenose.yaml', - 'white_shark+fin_dorsal': 
'https://cthulhu.dyn.wildme.io/public/models/tbd.bottlenose.yaml', - 'spinner_dolphin': 'https://cthulhu.dyn.wildme.io/public/models/tbd.bottlenose.yaml', - 'stenella_longirostris': 'https://cthulhu.dyn.wildme.io/public/models/tbd.bottlenose.yaml', - 'sotalia_guianensis': 'https://cthulhu.dyn.wildme.io/public/models/tbd.bottlenose.yaml', - 'short_fin_pilot_whale+fin_dorsal': 'https://cthulhu.dyn.wildme.io/public/models/tbd.bottlenose.yaml', - 'globicephala_melas': 'https://cthulhu.dyn.wildme.io/public/models/tbd.bottlenose.yaml', - 'pilot_whale+fin_dorsal': 'https://cthulhu.dyn.wildme.io/public/models/tbd.bottlenose.yaml', - 'globicephala_macrorhynchus': 'https://cthulhu.dyn.wildme.io/public/models/tbd.bottlenose.yaml',} + 'whale_beluga': 'https://cthulhu.dyn.wildme.io/public/models/miew_id.beluga.yaml', + 'delphinapterus_leucas': 'https://cthulhu.dyn.wildme.io/public/models/miew_id.beluga.yaml', + 'tursiops_truncatus': 'https://cthulhu.dyn.wildme.io/public/models/miew_id.bottlenose.yaml', + 'dolphin_whitesided+fin_dorsal': 'https://cthulhu.dyn.wildme.io/public/models/miew_id.bottlenose.yaml', + 'white_shark+fin_dorsal': 'https://cthulhu.dyn.wildme.io/public/models/miew_id.bottlenose.yaml', + 'spinner_dolphin': 'https://cthulhu.dyn.wildme.io/public/models/miew_id.bottlenose.yaml', + 'stenella_longirostris': 'https://cthulhu.dyn.wildme.io/public/models/miew_id.bottlenose.yaml', + 'sotalia_guianensis': 'https://cthulhu.dyn.wildme.io/public/models/miew_id.bottlenose.yaml', + 'short_fin_pilot_whale+fin_dorsal': 'https://cthulhu.dyn.wildme.io/public/models/miew_id.bottlenose.yaml', + 'globicephala_melas': 'https://cthulhu.dyn.wildme.io/public/models/miew_id.bottlenose.yaml', + 'pilot_whale+fin_dorsal': 'https://cthulhu.dyn.wildme.io/public/models/miew_id.bottlenose.yaml', + 'globicephala_macrorhynchus': 'https://cthulhu.dyn.wildme.io/public/models/miew_id.bottlenose.yaml', + 'globicephala_melas': 'https://cthulhu.dyn.wildme.io/public/models/miew_id.bottlenose.yaml', + 
'short_fin_pilot_whale+fin_dorsal': 'https://cthulhu.dyn.wildme.io/public/models/miew_id.bottlenose.yaml', + 'hyperoodon_ampullatus': 'https://cthulhu.dyn.wildme.io/public/models/miew_id.bottlenose.yaml', + 'whale_humpback+fin_dorsal': 'https://cthulhu.dyn.wildme.io/public/models/miew_id.bottlenose.yaml', + 'lagenodelphis_hosei': 'https://cthulhu.dyn.wildme.io/public/models/miew_id.bottlenose.yaml', + 'cougar+head': 'https://cthulhu.dyn.wildme.io/public/models/miewid_lion_head_v0.yaml', + 'lion+head': 'https://cthulhu.dyn.wildme.io/public/models/miewid_lion_head_v0.yaml', + 'lioness+head': 'https://cthulhu.dyn.wildme.io/public/models/miewid_lion_head_v0.yaml', + 'lion_general+head': 'https://cthulhu.dyn.wildme.io/public/models/miewid_lion_head_v0.yaml', + 'panthera_leo': 'https://cthulhu.dyn.wildme.io/public/models/miewid_lion_head_v0.yaml', + 'dolphin_spotted+fin_dorsal': 'https://cthulhu.dyn.wildme.io/public/models/miew_id.bottlenose.yaml', + 'stenella_frontalis': 'https://cthulhu.dyn.wildme.io/public/models/miew_id.bottlenose.yaml', + 'whale_falsekiller': 'https://cthulhu.dyn.wildme.io/public/models/miew_id.bottlenose.yaml', + 'pseudorca_crassidens': 'https://cthulhu.dyn.wildme.io/public/models/miew_id.bottlenose.yaml', + 'dolphin_rissos+fin_dorsal': 'https://cthulhu.dyn.wildme.io/public/models/miew_id.bottlenose.yaml', + 'grampus_griseus': 'https://cthulhu.dyn.wildme.io/public/models/miew_id.bottlenose.yaml', + 'phocoena_phocoena': 'https://cthulhu.dyn.wildme.io/public/models/miew_id.bottlenose.yaml', + 'harbour_porpoise+fin_dorsal': 'https://cthulhu.dyn.wildme.io/public/models/miew_id.bottlenose.yaml', + 'balaenoptera_acutorostrata+fin_dorsal': 'https://cthulhu.dyn.wildme.io/public/models/miew_id.bottlenose.yaml', + 'whale_minke+fin_dorsal': 'https://cthulhu.dyn.wildme.io/public/models/miew_id.bottlenose.yaml' + } + MODELS = { - 'whale_beluga': 'https://cthulhu.dyn.wildme.io/public/models/tbd.beluga.bin', - 'delphinapterus_leucas': 
'https://cthulhu.dyn.wildme.io/public/models/tbd.beluga.bin', - 'tursiops_truncatus': 'https://cthulhu.dyn.wildme.io/public/models/tbd.bottlenose.bin', - 'dolphin_whitesided+fin_dorsal': 'https://cthulhu.dyn.wildme.io/public/models/tbd.bottlenose.bin', - 'white_shark+fin_dorsal': 'https://cthulhu.dyn.wildme.io/public/models/tbd.bottlenose.bin', - 'spinner_dolphin': 'https://cthulhu.dyn.wildme.io/public/models/tbd.bottlenose.bin', - 'stenella_longirostris': 'https://cthulhu.dyn.wildme.io/public/models/tbd.bottlenose.bin', - 'sotalia_guianensis': 'https://cthulhu.dyn.wildme.io/public/models/tbd.bottlenose.bin', - 'short_fin_pilot_whale+fin_dorsal': 'https://cthulhu.dyn.wildme.io/public/models/tbd.bottlenose.bin', - 'globicephala_melas': 'https://cthulhu.dyn.wildme.io/public/models/tbd.bottlenose.bin', - 'pilot_whale+fin_dorsal': 'https://cthulhu.dyn.wildme.io/public/models/tbd.bottlenose.bin', - 'globicephala_macrorhynchus': 'https://cthulhu.dyn.wildme.io/public/models/tbd.bottlenose.bin', + 'whale_beluga': 'https://cthulhu.dyn.wildme.io/public/models/miew_id.beluga.bin', + 'delphinapterus_leucas': 'https://cthulhu.dyn.wildme.io/public/models/miew_id.beluga.bin', + 'tursiops_truncatus': 'https://cthulhu.dyn.wildme.io/public/models/miew_id.bottlenose.bin', + 'dolphin_whitesided+fin_dorsal': 'https://cthulhu.dyn.wildme.io/public/models/miew_id.bottlenose.bin', + 'white_shark+fin_dorsal': 'https://cthulhu.dyn.wildme.io/public/models/miew_id.bottlenose.bin', + 'spinner_dolphin': 'https://cthulhu.dyn.wildme.io/public/models/miew_id.bottlenose.bin', + 'stenella_longirostris': 'https://cthulhu.dyn.wildme.io/public/models/miew_id.bottlenose.bin', + 'sotalia_guianensis': 'https://cthulhu.dyn.wildme.io/public/models/miew_id.bottlenose.bin', + 'short_fin_pilot_whale+fin_dorsal': 'https://cthulhu.dyn.wildme.io/public/models/miew_id.bottlenose.bin', + 'globicephala_melas': 'https://cthulhu.dyn.wildme.io/public/models/miew_id.bottlenose.bin', + 'pilot_whale+fin_dorsal': 
'https://cthulhu.dyn.wildme.io/public/models/miew_id.bottlenose.bin', + 'globicephala_macrorhynchus': 'https://cthulhu.dyn.wildme.io/public/models/miew_id.bottlenose.bin', + 'globicephala_melas': 'https://cthulhu.dyn.wildme.io/public/models/miew_id.bottlenose.bin', + 'short_fin_pilot_whale+fin_dorsal': 'https://cthulhu.dyn.wildme.io/public/models/miew_id.bottlenose.bin', + 'hyperoodon_ampullatus': 'https://cthulhu.dyn.wildme.io/public/models/miew_id.bottlenose.bin', + 'whale_humpback+fin_dorsal': 'https://cthulhu.dyn.wildme.io/public/models/miew_id.bottlenose.bin', + 'lagenodelphis_hosei': 'https://cthulhu.dyn.wildme.io/public/models/miew_id.bottlenose.bin', + 'lion+head': 'https://cthulhu.dyn.wildme.io/public/models/miewid_lion_head_v0.bin', + 'lioness+head': 'https://cthulhu.dyn.wildme.io/public/models/miewid_lion_head_v0.bin', + 'lion_general+head': 'https://cthulhu.dyn.wildme.io/public/models/miewid_lion_head_v0.bin', + 'panthera_leo': 'https://cthulhu.dyn.wildme.io/public/models/miewid_lion_head_v0.bin', + 'cougar+head': 'https://cthulhu.dyn.wildme.io/public/models/miewid_lion_head_v0.bin', + 'dolphin_spotted+fin_dorsal': 'https://cthulhu.dyn.wildme.io/public/models/miew_id.bottlenose.bin', + 'stenella_frontalis': 'https://cthulhu.dyn.wildme.io/public/models/miew_id.bottlenose.bin', + 'whale_falsekiller': 'https://cthulhu.dyn.wildme.io/public/models/miew_id.bottlenose.bin', + 'pseudorca_crassidens': 'https://cthulhu.dyn.wildme.io/public/models/miew_id.bottlenose.bin', + 'dolphin_rissos+fin_dorsal': 'https://cthulhu.dyn.wildme.io/public/models/miew_id.bottlenose.bin', + 'grampus_griseus': 'https://cthulhu.dyn.wildme.io/public/models/miew_id.bottlenose.bin', + 'phocoena_phocoena': 'https://cthulhu.dyn.wildme.io/public/models/miew_id.bottlenose.bin', + 'harbour_porpoise+fin_dorsal': 'https://cthulhu.dyn.wildme.io/public/models/miew_id.bottlenose.bin', + 'balaenoptera_acutorostrata+fin_dorsal': 'https://cthulhu.dyn.wildme.io/public/models/miew_id.bottlenose.bin', 
+ 'whale_minke+fin_dorsal': 'https://cthulhu.dyn.wildme.io/public/models/miew_id.bottlenose.bin' } @@ -69,7 +113,7 @@ @register_ibs_method -def tbd_embedding(ibs, aid_list, config=None, use_depc=True): +def miew_id_embedding(ibs, aid_list, config=None, use_depc=True): r""" Generate embeddings using the Pose-Invariant Embedding (TBD) Args: @@ -77,13 +121,13 @@ def tbd_embedding(ibs, aid_list, config=None, use_depc=True): aid_list (int): annot ids specifying the input use_depc (bool): use dependency cache CommandLine: - python -m wbia_tbd._plugin tbd_embedding + python -m wbia_miew_id._plugin miew_id_embedding Example: >>> # ENABLE_DOCTEST - >>> import wbia_tbd - >>> from wbia_tbd._plugin import DEMOS, CONFIGS, MODELS + >>> import wbia_miew_id + >>> from wbia_miew_id._plugin import DEMOS, CONFIGS, MODELS >>> species = 'rhincodon_typus' - >>> test_ibs = wbia_tbd._plugin.wbia_tbd_test_ibs(DEMOS[species], species, 'test2021') + >>> test_ibs = wbia_miew_id._plugin.wbia_miew_id_test_ibs(DEMOS[species], species, 'test2021') >>> aid_list = test_ibs.get_valid_aids(species=species) >>> rank1 = test_ibs.evaluate_distmat(aid_list, CONFIGS[species], use_depc=False) >>> expected_rank1 = 0.81366 @@ -91,10 +135,10 @@ def tbd_embedding(ibs, aid_list, config=None, use_depc=True): Example: >>> # ENABLE_DOCTEST - >>> import wbia_tbd - >>> from wbia_tbd._plugin import DEMOS, CONFIGS, MODELS + >>> import wbia_miew_id + >>> from wbia_miew_id._plugin import DEMOS, CONFIGS, MODELS >>> species = 'whale_grey' - >>> test_ibs = wbia_tbd._plugin.wbia_tbd_test_ibs(DEMOS[species], species, 'test2021') + >>> test_ibs = wbia_miew_id._plugin.wbia_miew_id_test_ibs(DEMOS[species], species, 'test2021') >>> aid_list = test_ibs.get_valid_aids(species=species) >>> rank1 = test_ibs.evaluate_distmat(aid_list, CONFIGS[species], use_depc=False) >>> expected_rank1 = 0.69505 @@ -102,10 +146,10 @@ def tbd_embedding(ibs, aid_list, config=None, use_depc=True): Example: >>> # ENABLE_DOCTEST - >>> import wbia_tbd - 
>>> from wbia_tbd._plugin import DEMOS, CONFIGS, MODELS + >>> import wbia_miew_id + >>> from wbia_miew_id._plugin import DEMOS, CONFIGS, MODELS >>> species = 'horse_wild' - >>> test_ibs = wbia_tbd._plugin.wbia_tbd_test_ibs(DEMOS[species], species, 'test2021') + >>> test_ibs = wbia_miew_id._plugin.wbia_miew_id_test_ibs(DEMOS[species], species, 'test2021') >>> aid_list = test_ibs.get_valid_aids(species=species) >>> rank1 = test_ibs.evaluate_distmat(aid_list, CONFIGS[species], use_depc=False) >>> expected_rank1 = 0.32773 @@ -124,10 +168,10 @@ def tbd_embedding(ibs, aid_list, config=None, use_depc=True): if use_depc: config_map = {'config_path': config} dirty_embeddings = ibs.depc_annot.get( - 'TbdEmbedding', dirty_aids, 'embedding', config_map + 'MiewIdEmbedding', dirty_aids, 'embedding', config_map ) else: - dirty_embeddings = tbd_compute_embedding(ibs, dirty_aids, config) + dirty_embeddings = miew_id_compute_embedding(ibs, dirty_aids, config) for dirty_aid, dirty_embedding in zip(dirty_aids, dirty_embeddings): GLOBAL_EMBEDDING_CACHE[dirty_aid] = dirty_embedding @@ -137,31 +181,31 @@ def tbd_embedding(ibs, aid_list, config=None, use_depc=True): return embeddings -class TbdEmbeddingConfig(dt.Config): # NOQA +class MiewIdEmbeddingConfig(dt.Config): # NOQA _param_info_list = [ ut.ParamInfo('config_path', default=None), ] @register_preproc_annot( - tablename='TbdEmbedding', + tablename='MiewIdEmbedding', parents=[ANNOTATION_TABLE], colnames=['embedding'], coltypes=[np.ndarray], - configclass=TbdEmbeddingConfig, - fname='tbd', + configclass=MiewIdEmbeddingConfig, + fname='miew_id', chunksize=128, ) @register_ibs_method -def tbd_embedding_depc(depc, aid_list, config=None): +def miew_id_embedding_depc(depc, aid_list, config=None): ibs = depc.controller - embs = tbd_compute_embedding(ibs, aid_list, config=config['config_path']) + embs = miew_id_compute_embedding(ibs, aid_list, config=config['config_path']) for aid, emb in zip(aid_list, embs): yield (np.array(emb),) 
@register_ibs_method -def tbd_compute_embedding(ibs, aid_list, config=None, multithread=False): +def miew_id_compute_embedding(ibs, aid_list, config=None, multithread=False): # Get species from the first annotation species = ibs.get_annot_species_texts(aid_list[0]) @@ -180,7 +224,7 @@ def tbd_compute_embedding(ibs, aid_list, config=None, multithread=False): embeddings = [] model.eval() with torch.no_grad(): - for images, names in test_loader: + for images, names, image_paths, image_bboxes in test_loader: if config.use_gpu: images = images.cuda(non_blocking=True) @@ -191,7 +235,7 @@ def tbd_compute_embedding(ibs, aid_list, config=None, multithread=False): return embeddings -class TbdConfig(dt.Config): # NOQA +class MiewIdConfig(dt.Config): # NOQA def get_param_info_list(self): return [ ut.ParamInfo('config_path', None), @@ -243,9 +287,9 @@ def get_match_results(depc, qaid_list, daid_list, score_list, config): yield match_result -class TbdRequest(dt.base.VsOneSimilarityRequest): +class MiewIdRequest(dt.base.VsOneSimilarityRequest): _symmetric = False - _tablename = 'Tbd' + _tablename = 'MiewId' @ut.accepts_scalar_input def get_fmatch_overlayed_chip(request, aid_list, overlay=True, config=None): @@ -254,15 +298,64 @@ def get_fmatch_overlayed_chip(request, aid_list, overlay=True, config=None): chips = ibs.get_annot_chips(aid_list) return chips + # def render_single_result(request, cm, aid, **kwargs): + # # HACK FOR WEB VIEWER + # overlay = kwargs.get('draw_fmatches') + # chips = request.get_fmatch_overlayed_chip( + # [cm.qaid, aid], overlay=overlay, config=request.config + # ) + # out_image = vt.stack_image_list(chips) + + # return out_image + def render_single_result(request, cm, aid, **kwargs): - # HACK FOR WEB VIEWER - overlay = kwargs.get('draw_fmatches') - chips = request.get_fmatch_overlayed_chip( - [cm.qaid, aid], overlay=overlay, config=request.config - ) - out_image = vt.stack_image_list(chips) + + depc = request.depc + ibs = depc.controller + + # Load config 
+ species = ibs.get_annot_species_texts(aid) + + config = None + if config is None: + config = CONFIGS[species] + config = _load_config(config) + + # Load model + model = _load_model(config, MODELS[species], use_dataparallel=False) + + # This list has to be in the format of [query_aid, db_aid] + aid_list = [cm.qaid, aid] + test_loader, test_dataset = _load_data(ibs, aid_list, config) + + out_image = draw_one(config, test_loader, model, images_dir = '', method='gradcam_plus_plus', eigen_smooth=False, show=False) + return out_image + + def render_batch_result(request, cm, aids): + + depc = request.depc + ibs = depc.controller + + # Load config + species = ibs.get_annot_species_texts(aids)[0] + + config = None + if config is None: + config = CONFIGS[species] + config = _load_config(config) + + # Load model + model = _load_model(config, MODELS[species], use_dataparallel=False) + + # This list has to be in the format of [query_aid, db_aid] + aid_list = np.concatenate(([cm.qaid], aids)) + test_loader, test_dataset = _load_data(ibs, aid_list, config) + + batch_images = draw_batch(config, test_loader, model, images_dir = '', method='gradcam_plus_plus', eigen_smooth=False, show=False) + return batch_images + def postprocess_execute(request, table, parent_rowids, rowids, result_list): qaid_list, daid_list = list(zip(*parent_rowids)) score_list = ut.take_column(result_list, 0) @@ -274,7 +367,7 @@ def postprocess_execute(request, table, parent_rowids, rowids, result_list): def execute(request, *args, **kwargs): # kwargs['use_cache'] = False - result_list = super(TbdRequest, request).execute(*args, **kwargs) + result_list = super(MiewIdRequest, request).execute(*args, **kwargs) qaids = kwargs.pop('qaids', None) if qaids is not None: result_list = [result for result in result_list if result.qaid in qaids] @@ -282,17 +375,17 @@ def execute(request, *args, **kwargs): @register_preproc_annot( - tablename='Tbd', + tablename='MiewId', parents=[ANNOTATION_TABLE, ANNOTATION_TABLE], 
colnames=['score'], coltypes=[float], - configclass=TbdConfig, - requestclass=TbdRequest, - fname='tbd', + configclass=MiewIdConfig, + requestclass=MiewIdRequest, + fname='miew_id', rm_extern_on_delete=True, chunksize=None, ) -def wbia_plugin_tbd(depc, qaid_list, daid_list, config): +def wbia_plugin_miew_id(depc, qaid_list, daid_list, config): ibs = depc.controller qaids = list(set(qaid_list)) @@ -303,27 +396,27 @@ def wbia_plugin_tbd(depc, qaid_list, daid_list, config): qaid_score_dict = {} for qaid in tqdm.tqdm(qaids): if use_knn: - tbd_dists = ibs.tbd_predict_light( + miew_id_dists = ibs.miew_id_predict_light( qaid, daids, config['config_path'], ) - tbd_scores = distance_dicts_to_score_dicts(tbd_dists) + miew_id_scores = distance_dicts_to_score_dicts(miew_id_dists) - # aid_score_list = aid_scores_from_name_scores(ibs, tbd_name_scores, daids) - aid_score_list = aid_scores_from_score_dict(tbd_scores, daids) + # aid_score_list = aid_scores_from_name_scores(ibs, miew_id_name_scores, daids) + aid_score_list = aid_scores_from_score_dict(miew_id_scores, daids) aid_score_dict = dict(zip(daids, aid_score_list)) qaid_score_dict[qaid] = aid_score_dict else: - tbd_annot_distances = ibs.tbd_predict_light_distance( + miew_id_annot_distances = ibs.miew_id_predict_light_distance( qaid, daids, config['config_path'], ) qaid_score_dict[qaid] = {} - for daid, tbd_annot_distance in zip(daids, tbd_annot_distances): - qaid_score_dict[qaid][daid] = distance_to_score(tbd_annot_distance) + for daid, miew_id_annot_distance in zip(daids, miew_id_annot_distances): + qaid_score_dict[qaid][daid] = distance_to_score(miew_id_annot_distance) for qaid, daid in zip(qaid_list, daid_list): if qaid == daid: @@ -339,7 +432,7 @@ def evaluate_distmat(ibs, aid_list, config, use_depc, ranks=[1, 5, 10, 20]): """Evaluate 1vsall accuracy of matching on annotations by computing distance matrix. 
""" - embs = np.array(tbd_embedding(ibs, aid_list, config, use_depc)) + embs = np.array(miew_id_embedding(ibs, aid_list, config, use_depc)) print('Computing distance matrix ...') distmat = compute_distance_matrix(embs, embs, metric='cosine') @@ -360,7 +453,7 @@ def _load_config(config_url): """ config_fname = config_url.split('/')[-1] config_file = ut.grab_file_url( - config_url, appname='wbia_tbd', check_hash=False, fname=config_fname + config_url, appname='wbia_miew_id', check_hash=False, fname=config_fname ) config = get_config(config_file) @@ -370,7 +463,7 @@ def _load_config(config_url): return config -def _load_model(config, model_url): +def _load_model(config, model_url, use_dataparallel=True): r""" Load a model based on config file """ @@ -386,7 +479,7 @@ def _load_model(config, model_url): # Download the model weights model_fname = model_url.split('/')[-1] model_path = ut.grab_file_url( - model_url, appname='wbia_tbd', check_hash=False, fname=model_fname + model_url, appname='wbia_miew_id', check_hash=False, fname=model_fname ) # load_pretrained_weights(model, model_path) @@ -398,7 +491,7 @@ def _load_model(config, model_url): # else: # model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu'))) # print('Loaded model from {}'.format(model_path)) - if config.use_gpu: + if config.use_gpu and use_dataparallel: model = torch.nn.DataParallel(model).cuda() return model @@ -443,7 +536,7 @@ def _load_data(ibs, aid_list, config, multithread=False): return dataloader, dataset -def wbia_tbd_test_ibs(demo_db_url, species, subset): +def wbia_miew_id_test_ibs(demo_db_url, species, subset): r""" Create a database to test orientation detection from a coco annotation file """ @@ -454,7 +547,7 @@ def wbia_tbd_test_ibs(demo_db_url, species, subset): return test_ibs else: # Download demo data archive - db_dir = ut.grab_zipped_url(demo_db_url, appname='wbia_tbd') + db_dir = ut.grab_zipped_url(demo_db_url, appname='wbia_miew_id') # Load coco annotations 
json_file = os.path.join( @@ -495,9 +588,9 @@ def wbia_tbd_test_ibs(demo_db_url, species, subset): @register_ibs_method -def tbd_predict_light(ibs, qaid, daid_list, config=None): - db_embs = np.array(ibs.tbd_embedding(daid_list, config)) - query_emb = np.array(ibs.tbd_embedding([qaid], config)) +def miew_id_predict_light(ibs, qaid, daid_list, config=None): + db_embs = np.array(ibs.miew_id_embedding(daid_list, config)) + query_emb = np.array(ibs.miew_id_embedding([qaid], config)) # db_labels = np.array(ibs.get_annot_name_texts(daid_list, distinguish_unknowns=True)) db_labels = np.array(daid_list) @@ -507,10 +600,10 @@ def tbd_predict_light(ibs, qaid, daid_list, config=None): @register_ibs_method -def tbd_predict_light_distance(ibs, qaid, daid_list, config=None): +def miew_id_predict_light_distance(ibs, qaid, daid_list, config=None): assert len(daid_list) == len(set(daid_list)) - db_embs = np.array(ibs.tbd_embedding(daid_list, config)) - query_emb = np.array(ibs.tbd_embedding([qaid], config)) + db_embs = np.array(ibs.miew_id_embedding(daid_list, config)) + query_emb = np.array(ibs.miew_id_embedding([qaid], config)) input1 = torch.Tensor(query_emb) input2 = torch.Tensor(db_embs) @@ -519,10 +612,10 @@ def tbd_predict_light_distance(ibs, qaid, daid_list, config=None): return distances -def _tbd_accuracy(ibs, qaid, daid_list): +def _miew_id_accuracy(ibs, qaid, daid_list): daids = daid_list.copy() daids.remove(qaid) - ans = ibs.tbd_predict_light(qaid, daids) + ans = ibs.miew_id_predict_light(qaid, daids) ans_names = [row['label'] for row in ans] ground_truth = ibs.get_annot_name_texts(qaid) try: @@ -533,10 +626,10 @@ def _tbd_accuracy(ibs, qaid, daid_list): return rank -def tbd_mass_accuracy(ibs, aid_list, daid_list=None): +def miew_id_mass_accuracy(ibs, aid_list, daid_list=None): if daid_list is None: daid_list = aid_list - ranks = [_tbd_accuracy(ibs, aid, daid_list) for aid in aid_list] + ranks = [_miew_id_accuracy(ibs, aid, daid_list) for aid in aid_list] return ranks 
@@ -581,9 +674,9 @@ def subset_with_resights_range(ibs, aid_list, min_sights=3, max_sights=10): @register_ibs_method -def tbd_new_accuracy(ibs, aid_list, min_sights=3, max_sights=10): +def miew_id_new_accuracy(ibs, aid_list, min_sights=3, max_sights=10): aids = subset_with_resights_range(ibs, aid_list, min_sights, max_sights) - ranks = tbd_mass_accuracy(ibs, aids) + ranks = miew_id_mass_accuracy(ibs, aids) accuracy = accuracy_at_k(ibs, ranks) print( 'Accuracy at k for annotations with %s-%s sightings:' % (min_sights, max_sights) @@ -592,10 +685,10 @@ def tbd_new_accuracy(ibs, aid_list, min_sights=3, max_sights=10): return accuracy -# The following functions are cotbdd from TBD v1 because these functions +# The following functions are comiew_idd from TBD v1 because these functions # are agnostic tot eh method of computing embeddings: -# https://github.com/WildMeOrg/wbia-plugin-tbd/wbia_tbd/_plugin.py -def _db_labels_for_tbd(ibs, daid_list): +# https://github.com/WildMeOrg/wbia-plugin-miew_id/wbia_miew_id/_plugin.py +def _db_labels_for_miew_id(ibs, daid_list): db_labels = ibs.get_annot_name_texts(daid_list, distinguish_unknowns=True) # db_auuids = ibs.get_annot_name_rowids(daid_list) # # later we must know which db_labels are for single auuids, hence prefix @@ -631,7 +724,7 @@ def aid_scores_from_score_dict(name_score_dict, daid_list): return daid_scores def aid_scores_from_name_scores(ibs, name_score_dict, daid_list): - daid_name_list = list(_db_labels_for_tbd(ibs, daid_list)) + daid_name_list = list(_db_labels_for_miew_id(ibs, daid_list)) name_count_dict = { name: daid_name_list.count(name) for name in name_score_dict.keys() @@ -654,7 +747,7 @@ def aid_scores_from_name_scores(ibs, name_score_dict, daid_list): if __name__ == '__main__': r""" CommandLine: - python -m wbia_tbd._plugin --allexamples + python -m wbia_miew_id._plugin --allexamples """ import multiprocessing diff --git a/wbia_tbd/configs/default_config.yaml b/wbia_miew_id/configs/default_config.yaml 
similarity index 97% rename from wbia_tbd/configs/default_config.yaml rename to wbia_miew_id/configs/default_config.yaml index 7583ff0..a350a94 100644 --- a/wbia_tbd/configs/default_config.yaml +++ b/wbia_miew_id/configs/default_config.yaml @@ -20,7 +20,7 @@ data: engine: num_workers: 0 train_batch_size: 6 - valid_batch_size: 24 + valid_batch_size: 12 epochs: 30 seed: 42 device: cuda diff --git a/wbia_tbd/datasets/__init__.py b/wbia_miew_id/datasets/__init__.py similarity index 100% rename from wbia_tbd/datasets/__init__.py rename to wbia_miew_id/datasets/__init__.py diff --git a/wbia_tbd/datasets/default_dataset.py b/wbia_miew_id/datasets/default_dataset.py similarity index 63% rename from wbia_tbd/datasets/default_dataset.py rename to wbia_miew_id/datasets/default_dataset.py index ceaaf4a..8374a8d 100644 --- a/wbia_tbd/datasets/default_dataset.py +++ b/wbia_miew_id/datasets/default_dataset.py @@ -2,20 +2,23 @@ import cv2 import torch from torch.utils.data import Dataset +import numpy as np -class TbdDataset(Dataset): - def __init__(self, csv, images_dir, transforms=None): +class MiewIdDataset(Dataset): + def __init__(self, csv, images_dir, transforms=None, fliplr=False, fliplr_view=[]): self.csv = csv#.reset_index() self.augmentations = transforms self.images_dir = images_dir + self.fliplr = fliplr + self.fliplr_view = fliplr_view def __len__(self): return self.csv.shape[0] def __getitem__(self, index): row = self.csv.iloc[index] - + image_path = os.path.join(self.images_dir, row['file_name']) image = cv2.imread(image_path) image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) @@ -24,5 +27,9 @@ def __getitem__(self, index): augmented = self.augmentations(image=image) image = augmented['image'] + if self.fliplr: + if row['viewpoint'] in self.fliplr_view: + image = np.fliplr(image) + - return {"image": image, "label":torch.tensor(row['name']), "image_idx": self.csv.index[index]} \ No newline at end of file + return {"image": image, "label":torch.tensor(row['name']), 
"image_idx": self.csv.index[index], "file_path": image_path} \ No newline at end of file diff --git a/wbia_tbd/datasets/plugin_dataset.py b/wbia_miew_id/datasets/plugin_dataset.py similarity index 94% rename from wbia_tbd/datasets/plugin_dataset.py rename to wbia_miew_id/datasets/plugin_dataset.py index 4463c79..19ed212 100644 --- a/wbia_tbd/datasets/plugin_dataset.py +++ b/wbia_miew_id/datasets/plugin_dataset.py @@ -2,6 +2,7 @@ from torch.utils.data import Dataset import cv2 import numpy as np +import torch class PluginDataset(Dataset): @@ -67,6 +68,6 @@ def __getitem__(self, idx): image = augmented['image'] # image = self.transform(image.copy()) - return image, self.names[idx] + return image, self.names[idx], self.image_paths[idx], torch.Tensor(self.bboxes[idx]) diff --git a/wbia_tbd/datasets/transforms.py b/wbia_miew_id/datasets/transforms.py similarity index 100% rename from wbia_tbd/datasets/transforms.py rename to wbia_miew_id/datasets/transforms.py diff --git a/wbia_tbd/engine/__init__.py b/wbia_miew_id/engine/__init__.py similarity index 100% rename from wbia_tbd/engine/__init__.py rename to wbia_miew_id/engine/__init__.py diff --git a/wbia_tbd/engine/eval_fn.py b/wbia_miew_id/engine/eval_fn.py similarity index 100% rename from wbia_tbd/engine/eval_fn.py rename to wbia_miew_id/engine/eval_fn.py diff --git a/wbia_tbd/engine/run_fn.py b/wbia_miew_id/engine/run_fn.py similarity index 57% rename from wbia_tbd/engine/run_fn.py rename to wbia_miew_id/engine/run_fn.py index c947d45..9139569 100644 --- a/wbia_tbd/engine/run_fn.py +++ b/wbia_miew_id/engine/run_fn.py @@ -9,17 +9,18 @@ def run_fn(config, model, train_loader, valid_loader, criterion, optimizer, scheduler, device, checkpoint_dir, use_wandb=True): - best_loss = np.inf + best_score = 0 for epoch in range(config.engine.epochs): train_loss = train_fn(train_loader, model,criterion, optimizer, device,scheduler=scheduler,epoch=epoch, use_wandb=use_wandb) - torch.save(model.state_dict(), 
f'{checkpoint_dir}/model_{epoch}.bin') - valid_loss = eval_fn(valid_loader, model, device, use_wandb=use_wandb) + valid_score = eval_fn(valid_loader, model, device, use_wandb=use_wandb) - # if valid_loss.avg < best_loss: - # best_loss = valid_loss.avg - # torch.save(model.state_dict(),f'model_{config.model_name}_IMG_SIZE_{config.data.image_size[0]}_{config.engine.loss_module}.bin') - # print('best model found for epoch {}'.format(epoch)) \ No newline at end of file + if valid_score > best_score: + best_score = valid_score + torch.save(model.state_dict(), f'{checkpoint_dir}/model_best.bin') + print('best model found for epoch {}'.format(epoch)) + + return best_score \ No newline at end of file diff --git a/wbia_tbd/engine/train_fn.py b/wbia_miew_id/engine/train_fn.py similarity index 100% rename from wbia_tbd/engine/train_fn.py rename to wbia_miew_id/engine/train_fn.py diff --git a/wbia_tbd/etl/__init__.py b/wbia_miew_id/etl/__init__.py similarity index 100% rename from wbia_tbd/etl/__init__.py rename to wbia_miew_id/etl/__init__.py diff --git a/wbia_tbd/etl/coco_helpers.py b/wbia_miew_id/etl/coco_helpers.py similarity index 100% rename from wbia_tbd/etl/coco_helpers.py rename to wbia_miew_id/etl/coco_helpers.py diff --git a/wbia_tbd/etl/eda.py b/wbia_miew_id/etl/eda.py similarity index 100% rename from wbia_tbd/etl/eda.py rename to wbia_miew_id/etl/eda.py diff --git a/wbia_tbd/etl/preprocess.py b/wbia_miew_id/etl/preprocess.py similarity index 52% rename from wbia_tbd/etl/preprocess.py rename to wbia_miew_id/etl/preprocess.py index 31cc952..13f3195 100644 --- a/wbia_tbd/etl/preprocess.py +++ b/wbia_miew_id/etl/preprocess.py @@ -38,22 +38,27 @@ def convert_name_to_id(names): names_id = le.fit_transform(names) return names_id - - def preprocess_data(anno_path, name_keys=['name'], convert_names_to_ids=True, viewpoint_list=None, n_filter_min=None, n_subsample_max=None): df = load_to_df(anno_path) + print(f'** Loaded {anno_path} **') + print(' ', f'Found {len(df)} 
annotations') + df['name'] = df[name_keys].apply(lambda row: '_'.join(row.values.astype(str)), axis=1) + df['name_orig'] = df['name'].copy() if viewpoint_list: df = filter_viewpoint_df(df, viewpoint_list) + print(' ', len(df), 'annotations remain after filtering by viewpoint list', viewpoint_list) if n_filter_min: df = filter_min_names_df(df, n_filter_min) + print(' ', len(df), 'annotations remain after filtering by min', n_filter_min) if n_subsample_max: df = subsample_max_df(df, n_subsample_max) + print(' ', len(df), 'annotations remain after subsampling by max', n_subsample_max) if convert_names_to_ids: names = df['name'].values @@ -61,46 +66,5 @@ def preprocess_data(anno_path, name_keys=['name'], convert_names_to_ids=True, vi df['name'] = names_id return df - -# def make_dataframes(): -# DATA_DIR = "data/beluga-coco-v0-full" -# IMAGES_DIR = "data/beluga-440" - -# # anno_dir = os.path.join(DATA_DIR, "annotations") -# anno_dir = os.path.join(DATA_DIR, "") -# anno_file = lambda split: f"instances_{split}2023.json" - -# train_anno_path = os.path.join(anno_dir, anno_file("train")) -# val_anno_path = os.path.join(anno_dir, anno_file("val")) -# test_anno_path = os.path.join(anno_dir, anno_file("test")) - -# df_train = load_to_df(train_anno_path) -# df_val = load_to_df(val_anno_path) - -# df_train = df_train[df_train['viewpoint']=='up'] -# df_val = df_val[df_val['viewpoint']=='up'] - -# ## NOTE have to safely handle this case -# df_train['name'] = df_train['name'].astype(int) -# df_val['name'] = df_val['name'].astype(int) - -# df_train = df_train.groupby('name').filter(lambda g: len(g)>=4) -# df_val = df_val.groupby('name').filter(lambda g: len(g)>=2) - -# # df_train.groupby('name')['name'].count().hist() -# # df_val.groupby('name')['name'].count().hist() - -# le = LabelEncoder() -# df_train['name'] = le.fit_transform(df_train['name']) -# print('generated {n_train_classes} labels for the training set'.format(n_train_classes=df_train['name'].nunique())) -# # 
print(df_train['name'].max(), df_train['name'].nunique()) - -# ## NOTE column filtering can be done earlier to save memory for merge -# # df_train = df_train['name', 'file_name', 'viewpoint'] -# # df_val = df_val['name', 'file_name', 'viewpoint'] - - -# return df_train, df_val - if __name__ == "__main__": pass \ No newline at end of file diff --git a/wbia_tbd/helpers/__init__.py b/wbia_miew_id/helpers/__init__.py similarity index 100% rename from wbia_tbd/helpers/__init__.py rename to wbia_miew_id/helpers/__init__.py diff --git a/wbia_tbd/helpers/config.py b/wbia_miew_id/helpers/config.py similarity index 97% rename from wbia_tbd/helpers/config.py rename to wbia_miew_id/helpers/config.py index 8c56dbc..e2f9a21 100644 --- a/wbia_tbd/helpers/config.py +++ b/wbia_miew_id/helpers/config.py @@ -81,6 +81,7 @@ class Config(DictableClass): def get_config(file_path: str) -> Config: + print(f"Loading config from path: {file_path}") with open(file_path, 'r') as file: config_dict = yaml.safe_load(file) diff --git a/wbia_tbd/helpers/getters.py b/wbia_miew_id/helpers/getters.py similarity index 91% rename from wbia_tbd/helpers/getters.py rename to wbia_miew_id/helpers/getters.py index 9ff88a5..8347034 100644 --- a/wbia_tbd/helpers/getters.py +++ b/wbia_miew_id/helpers/getters.py @@ -1,10 +1,10 @@ # import torch -# from models import TbdNet -# from datasets import TbdDataset +# from models import MiewIdNet +# from datasets import MiewIdDataset # def get_model(cfg, checkpoint_path=None, use_gpu=True): -# model = TbdNet(**dict(cfg.model_params)) +# model = MiewIdNet(**dict(cfg.model_params)) # if use_gpu: @@ -18,7 +18,7 @@ # return model # def get_dataloader(df_data, images_dir, cfg, transforms, shuffle=True): -# dataset = TbdDataset( +# dataset = MiewIdDataset( # csv=df_data, # images_dir = images_dir, # transforms=transforms, diff --git a/wbia_tbd/helpers/tools.py b/wbia_miew_id/helpers/tools.py similarity index 100% rename from wbia_tbd/helpers/tools.py rename to 
wbia_miew_id/helpers/tools.py diff --git a/wbia_tbd/logging_utils/__init__.py b/wbia_miew_id/logging_utils/__init__.py similarity index 100% rename from wbia_tbd/logging_utils/__init__.py rename to wbia_miew_id/logging_utils/__init__.py diff --git a/wbia_tbd/logging_utils/wandb_utils.py b/wbia_miew_id/logging_utils/wandb_utils.py similarity index 100% rename from wbia_tbd/logging_utils/wandb_utils.py rename to wbia_miew_id/logging_utils/wandb_utils.py diff --git a/wbia_tbd/losses/__init__.py b/wbia_miew_id/losses/__init__.py similarity index 100% rename from wbia_tbd/losses/__init__.py rename to wbia_miew_id/losses/__init__.py diff --git a/wbia_tbd/losses/cross b/wbia_miew_id/losses/cross similarity index 100% rename from wbia_tbd/losses/cross rename to wbia_miew_id/losses/cross diff --git a/wbia_tbd/losses/focal_loss.py b/wbia_miew_id/losses/focal_loss.py similarity index 100% rename from wbia_tbd/losses/focal_loss.py rename to wbia_miew_id/losses/focal_loss.py diff --git a/wbia_tbd/losses/loss_utils.py b/wbia_miew_id/losses/loss_utils.py similarity index 100% rename from wbia_tbd/losses/loss_utils.py rename to wbia_miew_id/losses/loss_utils.py diff --git a/wbia_tbd/metrics/__init__.py b/wbia_miew_id/metrics/__init__.py similarity index 100% rename from wbia_tbd/metrics/__init__.py rename to wbia_miew_id/metrics/__init__.py diff --git a/wbia_tbd/metrics/average_meter.py b/wbia_miew_id/metrics/average_meter.py similarity index 100% rename from wbia_tbd/metrics/average_meter.py rename to wbia_miew_id/metrics/average_meter.py diff --git a/wbia_tbd/metrics/distance.py b/wbia_miew_id/metrics/distance.py similarity index 100% rename from wbia_tbd/metrics/distance.py rename to wbia_miew_id/metrics/distance.py diff --git a/wbia_tbd/metrics/eval_onevsall.py b/wbia_miew_id/metrics/eval_onevsall.py similarity index 100% rename from wbia_tbd/metrics/eval_onevsall.py rename to wbia_miew_id/metrics/eval_onevsall.py diff --git a/wbia_tbd/metrics/knn.py 
b/wbia_miew_id/metrics/knn.py similarity index 100% rename from wbia_tbd/metrics/knn.py rename to wbia_miew_id/metrics/knn.py diff --git a/wbia_tbd/models/__init__.py b/wbia_miew_id/models/__init__.py similarity index 100% rename from wbia_tbd/models/__init__.py rename to wbia_miew_id/models/__init__.py diff --git a/wbia_tbd/models/heads.py b/wbia_miew_id/models/heads.py similarity index 100% rename from wbia_tbd/models/heads.py rename to wbia_miew_id/models/heads.py diff --git a/wbia_tbd/models/model.py b/wbia_miew_id/models/model.py similarity index 98% rename from wbia_tbd/models/model.py rename to wbia_miew_id/models/model.py index f0d4a29..6629b96 100644 --- a/wbia_tbd/models/model.py +++ b/wbia_miew_id/models/model.py @@ -46,7 +46,7 @@ def __repr__(self): -class TbdNet(nn.Module): +class MiewIdNet(nn.Module): def __init__(self, n_classes, @@ -67,7 +67,7 @@ def __init__(self, :param pooling: One of ('SPoC', 'MAC', 'RMAC', 'GeM', 'Rpool', 'Flatten', 'CompactBilinearPooling') :param loss_module: One of ('arcface', 'cosface', 'softmax') """ - super(TbdNet, self).__init__() + super(MiewIdNet, self).__init__() print('Building Model Backbone for {} model'.format(model_name)) self.backbone = timm.create_model(model_name, pretrained=pretrained) diff --git a/wbia_tbd/models/model_helpers.py b/wbia_miew_id/models/model_helpers.py similarity index 64% rename from wbia_tbd/models/model_helpers.py rename to wbia_miew_id/models/model_helpers.py index f2b847e..62e33f9 100644 --- a/wbia_tbd/models/model_helpers.py +++ b/wbia_miew_id/models/model_helpers.py @@ -1,12 +1,12 @@ import torch import sys -# sys.path.append('..'); from wbia_pie_v2.models import TbdNet -# from datasets import TbdDataset -from .model import TbdNet +# sys.path.append('..'); from wbia_pie_v2.models import MiewIdNet +# from datasets import MiewIdDataset +from .model import MiewIdNet def get_model(cfg, checkpoint_path=None, use_gpu=True): - model = TbdNet(**dict(cfg.model_params)) + model = 
MiewIdNet(**dict(cfg.model_params)) if use_gpu: diff --git a/wbia_tbd/schedulers/__init__.py b/wbia_miew_id/schedulers/__init__.py similarity index 100% rename from wbia_tbd/schedulers/__init__.py rename to wbia_miew_id/schedulers/__init__.py diff --git a/wbia_tbd/schedulers/default_scheduler.py b/wbia_miew_id/schedulers/default_scheduler.py similarity index 93% rename from wbia_tbd/schedulers/default_scheduler.py rename to wbia_miew_id/schedulers/default_scheduler.py index 7d33cc2..3e3a45f 100644 --- a/wbia_tbd/schedulers/default_scheduler.py +++ b/wbia_miew_id/schedulers/default_scheduler.py @@ -1,7 +1,7 @@ import warnings from torch.optim.lr_scheduler import _LRScheduler -class TbdScheduler(_LRScheduler): +class MiewIdScheduler(_LRScheduler): def __init__(self, optimizer, lr_start=5e-6, lr_max=1e-5, lr_min=1e-6, lr_ramp_ep=5, lr_sus_ep=0, lr_decay=0.8, last_epoch=-1): @@ -11,7 +11,7 @@ def __init__(self, optimizer, lr_start=5e-6, lr_max=1e-5, self.lr_ramp_ep = lr_ramp_ep self.lr_sus_ep = lr_sus_ep self.lr_decay = lr_decay - super(TbdScheduler, self).__init__(optimizer, last_epoch) + super(MiewIdScheduler, self).__init__(optimizer, last_epoch) def get_lr(self): if not self._get_lr_called_within_step: diff --git a/wbia_tbd/schedulers/fetch_schedulers.py b/wbia_miew_id/schedulers/fetch_schedulers.py similarity index 100% rename from wbia_tbd/schedulers/fetch_schedulers.py rename to wbia_miew_id/schedulers/fetch_schedulers.py diff --git a/wbia_miew_id/sweep.py b/wbia_miew_id/sweep.py new file mode 100644 index 0000000..9c8e88a --- /dev/null +++ b/wbia_miew_id/sweep.py @@ -0,0 +1,75 @@ +import optuna +import yaml +from train import run +from helpers import get_config +from optuna.pruners import MedianPruner +from optuna.samplers import TPESampler +import pickle + + +import argparse + + +def parse_args(): + parser = argparse.ArgumentParser(description="Load configuration file.") + parser.add_argument( + "--config", + type=str, + default="configs/default_config.yaml", 
+        help="Path to the YAML configuration file. Default: configs/default_config.yaml", +    ) +    return parser.parse_args() + + +def objective(trial, config): + +    # Specify the parameters you want to optimize +    config.data.train_n_filter_min = trial.suggest_int("train_n_filter_min", 2, 5) +    image_size = trial.suggest_categorical("image_size", [192, 256, 384, 440, 512]) +    config.data.image_size = [image_size, image_size] +    n_epochs = trial.suggest_int("epochs", 20, 40) +    config.engine.epochs = n_epochs +    config.model_params.margin = trial.suggest_uniform("margin", 0.1, 0.7) +    config.model_params.s = trial.suggest_uniform("s", 20, 64) + +    # The scheduler params are derived from one base parameter to minimize the number of parameters to optimize +    lr_base = trial.suggest_loguniform("lr_base", 1e-6, 1e-2) +    config.scheduler_params.lr_start = lr_base +    config.scheduler_params.lr_max = lr_base * 10 +    config.scheduler_params.lr_min = lr_base / 2 +    result = run(config) + +    print("cfg", config.engine) + +    return result + + +if __name__ == "__main__": +    # args = parse_args() +    config_path = "configs/default_config.yaml"  # args.config + +    config = get_config(config_path) + +    study = optuna.create_study( +        sampler=TPESampler(), pruner=MedianPruner(), direction="maximize" +    ) + +    comb_objective = lambda trial: objective(trial, config) + +    study.optimize(comb_objective, n_trials=100) + +    print("Best trial:") +    trial_ = study.best_trial + +    print(f"Value: {trial_.value}") + +    print("Best parameters:") +    for key, value in trial_.params.items(): +        print(f"    {key}: {value}") + +    # saves best parameters +    save_dict = trial_.params +    save_dict['best_score'] = trial_.value + +    with open('sweep.pkl', 'wb') as f: +        pickle.dump(save_dict, f, pickle.HIGHEST_PROTOCOL) diff --git a/wbia_tbd/train.py b/wbia_miew_id/train.py similarity index 78% rename from wbia_tbd/train.py rename to wbia_miew_id/train.py index c160fef..d3f3948 100644 --- a/wbia_tbd/train.py +++ b/wbia_miew_id/train.py @@ -1,9 +1,9 
@@ -from datasets import TbdDataset, get_train_transforms, get_valid_transforms +from datasets import MiewIdDataset, get_train_transforms, get_valid_transforms from logging_utils import init_wandb -from models import TbdNet -from etl import preprocess_data, print_intersect_stats +from models import MiewIdNet +from etl import preprocess_data, print_intersect_stats, convert_name_to_id from losses import fetch_loss -from schedulers import TbdScheduler +from schedulers import MiewIdScheduler from engine import run_fn from helpers import get_config @@ -29,10 +29,8 @@ def parse_args(): ) return parser.parse_args() -def run(config_path): +def run(config): - config = get_config(config_path) - checkpoint_dir = f"{config.checkpoint_dir}/{config.project_name}/{config.exp_name}/{config.model_params.model_name}-{config.data.image_size[0]}-{config.engine.loss_module}" os.makedirs(checkpoint_dir, exist_ok=True) print('Checkpoints will be saved at: ', checkpoint_dir) @@ -61,22 +59,30 @@ def set_seed_torch(seed): viewpoint_list=config.data.viewpoint_list, n_filter_min=config.data.val_n_filter_min, n_subsample_max=config.data.val_n_subsample_max) - - print_intersect_stats(df_train, df_val) + + print_intersect_stats(df_train, df_val, individual_key='name_orig') + + # df_train['name'] = convert_name_to_id(df_train['name'].values) + # df_val['name'] = convert_name_to_id(df_val['name'].values) + n_train_classes = df_train['name'].nunique() - train_dataset = TbdDataset( + train_dataset = MiewIdDataset( csv=df_train, images_dir = config.data.images_dir, transforms=get_train_transforms(config), + fliplr=config.test.fliplr, + fliplr_view=config.test.fliplr_view ) - valid_dataset = TbdDataset( + valid_dataset = MiewIdDataset( csv=df_val, images_dir=config.data.images_dir, transforms=get_valid_transforms(config), + fliplr=config.test.fliplr, + fliplr_view=config.test.fliplr_view ) train_loader = torch.utils.data.DataLoader( @@ -101,7 +107,7 @@ def set_seed_torch(seed): if 
config.model_params.n_classes != n_train_classes: print(f"WARNING: Overriding n_classes in config ({config.model_params.n_classes}) which is different from actual n_train_classes ({n_train_classes}). This parameters has to be readjusted in config for proper checkpoint loading after training.") config.model_params.n_classes = n_train_classes - model = TbdNet(**dict(config.model_params)) + model = MiewIdNet(**dict(config.model_params)) model.to(device) criterion = fetch_loss() @@ -110,18 +116,21 @@ def set_seed_torch(seed): optimizer = torch.optim.Adam(model.parameters(), lr=config.scheduler_params.lr_start) - scheduler = TbdScheduler(optimizer,**dict(config.scheduler_params)) + scheduler = MiewIdScheduler(optimizer,**dict(config.scheduler_params)) if config.engine.use_wandb: load_dotenv() init_wandb(config.exp_name, config.project_name, config=None) - run_fn(config, model, train_loader, valid_loader, criterion, optimizer, scheduler, device, checkpoint_dir, use_wandb=config.engine.use_wandb) + best_score = run_fn(config, model, train_loader, valid_loader, criterion, optimizer, scheduler, device, checkpoint_dir, use_wandb=config.engine.use_wandb) + + return best_score if __name__ == '__main__': args = parse_args() config_path = args.config - print(f"Loading config from path: {config_path}") + + config = get_config(config_path) - run(config_path) \ No newline at end of file + run(config) \ No newline at end of file diff --git a/wbia_miew_id/visualization/__init__.py b/wbia_miew_id/visualization/__init__.py new file mode 100644 index 0000000..3d52129 --- /dev/null +++ b/wbia_miew_id/visualization/__init__.py @@ -0,0 +1 @@ +from .gradcam import * \ No newline at end of file diff --git a/wbia_miew_id/visualization/gradcam.py b/wbia_miew_id/visualization/gradcam.py new file mode 100644 index 0000000..471ba21 --- /dev/null +++ b/wbia_miew_id/visualization/gradcam.py @@ -0,0 +1,333 @@ +import os +import time +import pandas as pd +import numpy as np +import cv2 +import torch 
+from tqdm.auto import tqdm +import matplotlib.pyplot as plt + +from pytorch_grad_cam import GradCAMPlusPlus, EigenCAM +from pytorch_grad_cam.utils.image import show_cam_on_image + + +from wbia_miew_id.datasets import MiewIdDataset, get_valid_transforms +from wbia_miew_id.models import MiewIdNet + +def resize_image(image, new_height): + aspect_ratio = image.shape[1] / image.shape[0] + new_width = int(new_height * aspect_ratio) + resized_image = cv2.resize(image, (new_width, new_height)) + return resized_image + + +class SimilarityToConceptTarget: + def __init__(self, features): + self.features = features + + def __call__(self, model_output): + cos = torch.nn.CosineSimilarity(dim=0) + return cos(model_output, self.features) + +def batch_iter(iterable, n=1): + l = len(iterable) + for ndx in range(0, l, n): + yield iterable[ndx:min(ndx + n, l)] + +def load_image(image_path): + image = cv2.imread(image_path) + image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + return image + +def show_cam_on_image(img: np.ndarray, + mask: np.ndarray, + use_rgb: bool = False, + colormap: int = cv2.COLORMAP_JET, + image_weight: float = 0.6) -> np.ndarray: + heatmap = cv2.applyColorMap(np.uint8(255 * mask), colormap) + + # Keep heatmap areas lower than threshold transparent + heatmap[mask <= 0.05] = 0 + if use_rgb: + heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB) + heatmap = np.float32(heatmap) / 255 + + if np.max(img) > 1: + raise Exception( + "The input image should np.float32 in the range [0, 1]") + + if image_weight < 0 or image_weight > 1: + raise Exception( + f"image_weight should be in the range [0, 1].\ + Got: {image_weight}") + + cam = (1 - image_weight) * heatmap + image_weight * img + + cam = cam / np.max(cam) + return np.uint8(255 * cam) + +def draw_one(config, test_loader, model, images_dir = '', method='gradcam_plus_plus', eigen_smooth=False, show=False): + + # Generate embeddings for query and db + model.eval() + tk0 = tqdm(test_loader, total=len(test_loader)) + 
embeddings = [] + labels = [] + images = [] + paths = [] + bboxes = [] + with torch.no_grad(): + for batch in tk0: + batch_image = batch[0] + batch_name = batch[1] + batch_path = batch[2] + batch_bbox = batch[3] + + + images.extend(batch_image) + batch_embeddings = model(batch_image.to(config.engine.device)) + + batch_embeddings = batch_embeddings.detach().cpu().numpy() + + batch_embeddings_df = pd.DataFrame(batch_embeddings) + embeddings.append(batch_embeddings_df) + + batch_labels = batch_name.tolist() + labels.extend(batch_labels) + + paths.extend(batch_path) + + bboxes.extend(batch_bbox) + + bboxes = [t.int().tolist() for t in bboxes] + + embeddings = pd.concat(embeddings) + + target_layers = model.backbone.conv_head + + if method=='gradcam_plus_plus': + generate_cam = GradCAMPlusPlus(model=model,target_layers=[target_layers],use_cuda=True) + elif method=='eigencam': + generate_cam = EigenCAM(model=model,target_layers=[target_layers],use_cuda=True) + + qry_idx = 0 + db_idx = 1 + + qry_features = embeddings.iloc[qry_idx].values + qry_features = torch.Tensor(qry_features).to(config.engine.device) + + db_features = embeddings.iloc[db_idx].values + db_features = torch.Tensor(db_features).to(config.engine.device) + + similarity_to_qry = SimilarityToConceptTarget(qry_features) + similarity_to_db = SimilarityToConceptTarget(db_features) + + qry_tensor = images[qry_idx] + db_tensor = images[db_idx] + + db_tensor = db_tensor.unsqueeze(0) + qry_tensor = qry_tensor.unsqueeze(0) + + + # generate results + stack_tensor = torch.concatenate([db_tensor, qry_tensor]) + stack_target = [similarity_to_qry, similarity_to_db] + results_cam = generate_cam(input_tensor=stack_tensor,targets=stack_target,aug_smooth=False,eigen_smooth=eigen_smooth) + qry_grayscale_cam = results_cam[0, :] + db_grayscale_cam = results_cam[1, :] + + + # query image results + qry_image_path = paths[qry_idx] + qry_float = load_image(qry_image_path) + qry_bbox = bboxes[qry_idx] + + x1, y1, w, h = qry_bbox + + 
+ qry_float = qry_float[y1 : y1 + h, x1 : x1 + w] + if min(qry_float.shape) < 1: + # Use original image + qry_float = qry_float = load_image(qry_image_path) + + qry_float_norm = (qry_float - qry_float.min()) / (qry_float.max() - qry_float.min()) + db_grayscale_cam_res = cv2.resize(db_grayscale_cam, (qry_float_norm.shape[1], qry_float_norm.shape[0])) + cam_image_qry = show_cam_on_image(qry_float_norm, db_grayscale_cam_res, use_rgb=True) + + ai0 = cam_image_qry + ai1 = qry_float + + # db image results + db_image_path = paths[db_idx] + db_float = load_image(db_image_path) + db_bbox = bboxes[db_idx] + x1, y1, w, h = db_bbox + db_float = db_float[y1 : y1 + h, x1 : x1 + w] + if min(db_float.shape) < 1: + # Use original image + db_float = db_float = load_image(db_image_path) + + db_float_norm = (db_float - db_float.min()) / (db_float.max() - db_float.min()) + qry_grayscale_cam_res = cv2.resize(qry_grayscale_cam, (db_float_norm.shape[1], db_float_norm.shape[0])) + cam_image_db = show_cam_on_image(db_float_norm, qry_grayscale_cam_res, use_rgb=True) + + ai2 = cam_image_db + ai3 = db_float + + image_list = [ai0, ai1, ai2, ai3] + resize_height = 440 + resized_image_list = [resize_image(img, resize_height) for img in image_list] + comb_image = np.hstack(resized_image_list) + if show: + plt.imshow(comb_image) + + comb_image = cv2.cvtColor(comb_image, cv2.COLOR_BGR2RGB) + return comb_image + +def generate_embeddings(config, model, test_loader): + print('generating embeddings') + tk0 = tqdm(test_loader, total=len(test_loader)) + embeddings = [] + labels = [] + images = [] + paths = [] + bboxes = [] + with torch.no_grad(): + for batch in tk0: + batch_image = batch[0] + batch_name = batch[1] + batch_path = batch[2] + batch_bbox = batch[3] + + + images.extend(batch_image) + batch_embeddings = model(batch_image.to(config.engine.device)) + + batch_embeddings = batch_embeddings.detach().cpu().numpy() + + batch_embeddings_df = pd.DataFrame(batch_embeddings) + 
embeddings.append(batch_embeddings_df) + + batch_labels = batch_name.tolist() + labels.extend(batch_labels) + + paths.extend(batch_path) + + bboxes.extend(batch_bbox) + + bboxes = [t.int().tolist() for t in bboxes] + + embeddings = pd.concat(embeddings) + return embeddings, labels, images, paths, bboxes + +def draw_batch(config, test_loader, model, images_dir = '', method='gradcam_plus_plus', eigen_smooth=False, show=False): + + print('** draw_batch started') + + # Generate embeddings for query and db + model.eval() + + embeddings, labels, images, paths, bboxes = generate_embeddings(config, model, test_loader) + + target_layers = model.backbone.conv_head + + if method=='gradcam_plus_plus': + generate_cam = GradCAMPlusPlus(model=model,target_layers=[target_layers],use_cuda=True) + elif method=='eigencam': + generate_cam = EigenCAM(model=model,target_layers=[target_layers],use_cuda=True) + + qry_idx = 0 + db_idx = 1 + + qry_features = embeddings.iloc[qry_idx].values + qry_features = torch.Tensor(qry_features).to(config.engine.device) + + db_features_batch = embeddings.iloc[db_idx:].values + db_features_batch = torch.Tensor(db_features_batch).to(config.engine.device) + + tensors = [] + stack_target = [] + print('generating similarity targets') + for i, db_features in tqdm(enumerate(db_features_batch)): + + similarity_to_qry = SimilarityToConceptTarget(qry_features) + similarity_to_db = SimilarityToConceptTarget(db_features) + + qry_tensor = images[qry_idx] + db_tensor = images[db_idx + i] + + db_tensor = db_tensor.unsqueeze(0) + qry_tensor = qry_tensor.unsqueeze(0) + tensors.extend([db_tensor, qry_tensor]) + stack_target.extend([similarity_to_qry, similarity_to_db]) + + stack_tensor = torch.concatenate(tensors) + + batch_images = [] + results_cam = [] + + batch_size = test_loader.batch_size + + batch_step = max(batch_size//2, 2) + print('generating cams') + + for i in tqdm(range(0, len(stack_target), batch_step)): + stack_tensor_batch = stack_tensor[i:i+batch_step] + 
stack_target_batch = stack_target[i:i+batch_step] + results_cam_batch = generate_cam(input_tensor=stack_tensor_batch,targets=stack_target_batch,aug_smooth=False,eigen_smooth=eigen_smooth) + results_cam.extend(results_cam_batch) + + results_cam = np.array(results_cam) + + print('generating image tile') + for i in tqdm(range(0, results_cam.shape[0], 2)): + qry_grayscale_cam = results_cam[i, :] + db_grayscale_cam = results_cam[i+1, :] + + # query image results + qry_image_path = paths[qry_idx] + qry_float = load_image(qry_image_path) + qry_bbox = bboxes[qry_idx] + + x1, y1, w, h = qry_bbox + + qry_float = qry_float[y1 : y1 + h, x1 : x1 + w] + if min(qry_float.shape) < 1: + # Use original image + qry_float = qry_float = load_image(qry_image_path) + + qry_float_norm = (qry_float - qry_float.min()) / (qry_float.max() - qry_float.min()) + db_grayscale_cam_res = cv2.resize(db_grayscale_cam, (qry_float_norm.shape[1], qry_float_norm.shape[0])) + cam_image_qry = show_cam_on_image(qry_float_norm, db_grayscale_cam_res, use_rgb=True) + + ai0 = cam_image_qry + ai1 = qry_float + + # db image results + db_image_path = paths[db_idx + i//2] + db_float = load_image(db_image_path) + db_bbox = bboxes[db_idx + i//2] + x1, y1, w, h = db_bbox + db_float = db_float[y1 : y1 + h, x1 : x1 + w] + if min(db_float.shape) < 1: + # Use original image + db_float = db_float = load_image(db_image_path) + + db_float_norm = (db_float - db_float.min()) / (db_float.max() - db_float.min()) + qry_grayscale_cam_res = cv2.resize(qry_grayscale_cam, (db_float_norm.shape[1], db_float_norm.shape[0])) + cam_image_db = show_cam_on_image(db_float_norm, qry_grayscale_cam_res, use_rgb=True) + + ai2 = cam_image_db + ai3 = db_float + + image_list = [ai0, ai1, ai2, ai3] + resize_height = 440 + resized_image_list = [resize_image(img, resize_height) for img in image_list] + comb_image = np.hstack(resized_image_list) + if show: + plt.imshow(comb_image) + + comb_image = cv2.cvtColor(comb_image, cv2.COLOR_BGR2RGB) + + 
batch_images.append(comb_image) + return batch_images \ No newline at end of file diff --git a/wbia_tbd/__init__.py b/wbia_tbd/__init__.py deleted file mode 100644 index 49af4ee..0000000 --- a/wbia_tbd/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -from wbia_tbd import _plugin # NOQA - - -__version__ = '0.0.0'