Skip to content

Commit db3f66b

Browse files
committed
attempt to fix coverage
1 parent 619b365 commit db3f66b

File tree

2 files changed

+208
-0
lines changed

2 files changed

+208
-0
lines changed

tests/datamodules/classification/test_imagenet.py

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -425,3 +425,74 @@ def test_get_indices_without_ood_or_shift(self, tmp_path):
425425
assert idx["near_oods"] == []
426426
assert idx["far_oods"] == []
427427
assert idx["shift"] == []
428+
429+
def test_tta_wraps_ood_sets(self, monkeypatch, tmp_path):
430+
mod_name = ImageNetDataModule.__module__
431+
432+
monkeypatch.setattr(
433+
f"{mod_name}.download_and_extract_hf_dataset",
434+
self._fake_download_and_extract,
435+
raising=True,
436+
)
437+
monkeypatch.setattr(
438+
f"{mod_name}.download_and_extract_splits_from_hf",
439+
self._fake_download_and_extract_splits_from_hf,
440+
raising=True,
441+
)
442+
monkeypatch.setattr(
443+
f"{mod_name}.FileListDataset",
444+
self._DummyFileListDataset,
445+
raising=True,
446+
)
447+
monkeypatch.setattr(
448+
"torch_uncertainty.datamodules.classification.imagenet.get_ood_datasets",
449+
self._fake_get_ood_datasets,
450+
)
451+
452+
dm = ImageNetDataModule(
453+
root=tmp_path,
454+
batch_size=8,
455+
eval_ood=True,
456+
num_tta=2,
457+
train_transform=nn.Identity(),
458+
test_transform=nn.Identity(),
459+
num_workers=0,
460+
persistent_workers=False,
461+
pin_memory=False,
462+
)
463+
464+
dm.setup("test")
465+
466+
val_ood_wrapped = dm.get_val_ood_set()
467+
test_ood_wrapped = dm.get_test_ood_set()
468+
near_wrapped = dm.get_near_ood_set()
469+
far_wrapped = dm.get_far_ood_set()
470+
471+
assert len(val_ood_wrapped) == len(dm.val_ood) * dm.num_tta
472+
assert len(test_ood_wrapped) == len(dm.test_ood) * dm.num_tta
473+
assert all(
474+
len(w) == len(b) * dm.num_tta for w, b in zip(near_wrapped, dm.near_oods, strict=False)
475+
)
476+
assert all(
477+
len(w) == len(b) * dm.num_tta for w, b in zip(far_wrapped, dm.far_oods, strict=False)
478+
)
479+
480+
def _assert_first_block_repeat(wrapped_ds, num_tta: int):
481+
if len(wrapped_ds) == 0 or num_tta < 2:
482+
return
483+
x0, y0 = wrapped_ds[0]
484+
x1, y1 = wrapped_ds[1]
485+
assert y0 == y1
486+
if torch.is_tensor(x0) and torch.is_tensor(x1):
487+
assert x0.shape == x1.shape
488+
else:
489+
assert type(x0) is type(x1)
490+
if hasattr(x0, "size") and hasattr(x1, "size"):
491+
assert x0.size == x1.size
492+
493+
_assert_first_block_repeat(val_ood_wrapped, dm.num_tta)
494+
_assert_first_block_repeat(test_ood_wrapped, dm.num_tta)
495+
for w in near_wrapped:
496+
_assert_first_block_repeat(w, dm.num_tta)
497+
for w in far_wrapped:
498+
_assert_first_block_repeat(w, dm.num_tta)

tests/routines/test_classification.py

Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import logging
2+
import types
13
from pathlib import Path
24

35
import pytest
@@ -15,6 +17,7 @@
1517
from torch_uncertainty.losses import DECLoss, ELBOLoss
1618
from torch_uncertainty.ood.ood_criteria import (
1719
EntropyCriterion,
20+
MaxSoftmaxCriterion,
1821
)
1922
from torch_uncertainty.post_processing import ConformalClsTHR
2023
from torch_uncertainty.routines import ClassificationRoutine
@@ -518,3 +521,137 @@ def patched_setup(self, stage=None):
518521
for needs_setup in {"react", "adascale_a", "vim", "knn", "nnguide"}:
519522
if crit == needs_setup:
520523
assert getattr(c, "setup_flag", False), f"Setup not executed for '{crit}'."
524+
525+
def test_setup_logs_when_no_train_loader(self, caplog, monkeypatch):
526+
dm = DummyClassificationDataModule(
527+
root=Path(),
528+
batch_size=4,
529+
num_classes=3,
530+
num_images=16,
531+
eval_ood=True,
532+
)
533+
534+
def _raise_train_loader(*_a, **_k):
535+
raise RuntimeError("no train loader")
536+
537+
monkeypatch.setattr(
538+
ClassificationRoutine, "_hyperparam_search_ood", lambda _self: None, raising=True
539+
)
540+
monkeypatch.setattr(dm, "train_dataloader", _raise_train_loader, raising=True)
541+
542+
model = dummy_ood_model(in_channels=3, feat_dim=64, num_classes=3)
543+
routine = ClassificationRoutine(
544+
model=model,
545+
loss=None,
546+
num_classes=3,
547+
eval_ood=True,
548+
)
549+
routine.ood_criterion = MaxSoftmaxCriterion() # no setup() side-effects
550+
551+
routine.trainer = types.SimpleNamespace(datamodule=dm)
552+
553+
with caplog.at_level(logging.INFO):
554+
routine.setup("test")
555+
assert any("No train loader detected" in r.message for r in caplog.records)
556+
557+
def test_create_near_far_metric_dicts_non_ensemble(self):
558+
model = dummy_ood_model(in_channels=3, feat_dim=64, num_classes=3)
559+
routine = ClassificationRoutine(
560+
model=model, loss=None, num_classes=3, eval_ood=True, is_ensemble=False
561+
)
562+
routine.ood_criterion = MaxSoftmaxCriterion()
563+
564+
x = torch.rand(4, 3, 8, 8)
565+
y = torch.tensor([0, 1, 2, 0])
566+
567+
class _DS:
568+
def __init__(self, name):
569+
self.dataset_name = name
570+
571+
routine.trainer = types.SimpleNamespace(
572+
datamodule=types.SimpleNamespace(
573+
get_indices=lambda: {"val_ood": 9, "near_oods": [2], "far_oods": [3], "shift": []},
574+
near_oods=[_DS("nearX")],
575+
far_oods=[_DS("farY")],
576+
)
577+
)
578+
579+
routine.test_step((x, y), batch_idx=0, dataloader_idx=2) # near
580+
assert "nearX" in routine.test_ood_metrics_near
581+
582+
routine.test_step((x, y), batch_idx=0, dataloader_idx=3) # far
583+
assert "farY" in routine.test_ood_metrics_far
584+
585+
def test_create_near_far_metric_dicts_ensemble_and_aggregator(self):
586+
model = dummy_ood_model(in_channels=3, feat_dim=64, num_classes=3)
587+
routine = ClassificationRoutine(
588+
model=model, loss=None, num_classes=3, eval_ood=True, is_ensemble=True
589+
)
590+
routine.ood_criterion = MaxSoftmaxCriterion()
591+
592+
x = torch.rand(4, 3, 8, 8)
593+
y = torch.tensor([0, 1, 2, 0])
594+
595+
class _DS:
596+
def __init__(self, name):
597+
self.dataset_name = name
598+
599+
routine.trainer = types.SimpleNamespace(
600+
datamodule=types.SimpleNamespace(
601+
get_indices=lambda: {
602+
"val_ood": 9,
603+
"near_oods": [5],
604+
"far_oods": [6],
605+
"shift": [7],
606+
},
607+
near_oods=[_DS("n1")],
608+
far_oods=[_DS("f1")],
609+
)
610+
)
611+
612+
routine.test_step((x, y), batch_idx=0, dataloader_idx=1) # aggregator
613+
assert "n1" in routine.test_ood_ens_metrics_near
614+
assert "f1" in routine.test_ood_ens_metrics_far
615+
616+
routine.test_step((x, y), batch_idx=0, dataloader_idx=5) # near
617+
routine.test_step((x, y), batch_idx=0, dataloader_idx=6) # far
618+
assert "n1" in routine.test_ood_ens_metrics_near
619+
assert "f1" in routine.test_ood_ens_metrics_far
620+
621+
def test_skip_when_val_ood_loader(self):
622+
model = dummy_ood_model(in_channels=3, feat_dim=64, num_classes=3)
623+
routine = ClassificationRoutine(model=model, loss=None, num_classes=3, eval_ood=True)
624+
routine.ood_criterion = MaxSoftmaxCriterion()
625+
626+
routine.trainer = types.SimpleNamespace(
627+
datamodule=types.SimpleNamespace(
628+
get_indices=lambda: {"val_ood": 4, "near_oods": [], "far_oods": [], "shift": []}
629+
)
630+
)
631+
x = torch.rand(2, 3, 8, 8)
632+
y = torch.tensor([0, 1])
633+
routine.test_step((x, y), batch_idx=0, dataloader_idx=4)
634+
635+
def test_init_metrics_creates_shift_ens_metrics_when_ensemble_and_eval_shift(self):
636+
model = dummy_ood_model(in_channels=3, feat_dim=64, num_classes=3)
637+
routine = ClassificationRoutine(
638+
model=model, loss=None, num_classes=3, eval_shift=True, is_ensemble=True
639+
)
640+
assert hasattr(routine, "test_shift_ens_metrics")
641+
642+
def test_shift_ens_update_path(self):
643+
model = dummy_ood_model(in_channels=3, feat_dim=64, num_classes=3)
644+
routine = ClassificationRoutine(
645+
model=model, loss=None, num_classes=3, eval_shift=True, is_ensemble=True
646+
)
647+
routine.ood_criterion = MaxSoftmaxCriterion()
648+
649+
x = torch.rand(4, 3, 8, 8)
650+
y = torch.tensor([0, 1, 2, 0])
651+
652+
routine.trainer = types.SimpleNamespace(
653+
datamodule=types.SimpleNamespace(
654+
get_indices=lambda: {"val_ood": 99, "near_oods": [], "far_oods": [], "shift": [7]}
655+
)
656+
)
657+
routine.test_step((x, y), batch_idx=0, dataloader_idx=7)

0 commit comments

Comments
 (0)