Skip to content

Commit 080e6b8

Browse files
author
Borong Zhang
authored
feat(crabs): refine rendering for saved crabs policies (#330)
1 parent 93d4975 commit 080e6b8

File tree

13 files changed

+220
-78
lines changed

13 files changed

+220
-78
lines changed

omnisafe/adapter/crabs_adapter.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222

2323
from omnisafe.adapter.offpolicy_adapter import OffPolicyAdapter
2424
from omnisafe.common.buffer import VectorOffPolicyBuffer
25+
from omnisafe.common.control_barrier_function.crabs.models import MeanPolicy
2526
from omnisafe.common.logger import Logger
2627
from omnisafe.envs.crabs_env import CRABSEnv
2728
from omnisafe.models.actor_critic.constraint_actor_q_critic import ConstraintActorQCritic
@@ -55,14 +56,15 @@ def __init__( # pylint: disable=too-many-arguments
5556
"""Initialize a instance of :class:`CRABSAdapter`."""
5657
super().__init__(env_id, num_envs, seed, cfgs)
5758
self._env: CRABSEnv
59+
self._eval_env: CRABSEnv
5860
self.n_expl_episodes = 0
5961
self._max_ep_len = self._env.env.spec.max_episode_steps # type: ignore
6062
self.horizon = self._max_ep_len
6163

6264
def eval_policy( # pylint: disable=too-many-locals
6365
self,
6466
episode: int,
65-
agent: ConstraintActorQCritic,
67+
agent: ConstraintActorQCritic | MeanPolicy,
6668
logger: Logger,
6769
) -> None:
6870
"""Rollout the environment with deterministic agent action.
@@ -74,13 +76,13 @@ def eval_policy( # pylint: disable=too-many-locals
7476
"""
7577
for _ in range(episode):
7678
ep_ret, ep_cost, ep_len = 0.0, 0.0, 0
77-
obs, _ = self._eval_env.reset() # type: ignore
79+
obs, _ = self._eval_env.reset()
7880
obs = obs.to(self._device)
7981

8082
done = False
8183
while not done:
82-
act = agent.step(obs, deterministic=False)
83-
obs, reward, cost, terminated, truncated, info = self._eval_env.step(act) # type: ignore
84+
act = agent.step(obs, deterministic=True)
85+
obs, reward, cost, terminated, truncated, info = self._eval_env.step(act)
8486
obs, reward, cost, terminated, truncated = (
8587
torch.as_tensor(x, dtype=torch.float32, device=self._device)
8688
for x in (obs, reward, cost, terminated, truncated)

omnisafe/algorithms/off_policy/crabs.py

Lines changed: 26 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -31,12 +31,9 @@
3131
from omnisafe.common.control_barrier_function.crabs.models import (
3232
AddGaussianNoise,
3333
CrabsCore,
34-
EnsembleModel,
3534
ExplorationPolicy,
36-
GatedTransitionModel,
3735
MeanPolicy,
3836
MultiLayerPerceptron,
39-
TransitionModel,
4037
UniformPolicy,
4138
)
4239
from omnisafe.common.control_barrier_function.crabs.optimizers import (
@@ -46,7 +43,11 @@
4643
SLangevinOptimizer,
4744
StateBox,
4845
)
49-
from omnisafe.common.control_barrier_function.crabs.utils import Normalizer, get_pretrained_model
46+
from omnisafe.common.control_barrier_function.crabs.utils import (
47+
Normalizer,
48+
create_model_and_trainer,
49+
get_pretrained_model,
50+
)
5051
from omnisafe.models.actor_critic.constraint_actor_q_critic import ConstraintActorQCritic
5152

5253

@@ -115,48 +116,13 @@ def _init_model(self) -> None:
115116
).to(self._device)
116117
self.mean_policy = MeanPolicy(self._actor_critic.actor)
117118

118-
if self._cfgs.transition_model_cfgs.type == 'GatedTransitionModel':
119-
120-
def make_model(i):
121-
return GatedTransitionModel(
122-
self.dim_state,
123-
self.normalizer,
124-
[self.dim_state + self.dim_action, 256, 256, 256, 256, self.dim_state * 2],
125-
self._cfgs.transition_model_cfgs.train,
126-
name=f'model-{i}',
127-
)
128-
129-
self.model = EnsembleModel(
130-
[make_model(i) for i in range(self._cfgs.transition_model_cfgs.n_ensemble)],
131-
).to(self._device)
132-
self.model_trainer = pl.Trainer(
133-
max_epochs=0,
134-
accelerator='gpu',
135-
devices=[int(str(self._device)[-1])],
136-
default_root_dir=self._cfgs.logger_cfgs.log_dir,
137-
)
138-
elif self._cfgs.transition_model_cfgs.type == 'TransitionModel':
139-
140-
def make_model(i):
141-
return TransitionModel(
142-
self.dim_state,
143-
self.normalizer,
144-
[self.dim_state + self.dim_action, 256, 256, 256, 256, self.dim_state * 2],
145-
self._cfgs.transition_model_cfgs.train,
146-
name=f'model-{i}',
147-
)
148-
149-
self.model = EnsembleModel(
150-
[make_model(i) for i in range(self._cfgs.transition_model_cfgs.n_ensemble)],
151-
).to(self._device)
152-
self.model_trainer = pl.Trainer(
153-
max_epochs=0,
154-
accelerator='gpu',
155-
devices=[int(str(self._device)[-1])],
156-
default_root_dir=self._cfgs.logger_cfgs.log_dir,
157-
)
158-
else:
159-
raise AssertionError(f'unknown model type {self._cfgs.transition_model_cfgs.type}')
119+
self.model, self.model_trainer = create_model_and_trainer(
120+
self._cfgs,
121+
self.dim_state,
122+
self.dim_action,
123+
self.normalizer,
124+
self._device,
125+
)
160126

161127
def _init_log(self) -> None:
162128
super()._init_log()
@@ -167,9 +133,18 @@ def _init_log(self) -> None:
167133
what_to_save['obs_normalizer'] = self.normalizer
168134
self._logger.setup_torch_saver(what_to_save)
169135
self._logger.torch_save()
170-
self._logger.register_key('Metrics/RawPolicyEpRet', window_length=50)
171-
self._logger.register_key('Metrics/RawPolicyEpCost', window_length=50)
172-
self._logger.register_key('Metrics/RawPolicyEpLen', window_length=50)
136+
self._logger.register_key(
137+
'Metrics/RawPolicyEpRet',
138+
window_length=self._cfgs.logger_cfgs.window_lens,
139+
)
140+
self._logger.register_key(
141+
'Metrics/RawPolicyEpCost',
142+
window_length=self._cfgs.logger_cfgs.window_lens,
143+
)
144+
self._logger.register_key(
145+
'Metrics/RawPolicyEpLen',
146+
window_length=self._cfgs.logger_cfgs.window_lens,
147+
)
173148

174149
def _init(self) -> None:
175150
"""The initialization of the algorithm.
@@ -282,7 +257,7 @@ def learn(self):
282257
eval_start = time.time()
283258
self._env.eval_policy(
284259
episode=self._cfgs.train_cfgs.raw_policy_episodes,
285-
agent=self._actor_critic,
260+
agent=self.mean_policy,
286261
logger=self._logger,
287262
)
288263

@@ -330,7 +305,7 @@ def learn(self):
330305
eval_start = time.time()
331306
self._env.eval_policy(
332307
episode=self._cfgs.train_cfgs.raw_policy_episodes,
333-
agent=self.mean_policy, # type: ignore
308+
agent=self.mean_policy,
334309
logger=self._logger,
335310
)
336311
eval_time += time.time() - eval_start

omnisafe/algorithms/off_policy/ddpg.py

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -197,14 +197,32 @@ def _init_log(self) -> None:
197197
self._logger.setup_torch_saver(what_to_save)
198198
self._logger.torch_save()
199199

200-
self._logger.register_key('Metrics/EpRet', window_length=50)
201-
self._logger.register_key('Metrics/EpCost', window_length=50)
202-
self._logger.register_key('Metrics/EpLen', window_length=50)
200+
self._logger.register_key(
201+
'Metrics/EpRet',
202+
window_length=self._cfgs.logger_cfgs.window_lens,
203+
)
204+
self._logger.register_key(
205+
'Metrics/EpCost',
206+
window_length=self._cfgs.logger_cfgs.window_lens,
207+
)
208+
self._logger.register_key(
209+
'Metrics/EpLen',
210+
window_length=self._cfgs.logger_cfgs.window_lens,
211+
)
203212

204213
if self._cfgs.train_cfgs.eval_episodes > 0:
205-
self._logger.register_key('Metrics/TestEpRet', window_length=50)
206-
self._logger.register_key('Metrics/TestEpCost', window_length=50)
207-
self._logger.register_key('Metrics/TestEpLen', window_length=50)
214+
self._logger.register_key(
215+
'Metrics/TestEpRet',
216+
window_length=self._cfgs.logger_cfgs.window_lens,
217+
)
218+
self._logger.register_key(
219+
'Metrics/TestEpCost',
220+
window_length=self._cfgs.logger_cfgs.window_lens,
221+
)
222+
self._logger.register_key(
223+
'Metrics/TestEpLen',
224+
window_length=self._cfgs.logger_cfgs.window_lens,
225+
)
208226

209227
self._logger.register_key('Train/Epoch')
210228
self._logger.register_key('Train/LR')

omnisafe/algorithms/on_policy/base/policy_gradient.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -188,9 +188,18 @@ def _init_log(self) -> None:
188188
self._logger.setup_torch_saver(what_to_save)
189189
self._logger.torch_save()
190190

191-
self._logger.register_key('Metrics/EpRet', window_length=50)
192-
self._logger.register_key('Metrics/EpCost', window_length=50)
193-
self._logger.register_key('Metrics/EpLen', window_length=50)
191+
self._logger.register_key(
192+
'Metrics/EpRet',
193+
window_length=self._cfgs.logger_cfgs.window_lens,
194+
)
195+
self._logger.register_key(
196+
'Metrics/EpCost',
197+
window_length=self._cfgs.logger_cfgs.window_lens,
198+
)
199+
self._logger.register_key(
200+
'Metrics/EpLen',
201+
window_length=self._cfgs.logger_cfgs.window_lens,
202+
)
194203

195204
self._logger.register_key('Train/Epoch')
196205
self._logger.register_key('Train/Entropy')

omnisafe/common/control_barrier_function/crabs/utils.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,22 @@
1414
# ==============================================================================
1515
"""Utils for CRABS."""
1616
# pylint: disable=all
17+
from __future__ import annotations
18+
1719
import os
1820

21+
import pytorch_lightning as pl
1922
import requests
2023
import torch
2124
import torch.nn as nn
2225
from torch import load
2326

27+
from omnisafe.common.control_barrier_function.crabs.models import (
28+
EnsembleModel,
29+
GatedTransitionModel,
30+
TransitionModel,
31+
)
32+
2433

2534
class Normalizer(nn.Module):
2635
"""Normalizes input data to have zero mean and unit variance.
@@ -119,3 +128,59 @@ def get_pretrained_model(model_path, model_url, device):
119128
print('Model found locally.')
120129

121130
return load(model_path, map_location=device)
131+
132+
133+
def create_model_and_trainer(cfgs, dim_state, dim_action, normalizer, device):
    """Create the ensemble world model and its PyTorch Lightning trainer.

    Args:
        cfgs: Configs. ``cfgs.transition_model_cfgs`` supplies ``type``,
            ``n_ensemble`` and ``train``; ``cfgs.logger_cfgs.log_dir`` is used
            as the trainer's root directory.
        dim_state: Dimension of the state.
        dim_action: Dimension of the action.
        normalizer: Observation normalizer handed to every ensemble member.
        device: Device to load the model (a device string or ``torch.device``).

    Returns:
        Tuple[nn.Module, pl.Trainer]: World model and trainer.

    Raises:
        AssertionError: If ``cfgs.transition_model_cfgs.type`` is neither
            ``'GatedTransitionModel'`` nor ``'TransitionModel'``.
    """
    model_classes = {
        'GatedTransitionModel': GatedTransitionModel,
        'TransitionModel': TransitionModel,
    }

    # Validate the type once, up front, so a bad config is rejected even when
    # ``n_ensemble`` is 0 and no member model would otherwise be constructed.
    model_type = cfgs.transition_model_cfgs.type
    if model_type not in model_classes:
        raise AssertionError(f'unknown model type {model_type}')
    model_cls = model_classes[model_type]

    # Each member gets its own layer-size list (input is state + action, output
    # is ``dim_state * 2``) and a unique name of the form ``model-<i>``.
    models = [
        model_cls(
            dim_state,
            normalizer,
            [dim_state + dim_action, 256, 256, 256, 256, dim_state * 2],
            cfgs.transition_model_cfgs.train,
            name=f'model-{i}',
        )
        for i in range(cfgs.transition_model_cfgs.n_ensemble)
    ]

    model = EnsembleModel(models).to(device)

    devices: list[int] | int

    if str(device).startswith('cuda'):
        accelerator = 'gpu'
        # Parse the ordinal through ``torch.device`` instead of taking the last
        # character of the string: that broke on a bare ``'cuda'`` and on
        # two-digit ordinals such as ``'cuda:10'``.
        index = torch.device(str(device)).index
        devices = [torch.cuda.current_device() if index is None else index]
    else:
        accelerator = 'cpu'
        # NOTE(review): this hands the torch intra-op thread count to the
        # trainer's ``devices`` argument -- confirm this is the intended value
        # for the CPU accelerator.
        devices = torch.get_num_threads()

    # The trainer is created with ``max_epochs=0``, exactly as before.
    trainer = pl.Trainer(
        max_epochs=0,
        accelerator=accelerator,
        devices=devices,
        default_root_dir=cfgs.logger_cfgs.log_dir,
    )

    return model, trainer

omnisafe/common/offline/dataset.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,7 @@ def __init__( # pylint: disable=too-many-branches
126126
# Load data from local .npz file
127127
try:
128128
data = np.load(dataset_name)
129-
except Exception as e:
129+
except (ValueError, OSError) as e:
130130
raise ValueError(f'Failed to load data from {dataset_name}') from e
131131

132132
else:
@@ -284,7 +284,7 @@ def __init__( # pylint: disable=too-many-branches, super-init-not-called
284284
# Load data from local .npz file
285285
try:
286286
data = np.load(dataset_name)
287-
except Exception as e:
287+
except (ValueError, OSError) as e:
288288
raise ValueError(f'Failed to load data from {dataset_name}') from e
289289

290290
else:

omnisafe/configs/off-policy/CRABS.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ defaults:
8484
# save logger path
8585
log_dir: "./runs"
8686
# window length passed to the logger when registering metric keys
87-
window_lens: 10
87+
window_lens: 6
8888
# model configurations
8989
model_cfgs:
9090
# weight initialization mode

omnisafe/configs/off-policy/DDPG.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ defaults:
8181
# save logger path
8282
log_dir: "./runs"
8383
# window length passed to the logger when registering metric keys
84-
window_lens: 10
84+
window_lens: 50
8585
# model configurations
8686
model_cfgs:
8787
# weight initialization mode

omnisafe/configs/on-policy/PolicyGradient.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ defaults:
8787
# save logger path
8888
log_dir: "./runs"
8989
# window length passed to the logger when registering metric keys
90-
window_lens: 100
90+
window_lens: 50
9191
# model configurations
9292
model_cfgs:
9393
# weight initialization mode

omnisafe/envs/__init__.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,6 @@
1414
# ==============================================================================
1515
"""Environment API for OmniSafe."""
1616

17-
from contextlib import suppress
18-
1917
from omnisafe.envs import classic_control
2018
from omnisafe.envs.core import CMDP, env_register, make, support_envs
2119
from omnisafe.envs.crabs_env import CRABSEnv

0 commit comments

Comments
 (0)