From a88cd1073010c3082e42c62be86d1331dd508397 Mon Sep 17 00:00:00 2001 From: Rakotoarivony Date: Fri, 5 Sep 2025 16:55:45 +0200 Subject: [PATCH] Add VRA --- README.md | 1 + configs/postprocessors/vra.yml | 9 +++ openood/evaluation_api/postprocessor.py | 4 +- openood/postprocessors/__init__.py | 1 + openood/postprocessors/utils.py | 4 +- openood/postprocessors/vra_postprocessor.py | 71 +++++++++++++++++++++ scripts/ood/vra/cifar100_test_ood_vra.sh | 33 ++++++++++ scripts/ood/vra/cifar10_test_ood_vra.sh | 35 ++++++++++ scripts/ood/vra/imagenet200_test_ood_vra.sh | 23 +++++++ scripts/ood/vra/imagenet_test_ood_vra.sh | 47 ++++++++++++++ 10 files changed, 226 insertions(+), 2 deletions(-) create mode 100644 configs/postprocessors/vra.yml create mode 100644 openood/postprocessors/vra_postprocessor.py create mode 100755 scripts/ood/vra/cifar100_test_ood_vra.sh create mode 100755 scripts/ood/vra/cifar10_test_ood_vra.sh create mode 100755 scripts/ood/vra/imagenet200_test_ood_vra.sh create mode 100755 scripts/ood/vra/imagenet_test_ood_vra.sh diff --git a/README.md b/README.md index 98b47eb7..4d7ab53c 100644 --- a/README.md +++ b/README.md @@ -247,6 +247,7 @@ distance: f4d5b3 --> > - [x] [![gen](https://img.shields.io/badge/CVPR'23-GEN-fdd7e6?style=for-the-badge)](https://openaccess.thecvf.com/content/CVPR2023/papers/Liu_GEN_Pushing_the_Limits_of_Softmax-Based_Out-of-Distribution_Detection_CVPR_2023_paper.pdf)    ![postprocess] > - [x] [![nnguide](https://img.shields.io/badge/ICCV'23-NNGuide-fdd7e6?style=for-the-badge)](https://arxiv.org/abs/2309.14888)    ![postprocess] > - [x] [![relation](https://img.shields.io/badge/NEURIPS'23-Relation-fdd7e6?style=for-the-badge)](https://arxiv.org/abs/2301.12321)    ![postprocess] +> - [x] [![vra](https://img.shields.io/badge/NeurIPS'23-VRA-fdd7e6?style=for-the-badge)](https://github.com/zeroQiaoba/VRA)    ![postprocess] > - [x] [![scale](https://img.shields.io/badge/ICLR'24-Scale-fdd7e6?style=for-the-badge)](https://github.com/kai422/SCALE)    ![postprocess] > - [x] [![fdbd](https://img.shields.io/badge/ICML'24-fDBD-f4d5b3?style=for-the-badge)](https://github.com/litianliu/fDBD-OOD)    ![postprocess] > - [x] [![adascale-a](https://img.shields.io/badge/arXiv'25-AdaScale\_A-fdd7e6?style=for-the-badge)](https://github.com/sudarshanregmi/adascale)    ![postprocess] diff --git a/configs/postprocessors/vra.yml b/configs/postprocessors/vra.yml new file mode 100644 index 00000000..16217f3a --- /dev/null +++ b/configs/postprocessors/vra.yml @@ -0,0 +1,9 @@ +postprocessor: + name: vra + APS_mode: True + postprocessor_args: + percentile_high: 90 + percentile_low: 10 + postprocessor_sweep: + percentile_high_list: [85, 90, 95, 99] + percentile_low_list: [1, 5, 10, 15] diff --git a/openood/evaluation_api/postprocessor.py b/openood/evaluation_api/postprocessor.py index 4e545e68..717edbc6 100644 --- a/openood/evaluation_api/postprocessor.py +++ b/openood/evaluation_api/postprocessor.py @@ -16,7 +16,8 @@ RMDSPostprocessor, SHEPostprocessor, CIDERPostprocessor, NPOSPostprocessor, GENPostprocessor, NNGuidePostprocessor, RelationPostprocessor, T2FNormPostprocessor, ReweightOODPostprocessor, fDBDPostprocessor, - AdaScalePostprocessor, IODINPostprocessor, NCIPostprocessor) + AdaScalePostprocessor, IODINPostprocessor, NCIPostprocessor, + VRAPostprocessor) from openood.utils.config import Config, merge_configs postprocessors = { @@ -69,6 +70,7 @@ 'reweightood': ReweightOODPostprocessor, 'adascale_a': AdaScalePostprocessor, 'adascale_l': AdaScalePostprocessor, + 'vra': VRAPostprocessor, } link_prefix = 'https://raw.githubusercontent.com/Jingkang50/OpenOOD/main/configs/postprocessors/' diff --git a/openood/postprocessors/__init__.py b/openood/postprocessors/__init__.py index 6be60017..1aad6510 100644 --- a/openood/postprocessors/__init__.py +++ b/openood/postprocessors/__init__.py @@ -47,3 +47,4 @@ from .t2fnorm_postprocessor import T2FNormPostprocessor from .reweightood_postprocessor import ReweightOODPostprocessor from .adascale_postprocessor import AdaScalePostprocessor +from .vra_postprocessor import VRAPostprocessor diff --git a/openood/postprocessors/utils.py b/openood/postprocessors/utils.py index 343beaf0..3a65085b 100644 --- a/openood/postprocessors/utils.py +++ b/openood/postprocessors/utils.py @@ -43,11 +43,12 @@ from .rts_postprocessor import RTSPostprocessor from .gen_postprocessor import GENPostprocessor from .relation_postprocessor import RelationPostprocessor +from .vra_postprocessor import VRAPostprocessor def get_postprocessor(config: Config): postprocessors = { - 'nci': NCIPostprocessor, + 'nci': NCIPostprocessor, 'fdbd': fDBDPostprocessor, 'ash': ASHPostprocessor, 'cider': CIDERPostprocessor, @@ -90,6 +91,7 @@ def get_postprocessor(config: Config): 'gen': GENPostprocessor, 'relation': RelationPostprocessor, 't2fnorm': T2FNormPostprocessor, + 'vra': VRAPostprocessor, } return postprocessors[config.postprocessor.name](config) diff --git a/openood/postprocessors/vra_postprocessor.py b/openood/postprocessors/vra_postprocessor.py new file mode 100644 index 00000000..e0d58120 --- /dev/null +++ b/openood/postprocessors/vra_postprocessor.py @@ -0,0 +1,71 @@ +from typing import Any + +import numpy as np +import torch +import torch.nn as nn +from tqdm import tqdm + +from .base_postprocessor import BasePostprocessor + + +class VRAPostprocessor(BasePostprocessor): + + def __init__(self, config): + super(VRAPostprocessor, self).__init__(config) + self.args = self.config.postprocessor.postprocessor_args + self.percentile_high = self.args.percentile_high + self.percentile_low = self.args.percentile_low + self.args_dict = self.config.postprocessor.postprocessor_sweep + self.setup_flag = False + + def setup(self, net: nn.Module, id_loader_dict, ood_loader_dict): + if not self.setup_flag: + activation_log = [] + net.eval() + with torch.no_grad(): + for batch in tqdm(id_loader_dict['val'], + desc='Setup: ', + position=0, + leave=True): + data = batch['data'].cuda() + data = data.float() + + _, feature = net(data, return_feature=True) + activation_log.append(feature.data.cpu().numpy()) + + self.activation_log = np.concatenate(activation_log, axis=0) + self.setup_flag = True + else: + pass + + self.threshold_high = np.percentile(self.activation_log.flatten(), + self.percentile_high) + self.threshold_low = np.percentile(self.activation_log.flatten(), + self.percentile_low) + + @torch.no_grad() + def postprocess(self, net: nn.Module, data: Any): + _, feature_ood = net.forward(data, return_feature=True) + feature_ood = feature_ood.clip(min=self.threshold_low, + max=self.threshold_high) + feature_ood = feature_ood.view(feature_ood.size(0), -1) + logit_ood = net.fc(feature_ood) + score = torch.softmax(logit_ood, dim=1) + _, pred = torch.max(score, dim=1) + energyconf = torch.logsumexp(logit_ood.data.cpu(), dim=1) + return pred, energyconf + + def set_hyperparam(self, hyperparam: list): + self.percentile_high = hyperparam[0] + self.percentile_low = hyperparam[1] + self.threshold_high = np.percentile(self.activation_log.flatten(), + self.percentile_high) + self.threshold_low = np.percentile(self.activation_log.flatten(), + self.percentile_low) + print('Threshold at percentile {:2d} over id data is: {}'.format( + self.percentile_high, self.threshold_high)) + print('Threshold at percentile {:2d} over id data is: {}'.format( + self.percentile_low, self.threshold_low)) + + def get_hyperparam(self): + return [self.percentile_high, self.percentile_low] diff --git a/scripts/ood/vra/cifar100_test_ood_vra.sh b/scripts/ood/vra/cifar100_test_ood_vra.sh new file mode 100755 index 00000000..e55cc250 --- /dev/null +++ b/scripts/ood/vra/cifar100_test_ood_vra.sh @@ -0,0 +1,33 @@ +#!/bin/bash +# sh scripts/ood/ash/cifar100_test_ood_ash.sh + +# GPU=1 +# CPU=1 +# node=73 +# jobname=openood + +PYTHONPATH='.':$PYTHONPATH \ +# srun -p dsta --mpi=pmi2 --gres=gpu:${GPU} -n1 \ +# --cpus-per-task=${CPU} --ntasks-per-node=${GPU} \ +# --kill-on-bad-exit=1 --job-name=${jobname} -w SG-IDC1-10-51-2-${node} \ + +python main.py \ + --config configs/datasets/cifar100/cifar100.yml \ + configs/datasets/cifar100/cifar100_ood.yml \ + configs/networks/resnet18_32x32.yml \ + configs/pipelines/test/test_ood.yml \ + configs/preprocessors/base_preprocessor.yml \ + configs/postprocessors/vra.yml \ + --network.checkpoint 'results/cifar100_resnet18_32x32_base_e100_lr0.1_default/s0/best.ckpt' + +############################################ +# alternatively, we recommend using the +# new unified, easy-to-use evaluator with +# the example script scripts/eval_ood.py +# especially if you want to get results from +# multiple runs +python scripts/eval_ood.py \ + --id-data cifar100 \ + --root ./results/cifar100_resnet18_32x32_base_e100_lr0.1_default \ + --postprocessor vra \ + --save-score --save-csv diff --git a/scripts/ood/vra/cifar10_test_ood_vra.sh b/scripts/ood/vra/cifar10_test_ood_vra.sh new file mode 100755 index 00000000..c73c91c8 --- /dev/null +++ b/scripts/ood/vra/cifar10_test_ood_vra.sh @@ -0,0 +1,35 @@ +#!/bin/bash +# sh scripts/ood/she/cifar10_test_ood_she.sh + +# GPU=1 +# CPU=1 +# node=73 +# jobname=openood + +PYTHONPATH='.':$PYTHONPATH \ +# srun -p dsta --mpi=pmi2 --gres=gpu:${GPU} -n1 \ +# --cpus-per-task=${CPU} --ntasks-per-node=${GPU} \ +# --kill-on-bad-exit=1 --job-name=${jobname} -w SG-IDC1-10-51-2-${node} \ + +python main.py \ + --config configs/datasets/cifar10/cifar10.yml \ + configs/datasets/cifar10/cifar10_ood.yml \ + configs/networks/resnet18_32x32.yml \ + configs/pipelines/test/test_ood.yml \ + configs/preprocessors/base_preprocessor.yml \ + configs/postprocessors/vra.yml \ + --num_workers 8 \ + --network.checkpoint 'results/cifar10_resnet18_32x32_base_e100_lr0.1_default/s0/best.ckpt' \ + --mark 1 + +############################################ +# alternatively, we recommend using the +# new unified, easy-to-use evaluator with +# the example script scripts/eval_ood.py +# especially if you want to get results from +# multiple runs +python scripts/eval_ood.py \ + --id-data cifar10 \ + --root ./results/cifar10_resnet18_32x32_base_e100_lr0.1_default \ + --postprocessor vra \ + --save-score --save-csv diff --git a/scripts/ood/vra/imagenet200_test_ood_vra.sh b/scripts/ood/vra/imagenet200_test_ood_vra.sh new file mode 100755 index 00000000..d1c9df37 --- /dev/null +++ b/scripts/ood/vra/imagenet200_test_ood_vra.sh @@ -0,0 +1,23 @@ +#!/bin/bash +# sh scripts/ood/ash/imagenet200_test_ood_ash.sh + +############################################ +# alternatively, we recommend using the +# new unified, easy-to-use evaluator with +# the example script scripts/eval_ood.py +# especially if you want to get results from +# multiple runs + +# ood +python scripts/eval_ood.py \ + --id-data imagenet200 \ + --root ./results/imagenet200_resnet18_224x224_base_e90_lr0.1_default \ + --postprocessor vra \ + --save-score --save-csv #--fsood + +# full-spectrum ood +python scripts/eval_ood.py \ + --id-data imagenet200 \ + --root ./results/imagenet200_resnet18_224x224_base_e90_lr0.1_default \ + --postprocessor vra \ + --save-score --save-csv --fsood diff --git a/scripts/ood/vra/imagenet_test_ood_vra.sh b/scripts/ood/vra/imagenet_test_ood_vra.sh new file mode 100755 index 00000000..609438da --- /dev/null +++ b/scripts/ood/vra/imagenet_test_ood_vra.sh @@ -0,0 +1,47 @@ +#!/bin/bash +# sh scripts/ood/ash/imagenet_test_ood_ash.sh + +GPU=1 +CPU=1 +node=63 +jobname=openood + +PYTHONPATH='.':$PYTHONPATH \ +# srun -p dsta --mpi=pmi2 --gres=gpu:${GPU} -n1 \ +# --cpus-per-task=${CPU} --ntasks-per-node=${GPU} \ +# --kill-on-bad-exit=1 --job-name=${jobname} -w SG-IDC1-10-51-2-${node} \ +# python main.py \ +# --config configs/datasets/imagenet/imagenet.yml \ +# configs/datasets/imagenet/imagenet_ood.yml \ +# configs/networks/resnet50.yml \ +# configs/pipelines/test/test_ood.yml \ +# configs/preprocessors/base_preprocessor.yml \ +# configs/postprocessors/gen.yml \ +# --num_workers 4 \ +# --ood_dataset.image_size 256 \ +# --dataset.test.batch_size 256 \ +# --dataset.val.batch_size 256 \ +# --network.pretrained True \ +# --network.checkpoint 'results/pretrained_weights/resnet50_imagenet1k_v1.pth' \ +# --merge_option merge + +############################################ +# we recommend using the +# new unified, easy-to-use evaluator with +# the example script scripts/eval_ood_imagenet.py + +# available architectures: +# resnet50, swin-t, vit-b-16 +# ood +python scripts/eval_ood_imagenet.py \ + --tvs-pretrained \ + --arch resnet50 \ + --postprocessor vra \ + --save-score --save-csv #--fsood + +# full-spectrum ood +python scripts/eval_ood_imagenet.py \ + --tvs-pretrained \ + --arch resnet50 \ + --postprocessor vra \ + --save-score --save-csv --fsood