 class _DeepEnsembles(nn.Module):
     def __init__(
         self,
-        models: list[nn.Module],
+        core_models: list[nn.Module],
         store_on_cpu: bool = False,
     ) -> None:
         """Create a classification deep ensembles from a list of models."""
         super().__init__()
-        self.core_models = nn.ModuleList(models)
-        self.num_estimators = len(models)
+        self.core_models = nn.ModuleList(core_models)
+        self.num_estimators = len(core_models)
         self.store_on_cpu = store_on_cpu

     def forward(self, x: Tensor) -> Tensor:
@@ -52,11 +52,11 @@ class _RegDeepEnsembles(_DeepEnsembles):
     def __init__(
         self,
         probabilistic: bool,
-        models: list[nn.Module],
+        core_models: list[nn.Module],
         store_on_cpu: bool = False,
     ) -> None:
         """Create a regression deep ensembles from a list of models."""
-        super().__init__(models=models, store_on_cpu=store_on_cpu)
+        super().__init__(core_models=core_models, store_on_cpu=store_on_cpu)
         self.probabilistic = probabilistic

     def forward(self, x: Tensor) -> Tensor | dict[str, Tensor]:
@@ -87,7 +87,7 @@ def forward(self, x: Tensor) -> Tensor | dict[str, Tensor]:


 def deep_ensembles(
-    models: list[nn.Module] | nn.Module,
+    core_models: list[nn.Module] | nn.Module,
     num_estimators: int | None = None,
     task: Literal[
         "classification", "regression", "segmentation", "pixel_regression"
@@ -101,12 +101,12 @@ def deep_ensembles(
     """Build a Deep Ensembles out of the original models.

     Args:
-        models (list[nn.Module] | nn.Module): The model to be ensembled.
+        core_models (list[nn.Module] | nn.Module): The model to be ensembled.
         num_estimators (int | None): The number of estimators in the ensemble.
         task (Literal[``"classification"``, ``"regression"``, ``"segmentation"``, ``"pixel_regression"``]): The model task. Defaults to ``"classification"``.
         probabilistic (bool): Whether the regression model is probabilistic.
         reset_model_parameters (bool): Whether to reset the model parameters
-            when :attr:models is a module or a list of length 1. Defaults to ``True``.
+            when :attr:core_models is a module or a list of length 1. Defaults to ``True``.
         store_on_cpu (bool): Whether to store the models on CPU. Defaults to ``False``.
             This is useful for large models that do not fit in GPU memory. Only one
             model will be stored on GPU at a time during forward. The rest will be stored on CPU.
@@ -140,26 +140,28 @@ def deep_ensembles(
     <https://arxiv.org/abs/1612.01474>`_.

    """
-    if isinstance(models, list) and len(models) == 0:
+    if isinstance(core_models, list) and len(core_models) == 0:
         raise ValueError("Models must not be an empty list.")
-    if (isinstance(models, list) and len(models) == 1) or isinstance(models, nn.Module):
+    if (isinstance(core_models, list) and len(core_models) == 1) or isinstance(
+        core_models, nn.Module
+    ):
         if num_estimators is None:
             raise ValueError("if models is a module, num_estimators must be specified.")
         if num_estimators < 2:
             raise ValueError(f"num_estimators must be at least 2. Got {num_estimators}.")

-        if isinstance(models, list):
-            models = models[0]
+        if isinstance(core_models, list):
+            core_models = core_models[0]

-        models = [copy.deepcopy(models) for _ in range(num_estimators)]
+        core_models = [copy.deepcopy(core_models) for _ in range(num_estimators)]

         if reset_model_parameters:
-            for model in models:
+            for model in core_models:
                 for layer in model.modules():
                     if hasattr(layer, "reset_parameters"):
                         layer.reset_parameters()

-    elif isinstance(models, list) and len(models) > 1 and num_estimators is not None:
+    elif isinstance(core_models, list) and len(core_models) > 1 and num_estimators is not None:
         raise ValueError("num_estimators must be None if you provided a non-singleton list.")

     if ckpt_paths is not None:  # coverage: ignore
@@ -175,11 +177,11 @@ def deep_ensembles(
         if len(ckpt_paths) == 0:
             raise ValueError("No checkpoint files found in the directory.")

-        if len(models) != len(ckpt_paths):
+        if len(core_models) != len(ckpt_paths):
             raise ValueError(
                 "The number of models and the number of checkpoint paths must be the same."
             )
-        for model, ckpt_path in zip(models, ckpt_paths, strict=True):
+        for model, ckpt_path in zip(core_models, ckpt_paths, strict=True):
             if isinstance(ckpt_path, str | Path):
                 loaded_data = torch.load(ckpt_path, map_location="cpu")
                 if "state_dict" in loaded_data:
@@ -198,12 +200,12 @@ def deep_ensembles(

     match task:
         case "classification" | "segmentation":
-            return _DeepEnsembles(models=models, store_on_cpu=store_on_cpu)
+            return _DeepEnsembles(core_models=core_models, store_on_cpu=store_on_cpu)
         case "regression" | "pixel_regression":
             if probabilistic is None:
                 raise ValueError("probabilistic must be specified for regression models.")
             return _RegDeepEnsembles(
-                probabilistic=probabilistic, models=models, store_on_cpu=store_on_cpu
+                probabilistic=probabilistic, core_models=core_models, store_on_cpu=store_on_cpu
             )
         case _:
             raise ValueError(f"Unknown task: {task}.")
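Usage sketch (not part of the diff): the snippet below assumes `deep_ensembles` is importable from this module and uses hypothetical toy backbones for illustration; the stacking done by the ensemble's `forward` lives in hunks not shown here, so the output comment is indicative only.

import torch
from torch import nn

# Hypothetical toy classifier, for illustration only.
def make_backbone() -> nn.Module:
    return nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 10))

# Case 1: pass an explicit list of models; num_estimators must stay None
# for a non-singleton list, per the checks in the diff above.
clf_ens = deep_ensembles(
    core_models=[make_backbone() for _ in range(4)],
    task="classification",
)

# Case 2: pass a single module; it is deep-copied num_estimators times and
# each copy is re-initialized unless reset_model_parameters=False.
reg_ens = deep_ensembles(
    core_models=nn.Linear(8, 1),
    num_estimators=3,
    task="regression",
    probabilistic=False,
)

out = clf_ens(torch.randn(2, 3, 32, 32))  # outputs from all 4 estimators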