 from omnisafe.common.control_barrier_function.crabs.models import (
     AddGaussianNoise,
     CrabsCore,
-    EnsembleModel,
     ExplorationPolicy,
-    GatedTransitionModel,
     MeanPolicy,
     MultiLayerPerceptron,
-    TransitionModel,
     UniformPolicy,
 )
 from omnisafe.common.control_barrier_function.crabs.optimizers import (
     SLangevinOptimizer,
     StateBox,
 )
-from omnisafe.common.control_barrier_function.crabs.utils import Normalizer, get_pretrained_model
+from omnisafe.common.control_barrier_function.crabs.utils import (
+    Normalizer,
+    create_model_and_trainer,
+    get_pretrained_model,
+)
 from omnisafe.models.actor_critic.constraint_actor_q_critic import ConstraintActorQCritic

@@ -115,48 +116,13 @@ def _init_model(self) -> None:
         ).to(self._device)
         self.mean_policy = MeanPolicy(self._actor_critic.actor)

-        if self._cfgs.transition_model_cfgs.type == 'GatedTransitionModel':
-
-            def make_model(i):
-                return GatedTransitionModel(
-                    self.dim_state,
-                    self.normalizer,
-                    [self.dim_state + self.dim_action, 256, 256, 256, 256, self.dim_state * 2],
-                    self._cfgs.transition_model_cfgs.train,
-                    name=f'model-{i}',
-                )
-
-            self.model = EnsembleModel(
-                [make_model(i) for i in range(self._cfgs.transition_model_cfgs.n_ensemble)],
-            ).to(self._device)
-            self.model_trainer = pl.Trainer(
-                max_epochs=0,
-                accelerator='gpu',
-                devices=[int(str(self._device)[-1])],
-                default_root_dir=self._cfgs.logger_cfgs.log_dir,
-            )
-        elif self._cfgs.transition_model_cfgs.type == 'TransitionModel':
-
-            def make_model(i):
-                return TransitionModel(
-                    self.dim_state,
-                    self.normalizer,
-                    [self.dim_state + self.dim_action, 256, 256, 256, 256, self.dim_state * 2],
-                    self._cfgs.transition_model_cfgs.train,
-                    name=f'model-{i}',
-                )
-
-            self.model = EnsembleModel(
-                [make_model(i) for i in range(self._cfgs.transition_model_cfgs.n_ensemble)],
-            ).to(self._device)
-            self.model_trainer = pl.Trainer(
-                max_epochs=0,
-                accelerator='gpu',
-                devices=[int(str(self._device)[-1])],
-                default_root_dir=self._cfgs.logger_cfgs.log_dir,
-            )
-        else:
-            raise AssertionError(f'unknown model type {self._cfgs.transition_model_cfgs.type}')
+        self.model, self.model_trainer = create_model_and_trainer(
+            self._cfgs,
+            self.dim_state,
+            self.dim_action,
+            self.normalizer,
+            self._device,
+        )

     def _init_log(self) -> None:
         super()._init_log()
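
Note: the two deleted branches differed only in the transition-model class, so the change folds them into a single helper. The sketch below shows a plausible shape for that helper, inferred from the removed code; the actual create_model_and_trainer in crabs/utils.py may differ in details.

# Sketch only: inferred from the removed branches, not the shipped implementation.
import pytorch_lightning as pl

from omnisafe.common.control_barrier_function.crabs.models import (
    EnsembleModel,
    GatedTransitionModel,
    TransitionModel,
)


def create_model_and_trainer(cfgs, dim_state, dim_action, normalizer, device):
    """Build the transition-model ensemble and its Lightning trainer."""
    model_cls = {
        'GatedTransitionModel': GatedTransitionModel,
        'TransitionModel': TransitionModel,
    }.get(cfgs.transition_model_cfgs.type)
    if model_cls is None:
        raise AssertionError(f'unknown model type {cfgs.transition_model_cfgs.type}')

    def make_model(i):
        # Same network layout the removed branches used.
        return model_cls(
            dim_state,
            normalizer,
            [dim_state + dim_action, 256, 256, 256, 256, dim_state * 2],
            cfgs.transition_model_cfgs.train,
            name=f'model-{i}',
        )

    model = EnsembleModel(
        [make_model(i) for i in range(cfgs.transition_model_cfgs.n_ensemble)],
    ).to(device)
    trainer = pl.Trainer(
        max_epochs=0,
        accelerator='gpu',
        devices=[int(str(device)[-1])],
        default_root_dir=cfgs.logger_cfgs.log_dir,
    )
    return model, trainer
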
@@ -167,9 +133,18 @@ def _init_log(self) -> None:
         what_to_save['obs_normalizer'] = self.normalizer
         self._logger.setup_torch_saver(what_to_save)
         self._logger.torch_save()
-        self._logger.register_key('Metrics/RawPolicyEpRet', window_length=50)
-        self._logger.register_key('Metrics/RawPolicyEpCost', window_length=50)
-        self._logger.register_key('Metrics/RawPolicyEpLen', window_length=50)
+        self._logger.register_key(
+            'Metrics/RawPolicyEpRet',
+            window_length=self._cfgs.logger_cfgs.window_lens,
+        )
+        self._logger.register_key(
+            'Metrics/RawPolicyEpCost',
+            window_length=self._cfgs.logger_cfgs.window_lens,
+        )
+        self._logger.register_key(
+            'Metrics/RawPolicyEpLen',
+            window_length=self._cfgs.logger_cfgs.window_lens,
+        )

     def _init(self) -> None:
         """The initialization of the algorithm.
@@ -282,7 +257,7 @@ def learn(self):
         eval_start = time.time()
         self._env.eval_policy(
             episode=self._cfgs.train_cfgs.raw_policy_episodes,
-            agent=self._actor_critic,
+            agent=self.mean_policy,
             logger=self._logger,
         )

@@ -330,7 +305,7 @@ def learn(self):
            eval_start = time.time()
            self._env.eval_policy(
                episode=self._cfgs.train_cfgs.raw_policy_episodes,
-                agent=self.mean_policy,  # type: ignore
+                agent=self.mean_policy,
                logger=self._logger,
            )
            eval_time += time.time() - eval_start
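
Both evaluation sites now pass the deterministic wrapper built in _init_model (self.mean_policy = MeanPolicy(self._actor_critic.actor)) rather than the raw actor-critic. As a rough illustration only, assuming MeanPolicy simply queries the actor's mean action and that eval_policy drives the agent through a step-like call (the real class in crabs/models.py may expose a different interface):

# Illustrative stand-in, not the library's MeanPolicy.
import torch


class MeanPolicySketch:
    """Act with the actor's deterministic (mean) action during evaluation."""

    def __init__(self, actor):
        self._actor = actor

    @torch.no_grad()
    def step(self, obs):
        # Assumed agent interface for eval_policy; method name is an assumption.
        return self._actor.predict(obs, deterministic=True)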