diff --git a/deep_rl/agent/BaseAgent.py b/deep_rl/agent/BaseAgent.py
index 3a00f431..069694d7 100644
--- a/deep_rl/agent/BaseAgent.py
+++ b/deep_rl/agent/BaseAgent.py
@@ -28,8 +28,9 @@ def load(self, filename):
     def eval_step(self, state):
         raise Exception('eval_step not implemented')
 
-    def eval_episode(self):
-        env = self.config.eval_env
+    def eval_episode(self, env=None):
+        if env is None:
+            env = self.config.eval_env
         state = env.reset()
         total_rewards = 0
         while True:
@@ -41,10 +42,18 @@ def eval_episode(self):
         return total_rewards
 
     def eval_episodes(self):
+        # Do eval task 1
        rewards = []
         for ep in range(self.config.eval_episodes):
-            rewards.append(self.eval_episode())
-        self.config.logger.info('evaluation episode return: %f(%f)' % (
+            rewards.append(self.eval_episode(self.config.eval_env))
+        self.config.logger.info('Same env evaluation episode return: %f(%f)' % (
+            np.mean(rewards), np.std(rewards) / np.sqrt(len(rewards))))
+
+        # Do eval task 2
+        rewards = []
+        for ep in range(self.config.eval_episodes):
+            rewards.append(self.eval_episode(self.config.eval_env_alt))
+        self.config.logger.info('Diff 1 env evaluation episode return: %f(%f)' % (
             np.mean(rewards), np.std(rewards) / np.sqrt(len(rewards))))
 
 class BaseActor(mp.Process):
@@ -126,4 +135,4 @@ def set_network(self, net):
         if not self.config.async_actor:
             self._network = net
         else:
-            self.__pipe.send([self.NETWORK, net])
\ No newline at end of file
+            self.__pipe.send([self.NETWORK, net])
diff --git a/deep_rl/component/atari_wrapper.py b/deep_rl/component/atari_wrapper.py
index 9b4bf80d..c449c267 100644
--- a/deep_rl/component/atari_wrapper.py
+++ b/deep_rl/component/atari_wrapper.py
@@ -305,9 +305,9 @@ def get_gym_env_specs(env_id):
 def make_new_atari_env(env_id, env_mode=0, env_difficulty=0):
     specs = get_gym_env_specs(env_id)
     env_id = env_id.lower()
-    error_msg = f"{env_id} not supported in ALE 2.0"
+    error_msg = "%s not supported in ALE 2.0"%env_id
     assert env_id in supported_new_env_games, error_msg
-    path = os.path.abspath(f"deep_rl/updated_atari_env/roms/{env_id}.bin")
+    path = os.path.abspath("deep_rl/updated_atari_env/roms/%s.bin"%env_id)
     env = UpdatedAtariEnv(rom_path=path, obs_type='image', mode=env_mode,
                           difficulty=env_difficulty)
     env.spec = specs
diff --git a/deep_rl/component/replay.py b/deep_rl/component/replay.py
index d9454c3a..07f68bd7 100644
--- a/deep_rl/component/replay.py
+++ b/deep_rl/component/replay.py
@@ -43,6 +43,10 @@ def size(self):
     def empty(self):
         return not len(self.data)
 
+    def clear_buffer(self):
+        self.data = []
+        self.pos = 0
+
 class AsyncReplay(mp.Process):
     FEED = 0
     SAMPLE = 1
diff --git a/deep_rl/component/task.py b/deep_rl/component/task.py
index 14eb8327..23262e7c 100644
--- a/deep_rl/component/task.py
+++ b/deep_rl/component/task.py
@@ -9,6 +9,7 @@ from .bench import Monitor
 from ..utils import *
 import uuid
+import os
 
 
 class BaseTask:
     def __init__(self):
@@ -17,7 +18,10 @@ def __init__(self):
     def set_monitor(self, env, log_dir):
         if log_dir is None:
             return env
-        mkdir(log_dir)
+        try:
+            os.mkdir(log_dir)
+        except:
+            print("File exists")
         return Monitor(env, '%s/%s' % (log_dir, uuid.uuid4()))
 
     def reset(self):
diff --git a/deep_rl/utils/config.py b/deep_rl/utils/config.py
index 59325d79..640dc3ce 100644
--- a/deep_rl/utils/config.py
+++ b/deep_rl/utils/config.py
@@ -59,6 +59,7 @@ def __init__(self):
         self.mode = 0
         self.difficulty = 0
         self.load_model = None
+        self.name = None #Name of Atari Game
 
     @property
diff --git a/deep_rl/utils/logger.py b/deep_rl/utils/logger.py
index bba1fdf4..4d74880c 100644
--- a/deep_rl/utils/logger.py
+++ b/deep_rl/utils/logger.py
@@ -21,19 +21,20 @@ def get_logger(name='MAIN', file_name=None, log_dir='./log', skip=False, level=l
         fh.setFormatter(logging.Formatter('%(asctime)s - %(name)s - %(levelname)s: %(message)s'))
         fh.setLevel(level)
         logger.addHandler(fh)
-    return Logger(log_dir, logger, skip)
+    return Logger(log_dir, logger, file_name, skip)
 
 class Logger(object):
-    def __init__(self, log_dir, vanilla_logger, skip=False):
+    def __init__(self, log_dir, vanilla_logger, file_name, skip=False):
         try:
             for f in os.listdir(log_dir):
                 if not f.startswith('events'):
                     continue
-                os.remove('%s/%s' % (log_dir, f))
+                # os.remove('%s/%s' % (log_dir, f))
         except IOError:
             os.mkdir(log_dir)
         if not skip:
-            self.writer = SummaryWriter(log_dir)
+            print("LogDir: %s"%os.path.join(log_dir, file_name))
+            self.writer = SummaryWriter(os.path.join(log_dir, file_name))
         self.info = vanilla_logger.info
         self.debug = vanilla_logger.debug
         self.warning = vanilla_logger.warning
diff --git a/deep_rl/utils/misc.py b/deep_rl/utils/misc.py
index 0df5f0bf..ff7f8965 100644
--- a/deep_rl/utils/misc.py
+++ b/deep_rl/utils/misc.py
@@ -17,6 +17,8 @@ except:
     # python == 2.7
     from pathlib2 import Path
 
+from ..agent import DQN_agent as DQN_agent
+from ..component import PixelAtari
 
 def run_steps(agent):
     random_seed()
@@ -26,21 +28,51 @@ def run_steps(agent):
     config = agent.config
     agent_name = agent.__class__.__name__
     t0 = time.time()
     if config.load_model is not None:
         config.logger.info("Loading model {}".format(config.load_model))
         agent.load(config.load_model)
+    env_changed = False
     while True:
         if config.save_interval and not agent.total_steps % config.save_interval:
-            agent.save('data/model-%s-%s-%s.bin' % (agent_name, config.task_name, config.tag))
+            agent.save('data/model-%s-%s-%s-%d.bin' % (agent_name, config.task_name, config.tag, agent.total_steps))
         if config.log_interval and not agent.total_steps % config.log_interval and len(agent.episode_rewards):
             rewards = agent.episode_rewards
             agent.episode_rewards = []
             config.logger.info('total steps %d, returns %.2f/%.2f/%.2f/%.2f (mean/median/min/max), %.2f steps/s' % (
                 agent.total_steps, np.mean(rewards), np.median(rewards), np.min(rewards), np.max(rewards),
                 config.log_interval / (time.time() - t0)))
+            config.logger.scalar_summary("MeanReward", np.mean(rewards), step=agent.total_steps)
+            config.logger.scalar_summary("MedianReward", np.median(rewards), step=agent.total_steps)
+            config.logger.scalar_summary("MinReward", np.min(rewards), step=agent.total_steps)
+            config.logger.scalar_summary("MaxReward", np.max(rewards), step=agent.total_steps)
             t0 = time.time()
         if config.eval_interval and not agent.total_steps % config.eval_interval:
             agent.eval_episodes()
-        if config.max_steps and agent.total_steps >= config.max_steps:
-            agent.close()
-            break
+        if agent.total_steps>0 and agent.total_steps%config.switch_interval==0:
+            if config.env_difficulty == 0 and not env_changed:
+                config.env_difficulty = 1
+                config.logger.info("Environment Difficulty Changed from %d to %d"%(0,config.env_difficulty))
+                env_changed = True
+                if config.clearOnEnvChange:
+                    agent.replay.clear_buffer()
+            elif config.env_difficulty == 1:
+                config.env_difficulty = 0
+                config.logger.info("Environment Difficulty Changed from %d to %d"%(1,config.env_difficulty))
+                if config.clearOnEnvChange:
+                    agent.replay.clear_buffer()
+            else:
+                agent.close()
+                break
+            config.task_fn = lambda: PixelAtari(
+                config.name, frame_skip=4, history_length=config.history_length,
+                use_new_atari_env=config.use_new_atari_env, env_mode=config.env_mode,
+                env_difficulty=config.env_difficulty)
+            config.eval_env = PixelAtari(
+                config.name, frame_skip=4, history_length=config.history_length,
+                episode_life=False, use_new_atari_env=config.use_new_atari_env,
+                env_mode=config.env_mode, env_difficulty=config.env_difficulty)
+            current_state_dict = agent.network.state_dict()
+            agent.actor = DQN_agent.DQNActor(config)
+            agent.actor.set_network(agent.network)
+            # agent.close()
+            # break
         agent.step()
 
 def get_time_str():
diff --git a/deep_rl/utils/torch_utils.py b/deep_rl/utils/torch_utils.py
index 289df9b1..845b8781 100644
--- a/deep_rl/utils/torch_utils.py
+++ b/deep_rl/utils/torch_utils.py
@@ -9,8 +9,8 @@ import os
 
 
 def select_device(gpu_id):
-    # if torch.cuda.is_available() and gpu_id >= 0:
-    if gpu_id >= 0:
+    if torch.cuda.is_available() and gpu_id >= 0:
+    # if gpu_id >= 0:
         Config.DEVICE = torch.device('cuda:%d' % (gpu_id))
     else:
         Config.DEVICE = torch.device('cpu')
diff --git a/examples.py b/examples.py
index cfd718d1..e630996e 100644
--- a/examples.py
+++ b/examples.py
@@ -29,17 +29,19 @@ def dqn_cart_pole():
     config.sgd_update_frequency = 4
     config.gradient_clip = 5
     config.eval_interval = int(5e3)
-    config.max_steps = 1e5
+    config.max_steps = 1e7
     # config.async_actor = False
     config.logger = get_logger()
     run_steps(DQNAgent(config))
 
 def dqn_pixel_atari(name):
     config = Config()
+    config.double_q = True
     config.history_length = 4
-    config.use_new_atari_env = False
+    config.use_new_atari_env = True
     config.env_mode = 0
     config.env_difficulty = 0
+    config.name = name
     config.task_fn = lambda: PixelAtari(
         name, frame_skip=4, history_length=config.history_length,
         log_dir=get_default_log_dir(dqn_pixel_atari.__name__),
@@ -49,6 +51,10 @@ def dqn_pixel_atari(name):
         name, frame_skip=4, history_length=config.history_length,
         episode_life=False, use_new_atari_env=config.use_new_atari_env,
         env_mode=config.env_mode, env_difficulty=config.env_difficulty)
+    config.eval_env_alt = PixelAtari(
+        name, frame_skip=4, history_length=config.history_length,
+        episode_life=False, use_new_atari_env=config.use_new_atari_env,
+        env_mode=config.env_mode, env_difficulty=1)
 
     config.optimizer_fn = lambda params: torch.optim.RMSprop(
         params, lr=0.00025, alpha=0.95, eps=0.01, centered=True)
@@ -56,26 +62,33 @@ def dqn_pixel_atari(name):
     # config.network_fn = lambda: DuelingNet(config.action_dim, NatureConvBody(in_channels=config.history_length))
     config.random_action_prob = LinearSchedule(1.0, 0.01, 1e6)
-    # config.replay_fn = lambda: Replay(memory_size=int(1e6), batch_size=32)
-    config.replay_fn = lambda: AsyncReplay(memory_size=int(1e6), batch_size=32)
+    config.async_actor = False
+    config.memory_size = int(2e5)
+    # memory_size = 200000
+    config.replay_fn = lambda: Replay(memory_size=config.memory_size, batch_size=64)
+    # config.replay_fn = lambda: AsyncReplay(memory_size=memory_size, batch_size=64)
 
-    config.batch_size = 32
+    config.batch_size = 64
     config.state_normalizer = ImageNormalizer()
     config.reward_normalizer = SignNormalizer()
     config.discount = 0.99
     config.target_network_update_freq = 10000
-    config.exploration_steps = 50000
+    config.exploration_steps = int(1e6)
     config.sgd_update_frequency = 4
     config.gradient_clip = 5
-    # config.double_q = True
-    config.double_q = False
-    config.max_steps = int(2e7)
-    config.eval_interval = int(1e4)
+    config.double_q = True
+    # config.double_q = False
+    config.max_steps = int(1e7)
+    # config.max_steps = int(1e4)
+    config.switch_interval = int(1e7)
+    config.eval_interval = int(5e5)
     config.logger = get_logger(file_name=dqn_pixel_atari.__name__)
     config.load_model = None
     config.save_interval = int(1e5)
-
+    config.clearOnEnvChange = False
+    # config.max_steps = int(5e7)
+    config.logger.info(config.__dict__)
     run_steps(DQNAgent(config))
 
 def dqn_ram_atari(name):
@@ -554,9 +567,9 @@ def action_conditional_video_prediction():
     mkdir('data/video')
     mkdir('dataset')
     mkdir('log')
-    set_one_thread()
-    select_device(-1)
-    # select_device(0)
+    # set_one_thread()
+    # select_device(-1)
+    select_device(0)
 
     # dqn_cart_pole()
     # quantile_regression_dqn_cart_pole()
diff --git a/test.txt b/test.txt
new file mode 100644
index 00000000..273c1a9f
--- /dev/null
+++ b/test.txt
@@ -0,0 +1 @@
+This is a test.
\ No newline at end of file