SAC Agent for Ant (AntBulletEnv-v0) Has Dimension Mismatch (Training with GAIL) #93
Important Note: We do not do technical support or consulting, and we do not answer personal questions by email.
Describe the bug
When I use the available SAC agent for AntBulletEnv-v0 to create a dataset for GAIL, I get a dimension mismatch. I'm working in this repository and slightly modified the enjoy.py script to set up training.
Code example
import os
import sys
import argparse
import importlib
import warnings

# numpy warnings because of tensorflow
warnings.filterwarnings("ignore", category=FutureWarning, module='tensorflow')
warnings.filterwarnings("ignore", category=UserWarning, module='gym')

import gym
import utils.import_envs  # pytype: disable=import-error
import numpy as np
import pickle
import stable_baselines
from stable_baselines.common import set_global_seeds
from stable_baselines.common.vec_env import VecNormalize, VecFrameStack, VecEnv
from stable_baselines.gail import generate_expert_traj
from stable_baselines.gail import ExpertDataset
from stable_baselines import PPO2, SAC, GAIL
from utils import ALGOS, create_test_env, get_latest_run_id, get_saved_hyperparams, find_saved_model
from utils.utils import StoreDict

# Fix for breaking change in v2.6.0
sys.modules['stable_baselines.ddpg.memory'] = stable_baselines.common.buffers
stable_baselines.common.buffers.Memory = stable_baselines.common.buffers.ReplayBuffer


def evaluate():
    parser = argparse.ArgumentParser()
    parser.add_argument('--env', help='environment ID', type=str, default='AntBulletEnv-v0')
    parser.add_argument('-f', '--folder', help='Log folder', type=str, default='trained_agents')
    parser.add_argument('--algo', help='RL Algorithm', default='sac',
                        type=str, required=False, choices=list(ALGOS.keys()))
    parser.add_argument('-n', '--n-timesteps', help='number of timesteps', default=1000,
                        type=int)
    parser.add_argument('--n-envs', help='number of environments', default=1,
                        type=int)
    parser.add_argument('--exp-id', help='Experiment ID (default: -1, no exp folder, 0: latest)', default=-1,
                        type=int)
    parser.add_argument('--log-rollouts', help='Save expert trajectory data', default=None,
                        type=str)
    parser.add_argument('--n-episodes', help='How many episodes to roll out', default=100,
                        type=int)
    parser.add_argument('--verbose', help='Verbose mode (0: no output, 1: INFO)', default=1,
                        type=int)
    parser.add_argument('--no-render', action='store_true', default=True,
                        help='Do not render the environment (useful for tests)')
    parser.add_argument('--deterministic', action='store_true', default=False,
                        help='Use deterministic actions')
    parser.add_argument('--stochastic', action='store_true', default=False,
                        help='Use stochastic actions (for DDPG/DQN/SAC)')
    parser.add_argument('--load-best', action='store_true', default=False,
                        help='Load best model instead of last model if available')
    parser.add_argument('--norm-reward', action='store_true', default=False,
                        help='Normalize reward if applicable (trained with VecNormalize)')
    parser.add_argument('--seed', help='Random generator seed', type=int, default=np.random.randint(0, 1000))
    parser.add_argument('--reward-log', help='Where to log reward', default='', type=str)
    parser.add_argument('--gym-packages', type=str, nargs='+', default=[],
                        help='Additional external Gym environment package modules to import (e.g. gym_minigrid)')
    parser.add_argument('--env-kwargs', type=str, nargs='+', action=StoreDict,
                        help='Optional keyword argument to pass to the env constructor')
    args = parser.parse_args()

    # Go through custom gym packages to let them register in the global registry
    for env_module in args.gym_packages:
        importlib.import_module(env_module)

    env_id = args.env
    algo = args.algo
    folder = args.folder

    if args.exp_id == 0:
        args.exp_id = get_latest_run_id(os.path.join(folder, algo), env_id)
        print('Loading latest experiment, id={}'.format(args.exp_id))

    # Sanity checks
    if args.exp_id > 0:
        log_path = os.path.join(folder, algo, '{}_{}'.format(env_id, args.exp_id))
    else:
        log_path = os.path.join(folder, algo)

    assert os.path.isdir(log_path), "The {} folder was not found".format(log_path)

    model_path = find_saved_model(algo, log_path, env_id, load_best=args.load_best)

    if algo in ['dqn', 'ddpg', 'sac', 'td3']:
        args.n_envs = 1

    set_global_seeds(args.seed)

    is_atari = 'NoFrameskip' in env_id

    stats_path = os.path.join(log_path, env_id)
    hyperparams, stats_path = get_saved_hyperparams(stats_path, norm_reward=args.norm_reward, test_mode=True)
    log_dir = args.reward_log if args.reward_log != '' else None
    env_kwargs = {} if args.env_kwargs is None else args.env_kwargs

    env = create_test_env(env_id, n_envs=args.n_envs, is_atari=is_atari,
                          stats_path=stats_path, seed=args.seed, log_dir=log_dir,
                          should_render=not args.no_render,
                          hyperparams=hyperparams, env_kwargs=env_kwargs)

    # ACER raises errors because the environment passed must have
    # the same number of environments as the model was trained on.
    load_env = None if algo == 'acer' else env
    model = ALGOS[algo].load(model_path, env=load_env)

    # Roll out the trained SAC agent to build an expert dataset for GAIL
    generate_expert_traj(model, 'Ant_Test', n_episodes=10)
    dataset = ExpertDataset(expert_path='Ant_Test.npz', traj_limitation=-1, batch_size=128)

    model = GAIL('MlpPolicy', 'AntBulletEnv-v0', dataset, verbose=1)
    # Note: in practice, you need to train for 1M steps to have a working policy
    model.learn(total_timesteps=1.5e6)
    model.save("AntGAIL")

    # Evaluate the learned GAIL policy on a plain AntBulletEnv-v0
    model = GAIL.load("AntGAIL")
    env = gym.make('AntBulletEnv-v0')
    for i in range(10):
        obs = env.reset()
        rew = []
        done = False
        while not done:
            action, _states = model.predict(obs)
            obs, rewards, done, info = env.step(action)
            rew.append(rewards)
        print(np.sum(rew))


if __name__ == '__main__':
    evaluate()
When I run this, I get the following error:
pybullet build time: Jun 2 2020 06:47:43
WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/stable_baselines/sac/policies.py:194: flatten (from tensorflow.python.layers.core) is deprecated and will be removed in a future version.
Instructions for updating:
Use keras.layers.flatten instead.
WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/stable_baselines/common/tf_layers.py:57: dense (from tensorflow.python.layers.core) is deprecated and will be removed in a future version.
Instructions for updating:
Use keras.layers.dense instead.
WARNING:tensorflow:From /home/zrobertson/.local/lib/python3.6/site-packages/tensorflow/python/framework/op_def_library.py:263: colocate_with (from tensorflow.python.framework.ops) is deprecated and will be removed in a future version.
Instructions for updating:
Colocations handled automatically by placer.
WARNING:tensorflow:From /home/zrobertson/.local/lib/python3.6/site-packages/tensorflow/python/ops/math_ops.py:3066: to_int32 (from tensorflow.python.ops.math_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.cast instead.
actions (10000, 8)
obs (10000, 29)
rewards (10000,)
episode_returns (10,)
episode_starts (10000,)
actions (10000, 8)
obs (10000, 29)
rewards (10000,)
episode_returns (10,)
episode_starts (10000,)
Total trajectories: -1
Total transitions: 10000
Average returns: 3490.275921516167
Std for returns: 19.19584521063349
Creating environment from the given name, wrapped in a DummyVecEnv.
pybullet build time: Jun 2 2020 06:47:43
********** Iteration 0 ************
Optimizing Policy...
sampling
done in 2.682 seconds
computegrad
done in 0.106 seconds
conjugate_gradient
iter residual norm soln norm
0 0.124 0
1 0.0697 0.0323
2 0.0524 0.31
3 0.0111 0.764
4 0.0856 1.41
5 0.011 1.47
6 0.00497 2.42
7 0.0125 3.12
8 0.0023 3.13
9 0.00331 3.67
10 0.00095 3.72
done in 0.202 seconds
Expected: 0.087 Actual: 0.060
Stepsize OK!
vf
done in 0.059 seconds
sampling
done in 2.462 seconds
computegrad
done in 0.005 seconds
conjugate_gradient
iter residual norm soln norm
0 0.306 0
1 0.0414 0.0311
2 0.2 0.0867
3 0.063 0.32
4 0.277 0.437
5 0.0418 0.967
6 0.0204 0.989
7 0.388 1.64
8 0.0101 1.92
9 0.314 2.19
10 0.00671 2.88
done in 0.017 seconds
Expected: 0.084 Actual: 0.069
Stepsize OK!
vf
done in 0.029 seconds
sampling
done in 2.611 seconds
computegrad
done in 0.004 seconds
conjugate_gradient
iter residual norm soln norm
0 0.166 0
1 0.0871 0.0361
2 0.102 0.302
3 0.166 0.37
4 0.32 0.809
5 0.0234 0.916
6 0.0473 0.942
7 0.0232 1.02
8 0.131 1.08
9 0.0434 1.63
10 0.0287 1.68
done in 0.022 seconds
Expected: 0.076 Actual: 0.062
Stepsize OK!
vf
done in 0.041 seconds
Optimizing Discriminator...
generator_loss | expert_loss | entropy | entropy_loss | generator_acc | expert_acc
Traceback (most recent call last):
File "/home/zrobertson/Atom_Projects/Python/rl-baselines-zoo/enjoy_noise_GAIL.py", line 140, in <module>
evaluate()
File "/home/zrobertson/Atom_Projects/Python/rl-baselines-zoo/enjoy_noise_GAIL.py", line 121, in evaluate
model.learn(total_timesteps=1.5e6)
File "/usr/local/lib/python3.6/dist-packages/stable_baselines/gail/model.py", line 54, in learn
return super().learn(total_timesteps, callback, log_interval, tb_log_name, reset_num_timesteps)
File "/usr/local/lib/python3.6/dist-packages/stable_baselines/trpo_mpi/trpo_mpi.py", line 458, in learn
self.reward_giver.obs_rms.update(np.concatenate((ob_batch, ob_expert), 0))
File "<__array_function__ internals>", line 6, in concatenate
ValueError: all the input array dimensions for the concatenation axis must match exactly, but along dimension 1, the array at index 0 has size 28 and the array at index 1 has size 29
Process finished with exit code 1
System Info
- OS: Ubuntu 18.04
- How stable baselines was installed (pip, docker, source, ...): pip
- GPU models and configuration: no GPU
- Python version: 3.6.9
- Tensorflow version: 1.13.1
- Gym version: 0.12.5
- Pybullet version: 2.8.1
- Stable Baselines version: 2.10.0
Additional context
This is a general problem: the observation dimension of this version of Ant is 29 for the SAC agent even though the real size is 28. The code works with the a2c agent, for example. However, the reward is much higher for SAC, so I'd like to use this agent.
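For reference, here is a minimal sketch that makes the mismatch visible (the variable names are just illustrative, and it assumes the Ant_Test.npz file produced by the script above):

import gym
import numpy as np
import utils.import_envs  # registers the PyBullet envs such as AntBulletEnv-v0

# Observations recorded from the SAC agent's (wrapped) evaluation env
expert_data = np.load('Ant_Test.npz')
print(expert_data['obs'].shape)           # (10000, 29), as in the log above

# Plain env that GAIL builds internally from the env name
plain_env = gym.make('AntBulletEnv-v0')
print(plain_env.observation_space.shape)  # (28,), hence the concatenate error

The expert observations come from the env returned by create_test_env (which the SAC model was loaded with), while GAIL builds a plain AntBulletEnv-v0 via gym.make, so the two observation shapes would only agree if the same wrapping were applied on both sides.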