
Commit ce5e50c

style(nyz): polish dreamerv3 code style and add readme link

1 parent 1074bab commit ce5e50c
10 files changed (+54, -37 lines)

README.md (+11, -10)
@@ -246,16 +246,17 @@ P.S: The `.py` file in `Runnable Demo` can be found in `dizoo`
 | 41 | [CQL](https://arxiv.org/pdf/2006.04779.pdf) | ![offline](https://img.shields.io/badge/-offlineRL-darkblue) | [CQL doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/cql.html)<br>[policy/cql](https://github.com/opendilab/DI-engine/blob/main/ding/policy/cql.py) | python3 -u d4rl_cql_main.py |
 | 42 | [TD3BC](https://arxiv.org/pdf/2106.06860.pdf) | ![offline](https://img.shields.io/badge/-offlineRL-darkblue) | [TD3BC doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/td3_bc.html)<br>[policy/td3_bc](https://github.com/opendilab/DI-engine/blob/main/ding/policy/td3_bc.py) | python3 -u d4rl_td3_bc_main.py |
 | 43 | [Decision Transformer](https://arxiv.org/pdf/2106.01345.pdf) | ![offline](https://img.shields.io/badge/-offlineRL-darkblue) | [policy/dt](https://github.com/opendilab/DI-engine/blob/main/ding/policy/decision_transformer.py) | python3 -u d4rl_dt_main.py |
-| 44 | MBSAC([SAC](https://arxiv.org/abs/1801.01290)+[MVE](https://arxiv.org/abs/1803.00101)+[SVG](https://arxiv.org/abs/1510.09142)) | ![continuous](https://img.shields.io/badge/-continous-green)![mbrl](https://img.shields.io/badge/-ModelBasedRL-lightblue) | [policy/mbpolicy/mbsac](https://github.com/opendilab/DI-engine/blob/main/ding/policy/mbpolicy/mbsac.py) | python3 -u pendulum_mbsac_mbpo_config.py \ python3 -u pendulum_mbsac_ddppo_config.py |
-| 45 | STEVESAC([SAC](https://arxiv.org/abs/1801.01290)+[STEVE](https://arxiv.org/abs/1807.01675)+[SVG](https://arxiv.org/abs/1510.09142)) | ![continuous](https://img.shields.io/badge/-continous-green)![mbrl](https://img.shields.io/badge/-ModelBasedRL-lightblue) | [policy/mbpolicy/mbsac](https://github.com/opendilab/DI-engine/blob/main/ding/policy/mbpolicy/mbsac.py) | python3 -u pendulum_stevesac_mbpo_config.py |
-| 46 | [MBPO](https://arxiv.org/pdf/1906.08253.pdf) | ![mbrl](https://img.shields.io/badge/-ModelBasedRL-lightblue) | [MBPO doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/mbpo.html)<br>[world_model/mbpo](https://github.com/opendilab/DI-engine/blob/main/ding/world_model/mbpo.py) | python3 -u pendulum_sac_mbpo_config.py |
-| 47 | [DDPPO](https://openreview.net/forum?id=rzvOQrnclO0) | ![mbrl](https://img.shields.io/badge/-ModelBasedRL-lightblue) | [world_model/ddppo](https://github.com/opendilab/DI-engine/blob/main/ding/world_model/ddppo.py) | python3 -u pendulum_mbsac_ddppo_config.py |
-| 48 | [PER](https://arxiv.org/pdf/1511.05952.pdf) | ![other](https://img.shields.io/badge/-other-lightgrey) | [worker/replay_buffer](https://github.com/opendilab/DI-engine/blob/main/ding/worker/replay_buffer/advanced_buffer.py) | `rainbow demo` |
-| 49 | [GAE](https://arxiv.org/pdf/1506.02438.pdf) | ![other](https://img.shields.io/badge/-other-lightgrey) | [rl_utils/gae](https://github.com/opendilab/DI-engine/blob/main/ding/rl_utils/gae.py) | `ppo demo` |
-| 50 | [ST-DIM](https://arxiv.org/pdf/1906.08226.pdf) | ![other](https://img.shields.io/badge/-other-lightgrey) | [torch_utils/loss/contrastive_loss](https://github.com/opendilab/DI-engine/blob/main/ding/torch_utils/loss/contrastive_loss.py) | ding -m serial -c cartpole_dqn_stdim_config.py -s 0 |
-| 51 | [PLR](https://arxiv.org/pdf/2010.03934.pdf) | ![other](https://img.shields.io/badge/-other-lightgrey) | [PLR doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/plr.html)<br>[data/level_replay/level_sampler](https://github.com/opendilab/DI-engine/blob/main/ding/data/level_replay/level_sampler.py) | python3 -u bigfish_plr_config.py -s 0 |
-| 52 | [PCGrad](https://arxiv.org/pdf/2001.06782.pdf) | ![other](https://img.shields.io/badge/-other-lightgrey) | [torch_utils/optimizer_helper/PCGrad](https://github.com/opendilab/DI-engine/blob/main/ding/data/torch_utils/optimizer_helper.py) | python3 -u multi_mnist_pcgrad_main.py -s 0 |
-| 53 | [edac](https://arxiv.org/pdf/2110.01548.pdf) | ![offline](https://img.shields.io/badge/-offlineRL-darkblue) | [EDAC doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/edac.html)<br>[policy/edac](https://github.com/opendilab/DI-engine/blob/main/ding/policy/edac.py) | python3 -u d4rl_edac_main.py |
+| 44 | [EDAC](https://arxiv.org/pdf/2110.01548.pdf) | ![offline](https://img.shields.io/badge/-offlineRL-darkblue) | [EDAC doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/edac.html)<br>[policy/edac](https://github.com/opendilab/DI-engine/blob/main/ding/policy/edac.py) | python3 -u d4rl_edac_main.py |
+| 45 | MBSAC([SAC](https://arxiv.org/abs/1801.01290)+[MVE](https://arxiv.org/abs/1803.00101)+[SVG](https://arxiv.org/abs/1510.09142)) | ![continuous](https://img.shields.io/badge/-continous-green)![mbrl](https://img.shields.io/badge/-ModelBasedRL-lightblue) | [policy/mbpolicy/mbsac](https://github.com/opendilab/DI-engine/blob/main/ding/policy/mbpolicy/mbsac.py) | python3 -u pendulum_mbsac_mbpo_config.py \ python3 -u pendulum_mbsac_ddppo_config.py |
+| 46 | STEVESAC([SAC](https://arxiv.org/abs/1801.01290)+[STEVE](https://arxiv.org/abs/1807.01675)+[SVG](https://arxiv.org/abs/1510.09142)) | ![continuous](https://img.shields.io/badge/-continous-green)![mbrl](https://img.shields.io/badge/-ModelBasedRL-lightblue) | [policy/mbpolicy/mbsac](https://github.com/opendilab/DI-engine/blob/main/ding/policy/mbpolicy/mbsac.py) | python3 -u pendulum_stevesac_mbpo_config.py |
+| 47 | [MBPO](https://arxiv.org/pdf/1906.08253.pdf) | ![mbrl](https://img.shields.io/badge/-ModelBasedRL-lightblue) | [MBPO doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/mbpo.html)<br>[world_model/mbpo](https://github.com/opendilab/DI-engine/blob/main/ding/world_model/mbpo.py) | python3 -u pendulum_sac_mbpo_config.py |
+| 48 | [DDPPO](https://openreview.net/forum?id=rzvOQrnclO0) | ![mbrl](https://img.shields.io/badge/-ModelBasedRL-lightblue) | [world_model/ddppo](https://github.com/opendilab/DI-engine/blob/main/ding/world_model/ddppo.py) | python3 -u pendulum_mbsac_ddppo_config.py |
+| 49 | [DreamerV3](https://arxiv.org/pdf/2301.04104.pdf) | ![mbrl](https://img.shields.io/badge/-ModelBasedRL-lightblue) | [world_model/dreamerv3](https://github.com/opendilab/DI-engine/blob/main/ding/world_model/dreamerv3.py) | python3 -u cartpole_balance_dreamer_config.py |
+| 50 | [PER](https://arxiv.org/pdf/1511.05952.pdf) | ![other](https://img.shields.io/badge/-other-lightgrey) | [worker/replay_buffer](https://github.com/opendilab/DI-engine/blob/main/ding/worker/replay_buffer/advanced_buffer.py) | `rainbow demo` |
+| 51 | [GAE](https://arxiv.org/pdf/1506.02438.pdf) | ![other](https://img.shields.io/badge/-other-lightgrey) | [rl_utils/gae](https://github.com/opendilab/DI-engine/blob/main/ding/rl_utils/gae.py) | `ppo demo` |
+| 52 | [ST-DIM](https://arxiv.org/pdf/1906.08226.pdf) | ![other](https://img.shields.io/badge/-other-lightgrey) | [torch_utils/loss/contrastive_loss](https://github.com/opendilab/DI-engine/blob/main/ding/torch_utils/loss/contrastive_loss.py) | ding -m serial -c cartpole_dqn_stdim_config.py -s 0 |
+| 53 | [PLR](https://arxiv.org/pdf/2010.03934.pdf) | ![other](https://img.shields.io/badge/-other-lightgrey) | [PLR doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/plr.html)<br>[data/level_replay/level_sampler](https://github.com/opendilab/DI-engine/blob/main/ding/data/level_replay/level_sampler.py) | python3 -u bigfish_plr_config.py -s 0 |
+| 54 | [PCGrad](https://arxiv.org/pdf/2001.06782.pdf) | ![other](https://img.shields.io/badge/-other-lightgrey) | [torch_utils/optimizer_helper/PCGrad](https://github.com/opendilab/DI-engine/blob/main/ding/data/torch_utils/optimizer_helper.py) | python3 -u multi_mnist_pcgrad_main.py -s 0 |
 </details>
 
 

ding/entry/serial_entry_mbrl.py (+19, -10)
@@ -283,34 +283,43 @@ def serial_pipeline_dreamer(
         collect_kwargs = commander.step()
         # eval the policy
         if evaluator.should_eval(collector.envstep):
-            stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep, policy_kwargs=dict(world_model=world_model))
+            stop, reward = evaluator.eval(
+                learner.save_checkpoint,
+                learner.train_iter,
+                collector.envstep,
+                policy_kwargs=dict(world_model=world_model)
+            )
             if stop:
                 break
-
+
         # train world model and fill imagination buffer
         steps = (
             cfg.world_model.pretrain
-            if world_model.should_pretrain()
-            else int(world_model.should_train(collector.envstep))
+            if world_model.should_pretrain() else int(world_model.should_train(collector.envstep))
         )
         for _ in range(steps):
             batch_size = learner.policy.get_attribute('batch_size')
             batch_length = cfg.policy.learn.batch_length
-            post, context = world_model.train(env_buffer, collector.envstep, learner.train_iter, batch_size, batch_length)
-
+            post, context = world_model.train(
+                env_buffer, collector.envstep, learner.train_iter, batch_size, batch_length
+            )
+
             start = post
-
+
             learner.train(
                 start, collector.envstep, policy_kwargs=dict(world_model=world_model, envstep=collector.envstep)
             )
-
+
         # fill environment buffer
-        data = collector.collect(train_iter=learner.train_iter, policy_kwargs=dict(world_model=world_model, envstep=collector.envstep, **collect_kwargs))
+        data = collector.collect(
+            train_iter=learner.train_iter,
+            policy_kwargs=dict(world_model=world_model, envstep=collector.envstep, **collect_kwargs)
+        )
         env_buffer.push(data, cur_collector_envstep=collector.envstep)
 
         if collector.envstep >= max_env_step or learner.train_iter >= max_train_iter:
             break
 
     learner.call_hook('after_run')
 
-    return policy
+    return policy

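For readers new to the model-based entry, the `steps` expression reformatted above is the gate that decides how many world-model updates run per loop iteration: a one-off `pretrain` burst the first time training kicks in, and afterwards 0 or 1 updates depending on whether enough new environment steps have been collected. A minimal sketch of such gating, with a hypothetical `GatedWorldModel` and made-up `pretrain`/`train_freq` values (not DI-engine's actual `WorldModel` base class):

```python
class GatedWorldModel:
    """Hypothetical world model exposing the gating used in serial_pipeline_dreamer."""

    def __init__(self, pretrain: int = 100, train_freq: int = 2):
        self.pretrain = pretrain  # size of the initial burst of updates
        self.train_freq = train_freq  # env steps between regular updates
        self.last_train_step = -1  # envstep of the most recent update, -1 before any

    def should_pretrain(self) -> bool:
        # True only before the very first world-model update.
        return self.last_train_step < 0

    def should_train(self, envstep: int) -> bool:
        # True once enough new environment steps have accumulated since the last update.
        return envstep - self.last_train_step >= self.train_freq


world_model = GatedWorldModel()
steps = (
    world_model.pretrain
    if world_model.should_pretrain() else int(world_model.should_train(envstep=10))
)
print(steps)  # 100 here; 0 or 1 per iteration once last_train_step is set
```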
ding/entry/utils.py (+4, -2)
@@ -60,8 +60,10 @@ def random_collect(
         new_data = collector.collect(n_episode=policy_cfg.random_collect_size, policy_kwargs=collect_kwargs)
     else:
         new_data = collector.collect(
-            n_sample=policy_cfg.random_collect_size, random_collect=True,
-            record_random_collect=False, policy_kwargs=collect_kwargs
+            n_sample=policy_cfg.random_collect_size,
+            random_collect=True,
+            record_random_collect=False,
+            policy_kwargs=collect_kwargs
         ) # 'record_random_collect=False' means random collect without output log
     if postprocess_data_fn is not None:
         new_data = postprocess_data_fn(new_data)

ding/envs/env_wrappers/env_wrappers.py (+4, -4)
@@ -182,16 +182,15 @@ def observation(self, frame):
             import sys
             logging.warning("Please install opencv-python first.")
             sys.exit(1)
-        # to do
-        # channel_first
+        # deal with channel_first case
         if frame.shape[0] < 10:
             frame = frame.transpose(1, 2, 0)
             frame = cv2.resize(frame, (self.size, self.size), interpolation=cv2.INTER_AREA)
             frame = frame.transpose(2, 0, 1)
         else:
             frame = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
             frame = cv2.resize(frame, (self.size, self.size), interpolation=cv2.INTER_AREA)
-
+
         return frame
 
 
@@ -265,6 +264,7 @@ def reward(self, reward):
         """
         return np.sign(reward)
 
+
 @ENV_WRAPPER_REGISTRY.register('action_repeat')
 class ActionRepeatWrapper(gym.Wrapper):
     """
@@ -275,7 +275,7 @@ class ActionRepeatWrapper(gym.Wrapper):
     Properties:
         - env (:obj:`gym.Env`): the environment to wrap.
         - ``action_repeat``
-
+
     """
 
     def __init__(self, env, action_repeat=1):

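The clarified comment in the first hunk explains the `frame.shape[0] < 10` branch: `cv2.resize` and `cv2.cvtColor` expect HWC images, so a channel-first (CHW) frame is transposed to HWC before resizing and back afterwards. A standalone sketch of the same logic, with a hypothetical `resize_frame` helper and an assumed target size of 64 (requires `opencv-python`):

```python
import cv2
import numpy as np

def resize_frame(frame: np.ndarray, size: int = 64) -> np.ndarray:
    """Resize a frame to (size, size), handling both CHW and HWC layouts."""
    if frame.shape[0] < 10:  # heuristic: a leading dim this small is treated as the channel axis (CHW)
        frame = frame.transpose(1, 2, 0)  # CHW -> HWC for OpenCV
        frame = cv2.resize(frame, (size, size), interpolation=cv2.INTER_AREA)
        frame = frame.transpose(2, 0, 1)  # back to CHW
    else:  # HWC input: convert to grayscale, as in the wrapper's else branch
        frame = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
        frame = cv2.resize(frame, (size, size), interpolation=cv2.INTER_AREA)
    return frame

chw = (np.random.rand(3, 100, 100) * 255).astype(np.uint8)
print(resize_frame(chw).shape)  # (3, 64, 64)
```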
ding/policy/command_mode_policy_instance.py (+1)
@@ -303,6 +303,7 @@ class MBSACCommandModePolicy(MBSACPolicy, DummyCommandModePolicy):
 class STEVESACCommandModePolicy(STEVESACPolicy, DummyCommandModePolicy):
     pass
 
+
 @POLICY_REGISTRY.register('dreamer_command')
 class DREAMERCommandModePolicy(DREAMERPolicy, DummyCommandModePolicy):
     pass

ding/worker/collector/interaction_serial_evaluator.py (+5, -2)
@@ -190,6 +190,7 @@ def eval(
             envstep: int = -1,
             n_episode: Optional[int] = None,
             force_render: bool = False,
+            policy_kwargs: Optional[Dict] = {},
     ) -> Tuple[bool, Dict[str, List]]:
         '''
         Overview:
@@ -228,7 +229,9 @@
                 eval_monitor.update_video(self._env.ready_imgs)
 
                 if self._policy_cfg.type == 'dreamer_command':
-                    policy_output = self._policy.forward(obs, **policy_kwargs, reset=self._resets, state=self._states)
+                    policy_output = self._policy.forward(
+                        obs, **policy_kwargs, reset=self._resets, state=self._states
+                    )
                     #self._states = {env_id: output['state'] for env_id, output in policy_output.items()}
                     self._states = [output['state'] for output in policy_output.values()]
                 else:
@@ -317,4 +320,4 @@ def eval(
             stop_flag, episode_info = objects
 
         episode_info = to_item(episode_info)
-        return stop_flag, episode_info
+        return stop_flag, episode_info

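The reformatted `forward` call shows what is special about evaluating the dreamer policy: its recurrent world model carries latent state across environment steps, so the evaluator threads the previous `state` (and `reset` flags) into each call and stores the returned state for the next one. A toy illustration of that pattern with a hypothetical `StatefulPolicy` (not the real `DREAMERPolicy` interface):

```python
from typing import Dict, List, Optional

class StatefulPolicy:
    """Toy recurrent policy: the action depends on a per-env integer state."""

    def forward(self, obs: Dict[int, float], reset: Optional[List[bool]] = None,
                state: Optional[List[int]] = None) -> Dict[int, Dict]:
        state = [0] * len(obs) if state is None else list(state)
        if reset is not None:
            # zero the latent state of any environment that just finished an episode
            state = [0 if r else s for s, r in zip(state, reset)]
        return {env_id: {'action': o + s, 'state': s + 1}
                for (env_id, o), s in zip(obs.items(), state)}

policy, states, resets = StatefulPolicy(), None, None
for _ in range(3):
    output = policy.forward({0: 0.0, 1: 1.0}, reset=resets, state=states)
    # as in the evaluator: keep the returned state for the next forward call
    states = [out['state'] for out in output.values()]
    resets = [False, False]
print(states)  # [3, 3]
```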
ding/worker/collector/sample_serial_collector.py (+2, -2)
@@ -266,7 +266,7 @@ def collect(
                     self._states = [output['state'] for output in policy_output.values()]
                 else:
                     policy_output = self._policy.forward(obs, **policy_kwargs)
-                self._policy_output_pool.update(policy_output)
+                self._policy_output_pool.update(policy_output)
                 # Interact with env.
                 actions = {env_id: output['action'] for env_id, output in policy_output.items()}
                 actions = to_ndarray(actions)
@@ -410,4 +410,4 @@ def _output_log(self, train_iter: int) -> None:
                 self._tb_logger.add_scalar('{}_iter/'.format(self._instance_name) + k, v, train_iter)
                 if k in ['total_envstep_count']:
                     continue
-                self._tb_logger.add_scalar('{}_step/'.format(self._instance_name) + k, v, self._total_envstep_count)
+                self._tb_logger.add_scalar('{}_step/'.format(self._instance_name) + k, v, self._total_envstep_count)

ding/worker/replay_buffer/naive_buffer.py (+1, -1)
@@ -541,7 +541,7 @@ def _get_indices(self, size: int, sequence: int, sample_range: slice = None, rep
                 indices.append(np.random.randint(episode * 500, episode * 500 + available + 1))
                 batch += 1
         else:
-            raise NotImplemented("sample_range is not implemented in this version")
+            raise NotImplementedError("sample_range is not implemented in this version")
         return indices
 
     def _sample_with_indices(self, indices: List[int], sequence: int, cur_learner_iter: int) -> list:

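The one-line fix in this file replaces `NotImplemented`, a sentinel value meant to be returned from binary special methods, with `NotImplementedError`, the exception class. The old statement could never raise the intended error, as a quick demonstration shows:

```python
# NotImplemented is not callable and not an exception, so the old statement dies
# with an unrelated TypeError before the message is ever shown.
try:
    raise NotImplemented("sample_range is not implemented in this version")
except TypeError as e:
    print(e)  # 'NotImplementedType' object is not callable

# A bare `raise NotImplemented` fails as well: exceptions must derive from BaseException.

# NotImplementedError carries the message as intended.
try:
    raise NotImplementedError("sample_range is not implemented in this version")
except NotImplementedError as e:
    print(e)  # sample_range is not implemented in this version
```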
ding/world_model/dreamer.py (+6, -4)
@@ -143,7 +143,7 @@ def __init__(self, cfg, env, tb_logger):
             dist="binary",
             device=self._cfg.device,
         )
-
+
         if self._cuda:
             self.cuda()
         # to do
@@ -164,7 +164,9 @@ def should_pretrain(self):
 
     def train(self, env_buffer, envstep, train_iter, batch_size, batch_length):
         self.last_train_step = envstep
-        data = env_buffer.sample(batch_size, batch_length, train_iter) # [len=B, ele=[len=T, ele={dict_key: Tensor(any_dims)}]]
+        data = env_buffer.sample(
+            batch_size, batch_length, train_iter
+        ) # [len=B, ele=[len=T, ele={dict_key: Tensor(any_dims)}]]
         data = default_collate(data) # -> [len=T, ele={dict_key: Tensor(B, any_dims)}]
         data = lists_to_dicts(data, recursive=True) # -> {some_key: T lists}, each list is [B, some_dim]
         data = {k: torch.stack(data[k], dim=1) for k in data} # -> {dict_key: Tensor([B, T, any_dims])}
@@ -186,7 +188,7 @@ def train(self, env_buffer, envstep, train_iter, batch_size, batch_length):
         image = data['image'].reshape([-1] + list(data['image'].shape[-3:]))
         embed = self.encoder(image)
         embed = embed.reshape(list(data['image'].shape[:-3]) + [embed.shape[-1]])
-
+
         post, prior = self.dynamics.observe(embed, data["action"])
         kl_loss, kl_value, loss_lhs, loss_rhs = self.dynamics.kl_loss(
             post, prior, self._cfg.kl_forward, self._cfg.kl_free, self._cfg.kl_lscale, self._cfg.kl_rscale
@@ -209,7 +211,7 @@ def train(self, env_buffer, envstep, train_iter, batch_size, batch_length):
             self.optimizer.zero_grad()
             model_loss.backward()
             self.optimizer.step()
-
+
         self.requires_grad_(requires_grad=False)
         # log
         if self.tb_logger is not None:

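The shape comments preserved in the `train` hunk describe how a sampled batch is reshaped: a list of B trajectories, each a list of T per-step dicts, becomes one dict of (B, T, ...) tensors. A self-contained re-implementation of that pipeline in plain PyTorch with toy data (the real code uses DI-engine's `default_collate` and `lists_to_dicts` helpers for the intermediate steps):

```python
import torch

B, T = 4, 8  # 4 sampled trajectories, 8 steps each
# [len=B, ele=[len=T, ele={dict_key: Tensor(any_dims)}]]
data = [[{'image': torch.rand(3, 64, 64), 'action': torch.rand(6)} for _ in range(T)] for _ in range(B)]

# Collate over the batch dimension: for each step t, stack the B per-step dicts.
# -> [len=T, ele={dict_key: Tensor(B, any_dims)}]
collated = [{k: torch.stack([traj[t][k] for traj in data]) for k in data[0][0]} for t in range(T)]

# Gather each key's T tensors and stack them along dim=1.
# -> {dict_key: Tensor(B, T, any_dims)}
batch = {k: torch.stack([step[k] for step in collated], dim=1) for k in collated[0]}

print(batch['image'].shape)   # torch.Size([4, 8, 3, 64, 64])
print(batch['action'].shape)  # torch.Size([4, 8, 6])
```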
ding/world_model/tests/test_dreamer.py (+1, -2)
@@ -7,7 +7,7 @@
 from ding.utils import deep_merge_dicts
 
 # arguments
-state_size = [3,64,64]
+state_size = [3, 64, 64]
 action_size = [6, 1]
 args = list(product(*[state_size, action_size]))
 
@@ -30,4 +30,3 @@ def test_train(self, state_size, action_size):
         actions = torch.rand(1280, action_size)
 
         model = self.get_world_model(state_size, action_size)
-
