Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,7 @@ distance: f4d5b3 -->
> - [x] [![gen](https://img.shields.io/badge/CVPR'23-GEN-fdd7e6?style=for-the-badge)](https://openaccess.thecvf.com/content/CVPR2023/papers/Liu_GEN_Pushing_the_Limits_of_Softmax-Based_Out-of-Distribution_Detection_CVPR_2023_paper.pdf)    ![postprocess]
> - [x] [![nnguide](https://img.shields.io/badge/ICCV'23-NNGuide-fdd7e6?style=for-the-badge)](https://arxiv.org/abs/2309.14888)    ![postprocess]
> - [x] [![relation](https://img.shields.io/badge/NEURIPS'23-Relation-fdd7e6?style=for-the-badge)](https://arxiv.org/abs/2301.12321)    ![postprocess]
> - [x] [![vra](https://img.shields.io/badge/NeurIPS'23-VRA-fdd7e6?style=for-the-badge)](https://github.com/zeroQiaoba/VRA)    ![postprocess]
> - [x] [![scale](https://img.shields.io/badge/ICLR'24-Scale-fdd7e6?style=for-the-badge)](https://github.com/kai422/SCALE)    ![postprocess]
> - [x] [![fdbd](https://img.shields.io/badge/ICML'24-fDBD-f4d5b3?style=for-the-badge)](https://github.com/litianliu/fDBD-OOD)    ![postprocess]
> - [x] [![adascale-a](https://img.shields.io/badge/arXiv'25-AdaScale\_A-fdd7e6?style=for-the-badge)](https://github.com/sudarshanregmi/adascale)    ![postprocess]
Expand Down
9 changes: 9 additions & 0 deletions configs/postprocessors/vra.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
postprocessor:
name: vra
APS_mode: True
postprocessor_args:
percentile_high: 90
percentile_low: 10
postprocessor_sweep:
percentile_high_list: [85, 90, 95, 99]
percentile_low_list: [1, 5, 10, 15]
4 changes: 3 additions & 1 deletion openood/evaluation_api/postprocessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@
RMDSPostprocessor, SHEPostprocessor, CIDERPostprocessor, NPOSPostprocessor,
GENPostprocessor, NNGuidePostprocessor, RelationPostprocessor,
T2FNormPostprocessor, ReweightOODPostprocessor, fDBDPostprocessor,
AdaScalePostprocessor, IODINPostprocessor, NCIPostprocessor)
AdaScalePostprocessor, IODINPostprocessor, NCIPostprocessor,
VRAPostprocessor)
from openood.utils.config import Config, merge_configs

postprocessors = {
Expand Down Expand Up @@ -69,6 +70,7 @@
'reweightood': ReweightOODPostprocessor,
'adascale_a': AdaScalePostprocessor,
'adascale_l': AdaScalePostprocessor,
'vra': VRAPostprocessor,
}

link_prefix = 'https://raw.githubusercontent.com/Jingkang50/OpenOOD/main/configs/postprocessors/'
Expand Down
1 change: 1 addition & 0 deletions openood/postprocessors/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,3 +47,4 @@
from .t2fnorm_postprocessor import T2FNormPostprocessor
from .reweightood_postprocessor import ReweightOODPostprocessor
from .adascale_postprocessor import AdaScalePostprocessor
from .vra_postprocessor import VRAPostprocessor
4 changes: 3 additions & 1 deletion openood/postprocessors/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,11 +43,12 @@
from .rts_postprocessor import RTSPostprocessor
from .gen_postprocessor import GENPostprocessor
from .relation_postprocessor import RelationPostprocessor
from .vra_postprocessor import VRAPostprocessor


def get_postprocessor(config: Config):
postprocessors = {
'nci': NCIPostprocessor,
'nci': NCIPostprocessor,
'fdbd': fDBDPostprocessor,
'ash': ASHPostprocessor,
'cider': CIDERPostprocessor,
Expand Down Expand Up @@ -90,6 +91,7 @@ def get_postprocessor(config: Config):
'gen': GENPostprocessor,
'relation': RelationPostprocessor,
't2fnorm': T2FNormPostprocessor,
'vra': VRAPostprocessor,
}

return postprocessors[config.postprocessor.name](config)
71 changes: 71 additions & 0 deletions openood/postprocessors/vra_postprocessor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
from typing import Any

import numpy as np
import torch
import torch.nn as nn
from tqdm import tqdm

from .base_postprocessor import BasePostprocessor


class VRAPostprocessor(BasePostprocessor):

def __init__(self, config):
super(VRAPostprocessor, self).__init__(config)
self.args = self.config.postprocessor.postprocessor_args
self.percentile_high = self.args.percentile_high
self.percentile_low = self.args.percentile_low
self.args_dict = self.config.postprocessor.postprocessor_sweep
self.setup_flag = False

def setup(self, net: nn.Module, id_loader_dict, ood_loader_dict):
if not self.setup_flag:
activation_log = []
net.eval()
with torch.no_grad():
for batch in tqdm(id_loader_dict['val'],
desc='Setup: ',
position=0,
leave=True):
data = batch['data'].cuda()
data = data.float()

_, feature = net(data, return_feature=True)
activation_log.append(feature.data.cpu().numpy())

self.activation_log = np.concatenate(activation_log, axis=0)
self.setup_flag = True
else:
pass

self.threshold_high = np.percentile(self.activation_log.flatten(),
self.percentile_high)
self.threshold_low = np.percentile(self.activation_log.flatten(),
self.percentile_low)

@torch.no_grad()
def postprocess(self, net: nn.Module, data: Any):
_, feature_ood = net.forward(data, return_feature=True)
feature_ood = feature_ood.clip(min=self.threshold_low,
max=self.threshold_high)
feature_ood = feature_ood.view(feature_ood.size(0), -1)
logit_ood = net.fc(feature_ood)
score = torch.softmax(logit_ood, dim=1)
_, pred = torch.max(score, dim=1)
energyconf = torch.logsumexp(logit_ood.data.cpu(), dim=1)
return pred, energyconf

def set_hyperparam(self, hyperparam: list):
self.percentile_high = hyperparam[0]
self.percentile_low = hyperparam[1]
self.threshold_high = np.percentile(self.activation_log.flatten(),
self.percentile_high)
self.threshold_low = np.percentile(self.activation_log.flatten(),
self.percentile_low)
print('Threshold at percentile {:2d} over id data is: {}'.format(
self.percentile_high, self.threshold_high))
print('Threshold at percentile {:2d} over id data is: {}'.format(
self.percentile_low, self.threshold_low))

def get_hyperparam(self):
return [self.percentile_high, self.percentile_low]
33 changes: 33 additions & 0 deletions scripts/ood/vra/cifar100_test_ood_vra.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
#!/bin/bash
# sh scripts/ood/ash/cifar100_test_ood_ash.sh

# GPU=1
# CPU=1
# node=73
# jobname=openood

PYTHONPATH='.':$PYTHONPATH \
# srun -p dsta --mpi=pmi2 --gres=gpu:${GPU} -n1 \
# --cpus-per-task=${CPU} --ntasks-per-node=${GPU} \
# --kill-on-bad-exit=1 --job-name=${jobname} -w SG-IDC1-10-51-2-${node} \

python main.py \
--config configs/datasets/cifar100/cifar100.yml \
configs/datasets/cifar100/cifar100_ood.yml \
configs/networks/resnet18_32x32.yml \
configs/pipelines/test/test_ood.yml \
configs/preprocessors/base_preprocessor.yml \
configs/postprocessors/vra.yml \
--network.checkpoint 'results/cifar100_resnet18_32x32_base_e100_lr0.1_default/s0/best.ckpt'

############################################
# alternatively, we recommend using the
# new unified, easy-to-use evaluator with
# the example script scripts/eval_ood.py
# especially if you want to get results from
# multiple runs
python scripts/eval_ood.py \
--id-data cifar100 \
--root ./results/cifar100_resnet18_32x32_base_e100_lr0.1_default \
--postprocessor vra \
--save-score --save-csv
35 changes: 35 additions & 0 deletions scripts/ood/vra/cifar10_test_ood_vra.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
#!/bin/bash
# sh scripts/ood/she/cifar10_test_ood_she.sh

# GPU=1
# CPU=1
# node=73
# jobname=openood

PYTHONPATH='.':$PYTHONPATH \
# srun -p dsta --mpi=pmi2 --gres=gpu:${GPU} -n1 \
# --cpus-per-task=${CPU} --ntasks-per-node=${GPU} \
# --kill-on-bad-exit=1 --job-name=${jobname} -w SG-IDC1-10-51-2-${node} \

python main.py \
--config configs/datasets/cifar10/cifar10.yml \
configs/datasets/cifar10/cifar10_ood.yml \
configs/networks/resnet18_32x32.yml \
configs/pipelines/test/test_ood.yml \
configs/preprocessors/base_preprocessor.yml \
configs/postprocessors/vra.yml \
--num_workers 8 \
--network.checkpoint 'results/cifar10_resnet18_32x32_base_e100_lr0.1_default/s0/best.ckpt' \
--mark 1

############################################
# alternatively, we recommend using the
# new unified, easy-to-use evaluator with
# the example script scripts/eval_ood.py
# especially if you want to get results from
# multiple runs
python scripts/eval_ood.py \
--id-data cifar10 \
--root ./results/cifar10_resnet18_32x32_base_e100_lr0.1_default \
--postprocessor vra \
--save-score --save-csv
23 changes: 23 additions & 0 deletions scripts/ood/vra/imagenet200_test_ood_vra.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
#!/bin/bash
# sh scripts/ood/ash/imagenet200_test_ood_ash.sh

############################################
# alternatively, we recommend using the
# new unified, easy-to-use evaluator with
# the example script scripts/eval_ood.py
# especially if you want to get results from
# multiple runs

# ood
python scripts/eval_ood.py \
--id-data imagenet200 \
--root ./results/imagenet200_resnet18_224x224_base_e90_lr0.1_default \
--postprocessor vra \
--save-score --save-csv #--fsood

# full-spectrum ood
python scripts/eval_ood.py \
--id-data imagenet200 \
--root ./results/imagenet200_resnet18_224x224_base_e90_lr0.1_default \
--postprocessor vra \
--save-score --save-csv --fsood
47 changes: 47 additions & 0 deletions scripts/ood/vra/imagenet_test_ood_vra.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
#!/bin/bash
# sh scripts/ood/ash/imagenet_test_ood_ash.sh

GPU=1
CPU=1
node=63
jobname=openood

PYTHONPATH='.':$PYTHONPATH \
# srun -p dsta --mpi=pmi2 --gres=gpu:${GPU} -n1 \
# --cpus-per-task=${CPU} --ntasks-per-node=${GPU} \
# --kill-on-bad-exit=1 --job-name=${jobname} -w SG-IDC1-10-51-2-${node} \
# python main.py \
# --config configs/datasets/imagenet/imagenet.yml \
# configs/datasets/imagenet/imagenet_ood.yml \
# configs/networks/resnet50.yml \
# configs/pipelines/test/test_ood.yml \
# configs/preprocessors/base_preprocessor.yml \
# configs/postprocessors/gen.yml \
# --num_workers 4 \
# --ood_dataset.image_size 256 \
# --dataset.test.batch_size 256 \
# --dataset.val.batch_size 256 \
# --network.pretrained True \
# --network.checkpoint 'results/pretrained_weights/resnet50_imagenet1k_v1.pth' \
# --merge_option merge

############################################
# we recommend using the
# new unified, easy-to-use evaluator with
# the example script scripts/eval_ood_imagenet.py

# available architectures:
# resnet50, swin-t, vit-b-16
# ood
python scripts/eval_ood_imagenet.py \
--tvs-pretrained \
--arch resnet50 \
--postprocessor vra \
--save-score --save-csv #--fsood

# full-spectrum ood
python scripts/eval_ood_imagenet.py \
--tvs-pretrained \
--arch resnet50 \
--postprocessor vra \
--save-score --save-csv --fsood