feature(xrk): add new env named Frozen Lake and DQN algorithm. #781

Merged
merged 12 commits on Mar 13, 2024
399 changes: 203 additions & 196 deletions README.md

Large diffs are not rendered by default.

45 changes: 45 additions & 0 deletions ding/example/dqn_frozen_lake.py
@@ -0,0 +1,45 @@
from ditk import logging
from ding.model import DQN
from ding.policy import DQNPolicy
from ding.envs import DingEnvWrapper, BaseEnvManagerV2
from ding.data import DequeBuffer
from ding.config import compile_config
from ding.framework import task
from ding.framework.context import OnlineRLContext
from ding.framework.middleware import OffPolicyLearner, StepCollector, interaction_evaluator, data_pusher, \
eps_greedy_handler, CkptSaver, nstep_reward_enhancer, final_ctx_saver
from ding.utils import set_pkg_seed
from dizoo.frozen_lake.config.frozen_lake_dqn_config import main_config, create_config
from dizoo.frozen_lake.envs import FrozenLakeEnv


def main():
logging.getLogger().setLevel(logging.INFO)
    main_config.policy.nstep = 5  # override the configured default (nstep=3) for this example
cfg = compile_config(main_config, create_cfg=create_config, auto=True)
with task.start(async_mode=False, ctx=OnlineRLContext()):
collector_env = BaseEnvManagerV2(
env_fn=[lambda: FrozenLakeEnv(cfg=cfg.env) for _ in range(cfg.env.collector_env_num)], cfg=cfg.env.manager
)
evaluator_env = BaseEnvManagerV2(
env_fn=[lambda: FrozenLakeEnv(cfg=cfg.env) for _ in range(cfg.env.evaluator_env_num)], cfg=cfg.env.manager
)
set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda)

model = DQN(**cfg.policy.model)
buffer_ = DequeBuffer(size=cfg.policy.other.replay_buffer.replay_buffer_size)
policy = DQNPolicy(cfg.policy, model=model)

task.use(interaction_evaluator(cfg, policy.eval_mode, evaluator_env))
task.use(eps_greedy_handler(cfg))
task.use(StepCollector(cfg, policy.collect_mode, collector_env))
task.use(nstep_reward_enhancer(cfg))
task.use(data_pusher(cfg, buffer_))
task.use(OffPolicyLearner(cfg, policy.learn_mode, buffer_))
task.use(CkptSaver(policy, cfg.exp_name, train_freq=100))
task.use(final_ctx_saver(cfg.exp_name))
task.run()


if __name__ == "__main__":
main()
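
For quick reference, a hedged sketch of how to invoke this example programmatically (the module path is an assumption based on the file's location; running `python ding/example/dqn_frozen_lake.py` works just as well):

from ding.example.dqn_frozen_lake import main  # assumed import path mirroring this file

main()  # logs and checkpoints are written under the experiment dir named by cfg.exp_name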
Binary file added dizoo/frozen_lake/FrozenLake.gif
Empty file added dizoo/frozen_lake/__init__.py
Empty file.
1 change: 1 addition & 0 deletions dizoo/frozen_lake/config/__init__.py
@@ -0,0 +1 @@
from .frozen_lake_dqn_config import main_config, create_config
64 changes: 64 additions & 0 deletions dizoo/frozen_lake/config/frozen_lake_dqn_config.py
@@ -0,0 +1,64 @@
from easydict import EasyDict

frozen_lake_dqn_config = dict(
exp_name='frozen_lake_seed0',
env=dict(
collector_env_num=8,
evaluator_env_num=5,
n_evaluator_episode=10,
env_id='FrozenLake-v1',
desc=None,
map_name="4x4",
is_slippery=False,
save_replay_gif=False,
),
policy=dict(
cuda=True,
load_path='frozen_lake_seed0/ckpt/ckpt_best.pth.tar',
model=dict(
obs_shape=16,
action_shape=4,
encoder_hidden_size_list=[128, 128, 64],
dueling=True,
),
nstep=3,
discount_factor=0.97,
learn=dict(
update_per_collect=5,
batch_size=256,
learning_rate=0.001,
),
collect=dict(n_sample=10),
eval=dict(evaluator=dict(eval_freq=40, )),
other=dict(
eps=dict(
type='exp',
start=0.8,
end=0.1,
decay=10000,
),
replay_buffer=dict(replay_buffer_size=20000, ),
),
),
)

frozen_lake_dqn_config = EasyDict(frozen_lake_dqn_config)
main_config = frozen_lake_dqn_config

frozen_lake_dqn_create_config = dict(
env=dict(
type='frozen_lake',
import_names=['dizoo.frozen_lake.envs.frozen_lake_env'],
),
env_manager=dict(type='base'),
policy=dict(type='dqn'),
replay_buffer=dict(type='deque', import_names=['ding.data.buffer.deque_buffer_wrapper']),
)

frozen_lake_dqn_create_config = EasyDict(frozen_lake_dqn_create_config)
create_config = frozen_lake_dqn_create_config

if __name__ == "__main__":
# or you can enter `ding -m serial -c frozen_lake_dqn_config.py -s 0`
from ding.entry import serial_pipeline
serial_pipeline((main_config, create_config), max_env_step=5000, seed=0)
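
The `other.eps` block above configures exponential epsilon-greedy annealing from 0.8 down to 0.1 with a decay constant of 10000 environment steps. Below is a self-contained sketch of the usual form of such a schedule; the exact formula inside DI-engine's `eps_greedy_handler` is an assumption here, not copied from its implementation:

import math


def eps_exp(env_step: int, start: float = 0.8, end: float = 0.1, decay: int = 10000) -> float:
    # Exponential anneal from `start` toward `end`; `decay` is the time constant.
    # Assumed form, not taken verbatim from DI-engine.
    return (start - end) * math.exp(-env_step / decay) + end


print(eps_exp(0))  # 0.8
print(eps_exp(10000))  # ~0.36
print(eps_exp(50000))  # ~0.105, close to the configured floor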
1 change: 1 addition & 0 deletions dizoo/frozen_lake/envs/__init__.py
@@ -0,0 +1 @@
from .frozen_lake_env import FrozenLakeEnv
144 changes: 144 additions & 0 deletions dizoo/frozen_lake/envs/frozen_lake_env.py
@@ -0,0 +1,144 @@
from typing import List, Optional
import imageio
import os
import gymnasium as gymn
import numpy as np
from ding.envs import BaseEnv, BaseEnvTimestep
from ding.torch_utils import to_ndarray
from ding.utils import ENV_REGISTRY


@ENV_REGISTRY.register('frozen_lake')
class FrozenLakeEnv(BaseEnv):

def __init__(self, cfg) -> None:
self._cfg = cfg
        assert self._cfg.env_id == "FrozenLake-v1", "this env only supports FrozenLake-v1"
        self._init_flag = False
        self._save_replay_bool = False
        self._save_replay_count = 0
        self._frames = []
        self._replay_path = None

def reset(self) -> np.ndarray:
        if not self._init_flag:
            # When desc is None, gymnasium falls back to the preloaded map selected by map_name.
            self._env = gymn.make(
                self._cfg.env_id,
                desc=self._cfg.desc,
                map_name=self._cfg.map_name,
                is_slippery=self._cfg.is_slippery,
                render_mode="rgb_array"
            )
            self._observation_space = self._env.observation_space
            self._action_space = self._env.action_space
            self._reward_space = gymn.spaces.Box(
                low=self._env.reward_range[0], high=self._env.reward_range[1], shape=(1, ), dtype=np.float32
            )
            self._init_flag = True
self._eval_episode_return = 0
if hasattr(self, '_seed') and hasattr(self, '_dynamic_seed') and self._dynamic_seed:
np_seed = 100 * np.random.randint(1, 1000)
self._env_seed = self._seed + np_seed
elif hasattr(self, '_seed'):
self._env_seed = self._seed
if hasattr(self, '_seed'):
obs, info = self._env.reset(seed=self._env_seed)
else:
obs, info = self._env.reset()
        obs = np.eye(16, dtype=np.float32)[obs]  # one-hot encode the discrete state (0..15)
return obs

def close(self) -> None:
if self._init_flag:
self._env.close()
self._init_flag = False

def seed(self, seed: int, dynamic_seed: bool = True) -> None:
self._seed = seed
self._dynamic_seed = dynamic_seed
np.random.seed(self._seed)

    def step(self, action: np.ndarray) -> BaseEnvTimestep:
        # Actions arrive as length-1 arrays; unwrap for the underlying gymnasium env.
        obs, rew, terminated, truncated, info = self._env.step(action[0])
self._eval_episode_return += rew
        obs = np.eye(16, dtype=np.float32)[obs]  # one-hot encode the discrete state (0..15)
rew = to_ndarray([rew])
if self._save_replay_bool:
picture = self._env.render()
self._frames.append(picture)
        done = terminated or truncated
if done:
info['eval_episode_return'] = self._eval_episode_return
if self._save_replay_bool:
                assert self._replay_path is not None, "replay path must be set via enable_save_replay()"
path = os.path.join(
self._replay_path, '{}_episode_{}.gif'.format(self._cfg.env_id, self._save_replay_count)
)
self.frames_to_gif(self._frames, path)
self._frames = []
self._save_replay_count += 1
rew = rew.astype(np.float32)
return BaseEnvTimestep(obs, rew, done, info)

    def random_action(self) -> List:
        # Sample a legal action from the underlying action space and wrap it
        # to match the length-1 format expected by ``step``.
        raw_action = self._env.action_space.sample()
        return [raw_action]

def __repr__(self) -> str:
return "DI-engine Frozen Lake Env"

@property
def observation_space(self) -> gymn.spaces.Space:
return self._observation_space

@property
def action_space(self) -> gymn.spaces.Space:
return self._action_space

@property
def reward_space(self) -> gymn.spaces.Space:
return self._reward_space

def enable_save_replay(self, replay_path: Optional[str] = None) -> None:
if replay_path is None:
replay_path = './video'
self._replay_path = replay_path
self._save_replay_bool = True
self._save_replay_count = 0
self._frames = []

@staticmethod
    def frames_to_gif(frames: List[np.ndarray], gif_path: str, duration: float = 0.1) -> None:
        """
        Convert a list of frames into a GIF.
        Args:
            - frames (List[np.ndarray]): A list of frames, each frame is an image array.
            - gif_path (str): The path to save the GIF file.
            - duration (float): Duration between each frame in the GIF (seconds).

        Returns:
            None, the GIF file is saved directly to the specified path.
        """
        # Append the in-memory frames directly; imageio accepts RGB arrays,
        # so no temporary image files are needed.
        with imageio.get_writer(gif_path, mode='I', duration=duration) as writer:
            for frame in frames:
                writer.append_data(frame)
        print(f"GIF saved as {gif_path}")
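
A minimal interaction sketch for this env, mirroring the unit test below (the config keys match `frozen_lake_dqn_config`):

import numpy as np
from easydict import EasyDict
from dizoo.frozen_lake.envs import FrozenLakeEnv

env = FrozenLakeEnv(EasyDict({
    'env_id': 'FrozenLake-v1',
    'desc': None,
    'map_name': '4x4',
    'is_slippery': False,
}))
env.seed(0, dynamic_seed=False)
obs = env.reset()
assert obs.shape == (16, )  # the discrete state is one-hot encoded
timestep = env.step(np.array([env.action_space.sample()]))
print(timestep.reward, timestep.done)
env.close()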
44 changes: 44 additions & 0 deletions dizoo/frozen_lake/envs/test_frozen_lake_env.py
@@ -0,0 +1,44 @@
import numpy as np
import pytest
from dizoo.frozen_lake.envs import FrozenLakeEnv
from easydict import EasyDict


@pytest.mark.envtest
class TestFrozenLakeEnv:

    def test_frozen_lake(self):
env = FrozenLakeEnv(
EasyDict({
'env_id': 'FrozenLake-v1',
'desc': None,
'map_name': "4x4",
'is_slippery': False,
})
)
for _ in range(5):
env.seed(314, dynamic_seed=False)
assert env._seed == 314
obs = env.reset()
assert obs.shape == (
16,
), "Considering the one-hot encoding format, your observation should have a dimensionality of 16."
            # Enable replay saving once per episode; calling it every step
            # would reset the frame buffer.
            env.enable_save_replay("./video")
            for i in range(10):
                # Both ``env.random_action()`` and sampling from ``env.action_space``
                # generate legal random actions.
                if i < 5:
                    random_action = np.array([env.action_space.sample()])
                else:
                    random_action = env.random_action()
timestep = env.step(random_action)
print(timestep)
assert isinstance(timestep.obs, np.ndarray)
assert isinstance(timestep.done, bool)
assert timestep.obs.shape == (16, )
assert timestep.reward.shape == (1, )
assert timestep.reward >= env.reward_space.low
assert timestep.reward <= env.reward_space.high

print(env.observation_space, env.action_space, env.reward_space)
env.close()
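
These tests are gated behind the `envtest` marker; assuming that marker is registered in the repository's pytest configuration, they can be run with `pytest -m envtest dizoo/frozen_lake/envs/test_frozen_lake_env.py`.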