doc(zjow): add API doc for ding agent (#758)
* polish API doc for agent ppof and dqn

* add doc for ding agent

* polish code
zjowowen authored Dec 27, 2023
1 parent 779b4b8 commit 4d53074
Showing 10 changed files with 1,738 additions and 15 deletions.
173 changes: 172 additions & 1 deletion ding/bonus/a2c.py
@@ -23,7 +23,23 @@


class A2CAgent:
"""
Overview:
Class of agent for training, evaluation and deployment of the reinforcement learning algorithm \
Advantage Actor-Critic (A2C).
For more information about the system design of RL agents, please refer to \
<https://di-engine-docs.readthedocs.io/en/latest/03_system/agent.html>.
Interface:
``__init__``, ``train``, ``deploy``, ``collect_data``, ``batch_evaluate``, ``best``
"""
supported_env_list = list(supported_env_cfg.keys())
"""
Overview:
List of supported envs.
Examples:
>>> from ding.bonus.a2c import A2CAgent
>>> print(A2CAgent.supported_env_list)
"""

def __init__(
self,
@@ -35,6 +51,52 @@ def __init__(
cfg: Optional[Union[EasyDict, dict]] = None,
policy_state_dict: str = None,
) -> None:
"""
Overview:
Initialize agent for A2C algorithm.
Arguments:
- env_id (:obj:`str`): The environment id, which is a registered environment name in gym or gymnasium. \
If ``env_id`` is not specified, ``env_id`` in ``cfg.env`` must be specified. \
If ``env_id`` is specified, ``env_id`` in ``cfg.env`` will be ignored. \
``env_id`` should be one of the supported envs, which can be found in ``supported_env_list``.
- env (:obj:`BaseEnv`): The environment instance for training and evaluation. \
If ``env`` is not specified, ``env_id`` or ``cfg.env.env_id`` must be specified. \
``env_id`` or ``cfg.env.env_id`` will be used to create environment instance. \
If ``env`` is specified, ``env_id`` and ``cfg.env.env_id`` will be ignored.
- seed (:obj:`int`): The random seed, which is set before running the program. \
Default to 0.
- exp_name (:obj:`str`): The name of this experiment, which will be used to create the folder to save \
log data. Default to None. If not specified, the folder name will be ``env_id``-``algorithm``.
- model (:obj:`torch.nn.Module`): The model of the A2C algorithm, which should be an instance of class \
:class:`ding.model.VAC`. \
If not specified, a default model will be generated according to the configuration.
- cfg (:obj:`Union[EasyDict, dict]`): The configuration of the A2C algorithm, which is a dict. \
Default to None. If not specified, the default configuration will be used. \
The default configuration can be found in ``ding/config/example/A2C/gym_lunarlander_v2.py``.
- policy_state_dict (:obj:`str`): The path of the policy state dict saved by PyTorch in a local file. \
If specified, the policy will be loaded from this file. Default to None.
.. note::
An RL agent instance can be initialized in two basic ways. \
For example, we have an environment with id ``LunarLanderContinuous-v2`` registered in gym, \
and we want to train an agent with the A2C algorithm using the default configuration. \
Then we can initialize the agent in the following ways:
>>> agent = A2CAgent(env_id='LunarLanderContinuous-v2')
or, if we want to specify the ``env_id`` in the configuration:
>>> cfg = {'env': {'env_id': 'LunarLanderContinuous-v2'}, 'policy': ...... }
>>> agent = A2CAgent(cfg=cfg)
There are also other arguments to specify the agent when initializing.
For example, if we want to specify the environment instance:
>>> env = CustomizedEnv('LunarLanderContinuous-v2')
>>> agent = A2CAgent(cfg=cfg, env=env)
or, if we want to specify the model:
>>> model = VAC(**cfg.policy.model)
>>> agent = A2CAgent(cfg=cfg, model=model)
or, if we want to reload the policy from a saved policy state dict:
>>> agent = A2CAgent(cfg=cfg, policy_state_dict='LunarLanderContinuous-v2.pth.tar')
Make sure that the configuration is consistent with the saved policy state dict.
"""

assert env_id is not None or cfg is not None, "Please specify env_id or cfg."

if cfg is not None and not isinstance(cfg, EasyDict):
@@ -91,6 +153,32 @@ def train(
debug: bool = False,
wandb_sweep: bool = False,
) -> TrainingReturn:
"""
Overview:
Train the agent with the A2C algorithm for ``step`` environment steps with ``collector_env_num`` collector \
environments and ``evaluator_env_num`` evaluator environments. Information during training will be \
recorded and saved by wandb.
Arguments:
- step (:obj:`int`): The total training environment steps of all collector environments. Default to 1e7.
- collector_env_num (:obj:`int`): The collector environment number. Default to None. \
If not specified, it will be set according to the configuration.
- evaluator_env_num (:obj:`int`): The evaluator environment number. Default to None. \
If not specified, it will be set according to the configuration.
- n_iter_save_ckpt (:obj:`int`): The checkpoint saving frequency, i.e., a checkpoint is saved every \
``n_iter_save_ckpt`` training iterations. \
Default to 1000.
- context (:obj:`str`): The multi-process context of the environment manager. Default to None. \
It can be specified as ``spawn``, ``fork`` or ``forkserver``.
- debug (:obj:`bool`): Whether to use debug mode in the environment manager. Default to False. \
If set True, base environment manager will be used for easy debugging. Otherwise, \
subprocess environment manager will be used.
- wandb_sweep (:obj:`bool`): Whether to use wandb sweep, \
which is a hyper-parameter optimization process for seeking the best configurations. \
Default to False. If True, the wandb sweep id will be used as the experiment name.
Returns:
- (:obj:`TrainingReturn`): The training result, whose attributes are:
- wandb_url (:obj:`str`): The Weights & Biases (wandb) project URL of the training experiment.
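Examples:
A minimal usage sketch; the ``step`` value here is only illustrative, not a recommended training budget:
>>> agent = A2CAgent(env_id='LunarLanderContinuous-v2')
>>> training_return = agent.train(step=int(1e5))
>>> print(training_return.wandb_url)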
"""

if debug:
logging.getLogger().setLevel(logging.DEBUG)
logging.debug(self.policy._model)
@@ -142,6 +230,31 @@ def deploy(
seed: Optional[Union[int, List]] = None,
debug: bool = False
) -> EvalReturn:
"""
Overview:
Deploy the agent with the A2C algorithm by interacting with the environment, during which the replay video \
can be saved if ``enable_save_replay`` is True. The evaluation result will be returned.
Arguments:
- enable_save_replay (:obj:`bool`): Whether to save the replay video. Default to False.
- concatenate_all_replay (:obj:`bool`): Whether to concatenate all replay videos into one video. \
Default to False. If ``enable_save_replay`` is False, this argument will be ignored. \
If ``enable_save_replay`` is True and ``concatenate_all_replay`` is False, \
the replay video of each episode will be saved separately.
- replay_save_path (:obj:`str`): The path to save the replay video. Default to None. \
If not specified, the video will be saved in ``exp_name/videos``.
- seed (:obj:`Union[int, List]`): The random seed, which is set before running the program. \
Default to None. If not specified, ``self.seed`` will be used. \
If ``seed`` is an integer, the agent will be deployed once. \
If ``seed`` is a list of integers, the agent will be deployed once for each seed in the list.
- debug (:obj:`bool`): Whether to use debug mode in the environment manager. Default to False. \
If set True, base environment manager will be used for easy debugging. Otherwise, \
subprocess environment manager will be used.
Returns:
- (:obj:`EvalReturn`): The evaluation result, whose attributes are:
- eval_value (:obj:`np.float32`): The mean of evaluation return.
- eval_value_std (:obj:`np.float32`): The standard deviation of evaluation return.
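Examples:
A minimal usage sketch; enabling replay saving and the seed value are only illustrative:
>>> agent = A2CAgent(env_id='LunarLanderContinuous-v2')
>>> result = agent.deploy(enable_save_replay=True, seed=0)
>>> print(result.eval_value, result.eval_value_std)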
"""

if debug:
logging.getLogger().setLevel(logging.DEBUG)
# define env and policy
@@ -227,6 +340,26 @@ def collect_data(
context: Optional[str] = None,
debug: bool = False
) -> None:
"""
Overview:
Collect data with the A2C algorithm for ``n_sample`` samples or ``n_episode`` episodes with ``env_num`` \
collector environments. \
The collected data will be saved in ``save_data_path`` if specified, otherwise it will be saved in \
``exp_name/demo_data``.
Arguments:
- env_num (:obj:`int`): The number of collector environments. Default to 8.
- save_data_path (:obj:`str`): The path to save the collected data. Default to None. \
If not specified, the data will be saved in ``exp_name/demo_data``.
- n_sample (:obj:`int`): The number of samples to collect. Default to None. \
If not specified, ``n_episode`` must be specified.
- n_episode (:obj:`int`): The number of episodes to collect. Default to None. \
If not specified, ``n_sample`` must be specified.
- context (:obj:`str`): The multi-process context of the environment manager. Default to None. \
It can be specified as ``spawn``, ``fork`` or ``forkserver``.
- debug (:obj:`bool`): Whether to use debug mode in the environment manager. Default to False. \
If set True, base environment manager will be used for easy debugging. Otherwise, \
subprocess environment manager will be used.
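Examples:
A minimal usage sketch; the ``n_sample`` value is only illustrative:
>>> agent = A2CAgent(env_id='LunarLanderContinuous-v2')
>>> agent.collect_data(env_num=8, n_sample=1000)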
"""

if debug:
logging.getLogger().setLevel(logging.DEBUG)
if n_episode is not None:
@@ -258,6 +391,27 @@ def batch_evaluate(
context: Optional[str] = None,
debug: bool = False
) -> EvalReturn:
"""
Overview:
Evaluate the agent with the A2C algorithm for ``n_evaluator_episode`` episodes with ``env_num`` evaluator \
environments. The evaluation result will be returned.
The difference between methods ``batch_evaluate`` and ``deploy`` is that ``batch_evaluate`` will create \
multiple evaluator environments to evaluate the agent to get an average performance, while ``deploy`` \
will only create one evaluator environment to evaluate the agent and save the replay video.
Arguments:
- env_num (:obj:`int`): The number of evaluator environments. Default to 4.
- n_evaluator_episode (:obj:`int`): The number of episodes to evaluate. Default to 4.
- context (:obj:`str`): The multi-process context of the environment manager. Default to None. \
It can be specified as ``spawn``, ``fork`` or ``forkserver``.
- debug (:obj:`bool`): Whether to use debug mode in the environment manager. Default to False. \
If set True, base environment manager will be used for easy debugging. Otherwise, \
subprocess environment manager will be used.
Returns:
- (:obj:`EvalReturn`): The evaluation result, whose attributes are:
- eval_value (:obj:`np.float32`): The mean of evaluation return.
- eval_value_std (:obj:`np.float32`): The standard deviation of evaluation return.
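Examples:
A minimal usage sketch; the argument values are only illustrative:
>>> agent = A2CAgent(env_id='LunarLanderContinuous-v2')
>>> result = agent.batch_evaluate(env_num=4, n_evaluator_episode=8)
>>> print(result.eval_value, result.eval_value_std)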
"""

if debug:
logging.getLogger().setLevel(logging.DEBUG)
# define env and policy
@@ -280,7 +434,24 @@ def batch_evaluate(
return EvalReturn(eval_value=task.ctx.eval_value, eval_value_std=task.ctx.eval_value_std)

@property
def best(self):
def best(self) -> 'A2CAgent':
"""
Overview:
Load the best model from the checkpoint directory, \
which is by default saved as ``exp_name/ckpt/eval.pth.tar``. \
The return value is the agent with the best model.
Returns:
- (:obj:`A2CAgent`): The agent with the best model.
Examples:
>>> agent = A2CAgent(env_id='LunarLanderContinuous-v2')
>>> agent.train()
>>> agent = agent.best
.. note::
The best model is the model with the highest evaluation return. If this property is accessed, the current \
model will be replaced by the best model.
"""

best_model_file_path = os.path.join(self.checkpoint_save_dir, "eval.pth.tar")
# Load best model if it exists
if os.path.exists(best_model_file_path):
