
Commit aa27f6c

🎨 Rename attribute model to core_model in all wrappers
1 parent e6ed38e

File tree

13 files changed: +76 -71 lines changed


tests/models/wrappers/test_batch_ensemble.py

Lines changed: 2 additions & 2 deletions
@@ -45,8 +45,8 @@ def test_convert_layers(self) -> None:
         model = _DummyModel(in_features, out_features)
         wrapped_model = batch_ensemble(model, num_estimators, convert_layers=True)
         assert wrapped_model.num_estimators == num_estimators
-        assert isinstance(wrapped_model.model.conv, BatchConv2d)
-        assert isinstance(wrapped_model.model.fc, BatchLinear)
+        assert isinstance(wrapped_model.core_model.conv, BatchConv2d)
+        assert isinstance(wrapped_model.core_model.fc, BatchLinear)

     def test_forward_pass(self, img_input) -> None:
         batch_size = img_input.size(0)

tests/models/wrappers/test_mc_dropout.py

Lines changed: 1 addition & 1 deletion
@@ -135,7 +135,7 @@ def test_mc_dropout_errors(self) -> None:
             mc_dropout(model, num_estimators=5, task="regression")

         with pytest.raises(ValueError, match="`num_estimators` must be strictly positive"):
-            mc_dropout(model=model, num_estimators=-1, last_layer=True, on_batch=True)
+            mc_dropout(core_model=model, num_estimators=-1, last_layer=True, on_batch=True)

         dropout_model = mc_dropout(model, 5)
         with pytest.raises(TypeError, match="Training mode is expected to be boolean"):
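
Note on usage: for callers of mc_dropout, the visible change is the keyword name — core_model replaces model (positional calls such as mc_dropout(model, 5) are unaffected). A minimal sketch of the new keyword form; the import path and the output-shape convention are assumptions, and the toy network includes a dropout layer since the wrapper needs one:

    import torch
    from torch import nn

    from torch_uncertainty.models import mc_dropout  # import path assumed

    # Toy network with a dropout layer, kept active at evaluation by MC dropout.
    model = nn.Sequential(nn.Linear(10, 5), nn.ReLU(), nn.Dropout(p=0.5), nn.Linear(5, 2))

    # After this commit: `core_model=` instead of `model=`.
    mc_model = mc_dropout(core_model=model, num_estimators=5, last_layer=False, on_batch=True)

    mc_model.eval()
    out = mc_model(torch.randn(8, 10))
    print(out.shape)  # assumed: (40, 2), the batch repeated once per estimator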

torch_uncertainty/models/classification/lenet.py

Lines changed: 1 addition & 1 deletion
@@ -139,7 +139,7 @@ def batchensemble_lenet(
         dropout_rate=dropout_rate,
     )
     return BatchEnsemble(
-        model=model,
+        core_model=model,
         num_estimators=num_estimators,
         repeat_training_inputs=repeat_training_inputs,
         convert_layers=True,

torch_uncertainty/models/mlp.py

Lines changed: 1 addition & 1 deletion
@@ -173,7 +173,7 @@ def _mlp(
     )
     if stochastic:
         return StochasticModel(
-            model=model, num_samples=num_samples, probabilistic=dist_family is not None
+            core_model=model, num_samples=num_samples, probabilistic=dist_family is not None
         )
     return model

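_mlp is private, so this hunk only matters if you instantiate StochasticModel directly. A hedged sketch of the new keyword; the import path is an assumption, and a plain deterministic network stands in for a model with Bayesian layers purely to illustrate the call:

    from torch import nn

    from torch_uncertainty.models.wrappers import StochasticModel  # import path assumed

    # Placeholder network; in practice the wrapped model would contain Bayesian
    # layers so that the num_samples forward passes actually differ.
    model = nn.Sequential(nn.Linear(10, 5), nn.ReLU(), nn.Linear(5, 2))

    # After this commit: `core_model=` instead of `model=`.
    stoch = StochasticModel(core_model=model, num_samples=16, probabilistic=False)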

torch_uncertainty/models/wrappers/batch_ensemble.py

Lines changed: 19 additions & 19 deletions
@@ -8,7 +8,7 @@
 class BatchEnsemble(nn.Module):
     def __init__(
         self,
-        model: nn.Module,
+        core_model: nn.Module,
         num_estimators: int,
         repeat_training_inputs: bool = False,
         convert_layers: bool = False,
@@ -23,7 +23,7 @@ def __init__(
         ensuring that each estimator receives the correct data format.

         Args:
-            model (nn.Module): The BatchEnsemble model.
+            core_model (nn.Module): The BatchEnsemble model.
             num_estimators (int): Number of ensemble members.
             repeat_training_inputs (optional, bool): Whether to repeat the input batch during training.
                 If ``True``, the input batch is repeated during both training and evaluation. If ``False``,
@@ -33,37 +33,37 @@ def __init__(
                 BatchEnsemble counterparts. Default is ``False``.

         Raises:
-            ValueError: If neither ``BatchLinear`` nor ``BatchConv2d`` layers are found in the model at the
+            ValueError: If neither ``BatchLinear`` nor ``BatchConv2d`` layers are found in the core_model at the
                 end of initialization.
             ValueError: If ``num_estimators`` is less than or equal to ``0``.
             ValueError: If ``convert_layers=True`` and neither ``nn.Linear`` nor ``nn.Conv2d`` layers are
-                found in the model.
+                found in the core_model.

         Warning:
             If ``convert_layers==True``, the wrapper will attempt to convert all ``nn.Linear`` and ``nn.Conv2d``
-            layers in the model to their BatchEnsemble counterparts. If the model contains other types of
+            layers in the core_model to their BatchEnsemble counterparts. If the core_model contains other types of
             layers, the conversion won't happen for these layers. If don't have any ``nn.Linear`` or ``nn.Conv2d``
-            layers in the model, the wrapper will raise an error during conversion.
+            layers in the core_model, the wrapper will raise an error during conversion.

         Warning:
             If ``repeat_training_inputs==True`` and you want to use one of the ``torch_uncertainty.routines``
             for training, be sure to set ``format_batch_fn=RepeatTarget(num_repeats=num_estimators)`` when
             initializing the routine.

         Example:
-            >>> model = nn.Sequential(nn.Linear(10, 5), nn.ReLU(), nn.Linear(5, 2))
-            >>> model = BatchEnsemble(model, num_estimators=4, convert_layers=True)
+            >>> core_model = nn.Sequential(nn.Linear(10, 5), nn.ReLU(), nn.Linear(5, 2))
+            >>> model = BatchEnsemble(core_model, num_estimators=4, convert_layers=True)
             >>> model
             BatchEnsemble(
-              (model): Sequential(
+              (core_model): Sequential(
                 (0): BatchLinear(in_features=10, out_features=5, num_estimators=4)
                 (1): ReLU()
                 (2): BatchLinear(in_features=5, out_features=2, num_estimators=4)
               )
             )
         """
         super().__init__()
-        self.model = model
+        self.core_model = core_model
         self.num_estimators = num_estimators
         self.repeat_training_inputs = repeat_training_inputs
@@ -72,7 +72,7 @@ def __init__(

         filtered_modules = [
             module
-            for module in self.model.modules()
+            for module in self.core_model.modules()
             if isinstance(module, BatchLinear | BatchConv2d)
         ]
         _batch_ensemble_checks(filtered_modules, num_estimators)
@@ -81,22 +81,22 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         """Repeat the input if ``self.training==False`` or ``repeat_training_inputs==True`` and pass it through the model."""
         if not self.training or self.repeat_training_inputs:
             x = repeat(x, "b ... -> (m b) ...", m=self.num_estimators)
-        return self.model(x)
+        return self.core_model(x)

     def _convert_layers(self) -> None:
         """Convert the model's layers to BatchEnsemble layers."""
         no_valid_layers = True
-        for name, layer in self.model.named_modules():
+        for name, layer in self.core_model.named_modules():
             if isinstance(layer, nn.Linear):
                 setattr(
-                    self.model,
+                    self.core_model,
                     name,
                     BatchLinear.from_linear(layer, num_estimators=self.num_estimators),
                 )
                 no_valid_layers = False
             elif isinstance(layer, nn.Conv2d):
                 setattr(
-                    self.model,
+                    self.core_model,
                     name,
                     BatchConv2d.from_conv2d(layer, num_estimators=self.num_estimators),
                 )
@@ -121,15 +121,15 @@ def _batch_ensemble_checks(filtered_modules: list[nn.Module], num_estimators: in


 def batch_ensemble(
-    model: nn.Module,
+    core_model: nn.Module,
     num_estimators: int,
     repeat_training_inputs: bool = False,
     convert_layers: bool = False,
 ) -> BatchEnsemble:
     """BatchEnsemble wrapper for a model.

     Args:
-        model (nn.Module): model to wrap
+        core_model (nn.Module): model to wrap
         num_estimators (int): number of ensemble members
         repeat_training_inputs (bool, optional): whether to repeat the input batch during training.
             If ``True``, the input batch is repeated during both training and evaluation. If ``False``,
@@ -139,10 +139,10 @@ def batch_ensemble(
         BatchEnsemble counterparts. Default is ``False``.

     Returns:
-        BatchEnsemble: BatchEnsemble wrapper for the model
+        BatchEnsemble: BatchEnsemble wrapper for the :attr:`core_model`
     """
     return BatchEnsemble(
-        model=model,
+        core_model=core_model,
         num_estimators=num_estimators,
         repeat_training_inputs=repeat_training_inputs,
         convert_layers=convert_layers,
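
The rename is user-visible in two ways here: the core_model keyword of BatchEnsemble and batch_ensemble, and the registered submodule name, which is why the test above now asserts on wrapped_model.core_model. A minimal sketch consistent with the docstring example; the import path and the evaluation-time output shape are assumptions:

    import torch
    from torch import nn

    from torch_uncertainty.models.wrappers import batch_ensemble  # import path assumed

    core_model = nn.Sequential(nn.Linear(10, 5), nn.ReLU(), nn.Linear(5, 2))

    # convert_layers=True swaps nn.Linear/nn.Conv2d for BatchLinear/BatchConv2d.
    wrapped = batch_ensemble(core_model=core_model, num_estimators=4, convert_layers=True)

    # The inner module is now exposed as `wrapped.core_model` (was `wrapped.model`).
    print(type(wrapped.core_model[0]).__name__)  # BatchLinear

    wrapped.eval()
    out = wrapped(torch.randn(8, 10))
    print(out.shape)  # assumed: (32, 2), inputs repeated across the 4 estimators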

torch_uncertainty/models/wrappers/checkpoint_collector.py

Lines changed: 3 additions & 3 deletions
@@ -7,7 +7,7 @@
 class CheckpointCollector(nn.Module):
     def __init__(
         self,
-        model: nn.Module,
+        core_model: nn.Module,
         cycle_start: int | None = None,
         cycle_length: int | None = None,
         save_schedule: list[int] | None = None,
@@ -21,7 +21,7 @@ def __init__(
         as implemented in TorchUncertainty.

         Args:
-            model (nn.Module): The model to train and ensemble.
+            core_model (nn.Module): The model to train and ensemble.
             cycle_start (int): Epoch to start ensembling. Defaults to ``None``.
             cycle_length (int): Number of epochs between model collections. Defaults to ``None``.
             save_schedule (list[int] | None): The epochs at which to save the model. Defaults to ``None``.
@@ -52,7 +52,7 @@ def __init__(
                 f"The combination of arguments: cycle_start: {cycle_start}, cycle_length: {cycle_length}, save_schedule: {save_schedule} is not known."
             )

-        self.core_model = model
+        self.core_model = core_model
         self.cycle_start = cycle_start
         self.cycle_length = cycle_length
         self.save_schedule = save_schedule
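
CheckpointCollector already stored its network as self.core_model; the commit simply makes the constructor argument match the attribute. A construction sketch; the import path is an assumption, and passing save_schedule alone is assumed to be one of the accepted argument combinations:

    from torch import nn

    from torch_uncertainty.models.wrappers import CheckpointCollector  # import path assumed

    model = nn.Linear(10, 2)

    # After this commit: `core_model=` instead of `model=`.
    collector = CheckpointCollector(core_model=model, save_schedule=[10, 20, 30])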

torch_uncertainty/models/wrappers/deep_ensembles.py

Lines changed: 21 additions & 19 deletions
@@ -9,13 +9,13 @@
 class _DeepEnsembles(nn.Module):
     def __init__(
         self,
-        models: list[nn.Module],
+        core_models: list[nn.Module],
         store_on_cpu: bool = False,
     ) -> None:
         """Create a classification deep ensembles from a list of models."""
         super().__init__()
-        self.core_models = nn.ModuleList(models)
-        self.num_estimators = len(models)
+        self.core_models = nn.ModuleList(core_models)
+        self.num_estimators = len(core_models)
         self.store_on_cpu = store_on_cpu

     def forward(self, x: Tensor) -> Tensor:
@@ -52,11 +52,11 @@ class _RegDeepEnsembles(_DeepEnsembles):
     def __init__(
         self,
         probabilistic: bool,
-        models: list[nn.Module],
+        core_models: list[nn.Module],
         store_on_cpu: bool = False,
     ) -> None:
         """Create a regression deep ensembles from a list of models."""
-        super().__init__(models=models, store_on_cpu=store_on_cpu)
+        super().__init__(core_models=core_models, store_on_cpu=store_on_cpu)
         self.probabilistic = probabilistic

     def forward(self, x: Tensor) -> Tensor | dict[str, Tensor]:
@@ -87,7 +87,7 @@ def forward(self, x: Tensor) -> Tensor | dict[str, Tensor]:


 def deep_ensembles(
-    models: list[nn.Module] | nn.Module,
+    core_models: list[nn.Module] | nn.Module,
     num_estimators: int | None = None,
     task: Literal[
         "classification", "regression", "segmentation", "pixel_regression"
@@ -101,12 +101,12 @@ def deep_ensembles(
     """Build a Deep Ensembles out of the original models.

     Args:
-        models (list[nn.Module] | nn.Module): The model to be ensembled.
+        core_models (list[nn.Module] | nn.Module): The model to be ensembled.
         num_estimators (int | None): The number of estimators in the ensemble.
         task (Literal[``"classification"``, ``"regression"``, ``"segmentation"``, ``"pixel_regression"``]): The model task. Defaults to ``"classification"``.
         probabilistic (bool): Whether the regression model is probabilistic.
         reset_model_parameters (bool): Whether to reset the model parameters
-            when :attr:models is a module or a list of length 1. Defaults to ``True``.
+            when :attr:core_models is a module or a list of length 1. Defaults to ``True``.
         store_on_cpu (bool): Whether to store the models on CPU. Defaults to ``False``.
             This is useful for large models that do not fit in GPU memory. Only one
             model will be stored on GPU at a time during forward. The rest will be stored on CPU.
@@ -140,26 +140,28 @@ def deep_ensembles(
     <https://arxiv.org/abs/1612.01474>`_.

     """
-    if isinstance(models, list) and len(models) == 0:
+    if isinstance(core_models, list) and len(core_models) == 0:
         raise ValueError("Models must not be an empty list.")
-    if (isinstance(models, list) and len(models) == 1) or isinstance(models, nn.Module):
+    if (isinstance(core_models, list) and len(core_models) == 1) or isinstance(
+        core_models, nn.Module
+    ):
         if num_estimators is None:
             raise ValueError("if models is a module, num_estimators must be specified.")
         if num_estimators < 2:
             raise ValueError(f"num_estimators must be at least 2. Got {num_estimators}.")

-        if isinstance(models, list):
-            models = models[0]
+        if isinstance(core_models, list):
+            core_models = core_models[0]

-        models = [copy.deepcopy(models) for _ in range(num_estimators)]
+        core_models = [copy.deepcopy(core_models) for _ in range(num_estimators)]

         if reset_model_parameters:
-            for model in models:
+            for model in core_models:
                 for layer in model.modules():
                     if hasattr(layer, "reset_parameters"):
                         layer.reset_parameters()

-    elif isinstance(models, list) and len(models) > 1 and num_estimators is not None:
+    elif isinstance(core_models, list) and len(core_models) > 1 and num_estimators is not None:
         raise ValueError("num_estimators must be None if you provided a non-singleton list.")

     if ckpt_paths is not None:  # coverage: ignore
@@ -175,11 +177,11 @@ def deep_ensembles(
         if len(ckpt_paths) == 0:
             raise ValueError("No checkpoint files found in the directory.")

-        if len(models) != len(ckpt_paths):
+        if len(core_models) != len(ckpt_paths):
             raise ValueError(
                 "The number of models and the number of checkpoint paths must be the same."
             )
-        for model, ckpt_path in zip(models, ckpt_paths, strict=True):
+        for model, ckpt_path in zip(core_models, ckpt_paths, strict=True):
             if isinstance(ckpt_path, str | Path):
                 loaded_data = torch.load(ckpt_path, map_location="cpu")
                 if "state_dict" in loaded_data:
@@ -198,12 +200,12 @@ def deep_ensembles(

     match task:
         case "classification" | "segmentation":
-            return _DeepEnsembles(models=models, store_on_cpu=store_on_cpu)
+            return _DeepEnsembles(core_models=core_models, store_on_cpu=store_on_cpu)
         case "regression" | "pixel_regression":
             if probabilistic is None:
                 raise ValueError("probabilistic must be specified for regression models.")
             return _RegDeepEnsembles(
-                probabilistic=probabilistic, models=models, store_on_cpu=store_on_cpu
+                probabilistic=probabilistic, core_models=core_models, store_on_cpu=store_on_cpu
             )
         case _:
             raise ValueError(f"Unknown task: {task}.")

torch_uncertainty/models/wrappers/ema.py

Lines changed: 4 additions & 4 deletions
@@ -6,7 +6,7 @@
 class EMA(nn.Module):
     def __init__(
         self,
-        model: nn.Module,
+        core_model: nn.Module,
         momentum: float,
     ) -> None:
         """Exponential Moving Average (EMA).
@@ -18,13 +18,13 @@ def __init__(
         The EMA model is regularly updated with the inner-model and used at evaluation time.

         Args:
-            model (nn.Module): The model to train and ensemble.
+            core_model (nn.Module): The model to train and ensemble.
             momentum (float): The momentum of the moving average.
         """
         super().__init__()
         _ema_checks(momentum)
-        self.core_model = model
-        self.ema_model = copy.deepcopy(model)
+        self.core_model = core_model
+        self.ema_model = copy.deepcopy(core_model)
         self.momentum = momentum
         self.remainder = 1 - momentum
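
EMA follows the same pattern as CheckpointCollector: the constructor argument now matches the core_model attribute it was already assigned to. A minimal sketch; the import path is an assumption, and momentum is presumably constrained by _ema_checks to the unit interval:

    from torch import nn

    from torch_uncertainty.models.wrappers import EMA  # import path assumed

    model = nn.Linear(10, 2)

    # After this commit: `core_model=` instead of `model=`. The EMA copy is
    # blended as ema = momentum * ema + (1 - momentum) * current at each update.
    ema_model = EMA(core_model=model, momentum=0.99)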
