
Commit 0577694

attempt to fix coverage
1 parent db3f66b commit 0577694

3 files changed, 106 additions and 22 deletions

tests/datamodules/classification/test_cifar100.py

Lines changed: 70 additions & 0 deletions
@@ -254,3 +254,73 @@ def _fake_get_ood(**_):
         assert idx["near_oods"] == [3]
         assert idx["far_oods"] == [4]
         assert idx["shift"] == []
+
+    def test_assigns_dataset_name_when_missing(self, monkeypatch):
+        """If OOD datasets lack `dataset_name`, setup() should assign class-name.lower()."""
+        dm = CIFAR100DataModule(
+            root="./data/",
+            batch_size=16,
+            train_transform=nn.Identity(),
+            test_transform=nn.Identity(),
+            eval_ood=True,
+        )
+        dm.dataset = DummyClassificationDataset
+        dm.shift_dataset = DummyClassificationDataset
+
+        class _NoNameDS:
+            def __init__(
+                self, root="./data/", train=False, download=False, transform=None, num_images=3
+            ):
+                self.data = list(range(num_images))
+                self.transform = transform
+
+            def __len__(self):
+                return len(self.data)
+
+            def __getitem__(self, i):
+                x = self.data[i]
+                return (x if self.transform is None else self.transform(x)), 0
+
+        def _mock_get_ood(**_):
+            test_ood = _NoNameDS(num_images=3)
+            val_ood = _NoNameDS(num_images=4)
+            near_default = {"nearA": _NoNameDS(num_images=5)}
+            far_default = {"farB": _NoNameDS(num_images=6)}
+            return test_ood, val_ood, near_default, far_default
+
+        monkeypatch.setattr(
+            "torch_uncertainty.datamodules.classification.cifar100.get_ood_datasets",
+            _mock_get_ood,
+        )
+
+        dm.setup("test")
+
+        assert hasattr(dm.val_ood, "dataset_name")
+        assert dm.val_ood.dataset_name == "_nonameds"
+        for ds in dm.near_oods + dm.far_oods:
+            assert hasattr(ds, "dataset_name")
+            assert ds.dataset_name == "_nonameds"
+
+        assert dm.near_ood_names == [ds.dataset_name for ds in dm.near_oods]
+        assert dm.far_ood_names == [ds.dataset_name for ds in dm.far_oods]
+
+    def test_get_indices_empty_when_eval_ood_false(self):
+        dm = CIFAR100DataModule(
+            root="./data/",
+            batch_size=16,
+            train_transform=nn.Identity(),
+            test_transform=nn.Identity(),
+            eval_ood=False,
+            eval_shift=False,
+        )
+        dm.dataset = DummyClassificationDataset
+        dm.shift_dataset = DummyClassificationDataset
+        dm.setup("test")
+
+        idx = dm.get_indices()
+        assert idx["test"] == [0]
+        assert idx["test_ood"] == []
+        assert idx["val_ood"] == []
+        assert idx["near_oods"] == []
+        assert idx["far_oods"] == []
+        assert idx["shift"] == []

tests/routines/test_classification.py

Lines changed: 28 additions & 0 deletions
@@ -655,3 +655,31 @@ def test_shift_ens_update_path(self):
             )
         )
         routine.test_step((x, y), batch_idx=0, dataloader_idx=7)
+
+    def test_logs_when_eval_flags_mismatch_datamodule(self, caplog):
+        model = dummy_ood_model(in_channels=3, feat_dim=64, num_classes=3)
+        routine = ClassificationRoutine(
+            model=model, loss=None, num_classes=3, eval_ood=False, eval_shift=False
+        )
+        routine.ood_criterion = MaxSoftmaxCriterion()
+
+        class _DM:
+            def get_indices(self):
+                return {"val_ood": 9, "near_oods": [2], "far_oods": [3], "shift": [4]}
+
+        routine._trainer = types.SimpleNamespace(barebones=True, datamodule=_DM())
+
+        x = torch.rand(2, 3, 8, 8)
+        y = torch.tensor([0, 1])
+
+        with caplog.at_level(logging.INFO):
+            routine.test_step((x, y), batch_idx=0, dataloader_idx=0)
+
+        assert any(
+            "`eval_ood` to `True` in the datamodule and not in the routine" in r.message
+            for r in caplog.records
+        )
+        assert any(
+            "`eval_shift` to `True` in the datamodule and not in the routine" in r.message
+            for r in caplog.records
+        )
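
This test assumes that test_step() compares the datamodule's get_indices() output with the routine's own eval flags and logs an informational message for each mismatch instead of evaluating the extra loaders. A minimal sketch of such a guard, with a hypothetical helper name and a plausible full wording for the message (only the quoted substrings are asserted by the test):

import logging

def _warn_on_flag_mismatch(routine, indices):
    # Emit the informational messages the test looks for when the datamodule
    # prepared OOD/shift splits but the routine flags are disabled.
    if indices.get("val_ood") and not routine.eval_ood:
        logging.info(
            "You have set `eval_ood` to `True` in the datamodule and not in the routine."
        )
    if indices.get("shift") and not routine.eval_shift:
        logging.info(
            "You have set `eval_shift` to `True` in the datamodule and not in the routine."
        )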

torch_uncertainty/routines/classification.py

Lines changed: 8 additions & 22 deletions
@@ -440,35 +440,21 @@ def _hyperparam_search_ood(self):
         # ID val
         for x, _ in id_val:
             x = x.to(self.device)
-            logits = self.model(x)
-
-            if crit.input_type == OODCriterionInputType.LOGIT:
-                s = crit(logits).cpu().numpy()
-            elif crit.input_type == OODCriterionInputType.PROB:
-                probs = F.softmax(logits, dim=-1)
-                s = crit(probs).cpu().numpy()
-            else:  # DATASET
-                with torch.inference_mode(False), torch.enable_grad():
-                    x_input = x.detach().clone().requires_grad_(True)
-                    s = crit(self.model, x_input).cpu().numpy()
+
+            with torch.inference_mode(False), torch.enable_grad():
+                x_input = x.detach().clone().requires_grad_(True)
+                s = crit(self.model, x_input).cpu().numpy()

             all_scores.append(s)
             all_labels.append(np.zeros_like(s))

         # OOD val splits
         for x, _ in ood_val:
             x = x.to(self.device)
-            logits = self.model(x)
-
-            if crit.input_type == OODCriterionInputType.LOGIT:
-                s = crit(logits).cpu().numpy()
-            elif crit.input_type == OODCriterionInputType.PROB:
-                probs = F.softmax(logits, dim=-1)
-                s = crit(probs).cpu().numpy()
-            else:  # DATASET
-                with torch.inference_mode(False), torch.enable_grad():
-                    x_input = x.detach().clone().requires_grad_(True)
-                    s = crit(self.model, x_input).cpu().numpy()
+
+            with torch.inference_mode(False), torch.enable_grad():
+                x_input = x.detach().clone().requires_grad_(True)
+                s = crit(self.model, x_input).cpu().numpy()

             all_scores.append(s)
             all_labels.append(np.ones_like(s))
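
The change collapses the per-input-type branching: during the hyperparameter search every criterion now receives the model together with a gradient-enabled clone of the batch, so gradient-based scores keep working while forward-only criteria simply ignore the extra setup. A standalone sketch of that consolidated scoring step, written as a free function for illustration (crit and model mirror the names in the diff):

import torch

def score_batch(crit, model, x):
    # Re-enable autograd inside inference mode so criteria that differentiate
    # with respect to the input can do so; scores come back as numpy arrays
    # exactly as in the loops above.
    with torch.inference_mode(False), torch.enable_grad():
        x_input = x.detach().clone().requires_grad_(True)
        return crit(model, x_input).cpu().numpy()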
