From d333aaf190aad7aa685c496db5d54424d485988b Mon Sep 17 00:00:00 2001 From: jayyoung0802 Date: Sat, 14 Oct 2023 19:47:54 +0800 Subject: [PATCH 01/12] feature(yzj): add ptz with lightzero main --- lzero/mcts/buffer/game_buffer_muzero.py | 16 +- lzero/mcts/utils.py | 5 +- lzero/model/muzero_model_mlp.py | 67 +-- lzero/policy/muzero.py | 19 +- lzero/policy/utils.py | 14 + lzero/worker/muzero_collector.py | 9 +- lzero/worker/muzero_evaluator.py | 5 +- zoo/petting_zoo/__init__.py | 0 zoo/petting_zoo/config/__init__.py | 0 .../config/ptz_simple_spread_mz_config.py | 117 ++++++ zoo/petting_zoo/entry/__init__.py | 2 + zoo/petting_zoo/entry/eval_muzero.py | 81 ++++ zoo/petting_zoo/entry/train_muzero.py | 198 +++++++++ zoo/petting_zoo/envs/__init__.py | 0 .../envs/petting_zoo_simple_spread_env.py | 390 ++++++++++++++++++ .../test_petting_zoo_simple_spread_env.py | 133 ++++++ zoo/petting_zoo/model/__init__.py | 1 + zoo/petting_zoo/model/model.py | 291 +++++++++++++ 18 files changed, 1309 insertions(+), 39 deletions(-) create mode 100644 zoo/petting_zoo/__init__.py create mode 100644 zoo/petting_zoo/config/__init__.py create mode 100644 zoo/petting_zoo/config/ptz_simple_spread_mz_config.py create mode 100644 zoo/petting_zoo/entry/__init__.py create mode 100644 zoo/petting_zoo/entry/eval_muzero.py create mode 100644 zoo/petting_zoo/entry/train_muzero.py create mode 100644 zoo/petting_zoo/envs/__init__.py create mode 100644 zoo/petting_zoo/envs/petting_zoo_simple_spread_env.py create mode 100644 zoo/petting_zoo/envs/test_petting_zoo_simple_spread_env.py create mode 100644 zoo/petting_zoo/model/__init__.py create mode 100644 zoo/petting_zoo/model/model.py diff --git a/lzero/mcts/buffer/game_buffer_muzero.py b/lzero/mcts/buffer/game_buffer_muzero.py index daddf6f9f..9fe01a8c2 100644 --- a/lzero/mcts/buffer/game_buffer_muzero.py +++ b/lzero/mcts/buffer/game_buffer_muzero.py @@ -9,6 +9,8 @@ from lzero.mcts.utils import prepare_observation from lzero.policy import to_detach_cpu_numpy, concat_output, concat_output_value, inverse_scalar_transform from .game_buffer import GameBuffer +from ding.torch_utils import to_device, to_tensor +from ding.utils.data import default_collate if TYPE_CHECKING: from lzero.policy import MuZeroPolicy, EfficientZeroPolicy, SampledEfficientZeroPolicy @@ -199,6 +201,10 @@ def _prepare_reward_value_context( td_steps_list, action_mask_segment, to_play_segment """ zero_obs = game_segment_list[0].zero_obs() + zero_obs = np.array([{'agent_state': np.zeros((3, 18), dtype=np.float32), + 'global_state': np.zeros((30,), dtype=np.float32), + 'agent_alone_state': np.zeros((3, 14), dtype=np.float32), + 'agent_alone_padding_state': np.zeros((3, 18), dtype=np.float32),}]) value_obs_list = [] # the value is valid or not (out of game_segment) value_mask = [] @@ -242,7 +248,7 @@ def _prepare_reward_value_context( value_mask.append(0) obs = zero_obs - value_obs_list.append(obs) + value_obs_list.append(obs.tolist()) reward_value_context = [ value_obs_list, value_mask, pos_in_game_segment_list, rewards_list, game_segment_lens, td_steps_list, @@ -377,7 +383,13 @@ def _compute_target_reward_value(self, reward_value_context: List[Any], model: A beg_index = self._cfg.mini_infer_size * i end_index = self._cfg.mini_infer_size * (i + 1) - m_obs = torch.from_numpy(value_obs_list[beg_index:end_index]).to(self._cfg.device).float() + if self._cfg.model.model_type and self._cfg.model.model_type in ['conv', 'mlp']: + m_obs = torch.from_numpy(value_obs_list[beg_index:end_index]).to(self._cfg.device).float() + elif 
self._cfg.model.model_type and self._cfg.model.model_type == 'structure': + m_obs = value_obs_list[beg_index:end_index] + m_obs = sum(m_obs, []) + m_obs = default_collate(m_obs) + m_obs = to_device(m_obs, self._cfg.device) # calculate the target value m_output = model.initial_inference(m_obs) diff --git a/lzero/mcts/utils.py b/lzero/mcts/utils.py index 6f5737acb..747a0336a 100644 --- a/lzero/mcts/utils.py +++ b/lzero/mcts/utils.py @@ -63,7 +63,7 @@ def prepare_observation(observation_list, model_type='conv'): - observation_list (:obj:`List`): list of observations. - model_type (:obj:`str`): type of the model. (default is 'conv') """ - assert model_type in ['conv', 'mlp'] + assert model_type in ['conv', 'mlp', 'structure'] observation_array = np.array(observation_list) if model_type == 'conv': @@ -97,6 +97,9 @@ def prepare_observation(observation_list, model_type='conv'): # print(observation_array.shape) observation_array = observation_array.reshape(observation_array.shape[0], -1) # print(observation_array.shape) + + elif model_type == 'structure': + return observation_list return observation_array diff --git a/lzero/model/muzero_model_mlp.py b/lzero/model/muzero_model_mlp.py index caf1df15d..fad6ef59c 100644 --- a/lzero/model/muzero_model_mlp.py +++ b/lzero/model/muzero_model_mlp.py @@ -7,6 +7,7 @@ from .common import MZNetworkOutput, RepresentationNetworkMLP, PredictionNetworkMLP from .utils import renormalize, get_params_mean, get_dynamic_mean, get_reward_mean +from ding.utils.default_helper import get_shape0 @MODEL_REGISTRY.register('MuZeroModelMLP') @@ -34,6 +35,9 @@ def __init__( discrete_action_encoding_type: str = 'one_hot', norm_type: Optional[str] = 'BN', res_connection_in_dynamics: bool = False, + state_encoder=None, + state_prediction=None, + state_dynamics=None, *args, **kwargs ): @@ -101,30 +105,39 @@ def __init__( self.state_norm = state_norm self.res_connection_in_dynamics = res_connection_in_dynamics - self.representation_network = RepresentationNetworkMLP( - observation_shape=observation_shape, hidden_channels=self.latent_state_dim, norm_type=norm_type - ) - - self.dynamics_network = DynamicsNetwork( - action_encoding_dim=self.action_encoding_dim, - num_channels=self.latent_state_dim + self.action_encoding_dim, - common_layer_num=2, - fc_reward_layers=fc_reward_layers, - output_support_size=self.reward_support_size, - last_linear_layer_init_zero=self.last_linear_layer_init_zero, - norm_type=norm_type, - res_connection_in_dynamics=self.res_connection_in_dynamics, - ) - - self.prediction_network = PredictionNetworkMLP( - action_space_size=action_space_size, - num_channels=latent_state_dim, - fc_value_layers=fc_value_layers, - fc_policy_layers=fc_policy_layers, - output_support_size=self.value_support_size, - last_linear_layer_init_zero=self.last_linear_layer_init_zero, - norm_type=norm_type - ) + if state_encoder == None: + self.representation_network = RepresentationNetworkMLP( + observation_shape=observation_shape, hidden_channels=self.latent_state_dim, norm_type=norm_type + ) + else: + self.representation_network = state_encoder + + if state_dynamics == None: + self.dynamics_network = DynamicsNetwork( + action_encoding_dim=self.action_encoding_dim, + num_channels=self.latent_state_dim + self.action_encoding_dim, + common_layer_num=2, + fc_reward_layers=fc_reward_layers, + output_support_size=self.reward_support_size, + last_linear_layer_init_zero=self.last_linear_layer_init_zero, + norm_type=norm_type, + res_connection_in_dynamics=self.res_connection_in_dynamics, + ) + 
else: + self.dynamics_network = state_dynamics + + if state_prediction == None: + self.prediction_network = PredictionNetworkMLP( + action_space_size=action_space_size, + num_channels=latent_state_dim, + fc_value_layers=fc_value_layers, + fc_policy_layers=fc_policy_layers, + output_support_size=self.value_support_size, + last_linear_layer_init_zero=self.last_linear_layer_init_zero, + norm_type=norm_type + ) + else: + self.prediction_network = state_prediction if self.self_supervised_learning_loss: # self_supervised_learning_loss related network proposed in EfficientZero @@ -166,14 +179,14 @@ def initial_inference(self, obs: torch.Tensor) -> MZNetworkOutput: - policy_logits (:obj:`torch.Tensor`): :math:`(B, action_dim)`, where B is batch_size. - latent_state (:obj:`torch.Tensor`): :math:`(B, H)`, where B is batch_size, H is the dimension of latent state. """ - batch_size = obs.size(0) + batch_size = get_shape0(obs) latent_state = self._representation(obs) policy_logits, value = self._prediction(latent_state) return MZNetworkOutput( value, [0. for _ in range(batch_size)], policy_logits, - latent_state, + latent_state[1], ) def recurrent_inference(self, latent_state: torch.Tensor, action: torch.Tensor) -> MZNetworkOutput: @@ -201,7 +214,7 @@ def recurrent_inference(self, latent_state: torch.Tensor, action: torch.Tensor) """ next_latent_state, reward = self._dynamics(latent_state, action) policy_logits, value = self._prediction(next_latent_state) - return MZNetworkOutput(value, reward, policy_logits, next_latent_state) + return MZNetworkOutput(value, reward, policy_logits, next_latent_state[1]) def _representation(self, observation: torch.Tensor) -> Tuple[torch.Tensor]: """ diff --git a/lzero/policy/muzero.py b/lzero/policy/muzero.py index a72f6e748..ef372ff22 100644 --- a/lzero/policy/muzero.py +++ b/lzero/policy/muzero.py @@ -16,7 +16,7 @@ from lzero.policy import scalar_transform, InverseScalarTransform, cross_entropy_loss, phi_transform, \ DiscreteSupport, to_torch_float_tensor, mz_network_output_unpack, select_action, negative_cosine_similarity, \ prepare_obs - +from ding.utils.data import default_collate @POLICY_REGISTRY.register('muzero') class MuZeroPolicy(Policy): @@ -298,7 +298,7 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in target_reward = target_reward.view(self._cfg.batch_size, -1) target_value = target_value.view(self._cfg.batch_size, -1) - assert obs_batch.size(0) == self._cfg.batch_size == target_reward.size(0) + # assert obs_batch.size(0) == self._cfg.batch_size == target_reward.size(0) # ``scalar_transform`` to transform the original value to the scaled value, # i.e. h(.) function in paper https://arxiv.org/pdf/1805.11593.pdf. @@ -505,7 +505,13 @@ def _forward_collect( self._collect_model.eval() self._collect_mcts_temperature = temperature self.collect_epsilon = epsilon - active_collect_env_num = data.shape[0] + active_collect_env_num = len(data) + # + data = sum(data, []) + data = default_collate(data) + # data = to_device(data, self._device) + to_play = np.array(to_play).reshape(-1).tolist() + with torch.no_grad(): # data shape [B, S x C, W, H], e.g. {Tensor:(B, 12, 96, 96)} network_output = self._collect_model.initial_inference(data) @@ -629,7 +635,12 @@ def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: int = -1 ``visit_count_distribution_entropy``, ``value``, ``pred_value``, ``policy_logits``. 
""" self._eval_model.eval() - active_eval_env_num = data.shape[0] + active_eval_env_num = len(data) + # + data = sum(data, []) + data = default_collate(data) + # data = to_device(data, self._device) + to_play = np.array(to_play).reshape(-1).tolist() with torch.no_grad(): # data shape [B, S x C, W, H], e.g. {Tensor:(B, 12, 96, 96)} network_output = self._collect_model.initial_inference(data) diff --git a/lzero/policy/utils.py b/lzero/policy/utils.py index 61ceb455e..c347bcc9c 100644 --- a/lzero/policy/utils.py +++ b/lzero/policy/utils.py @@ -9,6 +9,8 @@ from easydict import EasyDict from scipy.stats import entropy from torch.nn import functional as F +from ding.torch_utils import to_device, to_tensor +from ding.utils.data import default_collate def visualize_avg_softmax(logits): @@ -316,6 +318,18 @@ def prepare_obs(obs_batch_ori: np.ndarray, cfg: EasyDict) -> Tuple[torch.Tensor, # ``obs_target_batch`` is only used for calculate consistency loss, which take the all obs other than # timestep t1, and is only performed in the last 8 timesteps in the second dim in ``obs_batch_ori``. obs_target_batch = obs_batch_ori[:, cfg.model.observation_shape:] + + elif cfg.model.model_type == 'structure': + obs_batch = obs_batch_ori[:, 0:cfg.model.frame_stack_num] + if cfg.model.self_supervised_learning_loss: + obs_target_batch = obs_batch_ori[:, cfg.model.frame_stack_num:] + else: + obs_target_batch = None + # obs_batch + obs_batch = obs_batch.tolist() + obs_batch = sum(obs_batch, []) + obs_batch = default_collate(obs_batch) + obs_batch = to_device(obs_batch, device=cfg.device) return obs_batch, obs_target_batch diff --git a/lzero/worker/muzero_collector.py b/lzero/worker/muzero_collector.py index 331c72d17..58ad5a2cf 100644 --- a/lzero/worker/muzero_collector.py +++ b/lzero/worker/muzero_collector.py @@ -399,11 +399,14 @@ def collect(self, chance_dict = {env_id: chance_dict[env_id] for env_id in ready_env_id} chance = [chance_dict[env_id] for env_id in ready_env_id] - stack_obs = to_ndarray(stack_obs) + if self.policy_config.model.model_type in ['conv', 'mlp']: + stack_obs = to_ndarray(stack_obs) - stack_obs = prepare_observation(stack_obs, self.policy_config.model.model_type) + stack_obs = prepare_observation(stack_obs, self.policy_config.model.model_type) - stack_obs = torch.from_numpy(stack_obs).to(self.policy_config.device).float() + stack_obs = torch.from_numpy(stack_obs).to(self.policy_config.device).float() + elif self.policy_config.model.model_type == 'structure': + stack_obs = prepare_observation(stack_obs, self.policy_config.model.model_type) # ============================================================== # policy forward diff --git a/lzero/worker/muzero_evaluator.py b/lzero/worker/muzero_evaluator.py index 9e8d360e3..4ed3b8959 100644 --- a/lzero/worker/muzero_evaluator.py +++ b/lzero/worker/muzero_evaluator.py @@ -280,8 +280,9 @@ def eval( to_play = [to_play_dict[env_id] for env_id in ready_env_id] stack_obs = to_ndarray(stack_obs) - stack_obs = prepare_observation(stack_obs, self.policy_config.model.model_type) - stack_obs = torch.from_numpy(stack_obs).to(self.policy_config.device).float() + if self.policy_config.model.model_type and self.policy_config.model.model_type in ['conv', 'mlp']: + stack_obs = prepare_observation(stack_obs, self.policy_config.model.model_type) + stack_obs = torch.from_numpy(stack_obs).to(self.policy_config.device).float() # ============================================================== # policy forward diff --git a/zoo/petting_zoo/__init__.py 
b/zoo/petting_zoo/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/zoo/petting_zoo/config/__init__.py b/zoo/petting_zoo/config/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/zoo/petting_zoo/config/ptz_simple_spread_mz_config.py b/zoo/petting_zoo/config/ptz_simple_spread_mz_config.py new file mode 100644 index 000000000..515242ee4 --- /dev/null +++ b/zoo/petting_zoo/config/ptz_simple_spread_mz_config.py @@ -0,0 +1,117 @@ +from easydict import EasyDict + +env_name = 'ptz_simple_spread' +multi_agent = True + +# ============================================================== +# begin of the most frequently changed config specified by the user +# ============================================================== +seed = 0 +n_agent = 3 +n_landmark = n_agent +collector_env_num = 8 +evaluator_env_num = 8 +n_episode = 8 +batch_size = 256 +num_simulations = 50 +update_per_collect = 1000 +reanalyze_ratio = 0. +action_space_size = 5*5*5 +eps_greedy_exploration_in_collect = True +# ============================================================== +# end of the most frequently changed config specified by the user +# ============================================================== + +main_config = dict( + exp_name= + f'data_mz_ctree/{env_name}_muzero_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_seed{seed}', + env=dict( + env_family='mpe', + env_id='simple_spread_v2', + n_agent=n_agent, + n_landmark=n_landmark, + max_cycles=25, + agent_obs_only=False, + agent_specific_global_state=False, + continuous_actions=False, + stop_value=0, + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + n_evaluator_episode=evaluator_env_num, + manager=dict(shared_memory=False, ), + ), + policy=dict( + multi_agent=multi_agent, + ignore_done=False, + model=dict( + model_type='structure', + latent_state_dim=256, + frame_stack_num=1, + action_space='discrete', + action_space_size=action_space_size, + agent_num=n_agent, + self_supervised_learning_loss=False, # default is False + agent_obs_shape=(2 + 2 + n_landmark * 2 + (n_agent - 1) * 2 + (n_agent - 1) * 2)*3, + global_obs_shape=2 + 2 + n_landmark * 2 + (n_agent - 1) * 2 + (n_agent - 1) * 2 + n_agent * (2 + 2) + + n_landmark * 2 + n_agent * (n_agent - 1) * 2, + discrete_action_encoding_type='one_hot', + global_cooperation=True, # TODO: doesn't work now + hidden_size_list=[256, 256], + norm_type='BN', + ), + cuda=True, + mcts_ctree=True, + gumbel_algo=False, + env_type='not_board_games', + game_segment_length=30, + random_collect_episode_num=0, + eps=dict( + eps_greedy_exploration_in_collect=eps_greedy_exploration_in_collect, + type='linear', + start=1., + end=0.05, + decay=int(1e5), + ), + use_augmentation=False, + update_per_collect=update_per_collect, + batch_size=batch_size, + optim_type='SGD', + lr_piecewise_constant_decay=True, + learning_rate=0.2, + ssl_loss_weight=0, # default is 0 + num_simulations=num_simulations, + reanalyze_ratio=reanalyze_ratio, + n_episode=n_episode, + eval_freq=int(2e3), + replay_buffer_size=int(1e6), # the size/capacity of replay_buffer, in the terms of transitions. 
+ collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + ), + learn=dict(learner=dict( + log_policy=True, + hook=dict(log_show_after_iter=10, ), + ), ), +) +main_config = EasyDict(main_config) +create_config = dict( + env=dict( + import_names=['zoo.petting_zoo.envs.petting_zoo_simple_spread_env'], + type='petting_zoo', + ), + env_manager=dict(type='base'), + policy=dict( + type='muzero', + import_names=['lzero.policy.muzero'], + ), + collector=dict( + type='episode_muzero', + import_names=['lzero.worker.muzero_collector'], + ) +) +create_config = EasyDict(create_config) +ptz_simple_spread_muzero_config = main_config +ptz_simple_spread_muzero_create_config = create_config + +if __name__ == '__main__': + from zoo.petting_zoo.entry import train_muzero + train_muzero([main_config, create_config], seed=seed) diff --git a/zoo/petting_zoo/entry/__init__.py b/zoo/petting_zoo/entry/__init__.py new file mode 100644 index 000000000..5e8144157 --- /dev/null +++ b/zoo/petting_zoo/entry/__init__.py @@ -0,0 +1,2 @@ +from .train_muzero import train_muzero +from .eval_muzero import eval_muzero \ No newline at end of file diff --git a/zoo/petting_zoo/entry/eval_muzero.py b/zoo/petting_zoo/entry/eval_muzero.py new file mode 100644 index 000000000..7eb3e4d17 --- /dev/null +++ b/zoo/petting_zoo/entry/eval_muzero.py @@ -0,0 +1,81 @@ +import logging +import os +from functools import partial +from typing import Optional, Tuple + +import torch +from ding.config import compile_config +from ding.envs import create_env_manager +from ding.envs import get_vec_env_setting +from ding.policy import create_policy +from ding.utils import set_pkg_seed, get_rank +from ding.rl_utils import get_epsilon_greedy_fn +from ding.worker import BaseLearner +from tensorboardX import SummaryWriter + +from lzero.worker import MuZeroCollector as Collector +from lzero.worker import MuZeroEvaluator as Evaluator +from lzero.entry.utils import log_buffer_memory_usage +from lzero.policy import visit_count_temperature +from lzero.entry.utils import random_collect +from zoo.petting_zoo.model import PettingZooEncoder + +def eval_muzero(main_cfg, create_cfg, seed=0): + assert create_cfg.policy.type in ['efficientzero', 'muzero', 'sampled_efficientzero', 'gumbel_muzero', 'multi_agent_efficientzero', 'multi_agent_muzero'], \ + "train_muzero entry now only support the following algo.: 'efficientzero', 'muzero', 'sampled_efficientzero', 'gumbel_muzero', 'multi_agent_efficientzero', 'multi_agent_muzero'" + + if create_cfg.policy.type == 'muzero' or create_cfg.policy.type == 'multi_agent_muzero': + from lzero.mcts import MuZeroGameBuffer as GameBuffer + from lzero.model.muzero_model_mlp import MuZeroModelMLP as Encoder + elif create_cfg.policy.type == 'efficientzero' or create_cfg.policy.type == 'multi_agent_efficientzero': + from lzero.mcts import EfficientZeroGameBuffer as GameBuffer + from lzero.model.efficientzero_model_mlp import EfficientZeroModelMLP as Encoder + elif create_cfg.policy.type == 'sampled_efficientzero': + from lzero.mcts import SampledEfficientZeroGameBuffer as GameBuffer + elif create_cfg.policy.type == 'gumbel_muzero': + from lzero.mcts import GumbelMuZeroGameBuffer as GameBuffer + + main_cfg.policy.device = 'cpu' + main_cfg.policy.load_path = 'exp_name/ckpt/ckpt_best.pth.tar' + main_cfg.env.replay_path = './' # when visualize must set as base + create_cfg.env_manager.type = 'base' # when visualize must set as base + main_cfg.env.evaluator_env_num = 1 # only 1 env for save replay + 
main_cfg.env.n_evaluator_episode = 1 + + cfg = compile_config(main_cfg, seed=seed, env=None, auto=True, create_cfg=create_cfg, save_cfg=True) + # Create main components: env, policy + env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env) + + evaluator_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in evaluator_env_cfg]) + + evaluator_env.seed(cfg.seed, dynamic_seed=False) + set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda) + + model = Encoder(**cfg.policy.model, state_encoder=PettingZooEncoder()) + policy = create_policy(cfg.policy, model=model, enable_field=['learn', 'collect', 'eval']) + policy.eval_mode.load_state_dict(torch.load(cfg.policy.load_path, map_location=cfg.policy.device)) + + # Create worker components: learner, collector, evaluator, replay buffer, commander. + tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial')) if get_rank() == 0 else None + learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name) + + # ============================================================== + # MCTS+RL algorithms related core code + # ============================================================== + policy_config = cfg.policy + evaluator = Evaluator( + eval_freq=cfg.policy.eval_freq, + n_evaluator_episode=cfg.env.n_evaluator_episode, + stop_value=cfg.env.stop_value, + env=evaluator_env, + policy=policy.eval_mode, + tb_logger=tb_logger, + exp_name=cfg.exp_name, + policy_config=policy_config + ) + stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter) + return stop, reward + +if __name__ == '__main__': + from zoo.petting_zoo.config.ptz_simple_spread_ez_config import main_config, create_config + eval_muzero(main_config, create_config, seed=0) \ No newline at end of file diff --git a/zoo/petting_zoo/entry/train_muzero.py b/zoo/petting_zoo/entry/train_muzero.py new file mode 100644 index 000000000..c31f36892 --- /dev/null +++ b/zoo/petting_zoo/entry/train_muzero.py @@ -0,0 +1,198 @@ +import logging +import os +from functools import partial +from typing import Optional, Tuple + +import torch +from ding.config import compile_config +from ding.envs import create_env_manager +from ding.envs import get_vec_env_setting +from ding.policy import create_policy +from ding.utils import set_pkg_seed, get_rank +from ding.rl_utils import get_epsilon_greedy_fn +from ding.worker import BaseLearner +from tensorboardX import SummaryWriter + +from lzero.entry.utils import log_buffer_memory_usage +from lzero.policy import visit_count_temperature +from lzero.policy.random_policy import LightZeroRandomPolicy +from lzero.worker import MuZeroCollector as Collector +from lzero.worker import MuZeroEvaluator as Evaluator +from zoo.petting_zoo.model import PettingZooEncoder, PettingZooPrediction, PettingZooDynamics +from lzero.entry.utils import random_collect + + +def train_muzero( + input_cfg: Tuple[dict, dict], + seed: int = 0, + model: Optional[torch.nn.Module] = None, + model_path: Optional[str] = None, + max_train_iter: Optional[int] = int(1e10), + max_env_step: Optional[int] = int(1e10), +) -> 'Policy': # noqa + """ + Overview: + The train entry for MCTS+RL algorithms, including MuZero, EfficientZero, Sampled EfficientZero, Gumbel Muzero. + Arguments: + - input_cfg (:obj:`Tuple[dict, dict]`): Config in dict type. + ``Tuple[dict, dict]`` type means [user_config, create_cfg]. + - seed (:obj:`int`): Random seed. + - model (:obj:`Optional[torch.nn.Module]`): Instance of torch.nn.Module. 
+ - model_path (:obj:`Optional[str]`): The pretrained model path, which should + point to the ckpt file of the pretrained model, and an absolute path is recommended. + In LightZero, the path is usually something like ``exp_name/ckpt/ckpt_best.pth.tar``. + - max_train_iter (:obj:`Optional[int]`): Maximum policy update iterations in training. + - max_env_step (:obj:`Optional[int]`): Maximum collected environment interaction steps. + Returns: + - policy (:obj:`Policy`): Converged policy. + """ + + cfg, create_cfg = input_cfg + assert create_cfg.policy.type in ['efficientzero', 'muzero', 'sampled_efficientzero', 'gumbel_muzero', 'stochastic_muzero'], \ + "train_muzero entry now only support the following algo.: 'efficientzero', 'muzero', 'sampled_efficientzero', 'gumbel_muzero'" + + if create_cfg.policy.type == 'muzero': + from lzero.mcts import MuZeroGameBuffer as GameBuffer + from lzero.model.muzero_model_mlp import MuZeroModelMLP as Encoder + elif create_cfg.policy.type == 'efficientzero': + from lzero.mcts import EfficientZeroGameBuffer as GameBuffer + elif create_cfg.policy.type == 'sampled_efficientzero': + from lzero.mcts import SampledEfficientZeroGameBuffer as GameBuffer + elif create_cfg.policy.type == 'gumbel_muzero': + from lzero.mcts import GumbelMuZeroGameBuffer as GameBuffer + elif create_cfg.policy.type == 'stochastic_muzero': + from lzero.mcts import StochasticMuZeroGameBuffer as GameBuffer + + if cfg.policy.cuda and torch.cuda.is_available(): + cfg.policy.device = 'cuda' + else: + cfg.policy.device = 'cpu' + + cfg = compile_config(cfg, seed=seed, env=None, auto=True, create_cfg=create_cfg, save_cfg=True) + # Create main components: env, policy + env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env) + + collector_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in collector_env_cfg]) + evaluator_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in evaluator_env_cfg]) + + collector_env.seed(cfg.seed) + evaluator_env.seed(cfg.seed, dynamic_seed=False) + set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda) + + model = Encoder(**cfg.policy.model, state_encoder=PettingZooEncoder(), state_prediction=PettingZooPrediction(), state_dynamics=PettingZooDynamics()) + policy = create_policy(cfg.policy, model=model, enable_field=['learn', 'collect', 'eval']) + + # load pretrained model + if model_path is not None: + policy.learn_mode.load_state_dict(torch.load(model_path, map_location=cfg.policy.device)) + + # Create worker components: learner, collector, evaluator, replay buffer, commander. 
+ tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial')) if get_rank() == 0 else None + learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name) + + # ============================================================== + # MCTS+RL algorithms related core code + # ============================================================== + policy_config = cfg.policy + batch_size = policy_config.batch_size + # specific game buffer for MCTS+RL algorithms + replay_buffer = GameBuffer(policy_config) + collector = Collector( + env=collector_env, + policy=policy.collect_mode, + tb_logger=tb_logger, + exp_name=cfg.exp_name, + policy_config=policy_config + ) + evaluator = Evaluator( + eval_freq=cfg.policy.eval_freq, + n_evaluator_episode=cfg.env.n_evaluator_episode, + stop_value=cfg.env.stop_value, + env=evaluator_env, + policy=policy.eval_mode, + tb_logger=tb_logger, + exp_name=cfg.exp_name, + policy_config=policy_config + ) + + # ============================================================== + # Main loop + # ============================================================== + # Learner's before_run hook. + learner.call_hook('before_run') + + if cfg.policy.update_per_collect is not None: + update_per_collect = cfg.policy.update_per_collect + + # The purpose of collecting random data before training: + # Exploration: The collection of random data aids the agent in exploring the environment and prevents premature convergence to a suboptimal policy. + # Comparation: The agent's performance during random action-taking can be used as a reference point to evaluate the efficacy of reinforcement learning algorithms. + if cfg.policy.random_collect_episode_num > 0: + random_collect(cfg.policy, policy, LightZeroRandomPolicy, collector, collector_env, replay_buffer) + + while True: + log_buffer_memory_usage(learner.train_iter, replay_buffer, tb_logger) + collect_kwargs = {} + # set temperature for visit count distributions according to the train_iter, + # please refer to Appendix D in MuZero paper for details. + collect_kwargs['temperature'] = visit_count_temperature( + policy_config.manual_temperature_decay, + policy_config.fixed_temperature_value, + policy_config.threshold_training_steps_for_final_temperature, + trained_steps=learner.train_iter + ) + + if policy_config.eps.eps_greedy_exploration_in_collect: + epsilon_greedy_fn = get_epsilon_greedy_fn( + start=policy_config.eps.start, + end=policy_config.eps.end, + decay=policy_config.eps.decay, + type_=policy_config.eps.type + ) + collect_kwargs['epsilon'] = epsilon_greedy_fn(collector.envstep) + else: + collect_kwargs['epsilon'] = 0.0 + + # Evaluate policy performance. + if evaluator.should_eval(learner.train_iter): + stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep) + if stop: + break + + # Collect data by default config n_sample/n_episode. + new_data = collector.collect(train_iter=learner.train_iter, policy_kwargs=collect_kwargs) + if cfg.policy.update_per_collect is None: + # update_per_collect is None, then update_per_collect is set to the number of collected transitions multiplied by the model_update_ratio. + collected_transitions_num = sum([len(game_segment) for game_segment in new_data[0]]) + update_per_collect = int(collected_transitions_num * cfg.policy.model_update_ratio) + # save returned new_data collected by the collector + replay_buffer.push_game_segments(new_data) + # remove the oldest data if the replay buffer is full. 
+ replay_buffer.remove_oldest_data_to_fit() + + # Learn policy from collected data. + for i in range(update_per_collect): + # Learner will train ``update_per_collect`` times in one iteration. + if replay_buffer.get_num_of_transitions() > batch_size: + train_data = replay_buffer.sample(batch_size, policy) + else: + logging.warning( + f'The data in replay_buffer is not sufficient to sample a mini-batch: ' + f'batch_size: {batch_size}, ' + f'{replay_buffer} ' + f'continue to collect now ....' + ) + break + + # The core train steps for MCTS+RL algorithms. + log_vars = learner.train(train_data, collector.envstep) + + if cfg.policy.use_priority: + replay_buffer.update_priority(train_data, log_vars[0]['value_priority_orig']) + + if collector.envstep >= max_env_step or learner.train_iter >= max_train_iter: + break + + # Learner's after_run hook. + learner.call_hook('after_run') + return policy diff --git a/zoo/petting_zoo/envs/__init__.py b/zoo/petting_zoo/envs/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/zoo/petting_zoo/envs/petting_zoo_simple_spread_env.py b/zoo/petting_zoo/envs/petting_zoo_simple_spread_env.py new file mode 100644 index 000000000..83f7eafaa --- /dev/null +++ b/zoo/petting_zoo/envs/petting_zoo_simple_spread_env.py @@ -0,0 +1,390 @@ +from typing import Any, List, Union, Optional, Dict +import gymnasium as gym +import numpy as np +import pettingzoo +from functools import reduce + +from ding.envs import BaseEnv, BaseEnvTimestep, FrameStackWrapper +from ding.torch_utils import to_ndarray, to_list +from ding.envs.common.common_function import affine_transform +from ding.utils import ENV_REGISTRY, import_module +from pettingzoo.utils.conversions import parallel_wrapper_fn +from pettingzoo.mpe._mpe_utils.simple_env import SimpleEnv, make_env +from pettingzoo.mpe.simple_spread.simple_spread import Scenario +from PIL import Image +import pygame + + +@ENV_REGISTRY.register('petting_zoo') +class PettingZooEnv(BaseEnv): + # Now only supports simple_spread_v2. + # All agents' observations should have the same shape. 
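    # A minimal usage sketch, assuming the config keys exercised by the config and
    # tests added in this patch (n_agent = n_landmark = 3, discrete joint actions):
    #
    #   from easydict import EasyDict
    #   env = PettingZooEnv(EasyDict(dict(
    #       env_family='mpe', env_id='simple_spread_v2',
    #       n_agent=3, n_landmark=3, max_cycles=25,
    #       agent_obs_only=False, agent_specific_global_state=False,
    #       continuous_actions=False,
    #   )))
    #   env.seed(0)
    #   obs = env.reset()       # {'observation': {...}, 'action_mask': [...], 'to_play': [-1]}
    #   timestep = env.step(0)  # a single flat joint-action index in [0, 5 ** 3)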
+ + def __init__(self, cfg: dict) -> None: + self._cfg = cfg + self._init_flag = False + self._replay_path = None + # self._replay_path = self._cfg.get('replay_path', None) + # self.frame_list = [] + self._env_family = self._cfg.env_family + self._env_id = self._cfg.env_id + self._num_agents = self._cfg.n_agent + self._num_landmarks = self._cfg.n_landmark + self._continuous_actions = self._cfg.get('continuous_actions', False) + self._max_cycles = self._cfg.get('max_cycles', 25) + self._act_scale = self._cfg.get('act_scale', False) + self._agent_specific_global_state = self._cfg.get('agent_specific_global_state', False) + if self._act_scale: + assert self._continuous_actions, 'Only continuous action space env needs act_scale' + + # joint action + import itertools + action_space = [0, 1, 2, 3, 4] + self.combinations = list(itertools.product(action_space, repeat=3)) + + def reset(self) -> np.ndarray: + if not self._init_flag: + # In order to align with the simple spread in Multiagent Particle Env (MPE), + # instead of adopting the pettingzoo interface directly, + # we have redefined the way rewards are calculated + + # import_module(['pettingzoo.{}.{}'.format(self._env_family, self._env_id)]) + # self._env = pettingzoo.__dict__[self._env_family].__dict__[self._env_id].parallel_env( + # N=self._cfg.n_agent, continuous_actions=self._continuous_actions, max_cycles=self._max_cycles + # ) + + # init parallel_env wrapper + _env = make_env(simple_spread_raw_env) + parallel_env = parallel_wrapper_fn(_env) + # init env + self._env = parallel_env( + N=self._cfg.n_agent, continuous_actions=self._continuous_actions, max_cycles=self._max_cycles + ) + # dynamic seed reduces training speed greatly + # if hasattr(self, '_seed') and hasattr(self, '_dynamic_seed') and self._dynamic_seed: + # np_seed = 100 * np.random.randint(1, 1000) + # self._env.seed(self._seed + np_seed) + if self._replay_path is not None: + self._env = gym.wrappers.Monitor( + self._env, self._replay_path, video_callable=lambda episode_id: True, force=True + ) + if hasattr(self, '_seed'): + obs = self._env.reset(seed=self._seed) + else: + obs = self._env.reset() + if not self._init_flag: + self._agents = self._env.agents + + self._action_space = gym.spaces.Dict({agent: self._env.action_space(agent) for agent in self._agents}) + single_agent_obs_space = self._env.action_space(self._agents[0]) + if isinstance(single_agent_obs_space, gym.spaces.Box): + self._action_dim = single_agent_obs_space.shape + elif isinstance(single_agent_obs_space, gym.spaces.Discrete): + self._action_dim = (single_agent_obs_space.n, ) + else: + raise Exception('Only support `Box` or `Discrete` obs space for single agent.') + + # only for env 'simple_spread_v2', n_agent = 5 + # now only for the case that each agent in the team have the same obs structure and corresponding shape. 
+ if not self._cfg.agent_obs_only: + self._observation_space = gym.spaces.Dict( + { + 'agent_state': gym.spaces.Box( + low=float("-inf"), + high=float("inf"), + shape=(self._num_agents, + self._env.observation_space('agent_0').shape[0]), # (self._num_agents, 30) + dtype=np.float32 + ), + 'global_state': gym.spaces.Box( + low=float("-inf"), + high=float("inf"), + shape=( + 4 * self._num_agents + 2 * self._num_landmarks + 2 * self._num_agents * + (self._num_agents - 1), + ), + dtype=np.float32 + ), + 'agent_alone_state': gym.spaces.Box( + low=float("-inf"), + high=float("inf"), + shape=(self._num_agents, 4 + 2 * self._num_landmarks + 2 * (self._num_agents - 1)), + dtype=np.float32 + ), + 'agent_alone_padding_state': gym.spaces.Box( + low=float("-inf"), + high=float("inf"), + shape=(self._num_agents, + self._env.observation_space('agent_0').shape[0]), # (self._num_agents, 30) + dtype=np.float32 + ), + 'action_mask': gym.spaces.Box( + low=float("-inf"), + high=float("inf"), + shape=(self._num_agents, self._action_dim[0]), # (self._num_agents, 5) + dtype=np.float32 + ) + } + ) + # whether use agent_specific_global_state. It is usually used in AC multiagent algos, e.g., mappo, masac, etc. + if self._agent_specific_global_state: + agent_specifig_global_state = gym.spaces.Box( + low=float("-inf"), + high=float("inf"), + shape=( + self._num_agents, self._env.observation_space('agent_0').shape[0] + 4 * self._num_agents + + 2 * self._num_landmarks + 2 * self._num_agents * (self._num_agents - 1) + ), + dtype=np.float32 + ) + self._observation_space['global_state'] = agent_specifig_global_state + else: + # for case when env.agent_obs_only=True + self._observation_space = gym.spaces.Box( + low=float("-inf"), + high=float("inf"), + shape=(self._num_agents, self._env.observation_space('agent_0').shape[0]), + dtype=np.float32 + ) + + self._reward_space = gym.spaces.Dict( + { + agent: gym.spaces.Box(low=float("-inf"), high=float("inf"), shape=(1, ), dtype=np.float32) + for agent in self._agents + } + ) + self._init_flag = True + # self._eval_episode_return = {agent: 0. for agent in self._agents} + self._eval_episode_return = 0. 
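        # Shape sanity check (a sketch, not enforced here): with n_agent = n_landmark = 3,
        # the structured observation built by ``_process_obs`` below -- and mirrored by
        # ``zero_obs`` in lzero/mcts/buffer/game_buffer_muzero.py -- is expected to be
        #   agent_state:               (3, 18)   # 2 + 2 + 3*2 + 2*2 + 2*2
        #   global_state:              (30,)     # 3*(2 + 2) + 3*2 + 3*2*2
        #   agent_alone_state:         (3, 14)   # 2 + 2 + 3*2 + 2*2
        #   agent_alone_padding_state: (3, 18)
        #   action_mask:               125 ones  # joint action space of size 5 ** 3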
+ self._step_count = 0 + obs_n = self._process_obs(obs) + return obs_n + + def close(self) -> None: + if self._init_flag: + self._env.close() + self._init_flag = False + + def render(self) -> None: + self._env.render() + + def seed(self, seed: int, dynamic_seed: bool = True) -> None: + self._seed = seed + self._dynamic_seed = dynamic_seed + np.random.seed(self._seed) + + def step(self, action: int) -> BaseEnvTimestep: + action = np.array(self.combinations[action]) + self._step_count += 1 + action = self._process_action(action) + if self._act_scale: + for agent in self._agents: + # print(action[agent]) + # print(self.action_space[agent]) + # print(self.action_space[agent].low, self.action_space[agent].high) + action[agent] = affine_transform( + action[agent], min_val=self.action_space[agent].low, max_val=self.action_space[agent].high + ) + + obs, rew, done, trunc, info = self._env.step(action) + obs_n = self._process_obs(obs) + rew_n = np.array([sum([rew[agent] for agent in self._agents])]) + rew_n = rew_n.astype(np.float32) + # collide_sum = 0 + # for i in range(self._num_agents): + # collide_sum += info['n'][i][1] + # collide_penalty = self._cfg.get('collide_penal', self._num_agent) + # rew_n += collide_sum * (1.0 - collide_penalty) + # rew_n = rew_n / (self._cfg.get('max_cycles', 25) * self._num_agent) + self._eval_episode_return += rew_n.item() + + # occupied_landmarks = info['n'][0][3] + # if self._step_count >= self._max_step or occupied_landmarks >= self._n_agent \ + # or occupied_landmarks >= self._num_landmarks: + # done_n = True + # else: + # done_n = False + done_n = reduce(lambda x, y: x and y, done.values()) or self._step_count >= self._max_cycles + + # for agent in self._agents: + # self._eval_episode_return[agent] += rew[agent] + # if self._replay_path is not None: + # self.frame_list.append(Image.fromarray(self._env.render())) + if done_n: # or reduce(lambda x, y: x and y, done.values()) + info['eval_episode_return'] = self._eval_episode_return + # if self._replay_path is not None: + # self.frame_list[0].save('out.gif', save_all=True, append_images=self.frame_list[1:], duration=3, loop=0) + # for agent in rew: + # rew[agent] = to_ndarray([rew[agent]]) + return BaseEnvTimestep(obs_n, rew_n, done_n, info) + + def enable_save_replay(self, replay_path: Optional[str] = None) -> None: + if replay_path is None: + replay_path = './video' + self._replay_path = replay_path + + def _process_obs(self, obs: 'torch.Tensor') -> np.ndarray: # noqa + obs = np.array([obs[agent] for agent in self._agents]).astype(np.float32) + if self._cfg.get('agent_obs_only', False): + return obs + ret = {} + # Raw agent observation structure is -- + # [self_vel, self_pos, landmark_rel_positions, other_agent_rel_positions, communication] + # where `communication` are signals from other agents (two for each agent in `simple_spread_v2`` env) + + # agent_state: Shape (n_agent, 2 + 2 + n_landmark * 2 + (n_agent - 1) * 2 + (n_agent - 1) * 2). + # Stacked observation. Contains + # - agent itself's state(velocity + position) + # - position of items that the agent can observe(e.g. other agents, landmarks) + # - communication + ret['agent_state'] = obs + # global_state: Shape (n_agent * (2 + 2) + n_landmark * 2 + n_agent * (n_agent - 1) * 2, ). + # 1-dim vector. 
Contains + # - all agents' state(velocity + position) + + # - all landmarks' position + + # - all agents' communication + ret['global_state'] = np.concatenate( + [ + obs[0, 2:-(self._num_agents - 1) * 2], # all agents' position + all landmarks' position + obs[:, 0:2].flatten(), # all agents' velocity + obs[:, -(self._num_agents - 1) * 2:].flatten() # all agents' communication + ] + ) + # agent_specific_global_state: Shape (n_agent, 2 + 2 + n_landmark * 2 + (n_agent - 1) * 2 + (n_agent - 1) * 2 + n_agent * (2 + 2) + n_landmark * 2 + n_agent * (n_agent - 1) * 2). + # 2-dim vector. contains + # - agent_state info + # - global_state info + if self._agent_specific_global_state: + ret['global_state'] = np.concatenate( + [ret['agent_state'], + np.expand_dims(ret['global_state'], axis=0).repeat(self._num_agents, axis=0)], + axis=1 + ) + # agent_alone_state: Shape (n_agent, 2 + 2 + n_landmark * 2 + (n_agent - 1) * 2). + # Stacked observation. Exclude other agents' positions from agent_state. Contains + # - agent itself's state(velocity + position) + + # - landmarks' positions (do not include other agents' positions) + # - communication + ret['agent_alone_state'] = np.concatenate( + [ + obs[:, 0:(4 + self._num_agents * 2)], # agent itself's state + landmarks' position + obs[:, -(self._num_agents - 1) * 2:], # communication + ], + 1 + ) + # agent_alone_padding_state: Shape (n_agent, 2 + 2 + n_landmark * 2 + (n_agent - 1) * 2 + (n_agent - 1) * 2). + # Contains the same information as agent_alone_state; + # But 0-padding other agents' positions. + ret['agent_alone_padding_state'] = np.concatenate( + [ + obs[:, 0:(4 + self._num_agents * 2)], # agent itself's state + landmarks' position + np.zeros((self._num_agents, + (self._num_agents - 1) * 2), np.float32), # Other agents' position(0-padding) + obs[:, -(self._num_agents - 1) * 2:] # communication + ], + 1 + ) + # action_mask: All actions are of use(either 1 for discrete or 5 for continuous). Thus all 1. + joint_action_mask = [1 for _ in range(5*5*5)] + return {'observation': ret, 'action_mask': joint_action_mask, 'to_play': [-1]} + + def _process_action(self, action: 'torch.Tensor') -> Dict[str, np.ndarray]: # noqa + dict_action = {} + for i, agent in enumerate(self._agents): + agent_action = action[i] + if agent_action.shape == (1, ): + agent_action = agent_action.squeeze() # 0-dim array + dict_action[agent] = agent_action + return dict_action + + def random_action(self) -> np.ndarray: + random_action = self.action_space.sample() + for k in random_action: + if isinstance(random_action[k], np.ndarray): + pass + elif isinstance(random_action[k], int): + random_action[k] = to_ndarray([random_action[k]], dtype=np.int64) + return random_action + + def __repr__(self) -> str: + return "DI-engine PettingZoo Env" + + @property + def agents(self) -> List[str]: + return self._agents + + @property + def observation_space(self) -> gym.spaces.Space: + return self._observation_space + + @property + def action_space(self) -> gym.spaces.Space: + return self._action_space + + @property + def reward_space(self) -> gym.spaces.Space: + return self._reward_space + + +class simple_spread_raw_env(SimpleEnv): + + def __init__(self, N=3, local_ratio=0.5, max_cycles=25, continuous_actions=False): + assert 0. <= local_ratio <= 1., "local_ratio is a proportion. Must be between 0 and 1." 
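        # Sketch of how the flat joint-action index used by ``PettingZooEnv.step`` maps
        # back to per-agent discrete actions: ``itertools.product(range(5), repeat=3)``
        # is lexicographic, so index i corresponds to the base-5 digits
        # (i // 25, (i // 5) % 5, i % 5); e.g. index 37 -> (1, 2, 2), i.e. agent_0
        # takes action 1 while agent_1 and agent_2 both take action 2.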
+ scenario = Scenario() + world = scenario.make_world(N) + super().__init__(scenario, world, max_cycles, continuous_actions=continuous_actions, local_ratio=local_ratio) + # self.render_mode = 'rgb_array' + self.metadata['name'] = "simple_spread_v2" + + def _execute_world_step(self): + # set action for each agent + for i, agent in enumerate(self.world.agents): + action = self.current_actions[i] + scenario_action = [] + if agent.movable: + mdim = self.world.dim_p * 2 + 1 + if self.continuous_actions: + scenario_action.append(action[0:mdim]) + action = action[mdim:] + else: + scenario_action.append(action % mdim) + action //= mdim + if not agent.silent: + scenario_action.append(action) + self._set_action(scenario_action, agent, self.action_spaces[agent.name]) + + self.world.step() + + global_reward = 0. + if self.local_ratio is not None: + global_reward = float(self.scenario.global_reward(self.world)) + + for agent in self.world.agents: + agent_reward = float(self.scenario.reward(agent, self.world)) + if self.local_ratio is not None: + # we changed reward calc way to keep same with mpe + # reward = global_reward * (1 - self.local_ratio) + agent_reward * self.local_ratio + reward = global_reward + agent_reward + else: + reward = agent_reward + + self.rewards[agent.name] = reward + + # def render(self): + # if self.render_mode is None: + # gym.logger.warn( + # "You are calling render method without specifying any render mode." + # ) + # return + + # self.enable_render(self.render_mode) + + # self.draw() + # observation = np.array(pygame.surfarray.pixels3d(self.screen)) + # if self.render_mode == "human": + # pygame.display.flip() + # return ( + # np.transpose(observation, axes=(1, 0, 2)) + # if self.render_mode == "rgb_array" + # else None + # ) diff --git a/zoo/petting_zoo/envs/test_petting_zoo_simple_spread_env.py b/zoo/petting_zoo/envs/test_petting_zoo_simple_spread_env.py new file mode 100644 index 000000000..22117cf85 --- /dev/null +++ b/zoo/petting_zoo/envs/test_petting_zoo_simple_spread_env.py @@ -0,0 +1,133 @@ +from easydict import EasyDict +import pytest +import numpy as np +import pettingzoo +from ding.utils import import_module + +from dizoo.petting_zoo.envs.petting_zoo_simple_spread_env import PettingZooEnv + + +@pytest.mark.envtest +class TestPettingZooEnv: + + def test_agent_obs_only(self): + n_agent = 5 + n_landmark = n_agent + env = PettingZooEnv( + EasyDict( + dict( + env_family='mpe', + env_id='simple_spread_v2', + n_agent=n_agent, + n_landmark=n_landmark, + max_step=100, + agent_obs_only=True, + continuous_actions=True, + ) + ) + ) + env.seed(123) + assert env._seed == 123 + obs = env.reset() + assert obs.shape == (n_agent, 2 + 2 + (n_agent - 1) * 2 + n_agent * 2 + (n_agent - 1) * 2) + for i in range(10): + random_action = env.random_action() + random_action = np.array([random_action[agent] for agent in random_action]) + timestep = env.step(random_action) + print(timestep) + assert isinstance(timestep.obs, np.ndarray), timestep.obs + assert timestep.obs.shape == (n_agent, 2 + 2 + (n_agent - 1) * 2 + n_agent * 2 + (n_agent - 1) * 2) + assert isinstance(timestep.done, bool), timestep.done + assert isinstance(timestep.reward, np.ndarray), timestep.reward + assert timestep.reward.dtype == np.float32 + print(env.observation_space, env.action_space, env.reward_space) + env.close() + + def test_dict_obs(self): + n_agent = 5 + n_landmark = n_agent + env = PettingZooEnv( + EasyDict( + dict( + env_family='mpe', + env_id='simple_spread_v2', + n_agent=n_agent, + n_landmark=n_landmark, + 
max_step=100, + agent_obs_only=False, + continuous_actions=True, + ) + ) + ) + env.seed(123) + assert env._seed == 123 + obs = env.reset() + for k, v in obs.items(): + print(k, v.shape) + for i in range(10): + random_action = env.random_action() + random_action = np.array([random_action[agent] for agent in random_action]) + timestep = env.step(random_action) + print(timestep) + assert isinstance(timestep.obs, dict), timestep.obs + assert isinstance(timestep.obs['agent_state'], np.ndarray), timestep.obs + assert timestep.obs['agent_state'].shape == ( + n_agent, 2 + 2 + n_landmark * 2 + (n_agent - 1) * 2 + (n_agent - 1) * 2 + ) + assert timestep.obs['global_state'].shape == ( + n_agent * (2 + 2) + n_landmark * 2 + n_agent * (n_agent - 1) * 2, + ) + assert timestep.obs['agent_alone_state'].shape == (n_agent, 2 + 2 + n_landmark * 2 + (n_agent - 1) * 2) + assert timestep.obs['agent_alone_padding_state'].shape == ( + n_agent, 2 + 2 + n_landmark * 2 + (n_agent - 1) * 2 + (n_agent - 1) * 2 + ) + assert timestep.obs['action_mask'].dtype == np.float32 + assert isinstance(timestep.done, bool), timestep.done + assert isinstance(timestep.reward, np.ndarray), timestep.reward + print(env.observation_space, env.action_space, env.reward_space) + env.close() + + def test_agent_specific_global_state(self): + n_agent = 5 + n_landmark = n_agent + env = PettingZooEnv( + EasyDict( + dict( + env_family='mpe', + env_id='simple_spread_v2', + n_agent=n_agent, + n_landmark=n_landmark, + max_step=100, + agent_obs_only=False, + agent_specific_global_state=True, + continuous_actions=True, + ) + ) + ) + env.seed(123) + assert env._seed == 123 + obs = env.reset() + for k, v in obs.items(): + print(k, v.shape) + for i in range(10): + random_action = env.random_action() + random_action = np.array([random_action[agent] for agent in random_action]) + timestep = env.step(random_action) + print(timestep) + assert isinstance(timestep.obs, dict), timestep.obs + assert isinstance(timestep.obs['agent_state'], np.ndarray), timestep.obs + assert timestep.obs['agent_state'].shape == ( + n_agent, 2 + 2 + n_landmark * 2 + (n_agent - 1) * 2 + (n_agent - 1) * 2 + ) + assert timestep.obs['global_state'].shape == ( + n_agent, 2 + 2 + n_landmark * 2 + (n_agent - 1) * 2 + (n_agent - 1) * 2 + n_agent * (2 + 2) + + n_landmark * 2 + n_agent * (n_agent - 1) * 2 + ) + assert timestep.obs['agent_alone_state'].shape == (n_agent, 2 + 2 + n_landmark * 2 + (n_agent - 1) * 2) + assert timestep.obs['agent_alone_padding_state'].shape == ( + n_agent, 2 + 2 + n_landmark * 2 + (n_agent - 1) * 2 + (n_agent - 1) * 2 + ) + assert isinstance(timestep.done, bool), timestep.done + assert isinstance(timestep.reward, np.ndarray), timestep.reward + print(env.observation_space, env.action_space, env.reward_space) + env.close() diff --git a/zoo/petting_zoo/model/__init__.py b/zoo/petting_zoo/model/__init__.py new file mode 100644 index 000000000..cc2eebcdf --- /dev/null +++ b/zoo/petting_zoo/model/__init__.py @@ -0,0 +1 @@ +from .model import PettingZooEncoder, PettingZooPrediction, PettingZooDynamics \ No newline at end of file diff --git a/zoo/petting_zoo/model/model.py b/zoo/petting_zoo/model/model.py new file mode 100644 index 000000000..dee3ee836 --- /dev/null +++ b/zoo/petting_zoo/model/model.py @@ -0,0 +1,291 @@ +import torch +import torch.nn as nn +from ding.model.common import FCEncoder +from ding.torch_utils import MLP, ResBlock + +from lzero.model.common import RepresentationNetworkMLP +from typing import Optional, Tuple +import torch +import torch.nn as nn 
+from ding.torch_utils import MLP +from ding.utils import MODEL_REGISTRY, SequenceType +from lzero.model.utils import get_dynamic_mean, get_reward_mean + +class PettingZooEncoder(nn.Module): + + def __init__(self): + super().__init__() + self.agent_encoder = RepresentationNetworkMLP(observation_shape=18, hidden_channels=128, norm_type='BN') + self.global_encoder = RepresentationNetworkMLP(observation_shape=30, hidden_channels=128, norm_type='BN') + self.encoder = RepresentationNetworkMLP(observation_shape=512, hidden_channels=128, norm_type='BN') + + def forward(self, x): + # agent + batch_size, agent_num = x['agent_state'].shape[0], x['agent_state'].shape[1] + agent_state = x['agent_state'].reshape(batch_size*agent_num, -1) + agent_state = self.agent_encoder(agent_state) + agent_state_B = agent_state.reshape(batch_size, -1) # [8, 768] + agent_state_B_A = agent_state.reshape(batch_size, agent_num, -1) + # global + global_state = self.global_encoder(x['global_state']) + global_state = self.encoder(torch.cat((agent_state_B, global_state),dim=1)) + return (agent_state_B, global_state) + + +class PettingZooPrediction(nn.Module): + + def __init__( + self, + action_space_size: int=125, + num_channels: int=128, + common_layer_num: int = 2, + fc_value_layers: SequenceType = [32], + fc_policy_layers: SequenceType = [32], + output_support_size: int = 601, + last_linear_layer_init_zero: bool = True, + activation: Optional[nn.Module] = nn.ReLU(inplace=True), + norm_type: Optional[str] = 'BN', + ): + """ + Overview: + The definition of policy and value prediction network with Multi-Layer Perceptron (MLP), + which is used to predict value and policy by the given latent state. + Arguments: + - action_space_size: (:obj:`int`): Action space size, usually an integer number. For discrete action \ + space, it is the number of discrete actions. + - num_channels (:obj:`int`): The channels of latent states. + - fc_value_layers (:obj:`SequenceType`): The number of hidden layers used in value head (MLP head). + - fc_policy_layers (:obj:`SequenceType`): The number of hidden layers used in policy head (MLP head). + - output_support_size (:obj:`int`): The size of categorical value output. + - last_linear_layer_init_zero (:obj:`bool`): Whether to use zero initializations for the last layer of \ + dynamics/prediction mlp, default sets it to True. + - activation (:obj:`Optional[nn.Module]`): Activation function used in network, which often use in-place \ + operation to speedup, e.g. ReLU(inplace=True). + - norm_type (:obj:`str`): The type of normalization in networks. defaults to 'BN'. + """ + super().__init__() + self.num_channels = num_channels + + # ******* common backbone ****** + self.fc_prediction_common = MLP( + in_channels=self.num_channels, + hidden_channels=self.num_channels, + out_channels=self.num_channels, + layer_num=common_layer_num, + activation=activation, + norm_type=norm_type, + output_activation=True, + output_norm=True, + # last_linear_layer_init_zero=False is important for convergence + last_linear_layer_init_zero=False, + ) + + # ******* value and policy head ****** + self.fc_value_head = MLP( + in_channels=self.num_channels, + hidden_channels=fc_value_layers[0], + out_channels=output_support_size, + layer_num=len(fc_value_layers) + 1, + activation=activation, + norm_type=norm_type, + output_activation=False, + output_norm=False, + # last_linear_layer_init_zero=True is beneficial for convergence speed. 
+ last_linear_layer_init_zero=last_linear_layer_init_zero + ) + self.fc_policy_head = MLP( + in_channels=self.num_channels*3, + hidden_channels=fc_policy_layers[0], + out_channels=action_space_size, + layer_num=len(fc_policy_layers) + 1, + activation=activation, + norm_type=norm_type, + output_activation=False, + output_norm=False, + # last_linear_layer_init_zero=True is beneficial for convergence speed. + last_linear_layer_init_zero=last_linear_layer_init_zero + ) + + def forward(self, latent_state: torch.Tensor): + """ + Overview: + Forward computation of the prediction network. + Arguments: + - latent_state (:obj:`torch.Tensor`): input tensor with shape (B, latent_state_dim). + Returns: + - policy (:obj:`torch.Tensor`): policy tensor with shape (B, action_space_size). + - value (:obj:`torch.Tensor`): value tensor with shape (B, output_support_size). + """ + agent_state, global_state = latent_state + global_state = self.fc_prediction_common(global_state) + + value = self.fc_value_head(global_state) + policy = self.fc_policy_head(agent_state) + return policy, value + + +class PettingZooDynamics(nn.Module): + + def __init__( + self, + action_encoding_dim: int = 125, + num_channels: int = 253, + common_layer_num: int = 2, + fc_reward_layers: SequenceType = [32], + output_support_size: int = 601, + last_linear_layer_init_zero: bool = True, + activation: Optional[nn.Module] = nn.ReLU(inplace=True), + norm_type: Optional[str] = 'BN', + res_connection_in_dynamics: bool = False, + ): + """ + Overview: + The definition of dynamics network in MuZero algorithm, which is used to predict next latent state + reward by the given current latent state and action. + The networks are mainly built on fully connected layers. + Arguments: + - action_encoding_dim (:obj:`int`): The dimension of action encoding. + - num_channels (:obj:`int`): The num of channels in latent states. + - common_layer_num (:obj:`int`): The number of common layers in dynamics network. + - fc_reward_layers (:obj:`SequenceType`): The number of hidden layers of the reward head (MLP head). + - output_support_size (:obj:`int`): The size of categorical reward output. + - last_linear_layer_init_zero (:obj:`bool`): Whether to use zero initializations for the last layer of value/policy mlp, default sets it to True. + - activation (:obj:`Optional[nn.Module]`): Activation function used in network, which often use in-place \ + operation to speedup, e.g. ReLU(inplace=True). + - norm_type (:obj:`str`): The type of normalization in networks. defaults to 'BN'. + - res_connection_in_dynamics (:obj:`bool`): Whether to use residual connection in dynamics network. 
+ """ + super().__init__() + self.num_channels = num_channels + self.action_encoding_dim = action_encoding_dim + self.latent_state_dim = self.num_channels - self.action_encoding_dim + + self.res_connection_in_dynamics = res_connection_in_dynamics + if self.res_connection_in_dynamics: + self.fc_dynamics_1 = MLP( + in_channels=self.num_channels, + hidden_channels=self.latent_state_dim, + layer_num=common_layer_num, + out_channels=self.latent_state_dim, + activation=activation, + norm_type=norm_type, + output_activation=True, + output_norm=True, + # last_linear_layer_init_zero=False is important for convergence + last_linear_layer_init_zero=False, + ) + self.fc_dynamics_2 = MLP( + in_channels=self.latent_state_dim, + hidden_channels=self.latent_state_dim, + layer_num=common_layer_num, + out_channels=self.latent_state_dim, + activation=activation, + norm_type=norm_type, + output_activation=True, + output_norm=True, + # last_linear_layer_init_zero=False is important for convergence + last_linear_layer_init_zero=False, + ) + else: + self.fc_dynamics_1 = MLP( + in_channels=self.num_channels, + hidden_channels=self.latent_state_dim, + layer_num=common_layer_num, + out_channels=self.latent_state_dim, + activation=activation, + norm_type=norm_type, + output_activation=True, + output_norm=True, + # last_linear_layer_init_zero=False is important for convergence + last_linear_layer_init_zero=False, + ) + + self.fc_dynamics_2 = MLP( + in_channels=self.num_channels, + hidden_channels=self.latent_state_dim, + layer_num=common_layer_num, + out_channels=self.latent_state_dim, + activation=activation, + norm_type=norm_type, + output_activation=True, + output_norm=True, + # last_linear_layer_init_zero=False is important for convergence + last_linear_layer_init_zero=False, + ) + + self.fc_dynamics_3 = MLP( + in_channels=self.num_channels, + hidden_channels=self.latent_state_dim, + layer_num=common_layer_num, + out_channels=self.latent_state_dim, + activation=activation, + norm_type=norm_type, + output_activation=True, + output_norm=True, + # last_linear_layer_init_zero=False is important for convergence + last_linear_layer_init_zero=False, + ) + + self.fc_dynamics_4 = MLP( + in_channels=self.num_channels, + hidden_channels=self.latent_state_dim, + layer_num=common_layer_num, + out_channels=self.latent_state_dim, + activation=activation, + norm_type=norm_type, + output_activation=True, + output_norm=True, + # last_linear_layer_init_zero=False is important for convergence + last_linear_layer_init_zero=False, + ) + + self.fc_reward_head = MLP( + in_channels=self.latent_state_dim, + hidden_channels=fc_reward_layers[0], + layer_num=2, + out_channels=output_support_size, + activation=activation, + norm_type=norm_type, + output_activation=False, + output_norm=False, + last_linear_layer_init_zero=last_linear_layer_init_zero + ) + + def forward(self, state_action_encoding: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Overview: + Forward computation of the dynamics network. Predict the next latent state given current latent state and action. + Arguments: + - state_action_encoding (:obj:`torch.Tensor`): The state-action encoding, which is the concatenation of \ + latent state and action encoding, with shape (batch_size, num_channels, height, width). + Returns: + - next_latent_state (:obj:`torch.Tensor`): The next latent state, with shape (batch_size, latent_state_dim). + - reward (:obj:`torch.Tensor`): The predicted reward for input state. 
+ """ + if self.res_connection_in_dynamics: + # take the state encoding (e.g. latent_state), + # state_action_encoding[:, -self.action_encoding_dim:] is action encoding + latent_state = state_action_encoding[:, :-self.action_encoding_dim] + x = self.fc_dynamics_1(state_action_encoding) + # the residual link: add the latent_state to the state_action encoding + next_latent_state = x + latent_state + next_latent_state_encoding = self.fc_dynamics_2(next_latent_state) + else: + batch_size = state_action_encoding.shape[0] + next_agent_latent_state_1 = self.fc_dynamics_1(state_action_encoding) + next_agent_latent_state_2 = self.fc_dynamics_2(state_action_encoding) + next_agent_latent_state_3 = self.fc_dynamics_3(state_action_encoding) + next_agent_latent_state = torch.stack((next_agent_latent_state_1, next_agent_latent_state_2, next_agent_latent_state_3), dim=1) + next_agent_latent_state = next_agent_latent_state.reshape(batch_size, -1) + next_global_latent_state = self.fc_dynamics_4(state_action_encoding) + next_latent_state_encoding = next_global_latent_state + + reward = self.fc_reward_head(next_latent_state_encoding) + + return (next_agent_latent_state, next_latent_state_encoding), reward + + def get_dynamic_mean(self) -> float: + return get_dynamic_mean(self) + + def get_reward_mean(self) -> float: + return get_reward_mean(self) \ No newline at end of file From 1bd05f0a019c4f017a11a0d28fcaf4fbbc887222 Mon Sep 17 00:00:00 2001 From: jayyoung0802 Date: Sat, 14 Oct 2023 19:54:38 +0800 Subject: [PATCH 02/12] fix(yzj): fix data device on mz policy --- lzero/policy/muzero.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/lzero/policy/muzero.py b/lzero/policy/muzero.py index ef372ff22..411675ae4 100644 --- a/lzero/policy/muzero.py +++ b/lzero/policy/muzero.py @@ -17,6 +17,7 @@ DiscreteSupport, to_torch_float_tensor, mz_network_output_unpack, select_action, negative_cosine_similarity, \ prepare_obs from ding.utils.data import default_collate +from ding.torch_utils import to_device, to_tensor @POLICY_REGISTRY.register('muzero') class MuZeroPolicy(Policy): @@ -509,7 +510,7 @@ def _forward_collect( # data = sum(data, []) data = default_collate(data) - # data = to_device(data, self._device) + data = to_device(data, self._device) to_play = np.array(to_play).reshape(-1).tolist() with torch.no_grad(): @@ -639,7 +640,7 @@ def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: int = -1 # data = sum(data, []) data = default_collate(data) - # data = to_device(data, self._device) + data = to_device(data, self._device) to_play = np.array(to_play).reshape(-1).tolist() with torch.no_grad(): # data shape [B, S x C, W, H], e.g. 
{Tensor:(B, 12, 96, 96)} From 1b4ff2be808bf3ec868fb4a5c18f6514a423ed24 Mon Sep 17 00:00:00 2001 From: jayyoung0802 Date: Mon, 16 Oct 2023 12:22:19 +0800 Subject: [PATCH 03/12] feature(yzj): adapt ptz simple env --- lzero/mcts/buffer/game_buffer_muzero.py | 4 + .../config/ptz_simple_mz_config.py | 116 ++++++++++++++++++ .../config/ptz_simple_spread_mz_config.py | 15 ++- zoo/petting_zoo/entry/train_muzero.py | 2 +- .../envs/petting_zoo_simple_spread_env.py | 2 +- zoo/petting_zoo/model/model.py | 98 +++++++-------- 6 files changed, 172 insertions(+), 65 deletions(-) create mode 100644 zoo/petting_zoo/config/ptz_simple_mz_config.py diff --git a/lzero/mcts/buffer/game_buffer_muzero.py b/lzero/mcts/buffer/game_buffer_muzero.py index 9fe01a8c2..e0759ecff 100644 --- a/lzero/mcts/buffer/game_buffer_muzero.py +++ b/lzero/mcts/buffer/game_buffer_muzero.py @@ -205,6 +205,10 @@ def _prepare_reward_value_context( 'global_state': np.zeros((30,), dtype=np.float32), 'agent_alone_state': np.zeros((3, 14), dtype=np.float32), 'agent_alone_padding_state': np.zeros((3, 18), dtype=np.float32),}]) + zero_obs = np.array([{'agent_state': np.zeros((1, 6), dtype=np.float32), + 'global_state': np.zeros((8,), dtype=np.float32), + 'agent_alone_state': np.zeros((1, 12), dtype=np.float32), + 'agent_alone_padding_state': np.zeros((1, 12), dtype=np.float32),}]) value_obs_list = [] # the value is valid or not (out of game_segment) value_mask = [] diff --git a/zoo/petting_zoo/config/ptz_simple_mz_config.py b/zoo/petting_zoo/config/ptz_simple_mz_config.py new file mode 100644 index 000000000..0c5338485 --- /dev/null +++ b/zoo/petting_zoo/config/ptz_simple_mz_config.py @@ -0,0 +1,116 @@ +from easydict import EasyDict + +env_name = 'ptz_simple' +multi_agent = True + +# ============================================================== +# begin of the most frequently changed config specified by the user +# ============================================================== +seed = 0 +n_agent = 1 +n_landmark = n_agent +collector_env_num = 8 +evaluator_env_num = 8 +n_episode = 8 +batch_size = 256 +num_simulations = 50 +update_per_collect = 50 +reanalyze_ratio = 0. 
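# Illustrative sketch, not part of the patch: with model_type='structure' the stacked
# observations are dicts of arrays rather than a single ndarray, which is why the buffer
# and policy above switch from numpy/torch stacking to ding's default_collate + to_device.
# Shapes here follow the zero_obs added above for the 1-agent simple_v2 setting.
import numpy as np
from ding.torch_utils import to_device
from ding.utils.data import default_collate

obs_list = [
    {'agent_state': np.zeros((1, 6), dtype=np.float32),
     'global_state': np.zeros((8,), dtype=np.float32)}
    for _ in range(4)
]
batch = default_collate(obs_list)  # dict of stacked tensors: (4, 1, 6) and (4, 8)
batch = to_device(batch, 'cpu')    # same nested structure, tensors moved to the target device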
+action_space_size = 5 +eps_greedy_exploration_in_collect = True +# ============================================================== +# end of the most frequently changed config specified by the user +# ============================================================== + +main_config = dict( + exp_name= + f'data_mz_ctree/{env_name}_muzero_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_seed{seed}', + env=dict( + env_family='mpe', + env_id='simple_v2', + n_agent=n_agent, + n_landmark=n_landmark, + max_cycles=25, + agent_obs_only=False, + agent_specific_global_state=False, + continuous_actions=False, + stop_value=0, + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + n_evaluator_episode=evaluator_env_num, + manager=dict(shared_memory=False, ), + ), + policy=dict( + multi_agent=multi_agent, + ignore_done=False, + model=dict( + model_type='structure', + latent_state_dim=256, + frame_stack_num=1, + action_space='discrete', + action_space_size=action_space_size, + agent_num=n_agent, + self_supervised_learning_loss=False, # default is False + agent_obs_shape=6, + global_obs_shape=8, + discrete_action_encoding_type='one_hot', + global_cooperation=True, # TODO: doesn't work now + hidden_size_list=[256, 256], + norm_type='BN', + ), + cuda=True, + mcts_ctree=True, + gumbel_algo=False, + env_type='not_board_games', + game_segment_length=30, + random_collect_episode_num=0, + eps=dict( + eps_greedy_exploration_in_collect=eps_greedy_exploration_in_collect, + type='linear', + start=1., + end=0.05, + decay=int(1e5), + ), + use_augmentation=False, + update_per_collect=update_per_collect, + batch_size=batch_size, + optim_type='Adam', + lr_piecewise_constant_decay=False, + learning_rate=0.003, + ssl_loss_weight=0, # NOTE: default is 0. + num_simulations=num_simulations, + reanalyze_ratio=reanalyze_ratio, + n_episode=n_episode, + eval_freq=int(2e3), + replay_buffer_size=int(1e6), # the size/capacity of replay_buffer, in the terms of transitions. + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + ), + learn=dict(learner=dict( + log_policy=True, + hook=dict(log_show_after_iter=10, ), + ), ), +) +main_config = EasyDict(main_config) +create_config = dict( + env=dict( + import_names=['zoo.petting_zoo.envs.petting_zoo_simple_spread_env'], + type='petting_zoo', + ), + env_manager=dict(type='base'), + policy=dict( + type='muzero', + import_names=['lzero.policy.muzero'], + ), + collector=dict( + type='episode_muzero', + import_names=['lzero.worker.muzero_collector'], + ) +) +create_config = EasyDict(create_config) +ptz_simple_spread_muzero_config = main_config +ptz_simple_spread_muzero_create_config = create_config + +if __name__ == '__main__': + from zoo.petting_zoo.entry import train_muzero + train_muzero([main_config, create_config], seed=seed) diff --git a/zoo/petting_zoo/config/ptz_simple_spread_mz_config.py b/zoo/petting_zoo/config/ptz_simple_spread_mz_config.py index 515242ee4..4655a85b6 100644 --- a/zoo/petting_zoo/config/ptz_simple_spread_mz_config.py +++ b/zoo/petting_zoo/config/ptz_simple_spread_mz_config.py @@ -14,7 +14,7 @@ n_episode = 8 batch_size = 256 num_simulations = 50 -update_per_collect = 1000 +update_per_collect = 50 reanalyze_ratio = 0. 
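# Illustrative sketch, not part of the patch: the simple_spread hunk below hard-codes
# agent_obs_shape=18 and global_obs_shape=30 in place of the earlier inline formulas.
# Assuming the stock MPE simple_spread layout with n_agent = n_landmark = 3, the numbers
# decompose as follows (self vel/pos, landmark and other-agent relative positions, comm):
na, nl = 3, 3
agent_obs = 2 + 2 + nl * 2 + (na - 1) * 2 + (na - 1) * 2   # 18 per-agent features
global_obs = na * (2 + 2) + nl * 2 + na * (na - 1) * 2     # 30 shared global features
assert (agent_obs, global_obs) == (18, 30)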
action_space_size = 5*5*5 eps_greedy_exploration_in_collect = True @@ -51,9 +51,8 @@ action_space_size=action_space_size, agent_num=n_agent, self_supervised_learning_loss=False, # default is False - agent_obs_shape=(2 + 2 + n_landmark * 2 + (n_agent - 1) * 2 + (n_agent - 1) * 2)*3, - global_obs_shape=2 + 2 + n_landmark * 2 + (n_agent - 1) * 2 + (n_agent - 1) * 2 + n_agent * (2 + 2) + - n_landmark * 2 + n_agent * (n_agent - 1) * 2, + agent_obs_shape=18, + global_obs_shape=30, discrete_action_encoding_type='one_hot', global_cooperation=True, # TODO: doesn't work now hidden_size_list=[256, 256], @@ -75,10 +74,10 @@ use_augmentation=False, update_per_collect=update_per_collect, batch_size=batch_size, - optim_type='SGD', - lr_piecewise_constant_decay=True, - learning_rate=0.2, - ssl_loss_weight=0, # default is 0 + optim_type='Adam', + lr_piecewise_constant_decay=False, + learning_rate=0.003, + ssl_loss_weight=0, # NOTE: default is 0. num_simulations=num_simulations, reanalyze_ratio=reanalyze_ratio, n_episode=n_episode, diff --git a/zoo/petting_zoo/entry/train_muzero.py b/zoo/petting_zoo/entry/train_muzero.py index c31f36892..3789bce38 100644 --- a/zoo/petting_zoo/entry/train_muzero.py +++ b/zoo/petting_zoo/entry/train_muzero.py @@ -79,7 +79,7 @@ def train_muzero( evaluator_env.seed(cfg.seed, dynamic_seed=False) set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda) - model = Encoder(**cfg.policy.model, state_encoder=PettingZooEncoder(), state_prediction=PettingZooPrediction(), state_dynamics=PettingZooDynamics()) + model = Encoder(**cfg.policy.model, state_encoder=PettingZooEncoder(cfg), state_prediction=PettingZooPrediction(cfg), state_dynamics=PettingZooDynamics(cfg)) policy = create_policy(cfg.policy, model=model, enable_field=['learn', 'collect', 'eval']) # load pretrained model diff --git a/zoo/petting_zoo/envs/petting_zoo_simple_spread_env.py b/zoo/petting_zoo/envs/petting_zoo_simple_spread_env.py index 83f7eafaa..556c78907 100644 --- a/zoo/petting_zoo/envs/petting_zoo_simple_spread_env.py +++ b/zoo/petting_zoo/envs/petting_zoo_simple_spread_env.py @@ -285,7 +285,7 @@ def _process_obs(self, obs: 'torch.Tensor') -> np.ndarray: # noqa 1 ) # action_mask: All actions are of use(either 1 for discrete or 5 for continuous). Thus all 1. 
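# Illustrative sketch, not part of the patch: the hunk below sizes the action mask as
# 5 ** n_agents because the wrapper flattens the per-agent Discrete(5) spaces into one
# joint discrete action, using the itertools.product table the env builds in __init__
# (self.combinations). The index value here is made up for the example.
import itertools
combinations = list(itertools.product(range(5), repeat=3))  # 125 joint actions for 3 agents
joint_index = 87
per_agent_actions = combinations[joint_index]               # (3, 2, 2): one discrete action per agent
assert combinations.index(per_agent_actions) == joint_index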
- joint_action_mask = [1 for _ in range(5*5*5)] + joint_action_mask = [1 for _ in range(np.power(5, self._num_agents))] return {'observation': ret, 'action_mask': joint_action_mask, 'to_play': [-1]} def _process_action(self, action: 'torch.Tensor') -> Dict[str, np.ndarray]: # noqa diff --git a/zoo/petting_zoo/model/model.py b/zoo/petting_zoo/model/model.py index dee3ee836..0ff96a06a 100644 --- a/zoo/petting_zoo/model/model.py +++ b/zoo/petting_zoo/model/model.py @@ -13,18 +13,29 @@ class PettingZooEncoder(nn.Module): - def __init__(self): + def __init__(self, cfg): super().__init__() - self.agent_encoder = RepresentationNetworkMLP(observation_shape=18, hidden_channels=128, norm_type='BN') - self.global_encoder = RepresentationNetworkMLP(observation_shape=30, hidden_channels=128, norm_type='BN') - self.encoder = RepresentationNetworkMLP(observation_shape=512, hidden_channels=128, norm_type='BN') + self.agent_num = cfg.policy.model.agent_num + agent_obs_shape = cfg.policy.model.agent_obs_shape + global_obs_shape = cfg.policy.model.global_obs_shape + self.agent_encoder = RepresentationNetworkMLP(observation_shape=agent_obs_shape, + hidden_channels=128, + norm_type='BN') + + self.global_encoder = RepresentationNetworkMLP(observation_shape=global_obs_shape, + hidden_channels=128, + norm_type='BN') + + self.encoder = RepresentationNetworkMLP(observation_shape=128+128*self.agent_num, + hidden_channels=128, + norm_type='BN') def forward(self, x): # agent batch_size, agent_num = x['agent_state'].shape[0], x['agent_state'].shape[1] agent_state = x['agent_state'].reshape(batch_size*agent_num, -1) agent_state = self.agent_encoder(agent_state) - agent_state_B = agent_state.reshape(batch_size, -1) # [8, 768] + agent_state_B = agent_state.reshape(batch_size, -1) agent_state_B_A = agent_state.reshape(batch_size, agent_num, -1) # global global_state = self.global_encoder(x['global_state']) @@ -36,7 +47,8 @@ class PettingZooPrediction(nn.Module): def __init__( self, - action_space_size: int=125, + cfg, + action_space_size: int=5, num_channels: int=128, common_layer_num: int = 2, fc_value_layers: SequenceType = [32], @@ -65,6 +77,8 @@ def __init__( """ super().__init__() self.num_channels = num_channels + self.agent_num = cfg.policy.model.agent_num + self.action_space_size = pow(action_space_size, self.agent_num) # ******* common backbone ****** self.fc_prediction_common = MLP( @@ -94,9 +108,9 @@ def __init__( last_linear_layer_init_zero=last_linear_layer_init_zero ) self.fc_policy_head = MLP( - in_channels=self.num_channels*3, + in_channels=self.num_channels*self.agent_num, hidden_channels=fc_policy_layers[0], - out_channels=action_space_size, + out_channels=self.action_space_size, layer_num=len(fc_policy_layers) + 1, activation=activation, norm_type=norm_type, @@ -128,7 +142,8 @@ class PettingZooDynamics(nn.Module): def __init__( self, - action_encoding_dim: int = 125, + cfg, + action_encoding_dim: int = 5, num_channels: int = 253, common_layer_num: int = 2, fc_reward_layers: SequenceType = [32], @@ -156,8 +171,9 @@ def __init__( - res_connection_in_dynamics (:obj:`bool`): Whether to use residual connection in dynamics network. 
""" super().__init__() - self.num_channels = num_channels - self.action_encoding_dim = action_encoding_dim + self.agent_num = cfg.policy.model.agent_num + self.action_encoding_dim = pow(action_encoding_dim, self.agent_num) + self.num_channels = 128 + self.action_encoding_dim self.latent_state_dim = self.num_channels - self.action_encoding_dim self.res_connection_in_dynamics = res_connection_in_dynamics @@ -187,46 +203,20 @@ def __init__( last_linear_layer_init_zero=False, ) else: - self.fc_dynamics_1 = MLP( - in_channels=self.num_channels, - hidden_channels=self.latent_state_dim, - layer_num=common_layer_num, - out_channels=self.latent_state_dim, - activation=activation, - norm_type=norm_type, - output_activation=True, - output_norm=True, - # last_linear_layer_init_zero=False is important for convergence - last_linear_layer_init_zero=False, - ) - - self.fc_dynamics_2 = MLP( - in_channels=self.num_channels, - hidden_channels=self.latent_state_dim, - layer_num=common_layer_num, - out_channels=self.latent_state_dim, - activation=activation, - norm_type=norm_type, - output_activation=True, - output_norm=True, - # last_linear_layer_init_zero=False is important for convergence - last_linear_layer_init_zero=False, - ) - - self.fc_dynamics_3 = MLP( - in_channels=self.num_channels, - hidden_channels=self.latent_state_dim, - layer_num=common_layer_num, - out_channels=self.latent_state_dim, - activation=activation, - norm_type=norm_type, - output_activation=True, - output_norm=True, - # last_linear_layer_init_zero=False is important for convergence - last_linear_layer_init_zero=False, - ) + self.fc_dynamics_list = nn.ModuleList( + MLP(in_channels=self.num_channels, + hidden_channels=self.latent_state_dim, + layer_num=common_layer_num, + out_channels=self.latent_state_dim, + activation=activation, + norm_type=norm_type, + output_activation=True, + output_norm=True, + # last_linear_layer_init_zero=False is important for convergence + last_linear_layer_init_zero=False, + ) for _ in range(self.agent_num)) - self.fc_dynamics_4 = MLP( + self.fc_dynamics_global = MLP( in_channels=self.num_channels, hidden_channels=self.latent_state_dim, layer_num=common_layer_num, @@ -272,12 +262,10 @@ def forward(self, state_action_encoding: torch.Tensor) -> Tuple[torch.Tensor, to next_latent_state_encoding = self.fc_dynamics_2(next_latent_state) else: batch_size = state_action_encoding.shape[0] - next_agent_latent_state_1 = self.fc_dynamics_1(state_action_encoding) - next_agent_latent_state_2 = self.fc_dynamics_2(state_action_encoding) - next_agent_latent_state_3 = self.fc_dynamics_3(state_action_encoding) - next_agent_latent_state = torch.stack((next_agent_latent_state_1, next_agent_latent_state_2, next_agent_latent_state_3), dim=1) + next_agent_latent_list = [self.fc_dynamics_list[i](state_action_encoding) for i in range(self.agent_num)] + next_agent_latent_state = torch.stack(next_agent_latent_list, dim=1) next_agent_latent_state = next_agent_latent_state.reshape(batch_size, -1) - next_global_latent_state = self.fc_dynamics_4(state_action_encoding) + next_global_latent_state = self.fc_dynamics_global(state_action_encoding) next_latent_state_encoding = next_global_latent_state reward = self.fc_reward_head(next_latent_state_encoding) From 60f0832b7b291b4c81271ad5ae5662da52f5bf02 Mon Sep 17 00:00:00 2001 From: jayyoung0802 Date: Fri, 20 Oct 2023 16:08:42 +0800 Subject: [PATCH 04/12] fix(yzj): fix visualization --- zoo/petting_zoo/entry/eval_muzero.py | 25 +++++++++++++------------ 1 file changed, 13 insertions(+), 12 
deletions(-) diff --git a/zoo/petting_zoo/entry/eval_muzero.py b/zoo/petting_zoo/entry/eval_muzero.py index 7eb3e4d17..d535c0f9d 100644 --- a/zoo/petting_zoo/entry/eval_muzero.py +++ b/zoo/petting_zoo/entry/eval_muzero.py @@ -13,30 +13,31 @@ from ding.worker import BaseLearner from tensorboardX import SummaryWriter -from lzero.worker import MuZeroCollector as Collector -from lzero.worker import MuZeroEvaluator as Evaluator from lzero.entry.utils import log_buffer_memory_usage from lzero.policy import visit_count_temperature -from lzero.entry.utils import random_collect -from zoo.petting_zoo.model import PettingZooEncoder +from lzero.policy.random_policy import LightZeroRandomPolicy +from lzero.worker import MuZeroCollector as Collector +from lzero.worker import MuZeroEvaluator as Evaluator +from zoo.petting_zoo.model import PettingZooEncoder, PettingZooPrediction, PettingZooDynamics def eval_muzero(main_cfg, create_cfg, seed=0): - assert create_cfg.policy.type in ['efficientzero', 'muzero', 'sampled_efficientzero', 'gumbel_muzero', 'multi_agent_efficientzero', 'multi_agent_muzero'], \ - "train_muzero entry now only support the following algo.: 'efficientzero', 'muzero', 'sampled_efficientzero', 'gumbel_muzero', 'multi_agent_efficientzero', 'multi_agent_muzero'" + assert create_cfg.policy.type in ['efficientzero', 'muzero', 'sampled_efficientzero', 'gumbel_muzero', 'stochastic_muzero'], \ + "train_muzero entry now only support the following algo.: 'efficientzero', 'muzero', 'sampled_efficientzero', 'gumbel_muzero'" - if create_cfg.policy.type == 'muzero' or create_cfg.policy.type == 'multi_agent_muzero': + if create_cfg.policy.type == 'muzero': from lzero.mcts import MuZeroGameBuffer as GameBuffer from lzero.model.muzero_model_mlp import MuZeroModelMLP as Encoder - elif create_cfg.policy.type == 'efficientzero' or create_cfg.policy.type == 'multi_agent_efficientzero': + elif create_cfg.policy.type == 'efficientzero': from lzero.mcts import EfficientZeroGameBuffer as GameBuffer - from lzero.model.efficientzero_model_mlp import EfficientZeroModelMLP as Encoder elif create_cfg.policy.type == 'sampled_efficientzero': from lzero.mcts import SampledEfficientZeroGameBuffer as GameBuffer elif create_cfg.policy.type == 'gumbel_muzero': from lzero.mcts import GumbelMuZeroGameBuffer as GameBuffer + elif create_cfg.policy.type == 'stochastic_muzero': + from lzero.mcts import StochasticMuZeroGameBuffer as GameBuffer main_cfg.policy.device = 'cpu' - main_cfg.policy.load_path = 'exp_name/ckpt/ckpt_best.pth.tar' + main_cfg.policy.load_path = '/Users/yangzhenjie/code/LightZero/zoo/petting_zoo/entry/ckpt_best.pth.tar' main_cfg.env.replay_path = './' # when visualize must set as base create_cfg.env_manager.type = 'base' # when visualize must set as base main_cfg.env.evaluator_env_num = 1 # only 1 env for save replay @@ -51,7 +52,7 @@ def eval_muzero(main_cfg, create_cfg, seed=0): evaluator_env.seed(cfg.seed, dynamic_seed=False) set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda) - model = Encoder(**cfg.policy.model, state_encoder=PettingZooEncoder()) + model = Encoder(**cfg.policy.model, state_encoder=PettingZooEncoder(cfg), state_prediction=PettingZooPrediction(cfg), state_dynamics=PettingZooDynamics(cfg)) policy = create_policy(cfg.policy, model=model, enable_field=['learn', 'collect', 'eval']) policy.eval_mode.load_state_dict(torch.load(cfg.policy.load_path, map_location=cfg.policy.device)) @@ -77,5 +78,5 @@ def eval_muzero(main_cfg, create_cfg, seed=0): return stop, reward if __name__ == '__main__': - from 
zoo.petting_zoo.config.ptz_simple_spread_ez_config import main_config, create_config + from zoo.petting_zoo.config.ptz_simple_mz_config import main_config, create_config eval_muzero(main_config, create_config, seed=0) \ No newline at end of file From 6f801735f1ef62131706d464266510e2290edc76 Mon Sep 17 00:00:00 2001 From: jayyoung0802 Date: Fri, 20 Oct 2023 21:17:31 +0800 Subject: [PATCH 05/12] fix(yzj): fix combinations --- .../envs/petting_zoo_simple_spread_env.py | 61 +++++++++---------- 1 file changed, 30 insertions(+), 31 deletions(-) diff --git a/zoo/petting_zoo/envs/petting_zoo_simple_spread_env.py b/zoo/petting_zoo/envs/petting_zoo_simple_spread_env.py index 556c78907..a76e1b651 100644 --- a/zoo/petting_zoo/envs/petting_zoo_simple_spread_env.py +++ b/zoo/petting_zoo/envs/petting_zoo_simple_spread_env.py @@ -23,9 +23,8 @@ class PettingZooEnv(BaseEnv): def __init__(self, cfg: dict) -> None: self._cfg = cfg self._init_flag = False - self._replay_path = None - # self._replay_path = self._cfg.get('replay_path', None) - # self.frame_list = [] + self._replay_path = self._cfg.get('replay_path', None) + self.frame_list = [] self._env_family = self._cfg.env_family self._env_id = self._cfg.env_id self._num_agents = self._cfg.n_agent @@ -40,7 +39,7 @@ def __init__(self, cfg: dict) -> None: # joint action import itertools action_space = [0, 1, 2, 3, 4] - self.combinations = list(itertools.product(action_space, repeat=3)) + self.combinations = list(itertools.product(action_space, repeat=self._num_agents)) def reset(self) -> np.ndarray: if not self._init_flag: @@ -64,10 +63,10 @@ def reset(self) -> np.ndarray: # if hasattr(self, '_seed') and hasattr(self, '_dynamic_seed') and self._dynamic_seed: # np_seed = 100 * np.random.randint(1, 1000) # self._env.seed(self._seed + np_seed) - if self._replay_path is not None: - self._env = gym.wrappers.Monitor( - self._env, self._replay_path, video_callable=lambda episode_id: True, force=True - ) + # if self._replay_path is not None: + # self._env = gym.wrappers.Monitor( + # self._env, self._replay_path, video_callable=lambda episode_id: True, force=True + # ) if hasattr(self, '_seed'): obs = self._env.reset(seed=self._seed) else: @@ -208,12 +207,12 @@ def step(self, action: int) -> BaseEnvTimestep: # for agent in self._agents: # self._eval_episode_return[agent] += rew[agent] - # if self._replay_path is not None: - # self.frame_list.append(Image.fromarray(self._env.render())) + if self._replay_path is not None: + self.frame_list.append(Image.fromarray(self._env.render())) if done_n: # or reduce(lambda x, y: x and y, done.values()) info['eval_episode_return'] = self._eval_episode_return - # if self._replay_path is not None: - # self.frame_list[0].save('out.gif', save_all=True, append_images=self.frame_list[1:], duration=3, loop=0) + if self._replay_path is not None: + self.frame_list[0].save('out.gif', save_all=True, append_images=self.frame_list[1:], duration=3, loop=0) # for agent in rew: # rew[agent] = to_ndarray([rew[agent]]) return BaseEnvTimestep(obs_n, rew_n, done_n, info) @@ -333,7 +332,7 @@ def __init__(self, N=3, local_ratio=0.5, max_cycles=25, continuous_actions=False scenario = Scenario() world = scenario.make_world(N) super().__init__(scenario, world, max_cycles, continuous_actions=continuous_actions, local_ratio=local_ratio) - # self.render_mode = 'rgb_array' + self.render_mode = 'rgb_array' self.metadata['name'] = "simple_spread_v2" def _execute_world_step(self): @@ -370,21 +369,21 @@ def _execute_world_step(self): self.rewards[agent.name] = 
reward - # def render(self): - # if self.render_mode is None: - # gym.logger.warn( - # "You are calling render method without specifying any render mode." - # ) - # return - - # self.enable_render(self.render_mode) - - # self.draw() - # observation = np.array(pygame.surfarray.pixels3d(self.screen)) - # if self.render_mode == "human": - # pygame.display.flip() - # return ( - # np.transpose(observation, axes=(1, 0, 2)) - # if self.render_mode == "rgb_array" - # else None - # ) + def render(self): + if self.render_mode is None: + gym.logger.warn( + "You are calling render method without specifying any render mode." + ) + return + + self.enable_render(self.render_mode) + + self.draw() + observation = np.array(pygame.surfarray.pixels3d(self.screen)) + if self.render_mode == "human": + pygame.display.flip() + return ( + np.transpose(observation, axes=(1, 0, 2)) + if self.render_mode == "rgb_array" + else None + ) \ No newline at end of file From 08ef2bad78b9a6dab7c64da922b32cfe12e98574 Mon Sep 17 00:00:00 2001 From: jayyoung0802 Date: Sat, 21 Oct 2023 23:57:48 +0800 Subject: [PATCH 06/12] feature(yzj): add ptz simple env --- lzero/mcts/buffer/game_buffer_muzero.py | 10 +++++----- lzero/model/muzero_model_mlp.py | 4 ++-- .../config/ptz_simple_mz_config.py | 6 +++--- zoo/petting_zoo/entry/train_muzero.py | 3 ++- zoo/petting_zoo/model/model.py | 19 ++++++++++--------- 5 files changed, 22 insertions(+), 20 deletions(-) diff --git a/lzero/mcts/buffer/game_buffer_muzero.py b/lzero/mcts/buffer/game_buffer_muzero.py index e0759ecff..aa92814fd 100644 --- a/lzero/mcts/buffer/game_buffer_muzero.py +++ b/lzero/mcts/buffer/game_buffer_muzero.py @@ -201,12 +201,12 @@ def _prepare_reward_value_context( td_steps_list, action_mask_segment, to_play_segment """ zero_obs = game_segment_list[0].zero_obs() - zero_obs = np.array([{'agent_state': np.zeros((3, 18), dtype=np.float32), - 'global_state': np.zeros((30,), dtype=np.float32), - 'agent_alone_state': np.zeros((3, 14), dtype=np.float32), - 'agent_alone_padding_state': np.zeros((3, 18), dtype=np.float32),}]) + # zero_obs = np.array([{'agent_state': np.zeros((3, 18), dtype=np.float32), + # 'global_state': np.zeros((30,), dtype=np.float32), + # 'agent_alone_state': np.zeros((3, 14), dtype=np.float32), + # 'agent_alone_padding_state': np.zeros((3, 18), dtype=np.float32),}]) zero_obs = np.array([{'agent_state': np.zeros((1, 6), dtype=np.float32), - 'global_state': np.zeros((8,), dtype=np.float32), + 'global_state': np.zeros((1, 14), dtype=np.float32), 'agent_alone_state': np.zeros((1, 12), dtype=np.float32), 'agent_alone_padding_state': np.zeros((1, 12), dtype=np.float32),}]) value_obs_list = [] diff --git a/lzero/model/muzero_model_mlp.py b/lzero/model/muzero_model_mlp.py index fad6ef59c..ecde46a32 100644 --- a/lzero/model/muzero_model_mlp.py +++ b/lzero/model/muzero_model_mlp.py @@ -186,7 +186,7 @@ def initial_inference(self, obs: torch.Tensor) -> MZNetworkOutput: value, [0. 
for _ in range(batch_size)], policy_logits, - latent_state[1], + latent_state, ) def recurrent_inference(self, latent_state: torch.Tensor, action: torch.Tensor) -> MZNetworkOutput: @@ -214,7 +214,7 @@ def recurrent_inference(self, latent_state: torch.Tensor, action: torch.Tensor) """ next_latent_state, reward = self._dynamics(latent_state, action) policy_logits, value = self._prediction(next_latent_state) - return MZNetworkOutput(value, reward, policy_logits, next_latent_state[1]) + return MZNetworkOutput(value, reward, policy_logits, next_latent_state) def _representation(self, observation: torch.Tensor) -> Tuple[torch.Tensor]: """ diff --git a/zoo/petting_zoo/config/ptz_simple_mz_config.py b/zoo/petting_zoo/config/ptz_simple_mz_config.py index 0c5338485..c4773c9fd 100644 --- a/zoo/petting_zoo/config/ptz_simple_mz_config.py +++ b/zoo/petting_zoo/config/ptz_simple_mz_config.py @@ -32,7 +32,7 @@ n_landmark=n_landmark, max_cycles=25, agent_obs_only=False, - agent_specific_global_state=False, + agent_specific_global_state=True, continuous_actions=False, stop_value=0, collector_env_num=collector_env_num, @@ -52,7 +52,7 @@ agent_num=n_agent, self_supervised_learning_loss=False, # default is False agent_obs_shape=6, - global_obs_shape=8, + global_obs_shape=14, discrete_action_encoding_type='one_hot', global_cooperation=True, # TODO: doesn't work now hidden_size_list=[256, 256], @@ -97,7 +97,7 @@ import_names=['zoo.petting_zoo.envs.petting_zoo_simple_spread_env'], type='petting_zoo', ), - env_manager=dict(type='base'), + env_manager=dict(type='subprocess'), policy=dict( type='muzero', import_names=['lzero.policy.muzero'], diff --git a/zoo/petting_zoo/entry/train_muzero.py b/zoo/petting_zoo/entry/train_muzero.py index 3789bce38..905f1469e 100644 --- a/zoo/petting_zoo/entry/train_muzero.py +++ b/zoo/petting_zoo/entry/train_muzero.py @@ -79,7 +79,8 @@ def train_muzero( evaluator_env.seed(cfg.seed, dynamic_seed=False) set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda) - model = Encoder(**cfg.policy.model, state_encoder=PettingZooEncoder(cfg), state_prediction=PettingZooPrediction(cfg), state_dynamics=PettingZooDynamics(cfg)) + # model = Encoder(**cfg.policy.model, state_encoder=PettingZooEncoder(cfg), state_prediction=PettingZooPrediction(cfg), state_dynamics=PettingZooDynamics(cfg)) + model = Encoder(**cfg.policy.model, state_encoder=PettingZooEncoder(cfg)) policy = create_policy(cfg.policy, model=model, enable_field=['learn', 'collect', 'eval']) # load pretrained model diff --git a/zoo/petting_zoo/model/model.py b/zoo/petting_zoo/model/model.py index 0ff96a06a..a108b4307 100644 --- a/zoo/petting_zoo/model/model.py +++ b/zoo/petting_zoo/model/model.py @@ -23,7 +23,7 @@ def __init__(self, cfg): norm_type='BN') self.global_encoder = RepresentationNetworkMLP(observation_shape=global_obs_shape, - hidden_channels=128, + hidden_channels=256, norm_type='BN') self.encoder = RepresentationNetworkMLP(observation_shape=128+128*self.agent_num, @@ -32,15 +32,16 @@ def __init__(self, cfg): def forward(self, x): # agent - batch_size, agent_num = x['agent_state'].shape[0], x['agent_state'].shape[1] - agent_state = x['agent_state'].reshape(batch_size*agent_num, -1) - agent_state = self.agent_encoder(agent_state) - agent_state_B = agent_state.reshape(batch_size, -1) - agent_state_B_A = agent_state.reshape(batch_size, agent_num, -1) + batch_size, agent_num = x['global_state'].shape[0], x['global_state'].shape[1] + latent_state = x['global_state'].reshape(batch_size*agent_num, -1) + latent_state = 
self.global_encoder(latent_state) + return latent_state + # agent_state_B = agent_state.reshape(batch_size, -1) + # agent_state_B_A = agent_state.reshape(batch_size, agent_num, -1) # global - global_state = self.global_encoder(x['global_state']) - global_state = self.encoder(torch.cat((agent_state_B, global_state),dim=1)) - return (agent_state_B, global_state) + # global_state = self.global_encoder(x['global_state']) + # global_state = self.encoder(torch.cat((agent_state_B, global_state),dim=1)) + # return (agent_state_B, global_state) class PettingZooPrediction(nn.Module): From 0e6dfd32a5181bdbd8b869e86df6e6128c790a47 Mon Sep 17 00:00:00 2001 From: jayyoung0802 Date: Mon, 23 Oct 2023 15:12:25 +0800 Subject: [PATCH 07/12] feature(yzj): ptz simple mz cfg is ready and add ptz simple ez cfg --- .../mcts/buffer/game_buffer_efficientzero.py | 22 +++- lzero/mcts/buffer/game_buffer_muzero.py | 4 +- lzero/model/efficientzero_model_mlp.py | 70 ++++++----- lzero/policy/efficientzero.py | 19 ++- .../config/ptz_simple_ez_config.py | 115 ++++++++++++++++++ .../config/ptz_simple_mz_config.py | 5 +- .../config/ptz_simple_spread_mz_config.py | 8 +- zoo/petting_zoo/entry/train_muzero.py | 1 + .../envs/petting_zoo_simple_spread_env.py | 11 +- zoo/petting_zoo/model/model.py | 4 +- 10 files changed, 209 insertions(+), 50 deletions(-) create mode 100644 zoo/petting_zoo/config/ptz_simple_ez_config.py diff --git a/lzero/mcts/buffer/game_buffer_efficientzero.py b/lzero/mcts/buffer/game_buffer_efficientzero.py index 4ab12259e..ae6b43786 100644 --- a/lzero/mcts/buffer/game_buffer_efficientzero.py +++ b/lzero/mcts/buffer/game_buffer_efficientzero.py @@ -9,6 +9,8 @@ from lzero.mcts.utils import prepare_observation from lzero.policy import to_detach_cpu_numpy, concat_output, concat_output_value, inverse_scalar_transform from .game_buffer_muzero import MuZeroGameBuffer +from ding.torch_utils import to_device, to_tensor +from ding.utils.data import default_collate @BUFFER_REGISTRY.register('game_buffer_efficientzero') @@ -100,7 +102,15 @@ def _prepare_reward_value_context( - reward_value_context (:obj:`list`): value_obs_list, value_mask, pos_in_game_segment_list, rewards_list, game_segment_lens, td_steps_list, action_mask_segment, to_play_segment """ - zero_obs = game_segment_list[0].zero_obs() + # zero_obs = game_segment_list[0].zero_obs() + # zero_obs = np.array([{'agent_state': np.zeros((3, 18), dtype=np.float32), + # 'global_state': np.zeros((84,), dtype=np.float32), + # 'agent_alone_state': np.zeros((3, 14), dtype=np.float32), + # 'agent_alone_padding_state': np.zeros((3, 18), dtype=np.float32),}]) + zero_obs = np.array([{'agent_state': np.zeros((1, 6), dtype=np.float32), + 'global_state': np.zeros((14, ), dtype=np.float32), + 'agent_alone_state': np.zeros((1, 12), dtype=np.float32), + 'agent_alone_padding_state': np.zeros((1, 12), dtype=np.float32),}]) value_obs_list = [] # the value is valid or not (out of trajectory) value_mask = [] @@ -152,7 +162,7 @@ def _prepare_reward_value_context( value_mask.append(0) obs = zero_obs - value_obs_list.append(obs) + value_obs_list.append(obs.tolist()) reward_value_context = [ value_obs_list, value_mask, pos_in_game_segment_list, rewards_list, game_segment_lens, td_steps_list, @@ -196,7 +206,13 @@ def _compute_target_reward_value(self, reward_value_context: List[Any], model: A beg_index = self._cfg.mini_infer_size * i end_index = self._cfg.mini_infer_size * (i + 1) - m_obs = torch.from_numpy(value_obs_list[beg_index:end_index]).to(self._cfg.device).float() + if 
self._cfg.model.model_type and self._cfg.model.model_type in ['conv', 'mlp']: + m_obs = torch.from_numpy(value_obs_list[beg_index:end_index]).to(self._cfg.device).float() + elif self._cfg.model.model_type and self._cfg.model.model_type == 'structure': + m_obs = value_obs_list[beg_index:end_index] + m_obs = sum(m_obs, []) + m_obs = default_collate(m_obs) + m_obs = to_device(m_obs, self._cfg.device) # calculate the target value m_output = model.initial_inference(m_obs) diff --git a/lzero/mcts/buffer/game_buffer_muzero.py b/lzero/mcts/buffer/game_buffer_muzero.py index aa92814fd..15177a4e2 100644 --- a/lzero/mcts/buffer/game_buffer_muzero.py +++ b/lzero/mcts/buffer/game_buffer_muzero.py @@ -202,11 +202,11 @@ def _prepare_reward_value_context( """ zero_obs = game_segment_list[0].zero_obs() # zero_obs = np.array([{'agent_state': np.zeros((3, 18), dtype=np.float32), - # 'global_state': np.zeros((30,), dtype=np.float32), + # 'global_state': np.zeros((84,), dtype=np.float32), # 'agent_alone_state': np.zeros((3, 14), dtype=np.float32), # 'agent_alone_padding_state': np.zeros((3, 18), dtype=np.float32),}]) zero_obs = np.array([{'agent_state': np.zeros((1, 6), dtype=np.float32), - 'global_state': np.zeros((1, 14), dtype=np.float32), + 'global_state': np.zeros((14, ), dtype=np.float32), 'agent_alone_state': np.zeros((1, 12), dtype=np.float32), 'agent_alone_padding_state': np.zeros((1, 12), dtype=np.float32),}]) value_obs_list = [] diff --git a/lzero/model/efficientzero_model_mlp.py b/lzero/model/efficientzero_model_mlp.py index a491cdb75..beed7830a 100644 --- a/lzero/model/efficientzero_model_mlp.py +++ b/lzero/model/efficientzero_model_mlp.py @@ -8,6 +8,7 @@ from .common import EZNetworkOutput, RepresentationNetworkMLP, PredictionNetworkMLP from .utils import renormalize, get_params_mean, get_dynamic_mean, get_reward_mean +from ding.utils.default_helper import get_shape0 @MODEL_REGISTRY.register('EfficientZeroModelMLP') @@ -36,6 +37,9 @@ def __init__( norm_type: Optional[str] = 'BN', discrete_action_encoding_type: str = 'one_hot', res_connection_in_dynamics: bool = False, + state_encoder=None, + state_prediction=None, + state_dynamics=None, *args, **kwargs, ): @@ -104,31 +108,40 @@ def __init__( self.state_norm = state_norm self.res_connection_in_dynamics = res_connection_in_dynamics - self.representation_network = RepresentationNetworkMLP( - observation_shape=observation_shape, hidden_channels=latent_state_dim, norm_type=norm_type - ) - - self.dynamics_network = DynamicsNetworkMLP( - action_encoding_dim=self.action_encoding_dim, - num_channels=latent_state_dim + self.action_encoding_dim, - common_layer_num=2, - lstm_hidden_size=lstm_hidden_size, - fc_reward_layers=fc_reward_layers, - output_support_size=self.reward_support_size, - last_linear_layer_init_zero=self.last_linear_layer_init_zero, - norm_type=norm_type, - res_connection_in_dynamics=self.res_connection_in_dynamics, - ) - - self.prediction_network = PredictionNetworkMLP( - action_space_size=action_space_size, - num_channels=latent_state_dim, - fc_value_layers=fc_value_layers, - fc_policy_layers=fc_policy_layers, - output_support_size=self.value_support_size, - last_linear_layer_init_zero=self.last_linear_layer_init_zero, - norm_type=norm_type - ) + if state_encoder == None: + self.representation_network = RepresentationNetworkMLP( + observation_shape=observation_shape, hidden_channels=latent_state_dim, norm_type=norm_type + ) + else: + self.representation_network = state_encoder + + if state_dynamics == None: + self.dynamics_network = 
DynamicsNetworkMLP( + action_encoding_dim=self.action_encoding_dim, + num_channels=latent_state_dim + self.action_encoding_dim, + common_layer_num=2, + lstm_hidden_size=lstm_hidden_size, + fc_reward_layers=fc_reward_layers, + output_support_size=self.reward_support_size, + last_linear_layer_init_zero=self.last_linear_layer_init_zero, + norm_type=norm_type, + res_connection_in_dynamics=self.res_connection_in_dynamics, + ) + else: + self.dynamics_network = state_dynamics + + if state_prediction == None: + self.prediction_network = PredictionNetworkMLP( + action_space_size=action_space_size, + num_channels=latent_state_dim, + fc_value_layers=fc_value_layers, + fc_policy_layers=fc_policy_layers, + output_support_size=self.value_support_size, + last_linear_layer_init_zero=self.last_linear_layer_init_zero, + norm_type=norm_type + ) + else: + self.prediction_network = state_prediction if self.self_supervised_learning_loss: # self_supervised_learning_loss related network proposed in EfficientZero @@ -171,15 +184,16 @@ def initial_inference(self, obs: torch.Tensor) -> EZNetworkOutput: - latent_state (:obj:`torch.Tensor`): :math:`(B, H)`, where B is batch_size, H is the dimension of latent state. - reward_hidden_state (:obj:`Tuple[torch.Tensor]`): The shape of each element is :math:`(1, B, lstm_hidden_size)`, where B is batch_size. """ - batch_size = obs.size(0) + batch_size = get_shape0(obs) latent_state = self._representation(obs) + device = latent_state.device policy_logits, value = self._prediction(latent_state) # zero initialization for reward hidden states # (hn, cn), each element shape is (layer_num=1, batch_size, lstm_hidden_size) reward_hidden_state = ( torch.zeros(1, batch_size, - self.lstm_hidden_size).to(obs.device), torch.zeros(1, batch_size, - self.lstm_hidden_size).to(obs.device) + self.lstm_hidden_size).to(device), torch.zeros(1, batch_size, + self.lstm_hidden_size).to(device) ) return EZNetworkOutput(value, [0. for _ in range(batch_size)], policy_logits, latent_state, reward_hidden_state) diff --git a/lzero/policy/efficientzero.py b/lzero/policy/efficientzero.py index b18a9e297..b160f4dd6 100644 --- a/lzero/policy/efficientzero.py +++ b/lzero/policy/efficientzero.py @@ -18,6 +18,8 @@ prepare_obs, \ configure_optimizers from lzero.policy.muzero import MuZeroPolicy +from ding.utils.data import default_collate +from ding.torch_utils import to_device, to_tensor @POLICY_REGISTRY.register('efficientzero') @@ -307,7 +309,7 @@ def _forward_learn(self, data: torch.Tensor) -> Dict[str, Union[float, int]]: target_value_prefix = target_value_prefix.view(self._cfg.batch_size, -1) target_value = target_value.view(self._cfg.batch_size, -1) - assert obs_batch.size(0) == self._cfg.batch_size == target_value_prefix.size(0) + # assert obs_batch.size(0) == self._cfg.batch_size == target_value_prefix.size(0) # ``scalar_transform`` to transform the original value to the scaled value, # i.e. h(.) function in paper https://arxiv.org/pdf/1805.11593.pdf. @@ -562,7 +564,13 @@ def _forward_collect( self._collect_model.eval() self._collect_mcts_temperature = temperature self.collect_epsilon = epsilon - active_collect_env_num = data.shape[0] + active_collect_env_num = len(data) + # + data = sum(data, []) + data = default_collate(data) + data = to_device(data, self._device) + to_play = np.array(to_play).reshape(-1).tolist() + with torch.no_grad(): # data shape [B, S x C, W, H], e.g. 
{Tensor:(B, 12, 96, 96)} network_output = self._collect_model.initial_inference(data) @@ -667,7 +675,12 @@ def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: -1, read ``visit_count_distribution_entropy``, ``value``, ``pred_value``, ``policy_logits``. """ self._eval_model.eval() - active_eval_env_num = data.shape[0] + active_eval_env_num = len(data) + # + data = sum(data, []) + data = default_collate(data) + data = to_device(data, self._device) + to_play = np.array(to_play).reshape(-1).tolist() with torch.no_grad(): # data shape [B, S x C, W, H], e.g. {Tensor:(B, 12, 96, 96)} network_output = self._eval_model.initial_inference(data) diff --git a/zoo/petting_zoo/config/ptz_simple_ez_config.py b/zoo/petting_zoo/config/ptz_simple_ez_config.py new file mode 100644 index 000000000..d691f68cf --- /dev/null +++ b/zoo/petting_zoo/config/ptz_simple_ez_config.py @@ -0,0 +1,115 @@ +from easydict import EasyDict + +env_name = 'ptz_simple' +multi_agent = True + +# ============================================================== +# begin of the most frequently changed config specified by the user +# ============================================================== +seed = 0 +n_agent = 1 +n_landmark = n_agent +collector_env_num = 8 +evaluator_env_num = 8 +n_episode = 8 +batch_size = 256 +num_simulations = 50 +update_per_collect = 50 +reanalyze_ratio = 0. +action_space_size = 5 +eps_greedy_exploration_in_collect = True +# ============================================================== +# end of the most frequently changed config specified by the user +# ============================================================== + +main_config = dict( + exp_name= + f'data_ez_ctree/{env_name}_efficientzero_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_seed{seed}', + env=dict( + env_family='mpe', + env_id='simple_v2', + n_agent=n_agent, + n_landmark=n_landmark, + max_cycles=25, + agent_obs_only=False, + agent_specific_global_state=True, + continuous_actions=False, + stop_value=0, + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + n_evaluator_episode=evaluator_env_num, + manager=dict(shared_memory=False, ), + ), + policy=dict( + multi_agent=multi_agent, + ignore_done=False, + model=dict( + model_type='structure', + latent_state_dim=256, + action_space='discrete', + action_space_size=action_space_size, + agent_num=n_agent, + self_supervised_learning_loss=False, # default is False + agent_obs_shape=6, + global_obs_shape=14, + discrete_action_encoding_type='one_hot', + global_cooperation=True, # TODO: doesn't work now + hidden_size_list=[256, 256], + norm_type='BN', + ), + cuda=True, + mcts_ctree=True, + gumbel_algo=False, + env_type='not_board_games', + game_segment_length=30, + random_collect_episode_num=0, + eps=dict( + eps_greedy_exploration_in_collect=eps_greedy_exploration_in_collect, + type='linear', + start=1., + end=0.05, + decay=int(2e5), + ), + use_augmentation=False, + update_per_collect=update_per_collect, + batch_size=batch_size, + optim_type='SGD', + lr_piecewise_constant_decay=True, + learning_rate=0.2, + ssl_loss_weight=0, # NOTE: default is 0. + num_simulations=num_simulations, + reanalyze_ratio=reanalyze_ratio, + n_episode=n_episode, + eval_freq=int(2e3), + replay_buffer_size=int(1e6), # the size/capacity of replay_buffer, in the terms of transitions. 
+ collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + ), + learn=dict(learner=dict( + log_policy=True, + hook=dict(log_show_after_iter=10, ), + ), ), +) +main_config = EasyDict(main_config) +create_config = dict( + env=dict( + import_names=['zoo.petting_zoo.envs.petting_zoo_simple_spread_env'], + type='petting_zoo', + ), + env_manager=dict(type='subprocess'), + policy=dict( + type='efficientzero', + import_names=['lzero.policy.efficientzero'], + ), + collector=dict( + type='episode_muzero', + import_names=['lzero.worker.muzero_collector'], + ) +) +create_config = EasyDict(create_config) +ptz_simple_spread_efficientzero_config = main_config +ptz_simple_spread_efficientzero_create_config = create_config + +if __name__ == "__main__": + from zoo.petting_zoo.entry import train_muzero + train_muzero([main_config, create_config], seed=seed) diff --git a/zoo/petting_zoo/config/ptz_simple_mz_config.py b/zoo/petting_zoo/config/ptz_simple_mz_config.py index c4773c9fd..45343f978 100644 --- a/zoo/petting_zoo/config/ptz_simple_mz_config.py +++ b/zoo/petting_zoo/config/ptz_simple_mz_config.py @@ -46,7 +46,6 @@ model=dict( model_type='structure', latent_state_dim=256, - frame_stack_num=1, action_space='discrete', action_space_size=action_space_size, agent_num=n_agent, @@ -69,7 +68,7 @@ type='linear', start=1., end=0.05, - decay=int(1e5), + decay=int(2e5), ), use_augmentation=False, update_per_collect=update_per_collect, @@ -111,6 +110,6 @@ ptz_simple_spread_muzero_config = main_config ptz_simple_spread_muzero_create_config = create_config -if __name__ == '__main__': +if __name__ == "__main__": from zoo.petting_zoo.entry import train_muzero train_muzero([main_config, create_config], seed=seed) diff --git a/zoo/petting_zoo/config/ptz_simple_spread_mz_config.py b/zoo/petting_zoo/config/ptz_simple_spread_mz_config.py index 4655a85b6..c93cb0ada 100644 --- a/zoo/petting_zoo/config/ptz_simple_spread_mz_config.py +++ b/zoo/petting_zoo/config/ptz_simple_spread_mz_config.py @@ -13,7 +13,7 @@ evaluator_env_num = 8 n_episode = 8 batch_size = 256 -num_simulations = 50 +num_simulations = 200 update_per_collect = 50 reanalyze_ratio = 0. 
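# Illustrative sketch, not part of the patch: with agent_specific_global_state=True the
# env change later in this patch flattens all agents' observations and appends the shared
# global state, so the hunk below sets global_obs_shape = 18*n_agent + 30 (the '# 84'
# comment) for simple_spread, and the simple_v2 config above uses 6 + 8 = 14.
n_agents, per_agent_obs, shared_global = 3, 18, 30
agent_specific_global = n_agents * per_agent_obs + shared_global
assert agent_specific_global == 84   # simple_spread
assert 1 * 6 + 8 == 14               # simple_v2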
action_space_size = 5*5*5 @@ -32,7 +32,7 @@ n_landmark=n_landmark, max_cycles=25, agent_obs_only=False, - agent_specific_global_state=False, + agent_specific_global_state=True, continuous_actions=False, stop_value=0, collector_env_num=collector_env_num, @@ -52,7 +52,7 @@ agent_num=n_agent, self_supervised_learning_loss=False, # default is False agent_obs_shape=18, - global_obs_shape=30, + global_obs_shape=18*n_agent+30, # 84 discrete_action_encoding_type='one_hot', global_cooperation=True, # TODO: doesn't work now hidden_size_list=[256, 256], @@ -69,7 +69,7 @@ type='linear', start=1., end=0.05, - decay=int(1e5), + decay=int(2e5), ), use_augmentation=False, update_per_collect=update_per_collect, diff --git a/zoo/petting_zoo/entry/train_muzero.py b/zoo/petting_zoo/entry/train_muzero.py index 905f1469e..1c32d89fe 100644 --- a/zoo/petting_zoo/entry/train_muzero.py +++ b/zoo/petting_zoo/entry/train_muzero.py @@ -56,6 +56,7 @@ def train_muzero( from lzero.model.muzero_model_mlp import MuZeroModelMLP as Encoder elif create_cfg.policy.type == 'efficientzero': from lzero.mcts import EfficientZeroGameBuffer as GameBuffer + from lzero.model.efficientzero_model_mlp import EfficientZeroModelMLP as Encoder elif create_cfg.policy.type == 'sampled_efficientzero': from lzero.mcts import SampledEfficientZeroGameBuffer as GameBuffer elif create_cfg.policy.type == 'gumbel_muzero': diff --git a/zoo/petting_zoo/envs/petting_zoo_simple_spread_env.py b/zoo/petting_zoo/envs/petting_zoo_simple_spread_env.py index a76e1b651..69b366b55 100644 --- a/zoo/petting_zoo/envs/petting_zoo_simple_spread_env.py +++ b/zoo/petting_zoo/envs/petting_zoo_simple_spread_env.py @@ -254,11 +254,12 @@ def _process_obs(self, obs: 'torch.Tensor') -> np.ndarray: # noqa # - agent_state info # - global_state info if self._agent_specific_global_state: - ret['global_state'] = np.concatenate( - [ret['agent_state'], - np.expand_dims(ret['global_state'], axis=0).repeat(self._num_agents, axis=0)], - axis=1 - ) + ret['global_state'] = np.concatenate((np.concatenate(ret['agent_state']), ret['global_state'])) + # ret['global_state'] = np.concatenate( + # [ret['agent_state'], + # np.expand_dims(ret['global_state'], axis=0).repeat(self._num_agents, axis=0)], + # axis=1 + # ) # agent_alone_state: Shape (n_agent, 2 + 2 + n_landmark * 2 + (n_agent - 1) * 2). # Stacked observation. Exclude other agents' positions from agent_state. 
Contains # - agent itself's state(velocity + position) + diff --git a/zoo/petting_zoo/model/model.py b/zoo/petting_zoo/model/model.py index a108b4307..2bf02a64c 100644 --- a/zoo/petting_zoo/model/model.py +++ b/zoo/petting_zoo/model/model.py @@ -32,8 +32,8 @@ def __init__(self, cfg): def forward(self, x): # agent - batch_size, agent_num = x['global_state'].shape[0], x['global_state'].shape[1] - latent_state = x['global_state'].reshape(batch_size*agent_num, -1) + batch_size = x['global_state'].shape[0] + latent_state = x['global_state'].reshape(batch_size, -1) latent_state = self.global_encoder(latent_state) return latent_state # agent_state_B = agent_state.reshape(batch_size, -1) From c323a445b61acaf71eba369d1868cc573d5f8d6b Mon Sep 17 00:00:00 2001 From: jayyoung0802 Date: Sun, 29 Oct 2023 18:16:00 +0800 Subject: [PATCH 08/12] fix(yzj): fix ptz simple ez eval muzero --- zoo/petting_zoo/entry/eval_muzero.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/zoo/petting_zoo/entry/eval_muzero.py b/zoo/petting_zoo/entry/eval_muzero.py index d535c0f9d..a6947e9c4 100644 --- a/zoo/petting_zoo/entry/eval_muzero.py +++ b/zoo/petting_zoo/entry/eval_muzero.py @@ -29,6 +29,7 @@ def eval_muzero(main_cfg, create_cfg, seed=0): from lzero.model.muzero_model_mlp import MuZeroModelMLP as Encoder elif create_cfg.policy.type == 'efficientzero': from lzero.mcts import EfficientZeroGameBuffer as GameBuffer + from lzero.model.efficientzero_model_mlp import EfficientZeroModelMLP as Encoder elif create_cfg.policy.type == 'sampled_efficientzero': from lzero.mcts import SampledEfficientZeroGameBuffer as GameBuffer elif create_cfg.policy.type == 'gumbel_muzero': @@ -52,7 +53,8 @@ def eval_muzero(main_cfg, create_cfg, seed=0): evaluator_env.seed(cfg.seed, dynamic_seed=False) set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda) - model = Encoder(**cfg.policy.model, state_encoder=PettingZooEncoder(cfg), state_prediction=PettingZooPrediction(cfg), state_dynamics=PettingZooDynamics(cfg)) + # model = Encoder(**cfg.policy.model, state_encoder=PettingZooEncoder(cfg), state_prediction=PettingZooPrediction(cfg), state_dynamics=PettingZooDynamics(cfg)) + model = Encoder(**cfg.policy.model, state_encoder=PettingZooEncoder(cfg)) policy = create_policy(cfg.policy, model=model, enable_field=['learn', 'collect', 'eval']) policy.eval_mode.load_state_dict(torch.load(cfg.policy.load_path, map_location=cfg.policy.device)) @@ -78,5 +80,6 @@ def eval_muzero(main_cfg, create_cfg, seed=0): return stop, reward if __name__ == '__main__': - from zoo.petting_zoo.config.ptz_simple_mz_config import main_config, create_config + # from zoo.petting_zoo.config.ptz_simple_mz_config import main_config, create_config + from zoo.petting_zoo.config.ptz_simple_ez_config import main_config, create_config eval_muzero(main_config, create_config, seed=0) \ No newline at end of file From 6ea3f9bb3069c61e6c97cf05a7ffe4a0b200dfa8 Mon Sep 17 00:00:00 2001 From: chosenone Date: Tue, 21 Nov 2023 23:15:49 +0800 Subject: [PATCH 09/12] feature(yzj): polish ctde2-(8,3,5) --- .../mcts/buffer/game_buffer_efficientzero.py | 33 +- lzero/mcts/buffer/game_buffer_muzero.py | 32 +- lzero/mcts/tree_search/mcts_ctree.py | 46 ++- lzero/model/efficientzero_model_mlp.py | 40 +- lzero/model/muzero_model_mlp.py | 37 +- lzero/policy/efficientzero.py | 98 +++-- lzero/policy/muzero.py | 148 ++++--- lzero/worker/muzero_collector.py | 280 +++++++++---- .../config/ptz_simple_ez_config.py | 11 +- .../config/ptz_simple_mz_config.py | 5 +- 
.../config/ptz_simple_spread_ez_config.py | 116 ++++++ .../config/ptz_simple_spread_mz_config.py | 14 +- zoo/petting_zoo/entry/eval_muzero.py | 2 +- zoo/petting_zoo/entry/train_muzero.py | 10 +- .../envs/petting_zoo_simple_spread_env.py | 29 +- zoo/petting_zoo/model/__init__.py | 2 +- zoo/petting_zoo/model/model.py | 370 ++++++++++++++---- 17 files changed, 941 insertions(+), 332 deletions(-) create mode 100644 zoo/petting_zoo/config/ptz_simple_spread_ez_config.py diff --git a/lzero/mcts/buffer/game_buffer_efficientzero.py b/lzero/mcts/buffer/game_buffer_efficientzero.py index ae6b43786..5f0273712 100644 --- a/lzero/mcts/buffer/game_buffer_efficientzero.py +++ b/lzero/mcts/buffer/game_buffer_efficientzero.py @@ -102,15 +102,15 @@ def _prepare_reward_value_context( - reward_value_context (:obj:`list`): value_obs_list, value_mask, pos_in_game_segment_list, rewards_list, game_segment_lens, td_steps_list, action_mask_segment, to_play_segment """ - # zero_obs = game_segment_list[0].zero_obs() - # zero_obs = np.array([{'agent_state': np.zeros((3, 18), dtype=np.float32), - # 'global_state': np.zeros((84,), dtype=np.float32), - # 'agent_alone_state': np.zeros((3, 14), dtype=np.float32), - # 'agent_alone_padding_state': np.zeros((3, 18), dtype=np.float32),}]) - zero_obs = np.array([{'agent_state': np.zeros((1, 6), dtype=np.float32), + zero_obs = game_segment_list[0].zero_obs() + # zero_obs = np.array([{'agent_state': np.zeros((18,), dtype=np.float32), + # 'global_state': np.zeros((48,), dtype=np.float32), + # 'agent_alone_state': np.zeros((14,), dtype=np.float32), + # 'agent_alone_padding_state': np.zeros((18,), dtype=np.float32),}]) + zero_obs = np.array([{'agent_state': np.zeros((6,), dtype=np.float32), 'global_state': np.zeros((14, ), dtype=np.float32), - 'agent_alone_state': np.zeros((1, 12), dtype=np.float32), - 'agent_alone_padding_state': np.zeros((1, 12), dtype=np.float32),}]) + 'agent_alone_state': np.zeros((12,), dtype=np.float32), + 'agent_alone_padding_state': np.zeros((12,), dtype=np.float32),}]) value_obs_list = [] # the value is valid or not (out of trajectory) value_mask = [] @@ -221,13 +221,16 @@ def _compute_target_reward_value(self, reward_value_context: List[Any], model: A # EfficientZero related core code # ============================================================== # if not in training, obtain the scalars of the value/reward - [m_output.latent_state, m_output.value, m_output.policy_logits] = to_detach_cpu_numpy( - [ - m_output.latent_state, - inverse_scalar_transform(m_output.value, self._cfg.model.support_scale), - m_output.policy_logits - ] - ) + # [m_output.latent_state, m_output.value, m_output.policy_logits] = to_detach_cpu_numpy( + # [ + # m_output.latent_state, + # inverse_scalar_transform(m_output.value, self._cfg.model.support_scale), + # m_output.policy_logits + # ] + # ) + m_output.latent_state = (to_detach_cpu_numpy(m_output.latent_state[0]), to_detach_cpu_numpy(m_output.latent_state[1])) + m_output.value = to_detach_cpu_numpy(inverse_scalar_transform(m_output.value, self._cfg.model.support_scale)) + m_output.policy_logits = to_detach_cpu_numpy(m_output.policy_logits) m_output.reward_hidden_state = ( m_output.reward_hidden_state[0].detach().cpu().numpy(), m_output.reward_hidden_state[1].detach().cpu().numpy() diff --git a/lzero/mcts/buffer/game_buffer_muzero.py b/lzero/mcts/buffer/game_buffer_muzero.py index 15177a4e2..6feaf65d6 100644 --- a/lzero/mcts/buffer/game_buffer_muzero.py +++ b/lzero/mcts/buffer/game_buffer_muzero.py @@ -201,14 +201,14 @@ def 
_prepare_reward_value_context( td_steps_list, action_mask_segment, to_play_segment """ zero_obs = game_segment_list[0].zero_obs() - # zero_obs = np.array([{'agent_state': np.zeros((3, 18), dtype=np.float32), - # 'global_state': np.zeros((84,), dtype=np.float32), - # 'agent_alone_state': np.zeros((3, 14), dtype=np.float32), - # 'agent_alone_padding_state': np.zeros((3, 18), dtype=np.float32),}]) - zero_obs = np.array([{'agent_state': np.zeros((1, 6), dtype=np.float32), + zero_obs = np.array([{'agent_state': np.zeros((18,), dtype=np.float32), + 'global_state': np.zeros((48,), dtype=np.float32), + 'agent_alone_state': np.zeros((14,), dtype=np.float32), + 'agent_alone_padding_state': np.zeros((18,), dtype=np.float32),}]) + zero_obs = np.array([{'agent_state': np.zeros((6,), dtype=np.float32), 'global_state': np.zeros((14, ), dtype=np.float32), - 'agent_alone_state': np.zeros((1, 12), dtype=np.float32), - 'agent_alone_padding_state': np.zeros((1, 12), dtype=np.float32),}]) + 'agent_alone_state': np.zeros((12,), dtype=np.float32), + 'agent_alone_padding_state': np.zeros((12,), dtype=np.float32),}]) value_obs_list = [] # the value is valid or not (out of game_segment) value_mask = [] @@ -400,14 +400,16 @@ def _compute_target_reward_value(self, reward_value_context: List[Any], model: A if not model.training: # if not in training, obtain the scalars of the value/reward - [m_output.latent_state, m_output.value, m_output.policy_logits] = to_detach_cpu_numpy( - [ - m_output.latent_state, - inverse_scalar_transform(m_output.value, self._cfg.model.support_scale), - m_output.policy_logits - ] - ) - + # [m_output.latent_state, m_output.value, m_output.policy_logits] = to_detach_cpu_numpy( + # [ + # m_output.latent_state, + # inverse_scalar_transform(m_output.value, self._cfg.model.support_scale), + # m_output.policy_logits + # ] + # ) + m_output.latent_state = (to_detach_cpu_numpy(m_output.latent_state[0]), to_detach_cpu_numpy(m_output.latent_state[1])) + m_output.value = to_detach_cpu_numpy(inverse_scalar_transform(m_output.value, self._cfg.model.support_scale)) + m_output.policy_logits = to_detach_cpu_numpy(m_output.policy_logits) network_output.append(m_output) # concat the output slices after model inference diff --git a/lzero/mcts/tree_search/mcts_ctree.py b/lzero/mcts/tree_search/mcts_ctree.py index 20336277c..d69781750 100644 --- a/lzero/mcts/tree_search/mcts_ctree.py +++ b/lzero/mcts/tree_search/mcts_ctree.py @@ -96,7 +96,9 @@ def search( pb_c_base, pb_c_init, discount_factor = self._cfg.pb_c_base, self._cfg.pb_c_init, self._cfg.discount_factor # the data storage of latent states: storing the latent state of all the nodes in one search. - latent_state_batch_in_search_path = [latent_state_roots] + agent_latent_state_roots, global_latent_state_roots = latent_state_roots + agent_latent_state_batch_in_search_path = [agent_latent_state_roots] + global_latent_state_batch_in_search_path = [global_latent_state_roots] # the data storage of value prefix hidden states in LSTM reward_hidden_state_c_batch = [reward_hidden_state_roots[0]] reward_hidden_state_h_batch = [reward_hidden_state_roots[1]] @@ -108,7 +110,8 @@ def search( for simulation_index in range(self._cfg.num_simulations): # In each simulation, we expanded a new node, so in one search, we have ``num_simulations`` num of nodes at most. 
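The replacement just below splits the single latent_states list into an agent history and a global history that are kept index-aligned along the MCTS search path. A minimal, self-contained sketch of that bookkeeping pattern, with toy sizes and illustrative names (not the repo's API):

import numpy as np

batch, agent_dim, global_dim = 4, 8, 16              # toy sizes, not the configured defaults
agent_hist = [np.zeros((batch, agent_dim))]          # one entry per expanded search step
global_hist = [np.zeros((batch, global_dim))]        # kept index-aligned with agent_hist

def gather_leaf_latents(path_idx, batch_idx):
    # collect the (agent, global) latent pair for every leaf selected in this simulation;
    # [ix][iy] addresses the iy-th environment's latent at the ix-th step of the search path
    agent_leaves = np.asarray([agent_hist[ix][iy] for ix, iy in zip(path_idx, batch_idx)])
    global_leaves = np.asarray([global_hist[ix][iy] for ix, iy in zip(path_idx, batch_idx)])
    return agent_leaves, global_leaves

# after each recurrent_inference both halves are appended, so the next simulation can index them
agent_hist.append(np.ones((batch, agent_dim)))
global_hist.append(np.ones((batch, global_dim)))
a, g = gather_leaf_latents([1, 0], [2, 3])
assert a.shape == (2, agent_dim) and g.shape == (2, global_dim)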
- latent_states = [] + agent_latent_states = [] + global_latent_states = [] hidden_states_c_reward = [] hidden_states_h_reward = [] @@ -132,11 +135,13 @@ def search( # obtain the latent state for leaf node for ix, iy in zip(latent_state_index_in_search_path, latent_state_index_in_batch): - latent_states.append(latent_state_batch_in_search_path[ix][iy]) + agent_latent_states.append(agent_latent_state_batch_in_search_path[ix][iy]) + global_latent_states.append(global_latent_state_batch_in_search_path[ix][iy]) hidden_states_c_reward.append(reward_hidden_state_c_batch[ix][0][iy]) hidden_states_h_reward.append(reward_hidden_state_h_batch[ix][0][iy]) - latent_states = torch.from_numpy(np.asarray(latent_states)).to(self._cfg.device).float() + agent_latent_states = torch.from_numpy(np.asarray(agent_latent_states)).to(self._cfg.device).float() + global_latent_states = torch.from_numpy(np.asarray(global_latent_states)).to(self._cfg.device).float() hidden_states_c_reward = torch.from_numpy(np.asarray(hidden_states_c_reward)).to(self._cfg.device ).unsqueeze(0) hidden_states_h_reward = torch.from_numpy(np.asarray(hidden_states_h_reward)).to(self._cfg.device @@ -151,10 +156,12 @@ def search( At the end of the simulation, the statistics along the trajectory are updated. """ network_output = model.recurrent_inference( - latent_states, (hidden_states_c_reward, hidden_states_h_reward), last_actions + (agent_latent_states, global_latent_states), (hidden_states_c_reward, hidden_states_h_reward), last_actions ) + network_output_agent_latent_state, network_output_global_latent_state = network_output.latent_state - network_output.latent_state = to_detach_cpu_numpy(network_output.latent_state) + network_output_agent_latent_state = to_detach_cpu_numpy(network_output_agent_latent_state) + network_output_global_latent_state = to_detach_cpu_numpy(network_output_global_latent_state) network_output.policy_logits = to_detach_cpu_numpy(network_output.policy_logits) network_output.value = to_detach_cpu_numpy(self.inverse_scalar_transform_handle(network_output.value)) network_output.value_prefix = to_detach_cpu_numpy(self.inverse_scalar_transform_handle(network_output.value_prefix)) @@ -164,7 +171,8 @@ def search( network_output.reward_hidden_state[1].detach().cpu().numpy() ) - latent_state_batch_in_search_path.append(network_output.latent_state) + agent_latent_state_batch_in_search_path.append(network_output_agent_latent_state) + global_latent_state_batch_in_search_path.append(network_output_global_latent_state) # tolist() is to be compatible with cpp datatype. value_prefix_batch = network_output.value_prefix.reshape(-1).tolist() value_batch = network_output.value.reshape(-1).tolist() @@ -273,7 +281,9 @@ def search( batch_size = roots.num pb_c_base, pb_c_init, discount_factor = self._cfg.pb_c_base, self._cfg.pb_c_init, self._cfg.discount_factor # the data storage of latent states: storing the latent state of all the nodes in the search. - latent_state_batch_in_search_path = [latent_state_roots] + agent_latent_state_roots, global_latent_state_roots = latent_state_roots + agent_latent_state_batch_in_search_path = [agent_latent_state_roots] + global_latent_state_batch_in_search_path = [global_latent_state_roots] # minimax value storage min_max_stats_lst = tree_muzero.MinMaxStatsList(batch_size) @@ -282,7 +292,8 @@ def search( for simulation_index in range(self._cfg.num_simulations): # In each simulation, we expanded a new node, so in one search, we have ``num_simulations`` num of nodes at most. 
- latent_states = [] + agent_latent_states = [] + global_latent_states = [] # prepare a result wrapper to transport results between python and c++ parts results = tree_muzero.ResultsWrapper(num=batch_size) @@ -302,9 +313,11 @@ def search( # obtain the latent state for leaf node for ix, iy in zip(latent_state_index_in_search_path, latent_state_index_in_batch): - latent_states.append(latent_state_batch_in_search_path[ix][iy]) + agent_latent_states.append(agent_latent_state_batch_in_search_path[ix][iy]) + global_latent_states.append(global_latent_state_batch_in_search_path[ix][iy]) - latent_states = torch.from_numpy(np.asarray(latent_states)).to(self._cfg.device).float() + agent_latent_states = torch.from_numpy(np.asarray(agent_latent_states)).to(self._cfg.device).float() + global_latent_states = torch.from_numpy(np.asarray(global_latent_states)).to(self._cfg.device).float() # .long() is only for discrete action last_actions = torch.from_numpy(np.asarray(last_actions)).to(self._cfg.device).long() """ @@ -314,14 +327,19 @@ def search( MCTS stage 3: Backup At the end of the simulation, the statistics along the trajectory are updated. """ - network_output = model.recurrent_inference(latent_states, last_actions) + network_output = model.recurrent_inference((agent_latent_states, global_latent_states), last_actions) - network_output.latent_state = to_detach_cpu_numpy(network_output.latent_state) + network_output_agent_latent_state, network_output_global_latent_state = network_output.latent_state + + # network_output.latent_state = to_detach_cpu_numpy(network_output.latent_state) + network_output_agent_latent_state = to_detach_cpu_numpy(network_output_agent_latent_state) + network_output_global_latent_state = to_detach_cpu_numpy(network_output_global_latent_state) network_output.policy_logits = to_detach_cpu_numpy(network_output.policy_logits) network_output.value = to_detach_cpu_numpy(self.inverse_scalar_transform_handle(network_output.value)) network_output.reward = to_detach_cpu_numpy(self.inverse_scalar_transform_handle(network_output.reward)) - latent_state_batch_in_search_path.append(network_output.latent_state) + agent_latent_state_batch_in_search_path.append(network_output_agent_latent_state) + global_latent_state_batch_in_search_path.append(network_output_global_latent_state) # tolist() is to be compatible with cpp datatype. 
reward_batch = network_output.reward.reshape(-1).tolist() value_batch = network_output.value.reshape(-1).tolist() diff --git a/lzero/model/efficientzero_model_mlp.py b/lzero/model/efficientzero_model_mlp.py index beed7830a..ec664906e 100644 --- a/lzero/model/efficientzero_model_mlp.py +++ b/lzero/model/efficientzero_model_mlp.py @@ -128,7 +128,17 @@ def __init__( res_connection_in_dynamics=self.res_connection_in_dynamics, ) else: - self.dynamics_network = state_dynamics + self.dynamics_network = state_dynamics( + action_encoding_dim=self.action_encoding_dim, + num_channels=latent_state_dim + self.action_encoding_dim, + common_layer_num=2, + lstm_hidden_size=lstm_hidden_size, + fc_reward_layers=fc_reward_layers, + output_support_size=self.reward_support_size, + last_linear_layer_init_zero=self.last_linear_layer_init_zero, + norm_type=norm_type, + res_connection_in_dynamics=self.res_connection_in_dynamics, + ) if state_prediction == None: self.prediction_network = PredictionNetworkMLP( @@ -141,7 +151,16 @@ def __init__( norm_type=norm_type ) else: - self.prediction_network = state_prediction + self.prediction_network = state_prediction( + action_space_size=action_space_size, + num_channels=latent_state_dim, + fc_value_layers=fc_value_layers, + fc_policy_layers=fc_policy_layers, + output_support_size=self.value_support_size, + last_linear_layer_init_zero=self.last_linear_layer_init_zero, + norm_type=norm_type + + ) if self.self_supervised_learning_loss: # self_supervised_learning_loss related network proposed in EfficientZero @@ -186,7 +205,7 @@ def initial_inference(self, obs: torch.Tensor) -> EZNetworkOutput: """ batch_size = get_shape0(obs) latent_state = self._representation(obs) - device = latent_state.device + device = latent_state[0].device policy_logits, value = self._prediction(latent_state) # zero initialization for reward hidden states # (hn, cn), each element shape is (layer_num=1, batch_size, lstm_hidden_size) @@ -307,19 +326,22 @@ def _dynamics(self, latent_state: torch.Tensor, reward_hidden_state: Tuple, # e.g., torch.Size([8]) -> torch.Size([8, 1]) action_encoding = action_encoding.unsqueeze(-1) - action_encoding = action_encoding.to(latent_state.device).float() + agent_latent_state, global_latent_state = latent_state + action_encoding = action_encoding.to(agent_latent_state.device).float() # state_action_encoding shape: (batch_size, latent_state[1] + action_dim]) or # (batch_size, latent_state[1] + action_space_size]) depending on the discrete_action_encoding_type. 
- state_action_encoding = torch.cat((latent_state, action_encoding), dim=1) + agent_state_action_encoding = torch.cat((agent_latent_state, action_encoding), dim=1) + global_state_action_encoding = torch.cat((agent_latent_state, global_latent_state, action_encoding), dim=1) # NOTE: the key difference with MuZero - next_latent_state, next_reward_hidden_state, value_prefix = self.dynamics_network( - state_action_encoding, reward_hidden_state + (next_agent_latent_state, next_global_latent_state), next_reward_hidden_state, value_prefix = self.dynamics_network( + (agent_state_action_encoding, global_state_action_encoding), reward_hidden_state ) if self.state_norm: - next_latent_state = renormalize(next_latent_state) - return next_latent_state, next_reward_hidden_state, value_prefix + next_agent_latent_state = renormalize(next_agent_latent_state) + next_global_latent_state = renormalize(next_global_latent_state) + return (next_agent_latent_state, next_global_latent_state), next_reward_hidden_state, value_prefix def project(self, latent_state: torch.Tensor, with_grad=True): """ diff --git a/lzero/model/muzero_model_mlp.py b/lzero/model/muzero_model_mlp.py index ecde46a32..2bd69ff42 100644 --- a/lzero/model/muzero_model_mlp.py +++ b/lzero/model/muzero_model_mlp.py @@ -124,7 +124,16 @@ def __init__( res_connection_in_dynamics=self.res_connection_in_dynamics, ) else: - self.dynamics_network = state_dynamics + self.dynamics_network = state_dynamics( + action_encoding_dim=self.action_encoding_dim, + num_channels=self.latent_state_dim + self.action_encoding_dim, + common_layer_num=2, + fc_reward_layers=fc_reward_layers, + output_support_size=self.reward_support_size, + last_linear_layer_init_zero=self.last_linear_layer_init_zero, + norm_type=norm_type, + res_connection_in_dynamics=self.res_connection_in_dynamics, + ) if state_prediction == None: self.prediction_network = PredictionNetworkMLP( @@ -137,7 +146,15 @@ def __init__( norm_type=norm_type ) else: - self.prediction_network = state_prediction + self.prediction_network = state_prediction( + action_space_size=action_space_size, + num_channels=latent_state_dim, + fc_value_layers=fc_value_layers, + fc_policy_layers=fc_policy_layers, + output_support_size=self.value_support_size, + last_linear_layer_init_zero=self.last_linear_layer_init_zero, + norm_type=norm_type + ) if self.self_supervised_learning_loss: # self_supervised_learning_loss related network proposed in EfficientZero @@ -293,18 +310,20 @@ def _dynamics(self, latent_state: torch.Tensor, action: torch.Tensor) -> Tuple[t # e.g., torch.Size([8]) -> torch.Size([8, 1]) action_encoding = action_encoding.unsqueeze(-1) - action_encoding = action_encoding.to(latent_state.device).float() + agent_latent_state, global_latent_state = latent_state + action_encoding = action_encoding.to(agent_latent_state.device).float() # state_action_encoding shape: (batch_size, latent_state[1] + action_dim]) or # (batch_size, latent_state[1] + action_space_size]) depending on the discrete_action_encoding_type. 
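For orientation, a toy sketch of the concatenation performed just below (and in the EfficientZero counterpart above): the agent branch conditions on its own latent plus the action encoding, while the global branch additionally sees the agent latent. Sizes here are illustrative assumptions, not values taken from the configs.

import torch

batch, latent_dim, action_dim = 4, 256, 5
agent_latent = torch.zeros(batch, latent_dim)
global_latent = torch.zeros(batch, latent_dim)
action_onehot = torch.zeros(batch, action_dim)

agent_in = torch.cat((agent_latent, action_onehot), dim=1)
global_in = torch.cat((agent_latent, global_latent, action_onehot), dim=1)
assert agent_in.shape == (batch, latent_dim + action_dim)        # e.g. (4, 261)
assert global_in.shape == (batch, 2 * latent_dim + action_dim)   # e.g. (4, 517)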
- state_action_encoding = torch.cat((latent_state, action_encoding), dim=1) - - next_latent_state, reward = self.dynamics_network(state_action_encoding) + agent_state_action_encoding = torch.cat((agent_latent_state, action_encoding), dim=1) + global_state_action_encoding = torch.cat((agent_latent_state, global_latent_state, action_encoding), dim=1) + (next_agent_latent_state, next_global_latent_state), reward = self.dynamics_network((agent_state_action_encoding, global_state_action_encoding)) if not self.state_norm: - return next_latent_state, reward + return (next_agent_latent_state, next_global_latent_state), reward else: - next_latent_state_normalized = renormalize(next_latent_state) - return next_latent_state_normalized, reward + next_agent_latent_state_normalized = renormalize(next_agent_latent_state) + next_global_latent_state_normalized = renormalize(next_global_latent_state) + return (next_agent_latent_state_normalized, next_global_latent_state_normalized), reward def project(self, latent_state: torch.Tensor, with_grad=True) -> torch.Tensor: """ diff --git a/lzero/policy/efficientzero.py b/lzero/policy/efficientzero.py index b160f4dd6..001b46294 100644 --- a/lzero/policy/efficientzero.py +++ b/lzero/policy/efficientzero.py @@ -20,6 +20,7 @@ from lzero.policy.muzero import MuZeroPolicy from ding.utils.data import default_collate from ding.torch_utils import to_device, to_tensor +from collections import defaultdict @POLICY_REGISTRY.register('efficientzero') @@ -334,7 +335,9 @@ def _forward_learn(self, data: torch.Tensor) -> Dict[str, Union[float, int]]: # Note: The following lines are just for debugging. predicted_value_prefixs = [] if self._cfg.monitor_extra_statistics: - latent_state_list = latent_state.detach().cpu().numpy() + agent_latent_state, global_latent_state = latent_state + agent_latent_state_list = agent_latent_state.detach().cpu().numpy() + global_latent_state_list = global_latent_state.detach().cpu().numpy() predicted_values, predicted_policies = original_value.detach().cpu(), torch.softmax( policy_logits, dim=1 ).detach().cpu() @@ -397,17 +400,22 @@ def _forward_learn(self, data: torch.Tensor) -> Dict[str, Union[float, int]]: if self._cfg.ssl_loss_weight > 0: # obtain the oracle latent states from representation function. beg_index, end_index = self._get_target_obs_index_in_step_k(step_k) - network_output = self._learn_model.initial_inference(obs_target_batch[:, beg_index:end_index]) + obs_target_batch_tmp = default_collate(obs_target_batch[:, beg_index:end_index].squeeze()) + network_output = self._learn_model.initial_inference(obs_target_batch_tmp) latent_state = to_tensor(latent_state) representation_state = to_tensor(network_output.latent_state) # NOTE: no grad for the representation_state branch. 
- dynamic_proj = self._learn_model.project(latent_state, with_grad=True) - observation_proj = self._learn_model.project(representation_state, with_grad=False) - temp_loss = negative_cosine_similarity(dynamic_proj, observation_proj) * mask_batch[:, step_k] + dynamic_proj = self._learn_model.project(latent_state[0], with_grad=True) + observation_proj = self._learn_model.project(representation_state[0], with_grad=False) + temp_loss_0 = negative_cosine_similarity(dynamic_proj, observation_proj) * mask_batch[:, step_k] - consistency_loss += temp_loss + dynamic_proj = self._learn_model.project(latent_state[1], with_grad=True) + observation_proj = self._learn_model.project(representation_state[1], with_grad=False) + temp_loss_1 = negative_cosine_similarity(dynamic_proj, observation_proj) * mask_batch[:, step_k] + + consistency_loss += (temp_loss_0 + temp_loss_1) # NOTE: the target policy, target_value_categorical, target_value_prefix_categorical is calculated in # game buffer now. @@ -455,7 +463,8 @@ def _forward_learn(self, data: torch.Tensor) -> Dict[str, Union[float, int]]: ) predicted_value_prefixs.append(original_value_prefixs_cpu) predicted_policies = torch.cat((predicted_policies, torch.softmax(policy_logits, dim=1).detach().cpu())) - latent_state_list = np.concatenate((latent_state_list, latent_state.detach().cpu().numpy())) + agent_latent_state_list = np.concatenate((agent_latent_state_list, agent_latent_state.detach().cpu().numpy())) + global_latent_state_list = np.concatenate((global_latent_state_list, global_latent_state.detach().cpu().numpy())) # ============================================================== # the core learn model update step. @@ -565,11 +574,13 @@ def _forward_collect( self._collect_mcts_temperature = temperature self.collect_epsilon = epsilon active_collect_env_num = len(data) + batch_size = active_collect_env_num*self.cfg.model.agent_num # - data = sum(data, []) + data = sum(sum(data, []), []) data = default_collate(data) data = to_device(data, self._device) to_play = np.array(to_play).reshape(-1).tolist() + action_mask = sum(action_mask, []) with torch.no_grad(): # data shape [B, S x C, W, H], e.g. {Tensor:(B, 12, 96, 96)} @@ -577,41 +588,43 @@ def _forward_collect( latent_state_roots, value_prefix_roots, reward_hidden_state_roots, pred_values, policy_logits = ez_network_output_unpack( network_output ) - + agent_latent_state_roots, global_latent_state_roots = latent_state_roots pred_values = self.inverse_scalar_transform_handle(pred_values).detach().cpu().numpy() - latent_state_roots = latent_state_roots.detach().cpu().numpy() + agent_latent_state_roots = agent_latent_state_roots.detach().cpu().numpy() + global_latent_state_roots = global_latent_state_roots.detach().cpu().numpy() reward_hidden_state_roots = ( reward_hidden_state_roots[0].detach().cpu().numpy(), reward_hidden_state_roots[1].detach().cpu().numpy() ) policy_logits = policy_logits.detach().cpu().numpy().tolist() - legal_actions = [[i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(active_collect_env_num)] + legal_actions = [[i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(batch_size)] # the only difference between collect and eval is the dirichlet noise. 
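The collect/eval paths above flatten the (env, agent) structure into one batch of env_num * agent_num items and later recover env_id by integer division; a small self-contained illustration of that index mapping (toy numbers, and a single flatten standing in for the nested sum(...) calls). Everything downstream, including the Dirichlet noise and MCTS roots that follow, is simply sized by this flattened batch.

env_num, agent_num = 2, 3
nested = [[f'env{e}_agent{a}' for a in range(agent_num)] for e in range(env_num)]
flat = sum(nested, [])                      # flatten per-env, per-agent items into one batch
for i, item in enumerate(flat):
    env_id, agent_id = i // agent_num, i % agent_num
    assert item == f'env{env_id}_agent{agent_id}'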
noises = [ np.random.dirichlet([self._cfg.root_dirichlet_alpha] * int(sum(action_mask[j])) - ).astype(np.float32).tolist() for j in range(active_collect_env_num) + ).astype(np.float32).tolist() for j in range(batch_size) ] if self._cfg.mcts_ctree: # cpp mcts_tree - roots = MCTSCtree.roots(active_collect_env_num, legal_actions) + roots = MCTSCtree.roots(batch_size, legal_actions) else: # python mcts_tree - roots = MCTSPtree.roots(active_collect_env_num, legal_actions) + roots = MCTSPtree.roots(batch_size, legal_actions) roots.prepare(self._cfg.root_noise_weight, noises, value_prefix_roots, policy_logits, to_play) self._mcts_collect.search( - roots, self._collect_model, latent_state_roots, reward_hidden_state_roots, to_play + roots, self._collect_model, (agent_latent_state_roots, global_latent_state_roots), reward_hidden_state_roots, to_play ) roots_visit_count_distributions = roots.get_distributions() roots_values = roots.get_values() # shape: {list: batch_size} data_id = [i for i in range(active_collect_env_num)] - output = {i: None for i in data_id} + output = {i: defaultdict(list) for i in data_id} if ready_env_id is None: ready_env_id = np.arange(active_collect_env_num) - for i, env_id in enumerate(ready_env_id): + for i in range(batch_size): + env_id = i // self.cfg.model.agent_num distributions, value = roots_visit_count_distributions[i], roots_values[i] if self._cfg.eps.eps_greedy_exploration_in_collect: # eps-greedy collect @@ -630,15 +643,12 @@ def _forward_collect( ) # NOTE: Convert the ``action_index_in_legal_action_set`` to the corresponding ``action`` in the entire action set. action = np.where(action_mask[i] == 1.0)[0][action_index_in_legal_action_set] - output[env_id] = { - 'action': action, - 'distributions': distributions, - 'visit_count_distribution_entropy': visit_count_distribution_entropy, - 'value': value, - 'pred_value': pred_values[i], - 'policy_logits': policy_logits[i], - } - + output[env_id]['action'].append(action) + output[env_id]['distributions'].append(distributions) + output[env_id]['visit_count_distribution_entropy'].append(visit_count_distribution_entropy) + output[env_id]['value'].append(value) + output[env_id]['pred_value'].append(pred_values[i]) + output[env_id]['policy_logits'].append(policy_logits[i]) return output def _init_eval(self) -> None: @@ -676,11 +686,13 @@ def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: -1, read """ self._eval_model.eval() active_eval_env_num = len(data) + batch_size = active_eval_env_num*self.cfg.model.agent_num # - data = sum(data, []) + data = sum(sum(data, []), []) data = default_collate(data) data = to_device(data, self._device) to_play = np.array(to_play).reshape(-1).tolist() + action_mask = sum(action_mask, []) with torch.no_grad(): # data shape [B, S x C, W, H], e.g. 
{Tensor:(B, 12, 96, 96)} network_output = self._eval_model.initial_inference(data) @@ -691,33 +703,36 @@ def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: -1, read if not self._eval_model.training: # if not in training, obtain the scalars of the value/reward pred_values = self.inverse_scalar_transform_handle(pred_values).detach().cpu().numpy() # shape(B, 1) - latent_state_roots = latent_state_roots.detach().cpu().numpy() + agent_latent_state_roots, global_latent_state_roots = latent_state_roots + agent_latent_state_roots = agent_latent_state_roots.detach().cpu().numpy() + global_latent_state_roots = global_latent_state_roots.detach().cpu().numpy() reward_hidden_state_roots = ( reward_hidden_state_roots[0].detach().cpu().numpy(), reward_hidden_state_roots[1].detach().cpu().numpy() ) policy_logits = policy_logits.detach().cpu().numpy().tolist() # list shape(B, A) - legal_actions = [[i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(active_eval_env_num)] + legal_actions = [[i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(batch_size)] if self._cfg.mcts_ctree: # cpp mcts_tree - roots = MCTSCtree.roots(active_eval_env_num, legal_actions) + roots = MCTSCtree.roots(batch_size, legal_actions) else: # python mcts_tree - roots = MCTSPtree.roots(active_eval_env_num, legal_actions) + roots = MCTSPtree.roots(batch_size, legal_actions) roots.prepare_no_noise(value_prefix_roots, policy_logits, to_play) - self._mcts_eval.search(roots, self._eval_model, latent_state_roots, reward_hidden_state_roots, to_play) + self._mcts_eval.search(roots, self._eval_model, (agent_latent_state_roots, global_latent_state_roots), reward_hidden_state_roots, to_play) # list of list, shape: ``{list: batch_size} -> {list: action_space_size}`` roots_visit_count_distributions = roots.get_distributions() roots_values = roots.get_values() # shape: {list: batch_size} data_id = [i for i in range(active_eval_env_num)] - output = {i: None for i in data_id} + output = {i: defaultdict(list) for i in data_id} if ready_env_id is None: ready_env_id = np.arange(active_eval_env_num) - for i, env_id in enumerate(ready_env_id): + for i in range(batch_size): + env_id = i // self.cfg.model.agent_num distributions, value = roots_visit_count_distributions[i], roots_values[i] # NOTE: Only legal actions possess visit counts, so the ``action_index_in_legal_action_set`` represents # the index within the legal action set, rather than the index in the entire action set. @@ -727,15 +742,12 @@ def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: -1, read ) # NOTE: Convert the ``action_index_in_legal_action_set`` to the corresponding ``action`` in the entire action set. 
action = np.where(action_mask[i] == 1.0)[0][action_index_in_legal_action_set] - output[env_id] = { - 'action': action, - 'distributions': distributions, - 'visit_count_distribution_entropy': visit_count_distribution_entropy, - 'value': value, - 'pred_value': pred_values[i], - 'policy_logits': policy_logits[i], - } - + output[env_id]['action'].append(action) + output[env_id]['distributions'].append(distributions) + output[env_id]['visit_count_distribution_entropy'].append(visit_count_distribution_entropy) + output[env_id]['value'].append(value) + output[env_id]['pred_value'].append(pred_values[i]) + output[env_id]['policy_logits'].append(policy_logits[i]) return output def _monitor_vars_learn(self) -> List[str]: diff --git a/lzero/policy/muzero.py b/lzero/policy/muzero.py index 411675ae4..cbca2c18c 100644 --- a/lzero/policy/muzero.py +++ b/lzero/policy/muzero.py @@ -18,6 +18,7 @@ prepare_obs from ding.utils.data import default_collate from ding.torch_utils import to_device, to_tensor +from collections import defaultdict @POLICY_REGISTRY.register('muzero') class MuZeroPolicy(Policy): @@ -326,7 +327,9 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in # Note: The following lines are just for debugging. predicted_rewards = [] if self._cfg.monitor_extra_statistics: - latent_state_list = latent_state.detach().cpu().numpy() + agent_latent_state, global_latent_state = latent_state + agent_latent_state_list = agent_latent_state.detach().cpu().numpy() + global_latent_state_list = global_latent_state.detach().cpu().numpy() predicted_values, predicted_policies = original_value.detach().cpu(), torch.softmax( policy_logits, dim=1 ).detach().cpu() @@ -365,16 +368,22 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in if self._cfg.ssl_loss_weight > 0: # obtain the oracle latent states from representation function. beg_index, end_index = self._get_target_obs_index_in_step_k(step_k) - network_output = self._learn_model.initial_inference(obs_target_batch[:, beg_index:end_index]) + obs_target_batch_tmp = default_collate(obs_target_batch[:, beg_index:end_index].squeeze()) + network_output = self._learn_model.initial_inference(obs_target_batch_tmp) latent_state = to_tensor(latent_state) representation_state = to_tensor(network_output.latent_state) # NOTE: no grad for the representation_state branch - dynamic_proj = self._learn_model.project(latent_state, with_grad=True) - observation_proj = self._learn_model.project(representation_state, with_grad=False) - temp_loss = negative_cosine_similarity(dynamic_proj, observation_proj) * mask_batch[:, step_k] - consistency_loss += temp_loss + dynamic_proj = self._learn_model.project(latent_state[0], with_grad=True) + observation_proj = self._learn_model.project(representation_state[0], with_grad=False) + temp_loss_0 = negative_cosine_similarity(dynamic_proj, observation_proj) * mask_batch[:, step_k] + + dynamic_proj = self._learn_model.project(latent_state[1], with_grad=True) + observation_proj = self._learn_model.project(representation_state[1], with_grad=False) + temp_loss_1 = negative_cosine_similarity(dynamic_proj, observation_proj) * mask_batch[:, step_k] + + consistency_loss += (temp_loss_0 + temp_loss_1) # NOTE: the target policy, target_value_categorical, target_reward_categorical is calculated in # game buffer now. 
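A minimal sketch of the doubled consistency term that both policies now accumulate, one branch for the agent latent and one for the global latent, assuming a stand-in projection head and a SimSiam-style negative cosine similarity (all names and shapes below are illustrative, not the repo's implementations):

import torch
import torch.nn.functional as F

def neg_cos_sim(p, z):
    # negative cosine similarity with a stop-gradient on the observation branch
    return -F.cosine_similarity(p, z.detach(), dim=1)

batch, dim = 4, 256
project = torch.nn.Linear(dim, dim)                                  # stand-in projection head
latent = (torch.randn(batch, dim), torch.randn(batch, dim))          # dynamics output (agent, global)
representation = (torch.randn(batch, dim), torch.randn(batch, dim))  # encoder output (agent, global)
mask = torch.ones(batch)

consistency_loss = torch.zeros(batch)
for dyn, obs in zip(latent, representation):                         # agent branch, then global branch
    consistency_loss = consistency_loss + neg_cos_sim(project(dyn), project(obs)) * mask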
@@ -399,7 +408,8 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in ) predicted_rewards.append(original_rewards_cpu) predicted_policies = torch.cat((predicted_policies, torch.softmax(policy_logits, dim=1).detach().cpu())) - latent_state_list = np.concatenate((latent_state_list, latent_state.detach().cpu().numpy())) + agent_latent_state_list = np.concatenate((agent_latent_state_list, agent_latent_state.detach().cpu().numpy())) + global_latent_state_list = np.concatenate((global_latent_state_list, global_latent_state.detach().cpu().numpy())) # ============================================================== # the core learn model update step. @@ -507,48 +517,62 @@ def _forward_collect( self._collect_mcts_temperature = temperature self.collect_epsilon = epsilon active_collect_env_num = len(data) + batch_size = active_collect_env_num*self.cfg.model.agent_num # - data = sum(data, []) + data = sum(sum(data, []), []) data = default_collate(data) data = to_device(data, self._device) to_play = np.array(to_play).reshape(-1).tolist() + action_mask = sum(action_mask, []) with torch.no_grad(): # data shape [B, S x C, W, H], e.g. {Tensor:(B, 12, 96, 96)} network_output = self._collect_model.initial_inference(data) latent_state_roots, reward_roots, pred_values, policy_logits = mz_network_output_unpack(network_output) - + agent_latent_state_roots, global_latent_state_roots = latent_state_roots pred_values = self.inverse_scalar_transform_handle(pred_values).detach().cpu().numpy() - latent_state_roots = latent_state_roots.detach().cpu().numpy() + agent_latent_state_roots = agent_latent_state_roots.detach().cpu().numpy() + global_latent_state_roots = global_latent_state_roots.detach().cpu().numpy() + # policy_logits_tmp = policy_logits.reshape(active_collect_env_num, self.cfg.model.agent_num, -1) policy_logits = policy_logits.detach().cpu().numpy().tolist() - - legal_actions = [[i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(active_collect_env_num)] + reward_roots = [[reward_root]*self.cfg.model.agent_num for reward_root in reward_roots] + reward_roots = sum(reward_roots, []) + # # joint policy_logits + # prob_1 = policy_logits_tmp[:, 0, :].unsqueeze(1).unsqueeze(2) + # prob_2 = policy_logits_tmp[:, 1, :].unsqueeze(1).unsqueeze(3) + # prob_3 = policy_logits_tmp[:, 2, :].unsqueeze(2).unsqueeze(3) + # # TODO(after softmax?) 
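The commented-out joint-policy construction just below multiplies the three per-agent logits; one way to address the accompanying TODO is to apply softmax per agent first, so the outer product yields a properly normalised joint distribution over the 5**3 joint actions. Illustrative sketch only, not part of the patch's active code path:

import torch

per_agent_logits = torch.randn(8, 3, 5)            # (env, agent, action)
probs = torch.softmax(per_agent_logits, dim=-1)    # normalise each agent's head first
p1, p2, p3 = probs[:, 0], probs[:, 1], probs[:, 2]
joint = p1[:, :, None, None] * p2[:, None, :, None] * p3[:, None, None, :]
joint = joint.reshape(8, -1)                       # (8, 125); each row sums to 1
assert torch.allclose(joint.sum(dim=1), torch.ones(8))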
+ # joint_policy_logits = prob_1 * prob_2 * prob_3 # boardcast + # joint_policy_logits = joint_policy_logits.reshape(8, -1) + # joint_policy_logits = joint_policy_logits.detach().cpu().numpy().tolist() + + legal_actions = [[i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(batch_size)] # the only difference between collect and eval is the dirichlet noise noises = [ np.random.dirichlet([self._cfg.root_dirichlet_alpha] * int(sum(action_mask[j])) - ).astype(np.float32).tolist() for j in range(active_collect_env_num) + ).astype(np.float32).tolist() for j in range(batch_size) ] if self._cfg.mcts_ctree: # cpp mcts_tree - roots = MCTSCtree.roots(active_collect_env_num, legal_actions) + roots = MCTSCtree.roots(batch_size, legal_actions) else: # python mcts_tree - roots = MCTSPtree.roots(active_collect_env_num, legal_actions) - + roots = MCTSPtree.roots(batch_size, legal_actions) roots.prepare(self._cfg.root_noise_weight, noises, reward_roots, policy_logits, to_play) - self._mcts_collect.search(roots, self._collect_model, latent_state_roots, to_play) + self._mcts_collect.search(roots, self._collect_model, (agent_latent_state_roots, global_latent_state_roots), to_play) # list of list, shape: ``{list: batch_size} -> {list: action_space_size}`` roots_visit_count_distributions = roots.get_distributions() roots_values = roots.get_values() # shape: {list: batch_size} data_id = [i for i in range(active_collect_env_num)] - output = {i: None for i in data_id} + output = {i: defaultdict(list) for i in data_id} if ready_env_id is None: ready_env_id = np.arange(active_collect_env_num) - - for i, env_id in enumerate(ready_env_id): + + for i in range(batch_size): + env_id = i // self.cfg.model.agent_num distributions, value = roots_visit_count_distributions[i], roots_values[i] if self._cfg.eps.eps_greedy_exploration_in_collect: # eps greedy collect @@ -567,14 +591,40 @@ def _forward_collect( ) # NOTE: Convert the ``action_index_in_legal_action_set`` to the corresponding ``action`` in the entire action set. action = np.where(action_mask[i] == 1.0)[0][action_index_in_legal_action_set] - output[env_id] = { - 'action': action, - 'distributions': distributions, - 'visit_count_distribution_entropy': visit_count_distribution_entropy, - 'value': value, - 'pred_value': pred_values[i], - 'policy_logits': policy_logits[i], - } + output[env_id]['action'].append(action) + output[env_id]['distributions'].append(distributions) + output[env_id]['visit_count_distribution_entropy'].append(visit_count_distribution_entropy) + output[env_id]['value'].append(value) + output[env_id]['pred_value'].append(pred_values[i]) + output[env_id]['policy_logits'].append(policy_logits[i]) + + # for i, env_id in enumerate(ready_env_id): + # distributions, value = roots_visit_count_distributions[i], roots_values[i] + # if self._cfg.eps.eps_greedy_exploration_in_collect: + # # eps greedy collect + # action_index_in_legal_action_set, visit_count_distribution_entropy = select_action( + # distributions, temperature=self._collect_mcts_temperature, deterministic=True + # ) + # action = np.where(action_mask[i] == 1.0)[0][action_index_in_legal_action_set] + # if np.random.rand() < self.collect_epsilon: + # action = np.random.choice(legal_actions[i]) + # else: + # # normal collect + # # NOTE: Only legal actions possess visit counts, so the ``action_index_in_legal_action_set`` represents + # # the index within the legal action set, rather than the index in the entire action set. 
+ # action_index_in_legal_action_set, visit_count_distribution_entropy = select_action( + # distributions, temperature=self._collect_mcts_temperature, deterministic=False + # ) + # # NOTE: Convert the ``action_index_in_legal_action_set`` to the corresponding ``action`` in the entire action set. + # action = np.where(action_mask[i] == 1.0)[0][action_index_in_legal_action_set] + # output[env_id] = { + # 'action': action, + # 'distributions': distributions, + # 'visit_count_distribution_entropy': visit_count_distribution_entropy, + # 'value': value, + # 'pred_value': pred_values[i], + # 'policy_logits': policy_logits[i], + # } return output @@ -611,6 +661,9 @@ def _get_target_obs_index_in_step_k(self, step): elif self._cfg.model.model_type == 'mlp': beg_index = self._cfg.model.observation_shape * step end_index = self._cfg.model.observation_shape * (step + self._cfg.model.frame_stack_num) + elif self._cfg.model.model_type == 'structure': + beg_index = step + end_index = step + self._cfg.model.frame_stack_num return beg_index, end_index def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: int = -1, ready_env_id=None) -> Dict: @@ -637,11 +690,13 @@ def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: int = -1 """ self._eval_model.eval() active_eval_env_num = len(data) + batch_size = active_eval_env_num*self.cfg.model.agent_num # - data = sum(data, []) + data = sum(sum(data, []), []) data = default_collate(data) data = to_device(data, self._device) to_play = np.array(to_play).reshape(-1).tolist() + action_mask = sum(action_mask, []) with torch.no_grad(): # data shape [B, S x C, W, H], e.g. {Tensor:(B, 12, 96, 96)} network_output = self._collect_model.initial_inference(data) @@ -650,30 +705,35 @@ def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: int = -1 if not self._eval_model.training: # if not in training, obtain the scalars of the value/reward pred_values = self.inverse_scalar_transform_handle(pred_values).detach().cpu().numpy() # shape(B, 1) - latent_state_roots = latent_state_roots.detach().cpu().numpy() + agent_latent_state_roots, global_latent_state_roots = latent_state_roots + agent_latent_state_roots = agent_latent_state_roots.detach().cpu().numpy() + global_latent_state_roots = global_latent_state_roots.detach().cpu().numpy() policy_logits = policy_logits.detach().cpu().numpy().tolist() # list shape(B, A) + reward_roots = [[reward_root]*self.cfg.model.agent_num for reward_root in reward_roots] + reward_roots = sum(reward_roots, []) - legal_actions = [[i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(active_eval_env_num)] + legal_actions = [[i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(batch_size)] if self._cfg.mcts_ctree: # cpp mcts_tree - roots = MCTSCtree.roots(active_eval_env_num, legal_actions) + roots = MCTSCtree.roots(batch_size, legal_actions) else: # python mcts_tree - roots = MCTSPtree.roots(active_eval_env_num, legal_actions) + roots = MCTSPtree.roots(batch_size, legal_actions) roots.prepare_no_noise(reward_roots, policy_logits, to_play) - self._mcts_eval.search(roots, self._eval_model, latent_state_roots, to_play) + self._mcts_eval.search(roots, self._eval_model, (agent_latent_state_roots, global_latent_state_roots), to_play) # list of list, shape: ``{list: batch_size} -> {list: action_space_size}`` roots_visit_count_distributions = roots.get_distributions() roots_values = roots.get_values() # shape: {list: batch_size} data_id = [i for i in range(active_eval_env_num)] - output = 
{i: None for i in data_id} + output = {i: defaultdict(list) for i in data_id} if ready_env_id is None: ready_env_id = np.arange(active_eval_env_num) - for i, env_id in enumerate(ready_env_id): + for i in range(batch_size): + env_id = i // self.cfg.model.agent_num distributions, value = roots_visit_count_distributions[i], roots_values[i] # NOTE: Only legal actions possess visit counts, so the ``action_index_in_legal_action_set`` represents # the index within the legal action set, rather than the index in the entire action set. @@ -685,16 +745,12 @@ def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: int = -1 # NOTE: Convert the ``action_index_in_legal_action_set`` to the corresponding ``action`` in the # entire action set. action = np.where(action_mask[i] == 1.0)[0][action_index_in_legal_action_set] - - output[env_id] = { - 'action': action, - 'distributions': distributions, - 'visit_count_distribution_entropy': visit_count_distribution_entropy, - 'value': value, - 'pred_value': pred_values[i], - 'policy_logits': policy_logits[i], - } - + output[env_id]['action'].append(action) + output[env_id]['distributions'].append(distributions) + output[env_id]['visit_count_distribution_entropy'].append(visit_count_distribution_entropy) + output[env_id]['value'].append(value) + output[env_id]['pred_value'].append(pred_values[i]) + output[env_id]['policy_logits'].append(policy_logits[i]) return output def _monitor_vars_learn(self) -> List[str]: diff --git a/lzero/worker/muzero_collector.py b/lzero/worker/muzero_collector.py index 58ad5a2cf..249ef022a 100644 --- a/lzero/worker/muzero_collector.py +++ b/lzero/worker/muzero_collector.py @@ -1,5 +1,5 @@ import time -from collections import deque, namedtuple +from collections import deque, namedtuple, defaultdict from typing import Optional, Any, List import numpy as np @@ -221,6 +221,29 @@ def _compute_priorities(self, i, pred_values_lst, search_values_lst): priorities = None return priorities + + def _compute_priorities_single(self, i, agent_id, pred_values_lst, search_values_lst): + """ + Overview: + obtain the priorities at index i. + Arguments: + - i: index. + - pred_values_lst: The list of value being predicted. + - search_values_lst: The list of value obtained through search. 
+ """ + if self.policy_config.use_priority: + pred_values = torch.from_numpy(np.array(pred_values_lst[i][agent_id])).to(self.policy_config.device + ).float().view(-1) + search_values = torch.from_numpy(np.array(search_values_lst[i][agent_id])).to(self.policy_config.device + ).float().view(-1) + priorities = L1Loss(reduction='none' + )(pred_values, + search_values).detach().cpu().numpy() + 1e-6 # avoid zero priority + else: + # priorities is None -> use the max priority for all newly collected data + priorities = None + + return priorities def pad_and_save_last_trajectory(self, i, last_game_segments, last_game_priorities, game_segments, done) -> None: """ @@ -290,6 +313,66 @@ def pad_and_save_last_trajectory(self, i, last_game_segments, last_game_prioriti last_game_priorities[i] = None return None + + def pad_and_save_last_trajectory_single( + self, i, agent_id, last_game_segments, last_game_priorities, game_segments, done + ) -> None: + """ + Overview: + put the last game block into the pool if the current game is finished + Arguments: + - last_game_segments (:obj:`list`): list of the last game segments + - last_game_priorities (:obj:`list`): list of the last game priorities + - game_segments (:obj:`list`): list of the current game segments + Note: + (last_game_segments[i].obs_segment[-4:][j] == game_segments[i].obs_segment[:4][j]).all() is True + """ + # pad over last block trajectory + beg_index = self.policy_config.model.frame_stack_num + end_index = beg_index + self.policy_config.num_unroll_steps + + # the start obs is init zero obs, so we take the [ : +] obs as the pad obs + # e.g. the start 4 obs is init zero obs, the num_unroll_steps is 5, so we take the [4:9] obs as the pad obs + pad_obs_lst = game_segments[i][agent_id].obs_segment[beg_index:end_index] + pad_child_visits_lst = game_segments[i][agent_id].child_visit_segment[:self.policy_config.num_unroll_steps] + # EfficientZero original repo bug: + # pad_child_visits_lst = game_segments[i].child_visit_segment[beg_index:end_index] + + beg_index = 0 + # self.unroll_plus_td_steps = self.policy_config.num_unroll_steps + self.policy_config.td_steps + end_index = beg_index + self.unroll_plus_td_steps - 1 + + pad_reward_lst = game_segments[i][agent_id].reward_segment[beg_index:end_index] + + beg_index = 0 + end_index = beg_index + self.unroll_plus_td_steps + + pad_root_values_lst = game_segments[i][agent_id].root_value_segment[beg_index:end_index] + + # pad over and save + last_game_segments[i][agent_id].pad_over(pad_obs_lst, pad_reward_lst, pad_root_values_lst, pad_child_visits_lst) + """ + Note: + game_segment element shape: + obs: game_segment_length + stack + num_unroll_steps, 20+4 +5 + rew: game_segment_length + stack + num_unroll_steps + td_steps -1 20 +5+3-1 + action: game_segment_length -> 20 + root_values: game_segment_length + num_unroll_steps + td_steps -> 20 +5+3 + child_visits: game_segment_length + num_unroll_steps -> 20 +5 + to_play: game_segment_length -> 20 + action_mask: game_segment_length -> 20 + """ + + last_game_segments[i][agent_id].game_segment_to_array() + + # put the game block into the pool + self.game_segment_pool.append((last_game_segments[i][agent_id], last_game_priorities[i][agent_id], done[i])) + + # reset last game_segments + last_game_segments[i][agent_id] = None + last_game_priorities[i][agent_id] = None + + return None def collect(self, n_episode: Optional[int] = None, @@ -322,6 +405,8 @@ def collect(self, # initializations init_obs = self._env.ready_obs + agent_num = len(init_obs[0]['action_mask']) + 
self._multi_agent = True retry_waiting_time = 0.001 while len(init_obs.keys()) != self._env_num: @@ -342,33 +427,39 @@ def collect(self, chance_dict = {i: to_ndarray(init_obs[i]['chance']) for i in range(env_nums)} game_segments = [ - GameSegment( - self._env.action_space, - game_segment_length=self.policy_config.game_segment_length, - config=self.policy_config - ) for _ in range(env_nums) + [ + GameSegment( + self._env.action_space, + game_segment_length=self.policy_config.game_segment_length, + config=self.policy_config + ) for _ in range(agent_num) + ] for _ in range(env_nums) ] + # stacked observation windows in reset stage for init game_segments - observation_window_stack = [[] for _ in range(env_nums)] + observation_window_stack = [[[] for _ in range(agent_num)] for _ in range(env_nums)] for env_id in range(env_nums): - observation_window_stack[env_id] = deque( - [to_ndarray(init_obs[env_id]['observation']) for _ in range(self.policy_config.model.frame_stack_num)], - maxlen=self.policy_config.model.frame_stack_num - ) - - game_segments[env_id].reset(observation_window_stack[env_id]) + for agent_id in range(agent_num): + observation_window_stack[env_id][agent_id] = deque( + [to_ndarray(init_obs[env_id]['observation'][agent_id]) for _ in range(self.policy_config.model.frame_stack_num)], + maxlen=self.policy_config.model.frame_stack_num + ) + game_segments[env_id][agent_id].reset(observation_window_stack[env_id][agent_id]) dones = np.array([False for _ in range(env_nums)]) - last_game_segments = [None for _ in range(env_nums)] - last_game_priorities = [None for _ in range(env_nums)] + last_game_segments = [[None for _ in range(agent_num)] for _ in range(env_nums)] + last_game_priorities = [[None for _ in range(agent_num)] for _ in range(env_nums)] # for priorities in self-play - search_values_lst = [[] for _ in range(env_nums)] - pred_values_lst = [[] for _ in range(env_nums)] + search_values_lst = [[[] for _ in range(agent_num)] for _ in range(env_nums)] + pred_values_lst = [[[] for _ in range(agent_num)] for _ in range(env_nums)] if self.policy_config.gumbel_algo: improved_policy_lst = [[] for _ in range(env_nums)] # some logs - eps_steps_lst, visit_entropies_lst = np.zeros(env_nums), np.zeros(env_nums) + if self._multi_agent: + eps_steps_lst, visit_entropies_lst = np.zeros(env_nums), np.zeros((env_nums, agent_num)) + else: + eps_steps_lst, visit_entropies_lst = np.zeros(env_nums), np.zeros(env_nums) if self.policy_config.gumbel_algo: completed_value_lst = np.zeros(env_nums) self_play_moves = 0. 
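The collector hunks above key every piece of bookkeeping by [env_id][agent_id]; a compact illustration of that nesting with a dummy segment type standing in for GameSegment (whose real constructor arguments are omitted here):

from collections import deque

env_num, agent_num, stack_num = 2, 3, 1

class DummySegment:                                  # placeholder for GameSegment in this sketch
    def reset(self, window):
        self.window = list(window)

game_segments = [[DummySegment() for _ in range(agent_num)] for _ in range(env_num)]
obs_windows = [[deque(maxlen=stack_num) for _ in range(agent_num)] for _ in range(env_num)]
init_obs = {e: {'observation': [f'obs_e{e}_a{a}' for a in range(agent_num)]} for e in range(env_num)}

for e in range(env_num):
    for a in range(agent_num):
        obs_windows[e][a].extend([init_obs[e]['observation'][a]] * stack_num)
        game_segments[e][a].reset(obs_windows[e][a])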
@@ -387,8 +478,13 @@ def collect(self, new_available_env_id = set(obs.keys()).difference(ready_env_id) ready_env_id = ready_env_id.union(set(list(new_available_env_id)[:remain_episode])) remain_episode -= min(len(new_available_env_id), remain_episode) - - stack_obs = {env_id: game_segments[env_id].get_obs() for env_id in ready_env_id} + if self._multi_agent: + stack_obs = defaultdict(list) + for env_id in ready_env_id: + for agent_id in range(agent_num): + stack_obs[env_id].append(game_segments[env_id][agent_id].get_obs()) + else: + stack_obs = {env_id: game_segments[env_id].get_obs() for env_id in ready_env_id} stack_obs = list(stack_obs.values()) action_mask_dict = {env_id: action_mask_dict[env_id] for env_id in ready_env_id} @@ -482,7 +578,13 @@ def collect(self, elif self.policy_config.gumbel_algo: game_segments[env_id].store_search_stats(distributions_dict[env_id], value_dict[env_id], improved_policy = improved_policy_dict[env_id]) else: - game_segments[env_id].store_search_stats(distributions_dict[env_id], value_dict[env_id]) + if self._multi_agent: + for agent_id in range(agent_num): + game_segments[env_id][agent_id].store_search_stats( + distributions_dict[env_id][agent_id], value_dict[env_id][agent_id] + ) + else: + game_segments[env_id].store_search_stats(distributions_dict[env_id], value_dict[env_id]) # append a transition tuple, including a_t, o_{t+1}, r_{t}, action_mask_{t}, to_play_{t} # in ``game_segments[env_id].init``, we have append o_{t} in ``self.obs_segment`` if self.policy_config.use_ture_chance_label_in_chance_encoder: @@ -491,10 +593,17 @@ def collect(self, to_play_dict[env_id], chance_dict[env_id] ) else: - game_segments[env_id].append( - actions[env_id], to_ndarray(obs['observation']), reward, action_mask_dict[env_id], - to_play_dict[env_id] - ) + if self._multi_agent: + for agent_id in range(agent_num): + game_segments[env_id][agent_id].append( + actions[env_id][agent_id], to_ndarray(obs['observation'][agent_id]), + reward, action_mask_dict[env_id][agent_id], to_play_dict[env_id][agent_id] + ) + else: + game_segments[env_id].append( + actions[env_id], to_ndarray(obs['observation']), reward, action_mask_dict[env_id], + to_play_dict[env_id] + ) # NOTE: the position of code snippet is very important. 
# the obs['action_mask'] and obs['to_play'] are corresponding to the next action @@ -508,7 +617,9 @@ def collect(self, else: dones[env_id] = done - visit_entropies_lst[env_id] += visit_entropy_dict[env_id] + # visit_entropies_lst[env_id] += visit_entropy_dict[env_id] + for agent_id in range(agent_num): + visit_entropies_lst[env_id][agent_id] += visit_entropy_dict[env_id][agent_id] if self.policy_config.gumbel_algo: completed_value_lst[env_id] += np.mean(np.array(completed_value_dict[env_id])) @@ -516,45 +627,56 @@ def collect(self, total_transitions += 1 if self.policy_config.use_priority: - pred_values_lst[env_id].append(pred_value_dict[env_id]) - search_values_lst[env_id].append(value_dict[env_id]) + for agent_id in range(agent_num): + pred_values_lst[env_id][agent_id].append(pred_value_dict[env_id][agent_id]) + search_values_lst[env_id][agent_id].append(value_dict[env_id][agent_id]) + # pred_values_lst[env_id].append(pred_value_dict[env_id]) + # search_values_lst[env_id].append(value_dict[env_id]) if self.policy_config.gumbel_algo: improved_policy_lst[env_id].append(improved_policy_dict[env_id]) # append the newest obs - observation_window_stack[env_id].append(to_ndarray(obs['observation'])) + if self._multi_agent: + for agent_id in range(agent_num): + observation_window_stack[env_id][agent_id].append(to_ndarray(obs['observation'][agent_id])) + else: + observation_window_stack[env_id].append(to_ndarray(obs['observation'])) # ============================================================== # we will save a game segment if it is the end of the game or the next game segment is finished. # ============================================================== # if game segment is full, we will save the last game segment - if game_segments[env_id].is_full(): - # pad over last segment trajectory - if last_game_segments[env_id] is not None: - # TODO(pu): return the one game segment - self.pad_and_save_last_trajectory( - env_id, last_game_segments, last_game_priorities, game_segments, dones - ) - - # calculate priority - priorities = self._compute_priorities(env_id, pred_values_lst, search_values_lst) - pred_values_lst[env_id] = [] - search_values_lst[env_id] = [] - if self.policy_config.gumbel_algo: - improved_policy_lst[env_id] = [] - - # the current game_segments become last_game_segment - last_game_segments[env_id] = game_segments[env_id] - last_game_priorities[env_id] = priorities + if self._multi_agent: + # set game_segments size enough large in MA + pass + else: + if game_segments[env_id].is_full(): + # pad over last segment trajectory + if last_game_segments[env_id] is not None: + # TODO(pu): return the one game segment + self.pad_and_save_last_trajectory( + env_id, last_game_segments, last_game_priorities, game_segments, dones + ) - # create new GameSegment - game_segments[env_id] = GameSegment( - self._env.action_space, - game_segment_length=self.policy_config.game_segment_length, - config=self.policy_config - ) - game_segments[env_id].reset(observation_window_stack[env_id]) + # calculate priority + priorities = self._compute_priorities(env_id, pred_values_lst, search_values_lst) + pred_values_lst[env_id] = [] + search_values_lst[env_id] = [] + if self.policy_config.gumbel_algo: + improved_policy_lst[env_id] = [] + + # the current game_segments become last_game_segment + last_game_segments[env_id] = game_segments[env_id] + last_game_priorities[env_id] = priorities + + # create new GameSegment + game_segments[env_id] = GameSegment( + self._env.action_space, + 
game_segment_length=self.policy_config.game_segment_length, + config=self.policy_config + ) + game_segments[env_id].reset(observation_window_stack[env_id]) self._env_info[env_id]['step'] += 1 collected_step += 1 @@ -580,21 +702,22 @@ def collect(self, # NOTE: put the penultimate game segment in one episode into the trajectory_pool # pad over 2th last game_segment using the last game_segment - if last_game_segments[env_id] is not None: - self.pad_and_save_last_trajectory( - env_id, last_game_segments, last_game_priorities, game_segments, dones - ) + for agent_id in range(agent_num): + if last_game_segments[env_id][agent_id] is not None: + self.pad_and_save_last_trajectory_single( + env_id, agent_id, last_game_segments, last_game_priorities, game_segments, dones + ) - # store current segment trajectory - priorities = self._compute_priorities(env_id, pred_values_lst, search_values_lst) + # store current segment trajectory + priorities = self._compute_priorities_single(env_id, agent_id, pred_values_lst, search_values_lst) - # NOTE: put the last game segment in one episode into the trajectory_pool - game_segments[env_id].game_segment_to_array() + # NOTE: put the last game segment in one episode into the trajectory_pool + game_segments[env_id][agent_id].game_segment_to_array() - # assert len(game_segments[env_id]) == len(priorities) - # NOTE: save the last game segment in one episode into the trajectory_pool if it's not null - if len(game_segments[env_id].reward_segment) != 0: - self.game_segment_pool.append((game_segments[env_id], priorities, dones[env_id])) + # assert len(game_segments[env_id]) == len(priorities) + # NOTE: save the last game segment in one episode into the trajectory_pool if it's not null + if len(game_segments[env_id][agent_id].reward_segment) != 0: + self.game_segment_pool.append((game_segments[env_id][agent_id], priorities, dones[env_id])) # print(game_segments[env_id].reward_segment) # reset the finished env and init game_segments @@ -627,18 +750,19 @@ def collect(self, if self.policy_config.use_ture_chance_label_in_chance_encoder: chance_dict[env_id] = to_ndarray(init_obs[env_id]['chance']) - game_segments[env_id] = GameSegment( - self._env.action_space, - game_segment_length=self.policy_config.game_segment_length, - config=self.policy_config - ) - observation_window_stack[env_id] = deque( - [init_obs[env_id]['observation'] for _ in range(self.policy_config.model.frame_stack_num)], - maxlen=self.policy_config.model.frame_stack_num - ) - game_segments[env_id].reset(observation_window_stack[env_id]) - last_game_segments[env_id] = None - last_game_priorities[env_id] = None + for agent_id in range(agent_num): + game_segments[env_id][agent_id] = GameSegment( + self._env.action_space, + game_segment_length=self.policy_config.game_segment_length, + config=self.policy_config + ) + observation_window_stack[env_id][agent_id] = deque( + [init_obs[env_id]['observation'][agent_id] for _ in range(self.policy_config.model.frame_stack_num)], + maxlen=self.policy_config.model.frame_stack_num + ) + game_segments[env_id][agent_id].reset(observation_window_stack[env_id][agent_id]) + last_game_segments[env_id] = [None for _ in range(agent_num)] + last_game_priorities[env_id] = [None for _ in range(agent_num)] # log self_play_moves_max = max(self_play_moves_max, eps_steps_lst[env_id]) diff --git a/zoo/petting_zoo/config/ptz_simple_ez_config.py b/zoo/petting_zoo/config/ptz_simple_ez_config.py index d691f68cf..71f3b8392 100644 --- a/zoo/petting_zoo/config/ptz_simple_ez_config.py +++ 
b/zoo/petting_zoo/config/ptz_simple_ez_config.py @@ -46,10 +46,11 @@ model=dict( model_type='structure', latent_state_dim=256, + frame_stack_num=1, action_space='discrete', action_space_size=action_space_size, agent_num=n_agent, - self_supervised_learning_loss=False, # default is False + self_supervised_learning_loss=True, agent_obs_shape=6, global_obs_shape=14, discrete_action_encoding_type='one_hot', @@ -73,10 +74,10 @@ use_augmentation=False, update_per_collect=update_per_collect, batch_size=batch_size, - optim_type='SGD', - lr_piecewise_constant_decay=True, - learning_rate=0.2, - ssl_loss_weight=0, # NOTE: default is 0. + optim_type='Adam', + lr_piecewise_constant_decay=False, + learning_rate=0.003, + ssl_loss_weight=2, # NOTE: default is 0. num_simulations=num_simulations, reanalyze_ratio=reanalyze_ratio, n_episode=n_episode, diff --git a/zoo/petting_zoo/config/ptz_simple_mz_config.py b/zoo/petting_zoo/config/ptz_simple_mz_config.py index 45343f978..7db1bc0e6 100644 --- a/zoo/petting_zoo/config/ptz_simple_mz_config.py +++ b/zoo/petting_zoo/config/ptz_simple_mz_config.py @@ -46,10 +46,11 @@ model=dict( model_type='structure', latent_state_dim=256, + frame_stack_num=1, action_space='discrete', action_space_size=action_space_size, agent_num=n_agent, - self_supervised_learning_loss=False, # default is False + self_supervised_learning_loss=True, agent_obs_shape=6, global_obs_shape=14, discrete_action_encoding_type='one_hot', @@ -76,7 +77,7 @@ optim_type='Adam', lr_piecewise_constant_decay=False, learning_rate=0.003, - ssl_loss_weight=0, # NOTE: default is 0. + ssl_loss_weight=2, # NOTE: default is 0. num_simulations=num_simulations, reanalyze_ratio=reanalyze_ratio, n_episode=n_episode, diff --git a/zoo/petting_zoo/config/ptz_simple_spread_ez_config.py b/zoo/petting_zoo/config/ptz_simple_spread_ez_config.py new file mode 100644 index 000000000..b6ceae941 --- /dev/null +++ b/zoo/petting_zoo/config/ptz_simple_spread_ez_config.py @@ -0,0 +1,116 @@ +from easydict import EasyDict + +env_name = 'ptz_simple_spread' +multi_agent = True + +# ============================================================== +# begin of the most frequently changed config specified by the user +# ============================================================== +seed = 0 +n_agent = 3 +n_landmark = n_agent +collector_env_num = 8 +evaluator_env_num = 8 +n_episode = 8 +batch_size = 256 +num_simulations = 50 +update_per_collect = 50 +reanalyze_ratio = 0. 
+action_space_size = 5 +eps_greedy_exploration_in_collect = True +# ============================================================== +# end of the most frequently changed config specified by the user +# ============================================================== + +main_config = dict( + exp_name= + f'data_ez_ctree/{env_name}_efficientzero_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_seed{seed}', + env=dict( + env_family='mpe', + env_id='simple_spread_v2', + n_agent=n_agent, + n_landmark=n_landmark, + max_cycles=25, + agent_obs_only=False, + agent_specific_global_state=True, + continuous_actions=False, + stop_value=0, + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + n_evaluator_episode=evaluator_env_num, + manager=dict(shared_memory=False, ), + ), + policy=dict( + multi_agent=multi_agent, + ignore_done=False, + model=dict( + model_type='structure', + latent_state_dim=256, + frame_stack_num=1, + action_space='discrete', + action_space_size=action_space_size, + agent_num=n_agent, + self_supervised_learning_loss=True, + agent_obs_shape=18, + global_obs_shape=2 + 2 + n_landmark * 2 + (n_agent - 1) * 2 + (n_agent - 1) * 2 + n_agent * (2 + 2) + n_landmark * 2 + n_agent * (n_agent - 1) * 2, + discrete_action_encoding_type='one_hot', + global_cooperation=True, # TODO: doesn't work now + hidden_size_list=[256, 256], + norm_type='BN', + ), + cuda=True, + mcts_ctree=True, + gumbel_algo=False, + env_type='not_board_games', + game_segment_length=30, + random_collect_episode_num=0, + eps=dict( + eps_greedy_exploration_in_collect=eps_greedy_exploration_in_collect, + type='linear', + start=1., + end=0.05, + decay=int(2e5), + ), + use_augmentation=False, + update_per_collect=update_per_collect, + batch_size=batch_size, + optim_type='Adam', + lr_piecewise_constant_decay=False, + learning_rate=0.003, + ssl_loss_weight=2, # NOTE: default is 0. + num_simulations=num_simulations, + reanalyze_ratio=reanalyze_ratio, + n_episode=n_episode, + eval_freq=int(2e3), + replay_buffer_size=int(1e6), # the size/capacity of replay_buffer, in the terms of transitions. + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + ), + learn=dict(learner=dict( + log_policy=True, + hook=dict(log_show_after_iter=10, ), + ), ), +) +main_config = EasyDict(main_config) +create_config = dict( + env=dict( + import_names=['zoo.petting_zoo.envs.petting_zoo_simple_spread_env'], + type='petting_zoo', + ), + env_manager=dict(type='subprocess'), + policy=dict( + type='efficientzero', + import_names=['lzero.policy.efficientzero'], + ), + collector=dict( + type='episode_muzero', + import_names=['lzero.worker.muzero_collector'], + ) +) +create_config = EasyDict(create_config) +ptz_simple_spread_efficientzero_config = main_config +ptz_simple_spread_efficientzero_create_config = create_config + +if __name__ == "__main__": + from zoo.petting_zoo.entry import train_muzero + train_muzero([main_config, create_config], seed=seed) diff --git a/zoo/petting_zoo/config/ptz_simple_spread_mz_config.py b/zoo/petting_zoo/config/ptz_simple_spread_mz_config.py index c93cb0ada..2260f4994 100644 --- a/zoo/petting_zoo/config/ptz_simple_spread_mz_config.py +++ b/zoo/petting_zoo/config/ptz_simple_spread_mz_config.py @@ -13,10 +13,10 @@ evaluator_env_num = 8 n_episode = 8 batch_size = 256 -num_simulations = 200 +num_simulations = 50 update_per_collect = 50 reanalyze_ratio = 0. 
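# Sanity check for the global_obs_shape formula used in the simple_spread configs of this patch
# (a minimal sketch, not part of the diff; the grouping of terms follows the usual MPE
# simple_spread observation layout and is an assumption):
n_agent, n_landmark = 3, 3
agent_obs = 2 + 2 + n_landmark * 2 + (n_agent - 1) * 2 + (n_agent - 1) * 2        # self vel/pos, landmarks, other agents, comm -> 18
shared_global = n_agent * (2 + 2) + n_landmark * 2 + n_agent * (n_agent - 1) * 2  # all agents' vel/pos, landmarks, comm -> 30
assert agent_obs == 18 and shared_global == 30 and agent_obs + shared_global == 48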
-action_space_size = 5*5*5 +action_space_size = 5 eps_greedy_exploration_in_collect = True # ============================================================== # end of the most frequently changed config specified by the user @@ -50,9 +50,9 @@ action_space='discrete', action_space_size=action_space_size, agent_num=n_agent, - self_supervised_learning_loss=False, # default is False + self_supervised_learning_loss=True, agent_obs_shape=18, - global_obs_shape=18*n_agent+30, # 84 + global_obs_shape=2 + 2 + n_landmark * 2 + (n_agent - 1) * 2 + (n_agent - 1) * 2 + n_agent * (2 + 2) + n_landmark * 2 + n_agent * (n_agent - 1) * 2, discrete_action_encoding_type='one_hot', global_cooperation=True, # TODO: doesn't work now hidden_size_list=[256, 256], @@ -77,7 +77,7 @@ optim_type='Adam', lr_piecewise_constant_decay=False, learning_rate=0.003, - ssl_loss_weight=0, # NOTE: default is 0. + ssl_loss_weight=2, # NOTE: default is 0. num_simulations=num_simulations, reanalyze_ratio=reanalyze_ratio, n_episode=n_episode, @@ -97,7 +97,7 @@ import_names=['zoo.petting_zoo.envs.petting_zoo_simple_spread_env'], type='petting_zoo', ), - env_manager=dict(type='base'), + env_manager=dict(type='subprocess'), policy=dict( type='muzero', import_names=['lzero.policy.muzero'], @@ -111,6 +111,6 @@ ptz_simple_spread_muzero_config = main_config ptz_simple_spread_muzero_create_config = create_config -if __name__ == '__main__': +if __name__ == "__main__": from zoo.petting_zoo.entry import train_muzero train_muzero([main_config, create_config], seed=seed) diff --git a/zoo/petting_zoo/entry/eval_muzero.py b/zoo/petting_zoo/entry/eval_muzero.py index a6947e9c4..6d6cffc1d 100644 --- a/zoo/petting_zoo/entry/eval_muzero.py +++ b/zoo/petting_zoo/entry/eval_muzero.py @@ -18,7 +18,7 @@ from lzero.policy.random_policy import LightZeroRandomPolicy from lzero.worker import MuZeroCollector as Collector from lzero.worker import MuZeroEvaluator as Evaluator -from zoo.petting_zoo.model import PettingZooEncoder, PettingZooPrediction, PettingZooDynamics +from zoo.petting_zoo.model import PettingZooEncoder, PettingZooPrediction def eval_muzero(main_cfg, create_cfg, seed=0): assert create_cfg.policy.type in ['efficientzero', 'muzero', 'sampled_efficientzero', 'gumbel_muzero', 'stochastic_muzero'], \ diff --git a/zoo/petting_zoo/entry/train_muzero.py b/zoo/petting_zoo/entry/train_muzero.py index 1c32d89fe..aa731f7e0 100644 --- a/zoo/petting_zoo/entry/train_muzero.py +++ b/zoo/petting_zoo/entry/train_muzero.py @@ -18,7 +18,7 @@ from lzero.policy.random_policy import LightZeroRandomPolicy from lzero.worker import MuZeroCollector as Collector from lzero.worker import MuZeroEvaluator as Evaluator -from zoo.petting_zoo.model import PettingZooEncoder, PettingZooPrediction, PettingZooDynamics +from zoo.petting_zoo.model import PettingZooEncoder, PettingZooPrediction from lzero.entry.utils import random_collect @@ -54,9 +54,13 @@ def train_muzero( if create_cfg.policy.type == 'muzero': from lzero.mcts import MuZeroGameBuffer as GameBuffer from lzero.model.muzero_model_mlp import MuZeroModelMLP as Encoder + from zoo.petting_zoo.model import PettingZooMZDynamics as PettingZooDynamics + elif create_cfg.policy.type == 'efficientzero': from lzero.mcts import EfficientZeroGameBuffer as GameBuffer from lzero.model.efficientzero_model_mlp import EfficientZeroModelMLP as Encoder + from zoo.petting_zoo.model import PettingZooEZDynamics as PettingZooDynamics + elif create_cfg.policy.type == 'sampled_efficientzero': from lzero.mcts import SampledEfficientZeroGameBuffer as 
GameBuffer elif create_cfg.policy.type == 'gumbel_muzero': @@ -80,8 +84,8 @@ def train_muzero( evaluator_env.seed(cfg.seed, dynamic_seed=False) set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda) - # model = Encoder(**cfg.policy.model, state_encoder=PettingZooEncoder(cfg), state_prediction=PettingZooPrediction(cfg), state_dynamics=PettingZooDynamics(cfg)) - model = Encoder(**cfg.policy.model, state_encoder=PettingZooEncoder(cfg)) + model = Encoder(**cfg.policy.model, state_encoder=PettingZooEncoder(cfg), state_prediction=PettingZooPrediction, state_dynamics=PettingZooDynamics) + # model = Encoder(**cfg.policy.model, state_encoder=PettingZooEncoder(cfg)) policy = create_policy(cfg.policy, model=model, enable_field=['learn', 'collect', 'eval']) # load pretrained model diff --git a/zoo/petting_zoo/envs/petting_zoo_simple_spread_env.py b/zoo/petting_zoo/envs/petting_zoo_simple_spread_env.py index 69b366b55..c60fecbed 100644 --- a/zoo/petting_zoo/envs/petting_zoo_simple_spread_env.py +++ b/zoo/petting_zoo/envs/petting_zoo_simple_spread_env.py @@ -36,11 +36,6 @@ def __init__(self, cfg: dict) -> None: if self._act_scale: assert self._continuous_actions, 'Only continuous action space env needs act_scale' - # joint action - import itertools - action_space = [0, 1, 2, 3, 4] - self.combinations = list(itertools.product(action_space, repeat=self._num_agents)) - def reset(self) -> np.ndarray: if not self._init_flag: # In order to align with the simple spread in Multiagent Particle Env (MPE), @@ -172,8 +167,7 @@ def seed(self, seed: int, dynamic_seed: bool = True) -> None: self._dynamic_seed = dynamic_seed np.random.seed(self._seed) - def step(self, action: int) -> BaseEnvTimestep: - action = np.array(self.combinations[action]) + def step(self, action) -> BaseEnvTimestep: self._step_count += 1 action = self._process_action(action) if self._act_scale: @@ -254,12 +248,11 @@ def _process_obs(self, obs: 'torch.Tensor') -> np.ndarray: # noqa # - agent_state info # - global_state info if self._agent_specific_global_state: - ret['global_state'] = np.concatenate((np.concatenate(ret['agent_state']), ret['global_state'])) - # ret['global_state'] = np.concatenate( - # [ret['agent_state'], - # np.expand_dims(ret['global_state'], axis=0).repeat(self._num_agents, axis=0)], - # axis=1 - # ) + ret['global_state'] = np.concatenate( + [ret['agent_state'], + np.expand_dims(ret['global_state'], axis=0).repeat(self._num_agents, axis=0)], + axis=1 + ) # agent_alone_state: Shape (n_agent, 2 + 2 + n_landmark * 2 + (n_agent - 1) * 2). # Stacked observation. Exclude other agents' positions from agent_state. Contains # - agent itself's state(velocity + position) + @@ -285,8 +278,14 @@ def _process_obs(self, obs: 'torch.Tensor') -> np.ndarray: # noqa 1 ) # action_mask: All actions are of use(either 1 for discrete or 5 for continuous). Thus all 1. 
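# Shape sketch for the agent_specific_global_state branch restored above (illustration only;
# assumes n_agent = 3 with the 18-dim agent and 30-dim shared observations used elsewhere in
# this patch):
import numpy as np
agent_state = np.zeros((3, 18), dtype=np.float32)   # per-agent observations
global_state = np.zeros((30,), dtype=np.float32)    # shared, agent-agnostic global state
per_agent_global = np.concatenate(
    [agent_state, np.expand_dims(global_state, axis=0).repeat(3, axis=0)], axis=1
)
assert per_agent_global.shape == (3, 48)            # each agent sees its own obs plus the shared state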
- joint_action_mask = [1 for _ in range(np.power(5, self._num_agents))] - return {'observation': ret, 'action_mask': joint_action_mask, 'to_play': [-1]} + action_mask = [[1 for _ in range(*self._action_dim)] for _ in range(self._num_agents)] + ret_transform = [] + for i in range(self._num_agents): + tmp = {} + for k,v in ret.items(): + tmp[k] = v[i] + ret_transform.append(tmp) + return {'observation': ret_transform, 'action_mask': action_mask, 'to_play': [-1 for _ in range(self._num_agents)]} def _process_action(self, action: 'torch.Tensor') -> Dict[str, np.ndarray]: # noqa dict_action = {} diff --git a/zoo/petting_zoo/model/__init__.py b/zoo/petting_zoo/model/__init__.py index cc2eebcdf..690a44580 100644 --- a/zoo/petting_zoo/model/__init__.py +++ b/zoo/petting_zoo/model/__init__.py @@ -1 +1 @@ -from .model import PettingZooEncoder, PettingZooPrediction, PettingZooDynamics \ No newline at end of file +from .model import PettingZooEncoder, PettingZooPrediction, PettingZooEZDynamics, PettingZooMZDynamics \ No newline at end of file diff --git a/zoo/petting_zoo/model/model.py b/zoo/petting_zoo/model/model.py index 2bf02a64c..156bfce02 100644 --- a/zoo/petting_zoo/model/model.py +++ b/zoo/petting_zoo/model/model.py @@ -10,47 +10,40 @@ from ding.torch_utils import MLP from ding.utils import MODEL_REGISTRY, SequenceType from lzero.model.utils import get_dynamic_mean, get_reward_mean +from numpy import ndarray class PettingZooEncoder(nn.Module): def __init__(self, cfg): super().__init__() self.agent_num = cfg.policy.model.agent_num - agent_obs_shape = cfg.policy.model.agent_obs_shape - global_obs_shape = cfg.policy.model.global_obs_shape - self.agent_encoder = RepresentationNetworkMLP(observation_shape=agent_obs_shape, - hidden_channels=128, + self.agent_obs_shape = cfg.policy.model.agent_obs_shape + self.global_obs_shape = cfg.policy.model.global_obs_shape + self.agent_encoder = RepresentationNetworkMLP(observation_shape=self.agent_obs_shape, + hidden_channels=256, norm_type='BN') - self.global_encoder = RepresentationNetworkMLP(observation_shape=global_obs_shape, + self.global_encoder = RepresentationNetworkMLP(observation_shape=self.global_obs_shape, hidden_channels=256, norm_type='BN') - - self.encoder = RepresentationNetworkMLP(observation_shape=128+128*self.agent_num, - hidden_channels=128, - norm_type='BN') def forward(self, x): # agent - batch_size = x['global_state'].shape[0] - latent_state = x['global_state'].reshape(batch_size, -1) - latent_state = self.global_encoder(latent_state) - return latent_state - # agent_state_B = agent_state.reshape(batch_size, -1) - # agent_state_B_A = agent_state.reshape(batch_size, agent_num, -1) + agent_state = x['agent_state'].reshape(-1, self.agent_obs_shape) + agent_state = self.agent_encoder(agent_state) + # agent_state = agent_state.reshape(batch_size, agent_num, -1) # global - # global_state = self.global_encoder(x['global_state']) - # global_state = self.encoder(torch.cat((agent_state_B, global_state),dim=1)) - # return (agent_state_B, global_state) - + global_state = x['global_state'].reshape(-1, self.global_obs_shape) + global_state = self.global_encoder(global_state) + # global_state = global_state.reshape(batch_size, agent_num, -1) + return (agent_state, global_state) class PettingZooPrediction(nn.Module): def __init__( self, - cfg, - action_space_size: int=5, - num_channels: int=128, + action_space_size, + num_channels, common_layer_num: int = 2, fc_value_layers: SequenceType = [32], fc_policy_layers: SequenceType = [32], @@ -78,11 +71,9 @@ 
def __init__( """ super().__init__() self.num_channels = num_channels - self.agent_num = cfg.policy.model.agent_num - self.action_space_size = pow(action_space_size, self.agent_num) # ******* common backbone ****** - self.fc_prediction_common = MLP( + self.fc_prediction_agent_common = MLP( in_channels=self.num_channels, hidden_channels=self.num_channels, out_channels=self.num_channels, @@ -95,6 +86,18 @@ def __init__( last_linear_layer_init_zero=False, ) + self.fc_prediction_global_common = MLP( + in_channels=self.num_channels, + hidden_channels=self.num_channels, + out_channels=self.num_channels, + layer_num=common_layer_num, + activation=activation, + norm_type=norm_type, + output_activation=True, + output_norm=True, + # last_linear_layer_init_zero=False is important for convergence + last_linear_layer_init_zero=False, + ) # ******* value and policy head ****** self.fc_value_head = MLP( in_channels=self.num_channels, @@ -109,9 +112,9 @@ def __init__( last_linear_layer_init_zero=last_linear_layer_init_zero ) self.fc_policy_head = MLP( - in_channels=self.num_channels*self.agent_num, + in_channels=self.num_channels, hidden_channels=fc_policy_layers[0], - out_channels=self.action_space_size, + out_channels=action_space_size, layer_num=len(fc_policy_layers) + 1, activation=activation, norm_type=norm_type, @@ -132,20 +135,20 @@ def forward(self, latent_state: torch.Tensor): - value (:obj:`torch.Tensor`): value tensor with shape (B, output_support_size). """ agent_state, global_state = latent_state - global_state = self.fc_prediction_common(global_state) + global_state_common = self.fc_prediction_global_common(global_state) + agent_state_common = self.fc_prediction_agent_common(agent_state) - value = self.fc_value_head(global_state) - policy = self.fc_policy_head(agent_state) + value = self.fc_value_head(global_state_common) + policy = self.fc_policy_head(agent_state_common) return policy, value -class PettingZooDynamics(nn.Module): +class PettingZooMZDynamics(nn.Module): def __init__( self, - cfg, - action_encoding_dim: int = 5, - num_channels: int = 253, + action_encoding_dim: int = 2, + num_channels: int = 64, common_layer_num: int = 2, fc_reward_layers: SequenceType = [32], output_support_size: int = 601, @@ -172,14 +175,13 @@ def __init__( - res_connection_in_dynamics (:obj:`bool`): Whether to use residual connection in dynamics network. 
""" super().__init__() - self.agent_num = cfg.policy.model.agent_num - self.action_encoding_dim = pow(action_encoding_dim, self.agent_num) - self.num_channels = 128 + self.action_encoding_dim + self.num_channels = num_channels + self.action_encoding_dim = action_encoding_dim self.latent_state_dim = self.num_channels - self.action_encoding_dim self.res_connection_in_dynamics = res_connection_in_dynamics if self.res_connection_in_dynamics: - self.fc_dynamics_1 = MLP( + self.agent_fc_dynamics_1 = MLP( in_channels=self.num_channels, hidden_channels=self.latent_state_dim, layer_num=common_layer_num, @@ -191,7 +193,31 @@ def __init__( # last_linear_layer_init_zero=False is important for convergence last_linear_layer_init_zero=False, ) - self.fc_dynamics_2 = MLP( + self.agent_fc_dynamics_2 = MLP( + in_channels=self.latent_state_dim, + hidden_channels=self.latent_state_dim, + layer_num=common_layer_num, + out_channels=self.latent_state_dim, + activation=activation, + norm_type=norm_type, + output_activation=True, + output_norm=True, + # last_linear_layer_init_zero=False is important for convergence + last_linear_layer_init_zero=False, + ) + self.global_fc_dynamics_1 = MLP( + in_channels=517, + hidden_channels=self.latent_state_dim, + layer_num=common_layer_num, + out_channels=512, + activation=activation, + norm_type=norm_type, + output_activation=True, + output_norm=True, + # last_linear_layer_init_zero=False is important for convergence + last_linear_layer_init_zero=False, + ) + self.global_fc_dynamics_2 = MLP( in_channels=self.latent_state_dim, hidden_channels=self.latent_state_dim, layer_num=common_layer_num, @@ -203,22 +229,33 @@ def __init__( # last_linear_layer_init_zero=False is important for convergence last_linear_layer_init_zero=False, ) + self.fc_alignment = MLP( + in_channels=512, + hidden_channels=self.latent_state_dim, + out_channels=self.latent_state_dim, + layer_num=common_layer_num, + activation=activation, + norm_type=norm_type, + output_activation=True, + output_norm=True, + # last_linear_layer_init_zero=False is important for convergence + last_linear_layer_init_zero=False, + ) else: - self.fc_dynamics_list = nn.ModuleList( - MLP(in_channels=self.num_channels, - hidden_channels=self.latent_state_dim, - layer_num=common_layer_num, - out_channels=self.latent_state_dim, - activation=activation, - norm_type=norm_type, - output_activation=True, - output_norm=True, - # last_linear_layer_init_zero=False is important for convergence - last_linear_layer_init_zero=False, - ) for _ in range(self.agent_num)) - - self.fc_dynamics_global = MLP( - in_channels=self.num_channels, + self.agent_fc_dynamics = MLP( + in_channels=261, + hidden_channels=self.latent_state_dim, + layer_num=common_layer_num, + out_channels=self.latent_state_dim, + activation=activation, + norm_type=norm_type, + output_activation=True, + output_norm=True, + # last_linear_layer_init_zero=False is important for convergence + last_linear_layer_init_zero=False, + ) + self.global_fc_dynamics = MLP( + in_channels=517, hidden_channels=self.latent_state_dim, layer_num=common_layer_num, out_channels=self.latent_state_dim, @@ -254,27 +291,222 @@ def forward(self, state_action_encoding: torch.Tensor) -> Tuple[torch.Tensor, to - reward (:obj:`torch.Tensor`): The predicted reward for input state. """ if self.res_connection_in_dynamics: - # take the state encoding (e.g. 
latent_state), - # state_action_encoding[:, -self.action_encoding_dim:] is action encoding - latent_state = state_action_encoding[:, :-self.action_encoding_dim] - x = self.fc_dynamics_1(state_action_encoding) - # the residual link: add the latent_state to the state_action encoding - next_latent_state = x + latent_state - next_latent_state_encoding = self.fc_dynamics_2(next_latent_state) + # # take the state encoding (e.g. latent_state), + # # state_action_encoding[:, -self.action_encoding_dim:] is action encoding + # latent_state = state_action_encoding[:, :-self.action_encoding_dim] + # x = self.fc_dynamics_1(state_action_encoding) + # # the residual link: add the latent_state to the state_action encoding + # next_latent_state = x + latent_state + # next_latent_state_encoding = self.fc_dynamics_2(next_latent_state) + + agent_state_action_encoding, global_state_action_encoding = state_action_encoding + # agent + agent_latent_state = agent_state_action_encoding[:, :-self.action_encoding_dim] + x = self.agent_fc_dynamics_1(agent_state_action_encoding) + next_agent_latent_state = x + agent_latent_state + next_agent_latent_state_encoding = self.agent_fc_dynamics_2(next_agent_latent_state) + # global + global_latent_state = global_state_action_encoding[:, :-self.action_encoding_dim] # + x = self.global_fc_dynamics_1(global_state_action_encoding) # x 512 + next_global_latent_state = x + global_latent_state + next_global_latent_state = self.fc_alignment(next_global_latent_state) + next_global_latent_state_encoding = self.global_fc_dynamics_2(next_global_latent_state) else: - batch_size = state_action_encoding.shape[0] - next_agent_latent_list = [self.fc_dynamics_list[i](state_action_encoding) for i in range(self.agent_num)] - next_agent_latent_state = torch.stack(next_agent_latent_list, dim=1) - next_agent_latent_state = next_agent_latent_state.reshape(batch_size, -1) - next_global_latent_state = self.fc_dynamics_global(state_action_encoding) - next_latent_state_encoding = next_global_latent_state + agent_state_action_encoding, global_state_action_encoding = state_action_encoding + next_agent_latent_state = self.agent_fc_dynamics(agent_state_action_encoding) + next_global_latent_state = self.global_fc_dynamics(global_state_action_encoding) + next_global_latent_state_encoding = next_global_latent_state - reward = self.fc_reward_head(next_latent_state_encoding) + reward = self.fc_reward_head(next_global_latent_state_encoding) - return (next_agent_latent_state, next_latent_state_encoding), reward + return (next_agent_latent_state, next_global_latent_state), reward def get_dynamic_mean(self) -> float: return get_dynamic_mean(self) def get_reward_mean(self) -> float: + return get_reward_mean(self) + +class PettingZooEZDynamics(nn.Module): + + def __init__( + self, + action_encoding_dim: int = 2, + num_channels: int = 64, + common_layer_num: int = 2, + fc_reward_layers: SequenceType = [32], + output_support_size: int = 601, + lstm_hidden_size: int = 512, + last_linear_layer_init_zero: bool = True, + activation: Optional[nn.Module] = nn.ReLU(inplace=True), + norm_type: Optional[str] = 'BN', + res_connection_in_dynamics: bool = False, + ): + """ + Overview: + The definition of dynamics network in EfficientZero algorithm, which is used to predict next latent state + value_prefix and reward_hidden_state by the given current latent state and action. + The networks are mainly built on fully connected layers. + Arguments: + - action_encoding_dim (:obj:`int`): The dimension of action encoding. 
+ - num_channels (:obj:`int`): The num of channels in latent states. + - common_layer_num (:obj:`int`): The number of common layers in dynamics network. + - fc_reward_layers (:obj:`SequenceType`): The number of hidden layers of the reward head (MLP head). + - output_support_size (:obj:`int`): The size of categorical reward output. + - lstm_hidden_size (:obj:`int`): The hidden size of lstm in dynamics network. + - last_linear_layer_init_zero (:obj:`bool`): Whether to use zero initializationss for the last layer of value/policy head, default sets it to True. + - activation (:obj:`Optional[nn.Module]`): Activation function used in network, which often use in-place \ + operation to speedup, e.g. ReLU(inplace=True). + - norm_type (:obj:`str`): The type of normalization in networks. defaults to 'BN'. + - res_connection_in_dynamics (:obj:`bool`): Whether to use residual connection in dynamics network. + """ + super().__init__() + assert num_channels > action_encoding_dim, f'num_channels:{num_channels} <= action_encoding_dim:{action_encoding_dim}' + + self.num_channels = num_channels + self.action_encoding_dim = action_encoding_dim + self.latent_state_dim = self.num_channels - self.action_encoding_dim + self.lstm_hidden_size = lstm_hidden_size + self.activation = activation + self.res_connection_in_dynamics = res_connection_in_dynamics + + if self.res_connection_in_dynamics: + self.agent_fc_dynamics_1 = MLP( + in_channels=self.num_channels, + hidden_channels=self.latent_state_dim, + layer_num=common_layer_num, + out_channels=self.latent_state_dim, + activation=activation, + norm_type=norm_type, + output_activation=True, + output_norm=True, + # last_linear_layer_init_zero=False is important for convergence + last_linear_layer_init_zero=False, + ) + self.agent_fc_dynamics_2 = MLP( + in_channels=self.latent_state_dim, + hidden_channels=self.latent_state_dim, + layer_num=common_layer_num, + out_channels=self.latent_state_dim, + activation=activation, + norm_type=norm_type, + output_activation=True, + output_norm=True, + # last_linear_layer_init_zero=False is important for convergence + last_linear_layer_init_zero=False, + ) + self.global_fc_dynamics_1 = MLP( + in_channels=517, + hidden_channels=self.latent_state_dim, + layer_num=common_layer_num, + out_channels=512, + activation=activation, + norm_type=norm_type, + output_activation=True, + output_norm=True, + # last_linear_layer_init_zero=False is important for convergence + last_linear_layer_init_zero=False, + ) + self.global_fc_dynamics_2 = MLP( + in_channels=self.latent_state_dim, + hidden_channels=self.latent_state_dim, + layer_num=common_layer_num, + out_channels=self.latent_state_dim, + activation=activation, + norm_type=norm_type, + output_activation=True, + output_norm=True, + # last_linear_layer_init_zero=False is important for convergence + last_linear_layer_init_zero=False, + ) + self.fc_alignment = MLP( + in_channels=512, + hidden_channels=self.latent_state_dim, + out_channels=self.latent_state_dim, + layer_num=common_layer_num, + activation=activation, + norm_type=norm_type, + output_activation=True, + output_norm=True, + # last_linear_layer_init_zero=False is important for convergence + last_linear_layer_init_zero=False, + ) + else: + self.fc_dynamics = MLP( + in_channels=self.num_channels, + hidden_channels=self.latent_state_dim, + layer_num=common_layer_num, + out_channels=self.latent_state_dim, + activation=activation, + norm_type=norm_type, + output_activation=True, + output_norm=True, + # last_linear_layer_init_zero=False is 
important for convergence + last_linear_layer_init_zero=False, + ) + + # input_shape: (sequence_length,batch_size,input_size) + # output_shape: (sequence_length, batch_size, hidden_size) + self.lstm = nn.LSTM(input_size=self.latent_state_dim, hidden_size=self.lstm_hidden_size) + + self.fc_reward_head = MLP( + in_channels=self.lstm_hidden_size, + hidden_channels=fc_reward_layers[0], + layer_num=2, + out_channels=output_support_size, + activation=self.activation, + norm_type=norm_type, + output_activation=False, + output_norm=False, + last_linear_layer_init_zero=last_linear_layer_init_zero + ) + + def forward(self, state_action_encoding: torch.Tensor, reward_hidden_state): + """ + Overview: + Forward computation of the dynamics network. Predict next latent state given current state_action_encoding and reward hidden state. + Arguments: + - state_action_encoding (:obj:`torch.Tensor`): The state-action encoding, which is the concatenation of \ + latent state and action encoding, with shape (batch_size, num_channels, height, width). + - reward_hidden_state (:obj:`Tuple[torch.Tensor, torch.Tensor]`): The input hidden state of LSTM about reward. + Returns: + - next_latent_state (:obj:`torch.Tensor`): The next latent state, with shape (batch_size, latent_state_dim). + - next_reward_hidden_state (:obj:`torch.Tensor`): The input hidden state of LSTM about reward. + - value_prefix (:obj:`torch.Tensor`): The predicted prefix sum of value for input state. + """ + if self.res_connection_in_dynamics: + # # take the state encoding (latent_state), state_action_encoding[:, -self.action_encoding_dim] + # # is action encoding + # latent_state = state_action_encoding[:, :-self.action_encoding_dim] + # x = self.fc_dynamics_1(state_action_encoding) + # # the residual link: add state encoding to the state_action encoding + # next_latent_state = x + latent_state + # next_latent_state_ = self.fc_dynamics_2(next_latent_state) + + agent_state_action_encoding, global_state_action_encoding = state_action_encoding + # agent + agent_latent_state = agent_state_action_encoding[:, :-self.action_encoding_dim] + x = self.agent_fc_dynamics_1(agent_state_action_encoding) + next_agent_latent_state = x + agent_latent_state + next_agent_latent_state_encoding = self.agent_fc_dynamics_2(next_agent_latent_state) + # global + global_latent_state = global_state_action_encoding[:, :-self.action_encoding_dim] # + x = self.global_fc_dynamics_1(global_state_action_encoding) # x 512 + next_global_latent_state = x + global_latent_state + next_global_latent_state = self.fc_alignment(next_global_latent_state) + next_latent_state_ = self.global_fc_dynamics_2(next_global_latent_state) + else: + next_latent_state = self.fc_dynamics(state_action_encoding) + next_latent_state_ = next_latent_state + + next_latent_state_unsqueeze = next_latent_state_.unsqueeze(0) + value_prefix, next_reward_hidden_state = self.lstm(next_latent_state_unsqueeze, reward_hidden_state) + value_prefix = self.fc_reward_head(value_prefix.squeeze(0)) + + return (next_agent_latent_state, next_global_latent_state), next_reward_hidden_state, value_prefix + + def get_dynamic_mean(self) -> float: + return get_dynamic_mean(self) + + def get_reward_mean(self) -> Tuple[ndarray, float]: return get_reward_mean(self) \ No newline at end of file From 59c7c56fad58e03d04d87fb1991b76d9e889790a Mon Sep 17 00:00:00 2001 From: chosenone Date: Sun, 26 Nov 2023 21:59:41 +0800 Subject: [PATCH 10/12] polish(yzj): polish reward_roots in mz --- lzero/policy/muzero.py | 4 ---- 1 file changed, 4 
deletions(-) diff --git a/lzero/policy/muzero.py b/lzero/policy/muzero.py index cbca2c18c..82c1800a9 100644 --- a/lzero/policy/muzero.py +++ b/lzero/policy/muzero.py @@ -535,8 +535,6 @@ def _forward_collect( global_latent_state_roots = global_latent_state_roots.detach().cpu().numpy() # policy_logits_tmp = policy_logits.reshape(active_collect_env_num, self.cfg.model.agent_num, -1) policy_logits = policy_logits.detach().cpu().numpy().tolist() - reward_roots = [[reward_root]*self.cfg.model.agent_num for reward_root in reward_roots] - reward_roots = sum(reward_roots, []) # # joint policy_logits # prob_1 = policy_logits_tmp[:, 0, :].unsqueeze(1).unsqueeze(2) # prob_2 = policy_logits_tmp[:, 1, :].unsqueeze(1).unsqueeze(3) @@ -709,8 +707,6 @@ def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: int = -1 agent_latent_state_roots = agent_latent_state_roots.detach().cpu().numpy() global_latent_state_roots = global_latent_state_roots.detach().cpu().numpy() policy_logits = policy_logits.detach().cpu().numpy().tolist() # list shape(B, A) - reward_roots = [[reward_root]*self.cfg.model.agent_num for reward_root in reward_roots] - reward_roots = sum(reward_roots, []) legal_actions = [[i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(batch_size)] if self._cfg.mcts_ctree: From 829d86d4be4921b2e3b90705553fdbd7e43706db Mon Sep 17 00:00:00 2001 From: chosenone Date: Sun, 26 Nov 2023 22:03:32 +0800 Subject: [PATCH 11/12] fix(yzj): fix device bug --- lzero/mcts/ctree/ctree_sampled_efficientzero/lib/cnode.cpp | 2 +- lzero/policy/efficientzero.py | 1 + lzero/policy/muzero.py | 1 + 3 files changed, 3 insertions(+), 1 deletion(-) diff --git a/lzero/mcts/ctree/ctree_sampled_efficientzero/lib/cnode.cpp b/lzero/mcts/ctree/ctree_sampled_efficientzero/lib/cnode.cpp index 0f0a6ae5f..f03d8b0d2 100644 --- a/lzero/mcts/ctree/ctree_sampled_efficientzero/lib/cnode.cpp +++ b/lzero/mcts/ctree/ctree_sampled_efficientzero/lib/cnode.cpp @@ -381,7 +381,7 @@ namespace tree for (size_t iter = 0; iter < disturbed_probs.size(); iter++) { #ifdef __APPLE__ - disc_action_with_probs.__emplace_back(std::make_pair(iter, disturbed_probs[iter])); + disc_action_with_probs.emplace_back(std::make_pair(iter, disturbed_probs[iter])); #else disc_action_with_probs.emplace_back(std::make_pair(iter, disturbed_probs[iter])); #endif diff --git a/lzero/policy/efficientzero.py b/lzero/policy/efficientzero.py index 001b46294..114851463 100644 --- a/lzero/policy/efficientzero.py +++ b/lzero/policy/efficientzero.py @@ -401,6 +401,7 @@ def _forward_learn(self, data: torch.Tensor) -> Dict[str, Union[float, int]]: # obtain the oracle latent states from representation function. beg_index, end_index = self._get_target_obs_index_in_step_k(step_k) obs_target_batch_tmp = default_collate(obs_target_batch[:, beg_index:end_index].squeeze()) + obs_target_batch_tmp = to_device(obs_target_batch_tmp, self._device) network_output = self._learn_model.initial_inference(obs_target_batch_tmp) latent_state = to_tensor(latent_state) diff --git a/lzero/policy/muzero.py b/lzero/policy/muzero.py index 82c1800a9..711bdce1d 100644 --- a/lzero/policy/muzero.py +++ b/lzero/policy/muzero.py @@ -369,6 +369,7 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in # obtain the oracle latent states from representation function. 
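# The device fixes in this commit follow one pattern: dict observations are collated on CPU by
# default_collate, so the collated batch has to be moved to the policy device before it is fed to
# initial_inference. A minimal sketch of that pattern (the dummy observation keys are only
# illustrative):
import torch
from ding.torch_utils import to_device
from ding.utils.data import default_collate

obs_slice = [{'agent_state': torch.zeros(18), 'global_state': torch.zeros(30)} for _ in range(4)]
obs_target_batch_tmp = default_collate(obs_slice)                               # dict of stacked CPU tensors
obs_target_batch_tmp = to_device(obs_target_batch_tmp, 'cuda' if torch.cuda.is_available() else 'cpu')
# network_output = self._learn_model.initial_inference(obs_target_batch_tmp)   # as in the patch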
beg_index, end_index = self._get_target_obs_index_in_step_k(step_k) obs_target_batch_tmp = default_collate(obs_target_batch[:, beg_index:end_index].squeeze()) + obs_target_batch_tmp = to_device(obs_target_batch_tmp, self._device) network_output = self._learn_model.initial_inference(obs_target_batch_tmp) latent_state = to_tensor(latent_state) From 0368c55ae20814852b672139c31f7efc1a0de0df Mon Sep 17 00:00:00 2001 From: jayyoung0802 Date: Sat, 2 Dec 2023 22:56:42 +0800 Subject: [PATCH 12/12] feature(yzj): polish buffer in ctde --- lzero/mcts/buffer/game_buffer.py | 31 ++++++++++++------- lzero/mcts/buffer/game_buffer_muzero.py | 21 +++++++------ lzero/model/muzero_model_mlp.py | 4 +-- lzero/policy/muzero.py | 9 +++--- .../config/ptz_simple_mz_config.py | 4 +-- .../config/ptz_simple_spread_mz_config.py | 4 +-- zoo/petting_zoo/entry/train_muzero.py | 2 +- .../envs/petting_zoo_simple_spread_env.py | 6 +++- zoo/petting_zoo/model/model.py | 24 +++++++------- 9 files changed, 62 insertions(+), 43 deletions(-) diff --git a/lzero/mcts/buffer/game_buffer.py b/lzero/mcts/buffer/game_buffer.py index f632d1e7d..59358d9e1 100644 --- a/lzero/mcts/buffer/game_buffer.py +++ b/lzero/mcts/buffer/game_buffer.py @@ -118,33 +118,41 @@ def _sample_orig_data(self, batch_size: int) -> Tuple: # +1e-6 for numerical stability probs = self.game_pos_priorities ** self._alpha + 1e-6 - probs /= probs.sum() + if self._cfg.multi_agent: + probs = np.array([probs[i] for i in range(0, len(probs), self._cfg.model.agent_num)]) #TODO: check this + probs /= probs.sum() + else: + probs /= probs.sum() # sample according to transition index # TODO(pu): replace=True - batch_index_list = np.random.choice(num_of_transitions, batch_size, p=probs, replace=False) + batch_index_list = np.random.choice(num_of_transitions//self._cfg.model.agent_num, batch_size, p=probs, replace=False) if self._cfg.reanalyze_outdated is True: # NOTE: used in reanalyze part batch_index_list.sort() - weights_list = (num_of_transitions * probs[batch_index_list]) ** (-self._beta) + weights_list = ((num_of_transitions//self._cfg.model.agent_num) * probs[batch_index_list]) ** (-self._beta) weights_list /= weights_list.max() game_segment_list = [] pos_in_game_segment_list = [] + agent_id_list = [] + true_batch_index_list = [] for idx in batch_index_list: - game_segment_idx, pos_in_game_segment = self.game_segment_game_pos_look_up[idx] + game_segment_idx, pos_in_game_segment, agent_id = self.game_segment_game_pos_look_up[idx] game_segment_idx -= self.base_idx - game_segment = self.game_segment_buffer[game_segment_idx] - - game_segment_list.append(game_segment) - pos_in_game_segment_list.append(pos_in_game_segment) + for i in range(self._cfg.model.agent_num): + game_segment = self.game_segment_buffer[game_segment_idx*self._cfg.model.agent_num+i] + game_segment_list.append(game_segment) + pos_in_game_segment_list.append(pos_in_game_segment) + agent_id_list.append(agent_id+i) + true_batch_index_list.append(idx) - make_time = [time.time() for _ in range(len(batch_index_list))] + make_time = [time.time() for _ in range(len(true_batch_index_list))] - orig_data = (game_segment_list, pos_in_game_segment_list, batch_index_list, weights_list, make_time) + orig_data = (game_segment_list, pos_in_game_segment_list, true_batch_index_list, weights_list, make_time) return orig_data def _preprocess_to_play_and_action_mask( @@ -349,8 +357,9 @@ def _push_game_segment(self, data: Any, meta: Optional[dict] = None) -> None: self.game_pos_priorities = 
np.concatenate((self.game_pos_priorities, priorities)) self.game_segment_buffer.append(data) + agent_id = data.obs_segment[0]['agent_id'] self.game_segment_game_pos_look_up += [ - (self.base_idx + len(self.game_segment_buffer) - 1, step_pos) for step_pos in range(len(data)) + (self.base_idx + len(self.game_segment_buffer) - 1, step_pos, agent_id) for step_pos in range(len(data)) ] def remove_oldest_data_to_fit(self) -> None: diff --git a/lzero/mcts/buffer/game_buffer_muzero.py b/lzero/mcts/buffer/game_buffer_muzero.py index 6feaf65d6..2ff400023 100644 --- a/lzero/mcts/buffer/game_buffer_muzero.py +++ b/lzero/mcts/buffer/game_buffer_muzero.py @@ -201,14 +201,17 @@ def _prepare_reward_value_context( td_steps_list, action_mask_segment, to_play_segment """ zero_obs = game_segment_list[0].zero_obs() - zero_obs = np.array([{'agent_state': np.zeros((18,), dtype=np.float32), - 'global_state': np.zeros((48,), dtype=np.float32), - 'agent_alone_state': np.zeros((14,), dtype=np.float32), - 'agent_alone_padding_state': np.zeros((18,), dtype=np.float32),}]) - zero_obs = np.array([{'agent_state': np.zeros((6,), dtype=np.float32), - 'global_state': np.zeros((14, ), dtype=np.float32), - 'agent_alone_state': np.zeros((12,), dtype=np.float32), - 'agent_alone_padding_state': np.zeros((12,), dtype=np.float32),}]) + zero_obs = np.array([{ + 'agent_id': np.array(0), + 'agent_state': np.zeros((18,), dtype=np.float32), + 'global_state': np.zeros((30,), dtype=np.float32), + 'agent_alone_state': np.zeros((14,), dtype=np.float32), + 'agent_alone_padding_state': np.zeros((18,), dtype=np.float32), + }]) + # zero_obs = np.array([{'agent_state': np.zeros((6,), dtype=np.float32), + # 'global_state': np.zeros((14, ), dtype=np.float32), + # 'agent_alone_state': np.zeros((12,), dtype=np.float32), + # 'agent_alone_padding_state': np.zeros((12,), dtype=np.float32),}]) value_obs_list = [] # the value is valid or not (out of game_segment) value_mask = [] @@ -218,7 +221,7 @@ def _prepare_reward_value_context( action_mask_segment, to_play_segment = [], [] td_steps_list = [] - for game_segment, state_index, idx in zip(game_segment_list, pos_in_game_segment_list, batch_index_list): + for game_segment, state_index in zip(game_segment_list, pos_in_game_segment_list): game_segment_len = len(game_segment) game_segment_lens.append(game_segment_len) diff --git a/lzero/model/muzero_model_mlp.py b/lzero/model/muzero_model_mlp.py index 2bd69ff42..28ba7d400 100644 --- a/lzero/model/muzero_model_mlp.py +++ b/lzero/model/muzero_model_mlp.py @@ -315,8 +315,8 @@ def _dynamics(self, latent_state: torch.Tensor, action: torch.Tensor) -> Tuple[t # state_action_encoding shape: (batch_size, latent_state[1] + action_dim]) or # (batch_size, latent_state[1] + action_space_size]) depending on the discrete_action_encoding_type. 
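# Channel bookkeeping for the state-action encodings built here and the in_channels values in
# zoo/petting_zoo/model/model.py (a sketch; assumes latent_state_dim = 256, n_agent = 3 and a
# 5-dim one-hot action, matching the simple_spread configs in this patch):
latent_dim, n_agent, action_dim = 256, 3, 5
assert latent_dim + action_dim == 261                # agent latent + action: agent_fc_dynamics in_channels
assert 2 * latent_dim + action_dim == 517            # agent + global latent + action: the earlier global_fc_dynamics(_1) in_channels
assert n_agent * latent_dim + latent_dim == 1024     # three next agent latents + global latent: global_fc_dynamics_1 in_channels after this commit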
agent_state_action_encoding = torch.cat((agent_latent_state, action_encoding), dim=1) - global_state_action_encoding = torch.cat((agent_latent_state, global_latent_state, action_encoding), dim=1) - (next_agent_latent_state, next_global_latent_state), reward = self.dynamics_network((agent_state_action_encoding, global_state_action_encoding)) + # global_state_action_encoding = torch.cat((agent_latent_state, global_latent_state, action_encoding), dim=1) + (next_agent_latent_state, next_global_latent_state), reward = self.dynamics_network((agent_state_action_encoding, global_latent_state)) if not self.state_norm: return (next_agent_latent_state, next_global_latent_state), reward diff --git a/lzero/policy/muzero.py b/lzero/policy/muzero.py index 711bdce1d..87ce6cb01 100644 --- a/lzero/policy/muzero.py +++ b/lzero/policy/muzero.py @@ -297,8 +297,8 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in [mask_batch, target_reward, target_value, target_policy, weights] = to_torch_float_tensor(data_list, self._cfg.device) - target_reward = target_reward.view(self._cfg.batch_size, -1) - target_value = target_value.view(self._cfg.batch_size, -1) + target_reward = target_reward.view(self._cfg.batch_size*self._cfg.model.agent_num, -1) + target_value = target_value.view(self._cfg.batch_size*self._cfg.model.agent_num, -1) # assert obs_batch.size(0) == self._cfg.batch_size == target_reward.size(0) @@ -344,8 +344,8 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in policy_loss = cross_entropy_loss(policy_logits, target_policy[:, 0]) value_loss = cross_entropy_loss(value, target_value_categorical[:, 0]) - reward_loss = torch.zeros(self._cfg.batch_size, device=self._cfg.device) - consistency_loss = torch.zeros(self._cfg.batch_size, device=self._cfg.device) + reward_loss = torch.zeros(self._cfg.batch_size*self._cfg.model.agent_num, device=self._cfg.device) + consistency_loss = torch.zeros(self._cfg.batch_size*self._cfg.model.agent_num, device=self._cfg.device) # ============================================================== # the core recurrent_inference in MuZero policy. 
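# Sketch of how the per-agent flattening in the hunks above interacts with the importance-sampling
# weights (illustration only; batch_size = 2 and agent_num = 3, matching the hard-coded
# repeat_interleave(3) in the next hunk):
import torch
batch_size, agent_num = 2, 3
loss = torch.zeros(batch_size * agent_num)         # one loss term per (sample, agent) pair
weights = torch.tensor([0.5, 1.0])                 # one priority weight per sampled transition
weights = weights.repeat_interleave(agent_num)     # -> [0.5, 0.5, 0.5, 1.0, 1.0, 1.0]
weighted_total_loss = (weights * loss).mean()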
@@ -420,6 +420,7 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in self._cfg.ssl_loss_weight * consistency_loss + self._cfg.policy_loss_weight * policy_loss + self._cfg.value_loss_weight * value_loss + self._cfg.reward_loss_weight * reward_loss ) + weights = weights.repeat_interleave(3) weighted_total_loss = (weights * loss).mean() gradient_scale = 1 / self._cfg.num_unroll_steps diff --git a/zoo/petting_zoo/config/ptz_simple_mz_config.py b/zoo/petting_zoo/config/ptz_simple_mz_config.py index 7db1bc0e6..72fedc80f 100644 --- a/zoo/petting_zoo/config/ptz_simple_mz_config.py +++ b/zoo/petting_zoo/config/ptz_simple_mz_config.py @@ -32,7 +32,7 @@ n_landmark=n_landmark, max_cycles=25, agent_obs_only=False, - agent_specific_global_state=True, + agent_specific_global_state=False, continuous_actions=False, stop_value=0, collector_env_num=collector_env_num, @@ -52,7 +52,7 @@ agent_num=n_agent, self_supervised_learning_loss=True, agent_obs_shape=6, - global_obs_shape=14, + global_obs_shape=8, discrete_action_encoding_type='one_hot', global_cooperation=True, # TODO: doesn't work now hidden_size_list=[256, 256], diff --git a/zoo/petting_zoo/config/ptz_simple_spread_mz_config.py b/zoo/petting_zoo/config/ptz_simple_spread_mz_config.py index 2260f4994..ce25254fb 100644 --- a/zoo/petting_zoo/config/ptz_simple_spread_mz_config.py +++ b/zoo/petting_zoo/config/ptz_simple_spread_mz_config.py @@ -32,7 +32,7 @@ n_landmark=n_landmark, max_cycles=25, agent_obs_only=False, - agent_specific_global_state=True, + agent_specific_global_state=False, continuous_actions=False, stop_value=0, collector_env_num=collector_env_num, @@ -52,7 +52,7 @@ agent_num=n_agent, self_supervised_learning_loss=True, agent_obs_shape=18, - global_obs_shape=2 + 2 + n_landmark * 2 + (n_agent - 1) * 2 + (n_agent - 1) * 2 + n_agent * (2 + 2) + n_landmark * 2 + n_agent * (n_agent - 1) * 2, + global_obs_shape=30, discrete_action_encoding_type='one_hot', global_cooperation=True, # TODO: doesn't work now hidden_size_list=[256, 256], diff --git a/zoo/petting_zoo/entry/train_muzero.py b/zoo/petting_zoo/entry/train_muzero.py index aa731f7e0..ba7f777b0 100644 --- a/zoo/petting_zoo/entry/train_muzero.py +++ b/zoo/petting_zoo/entry/train_muzero.py @@ -179,7 +179,7 @@ def train_muzero( # Learn policy from collected data. for i in range(update_per_collect): # Learner will train ``update_per_collect`` times in one iteration. 
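# The ``// 3`` guard immediately below compensates for per-agent storage: every env transition is
# pushed once per agent, so the buffer holds agent_num times more entries than sampled env steps.
# A hypothetical helper expressing the same check against the config instead of the literal 3:
def enough_data_to_train(num_of_transitions: int, batch_size: int, agent_num: int = 3) -> bool:
    # number of env-level transitions available for sampling
    return num_of_transitions // agent_num > batch_size

assert enough_data_to_train(num_of_transitions=3000, batch_size=256, agent_num=3)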
- if replay_buffer.get_num_of_transitions() > batch_size: + if replay_buffer.get_num_of_transitions()//3 > batch_size: train_data = replay_buffer.sample(batch_size, policy) else: logging.warning( diff --git a/zoo/petting_zoo/envs/petting_zoo_simple_spread_env.py b/zoo/petting_zoo/envs/petting_zoo_simple_spread_env.py index c60fecbed..3bc4b3084 100644 --- a/zoo/petting_zoo/envs/petting_zoo_simple_spread_env.py +++ b/zoo/petting_zoo/envs/petting_zoo_simple_spread_env.py @@ -282,8 +282,12 @@ def _process_obs(self, obs: 'torch.Tensor') -> np.ndarray: # noqa ret_transform = [] for i in range(self._num_agents): tmp = {} + tmp['agent_id'] = i for k,v in ret.items(): - tmp[k] = v[i] + if k == 'global_state': + tmp[k] = v + else: + tmp[k] = v[i] ret_transform.append(tmp) return {'observation': ret_transform, 'action_mask': action_mask, 'to_play': [-1 for _ in range(self._num_agents)]} diff --git a/zoo/petting_zoo/model/model.py b/zoo/petting_zoo/model/model.py index 156bfce02..6603a83f3 100644 --- a/zoo/petting_zoo/model/model.py +++ b/zoo/petting_zoo/model/model.py @@ -29,11 +29,11 @@ def __init__(self, cfg): def forward(self, x): # agent - agent_state = x['agent_state'].reshape(-1, self.agent_obs_shape) + agent_state = x['agent_state'] agent_state = self.agent_encoder(agent_state) # agent_state = agent_state.reshape(batch_size, agent_num, -1) # global - global_state = x['global_state'].reshape(-1, self.global_obs_shape) + global_state = x['global_state'] global_state = self.global_encoder(global_state) # global_state = global_state.reshape(batch_size, agent_num, -1) return (agent_state, global_state) @@ -206,10 +206,10 @@ def __init__( last_linear_layer_init_zero=False, ) self.global_fc_dynamics_1 = MLP( - in_channels=517, + in_channels=1024, hidden_channels=self.latent_state_dim, layer_num=common_layer_num, - out_channels=512, + out_channels=256, activation=activation, norm_type=norm_type, output_activation=True, @@ -218,7 +218,7 @@ def __init__( last_linear_layer_init_zero=False, ) self.global_fc_dynamics_2 = MLP( - in_channels=self.latent_state_dim, + in_channels=256, hidden_channels=self.latent_state_dim, layer_num=common_layer_num, out_channels=self.latent_state_dim, @@ -299,17 +299,19 @@ def forward(self, state_action_encoding: torch.Tensor) -> Tuple[torch.Tensor, to # next_latent_state = x + latent_state # next_latent_state_encoding = self.fc_dynamics_2(next_latent_state) - agent_state_action_encoding, global_state_action_encoding = state_action_encoding + agent_state_action_encoding, global_latent_state = state_action_encoding + batch_size = agent_state_action_encoding.shape[0] // 3 # agent agent_latent_state = agent_state_action_encoding[:, :-self.action_encoding_dim] x = self.agent_fc_dynamics_1(agent_state_action_encoding) next_agent_latent_state = x + agent_latent_state - next_agent_latent_state_encoding = self.agent_fc_dynamics_2(next_agent_latent_state) + # next_agent_latent_state_encoding = self.agent_fc_dynamics_2(next_agent_latent_state) # global - global_latent_state = global_state_action_encoding[:, :-self.action_encoding_dim] # - x = self.global_fc_dynamics_1(global_state_action_encoding) # x 512 - next_global_latent_state = x + global_latent_state - next_global_latent_state = self.fc_alignment(next_global_latent_state) + next_agent_latent_state_tmp = next_agent_latent_state.reshape(batch_size, -1) + global_latent_state = global_latent_state[::3, ] + global_latent_state = torch.cat((next_agent_latent_state_tmp, global_latent_state), dim=1) + next_global_latent_state = 
self.global_fc_dynamics_1(global_latent_state) # shape (batch_size, 256) + next_global_latent_state = next_global_latent_state.unsqueeze(1).expand(-1, 3, -1).reshape(-1, 256) + next_global_latent_state_encoding = self.global_fc_dynamics_2(next_global_latent_state) else: agent_state_action_encoding, global_state_action_encoding = state_action_encoding
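# Shape walk-through of the reworked residual dynamics forward above (a sketch with dummy
# tensors; assumes 3 agents, 256-dim latents and a 5-dim one-hot action, as elsewhere in the patch):
import torch
n_agent, latent_dim, action_dim, batch = 3, 256, 5, 4
agent_sa = torch.zeros(batch * n_agent, latent_dim + action_dim)   # (12, 261) per-agent state-action encoding
global_latent = torch.zeros(batch * n_agent, latent_dim)           # (12, 256), repeated per agent
next_agent = agent_sa[:, :latent_dim]                              # stands in for the agent residual path output
fused = torch.cat((next_agent.reshape(batch, -1), global_latent[::n_agent]), dim=1)
assert fused.shape == (batch, n_agent * latent_dim + latent_dim)   # (4, 1024), input of global_fc_dynamics_1
next_global = torch.zeros(batch, latent_dim)                       # stands in for global_fc_dynamics_1 output
next_global = next_global.unsqueeze(1).expand(-1, n_agent, -1).reshape(-1, latent_dim)
assert next_global.shape == (batch * n_agent, latent_dim)          # broadcast back to one row per agent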