Changes from all commits
19 changes: 14 additions & 5 deletions deep_rl/agent/BaseAgent.py
@@ -28,8 +28,9 @@ def load(self, filename):
def eval_step(self, state):
raise Exception('eval_step not implemented')

def eval_episode(self):
env = self.config.eval_env
def eval_episode(self, env=None):
if env is None:
env = self.config.eval_env
state = env.reset()
total_rewards = 0
while True:
@@ -41,10 +42,18 @@ def eval_episode(self):
return total_rewards

def eval_episodes(self):
# Eval task 1: the standard eval environment
rewards = []
for ep in range(self.config.eval_episodes):
rewards.append(self.eval_episode())
self.config.logger.info('evaluation episode return: %f(%f)' % (
rewards.append(self.eval_episode(self.config.eval_env))
self.config.logger.info('Same env evaluation episode return: %f(%f)' % (
np.mean(rewards), np.std(rewards) / np.sqrt(len(rewards))))

# Eval task 2: the alternate (difficulty-1) eval environment
rewards = []
for ep in range(self.config.eval_episodes):
rewards.append(self.eval_episode(self.config.eval_env_alt))
self.config.logger.info('Difficulty-1 env evaluation episode return: %f(%f)' % (
np.mean(rewards), np.std(rewards) / np.sqrt(len(rewards))))

class BaseActor(mp.Process):
@@ -126,4 +135,4 @@ def set_network(self, net):
if not self.config.async_actor:
self._network = net
else:
self.__pipe.send([self.NETWORK, net])
self.__pipe.send([self.NETWORK, net])
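The optional env argument lets a single evaluation loop serve more than one environment. A minimal sketch of the call pattern this enables, assuming numpy is imported as np (as elsewhere in BaseAgent.py) and a config that defines eval_env and eval_env_alt (set up in examples.py below); evaluate_on is a hypothetical helper, not part of the patch:

def evaluate_on(agent, env, label):
    # run config.eval_episodes episodes on the given env and log mean / std-error
    returns = [agent.eval_episode(env) for _ in range(agent.config.eval_episodes)]
    stderr = np.std(returns) / np.sqrt(len(returns))
    agent.config.logger.info('%s evaluation episode return: %f(%f)' % (label, np.mean(returns), stderr))

# evaluate_on(agent, agent.config.eval_env, 'Same env')
# evaluate_on(agent, agent.config.eval_env_alt, 'Difficulty-1 env')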
4 changes: 2 additions & 2 deletions deep_rl/component/atari_wrapper.py
@@ -305,9 +305,9 @@ def get_gym_env_specs(env_id):
def make_new_atari_env(env_id, env_mode=0, env_difficulty=0):
specs = get_gym_env_specs(env_id)
env_id = env_id.lower()
error_msg = f"{env_id} not supported in ALE 2.0"
error_msg = "%s not supported in ALE 2.0" % env_id
assert env_id in supported_new_env_games, error_msg
path = os.path.abspath(f"deep_rl/updated_atari_env/roms/{env_id}.bin")
path = os.path.abspath("deep_rl/updated_atari_env/roms/%s.bin" % env_id)
env = UpdatedAtariEnv(rom_path=path, obs_type='image',
mode=env_mode, difficulty=env_difficulty)
env.spec = specs
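For context, a hedged usage sketch of make_new_atari_env; 'breakout' is a hypothetical game id, and the call assumes it appears in supported_new_env_games with a matching deep_rl/updated_atari_env/roms/breakout.bin on disk:

env = make_new_atari_env('breakout', env_mode=0, env_difficulty=1)
state = env.reset()  # image observations, per obs_type='image' above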
4 changes: 4 additions & 0 deletions deep_rl/component/replay.py
@@ -43,6 +43,10 @@ def size(self):
def empty(self):
return not len(self.data)

def clear_buffer(self):
self.data = []
self.pos = 0

class AsyncReplay(mp.Process):
FEED = 0
SAMPLE = 1
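clear_buffer resets the synchronous Replay buffer in place. Note that AsyncReplay gets no equivalent here, which is why examples.py below swaps AsyncReplay for Replay before misc.py starts calling agent.replay.clear_buffer(). A short sketch, assuming the constructor arguments used in examples.py:

replay = Replay(memory_size=int(2e5), batch_size=64)
# ... feed transitions during training ...
replay.clear_buffer()  # drop stale transitions once the environment switches
assert replay.empty()  # data == [] and pos == 0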
6 changes: 5 additions & 1 deletion deep_rl/component/task.py
@@ -9,6 +9,7 @@
from .bench import Monitor
from ..utils import *
import uuid
import os

class BaseTask:
def __init__(self):
@@ -17,7 +18,10 @@ def __init__(self):
def set_monitor(self, env, log_dir):
if log_dir is None:
return env
mkdir(log_dir)
try:
os.mkdir(log_dir)
except OSError:
pass  # log_dir already exists
return Monitor(env, '%s/%s' % (log_dir, uuid.uuid4()))

def reset(self):
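If catching OSError feels too broad, Python 3 offers a race-free one-liner with the same effect (a sketch, not what the patch uses):

os.makedirs(log_dir, exist_ok=True)  # no-op when log_dir already exists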
1 change: 1 addition & 0 deletions deep_rl/utils/config.py
@@ -59,6 +59,7 @@ def __init__(self):
self.mode = 0
self.difficulty = 0
self.load_model = None
self.name = None  # name of the Atari game


@property
9 changes: 5 additions & 4 deletions deep_rl/utils/logger.py
@@ -21,19 +21,20 @@ def get_logger(name='MAIN', file_name=None, log_dir='./log', skip=False, level=l
fh.setFormatter(logging.Formatter('%(asctime)s - %(name)s - %(levelname)s: %(message)s'))
fh.setLevel(level)
logger.addHandler(fh)
return Logger(log_dir, logger, skip)
return Logger(log_dir, logger, file_name, skip)

class Logger(object):
def __init__(self, log_dir, vanilla_logger, skip=False):
def __init__(self, log_dir, vanilla_logger, file_name, skip=False):
try:
for f in os.listdir(log_dir):
if not f.startswith('events'):
continue
os.remove('%s/%s' % (log_dir, f))
# os.remove('%s/%s' % (log_dir, f))
except IOError:
os.mkdir(log_dir)
if not skip:
self.writer = SummaryWriter(log_dir)
run_dir = log_dir if file_name is None else os.path.join(log_dir, file_name)
print("LogDir: %s" % run_dir)
self.writer = SummaryWriter(run_dir)
self.info = vanilla_logger.info
self.debug = vanilla_logger.debug
self.warning = vanilla_logger.warning
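Each run now writes its TensorBoard events into a per-run subdirectory of log_dir; with the default file_name=None the events stay directly in log_dir (per the guard above). A hedged usage sketch, using the scalar_summary call that misc.py below relies on:

logger = get_logger(file_name='dqn_pixel_atari')  # events under ./log/dqn_pixel_atari
logger.scalar_summary('MeanReward', 123.4, step=10000)
logger.info('evaluation started')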
40 changes: 36 additions & 4 deletions deep_rl/utils/misc.py
@@ -17,6 +17,8 @@
except:
# python == 2.7
from pathlib2 import Path
from ..agent import DQN_agent
from ..component import PixelAtari

def run_steps(agent):
random_seed()
@@ -26,21 +28,51 @@ def run_steps(agent):
if config.load_model is not None:
config.logger.info("Loading model {}".format(config.load_model))
agent.load(config.load_model)
env_changed = False
while True:
if config.save_interval and not agent.total_steps % config.save_interval:
agent.save('data/model-%s-%s-%s.bin' % (agent_name, config.task_name, config.tag))
agent.save('data/model-%s-%s-%s-%d.bin' % (agent_name, config.task_name, config.tag, agent.total_steps))
if config.log_interval and not agent.total_steps % config.log_interval and len(agent.episode_rewards):
rewards = agent.episode_rewards
agent.episode_rewards = []
config.logger.info('total steps %d, returns %.2f/%.2f/%.2f/%.2f (mean/median/min/max), %.2f steps/s' % (
agent.total_steps, np.mean(rewards), np.median(rewards), np.min(rewards), np.max(rewards),
config.log_interval / (time.time() - t0)))
config.logger.scalar_summary("MeanReward", np.mean(rewards), step=agent.total_steps)
config.logger.scalar_summary("MedianReward", np.median(rewards), step=agent.total_steps)
config.logger.scalar_summary("MinReward", np.min(rewards), step=agent.total_steps)
config.logger.scalar_summary("MaxReward", np.max(rewards), step=agent.total_steps)
t0 = time.time()
if config.eval_interval and not agent.total_steps % config.eval_interval:
agent.eval_episodes()
if config.max_steps and agent.total_steps >= config.max_steps:
agent.close()
break
if agent.total_steps > 0 and agent.total_steps % config.switch_interval == 0:
if config.env_difficulty == 0 and not env_changed:
config.env_difficulty = 1
config.logger.info("Environment Difficulty Changed from %d to %d" % (0, config.env_difficulty))
env_changed = True
if config.clearOnEnvChange:
agent.replay.clear_buffer()
elif config.env_difficulty == 1:
config.env_difficulty = 0
config.logger.info("Environment Difficulty Changed from %d to %d" % (1, config.env_difficulty))
if config.clearOnEnvChange:
agent.replay.clear_buffer()
else:
agent.close()
break
config.task_fn = lambda: PixelAtari(
config.name, frame_skip=4, history_length=config.history_length,
use_new_atari_env=config.use_new_atari_env, env_mode=config.env_mode,
env_difficulty=config.env_difficulty)
config.eval_env = PixelAtari(
config.name, frame_skip=4, history_length=config.history_length,
episode_life=False, use_new_atari_env=config.use_new_atari_env,
env_mode=config.env_mode, env_difficulty=config.env_difficulty)
current_state_dict = agent.network.state_dict()
agent.actor = DQN_agent.DQNActor(config)
agent.actor.set_network(agent.network)
# agent.close()
# break
agent.step()

def get_time_str():
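The switch block toggles env_difficulty once every switch_interval steps (0 to 1 on the first switch, back to 0 on the second, and the run ends at the third switch point), rebuilds task_fn and eval_env at the new difficulty, and restarts the actor on the current network. The toggle reduced to a sketch, where next_difficulty is a hypothetical helper, not in the patch:

def next_difficulty(current, env_changed):
    if current == 0 and not env_changed:
        return 1, True           # first switch: difficulty 0 -> 1
    if current == 1:
        return 0, env_changed    # second switch: back to difficulty 0
    return None, env_changed     # third switch point: caller closes the agent

Note that current_state_dict is assigned but never read; the fresh DQNActor picks up the weights directly through set_network(agent.network).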
4 changes: 2 additions & 2 deletions deep_rl/utils/torch_utils.py
@@ -9,8 +9,8 @@
import os

def select_device(gpu_id):
# if torch.cuda.is_available() and gpu_id >= 0:
if gpu_id >= 0:
if torch.cuda.is_available() and gpu_id >= 0:
# if gpu_id >= 0:
Config.DEVICE = torch.device('cuda:%d' % (gpu_id))
else:
Config.DEVICE = torch.device('cpu')
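Restoring the cuda check lets select_device degrade gracefully on CPU-only machines; a quick sketch of the behavior:

select_device(0)      # cuda:0 when torch.cuda.is_available(), otherwise cpu
select_device(-1)     # always cpu
print(Config.DEVICE)  # e.g. device(type='cpu') on a machine without a GPU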
41 changes: 27 additions & 14 deletions examples.py
@@ -29,17 +29,19 @@ def dqn_cart_pole():
config.sgd_update_frequency = 4
config.gradient_clip = 5
config.eval_interval = int(5e3)
config.max_steps = 1e5
config.max_steps = 1e7
# config.async_actor = False
config.logger = get_logger()
run_steps(DQNAgent(config))

def dqn_pixel_atari(name):
config = Config()
config.double_q = True
config.history_length = 4
config.use_new_atari_env = False
config.use_new_atari_env = True
config.env_mode = 0
config.env_difficulty = 0
config.name = name
config.task_fn = lambda: PixelAtari(
name, frame_skip=4, history_length=config.history_length,
log_dir=get_default_log_dir(dqn_pixel_atari.__name__),
@@ -49,33 +51,44 @@ def dqn_pixel_atari(name):
name, frame_skip=4, history_length=config.history_length,
episode_life=False, use_new_atari_env=config.use_new_atari_env,
env_mode=config.env_mode, env_difficulty=config.env_difficulty)
config.eval_env_alt = PixelAtari(
name, frame_skip=4, history_length=config.history_length,
episode_life=False, use_new_atari_env=config.use_new_atari_env,
env_mode=config.env_mode, env_difficulty=1)

config.optimizer_fn = lambda params: torch.optim.RMSprop(
params, lr=0.00025, alpha=0.95, eps=0.01, centered=True)
config.network_fn = lambda: VanillaNet(config.action_dim, NatureConvBody(in_channels=config.history_length))
# config.network_fn = lambda: DuelingNet(config.action_dim, NatureConvBody(in_channels=config.history_length))
config.random_action_prob = LinearSchedule(1.0, 0.01, 1e6)

# config.replay_fn = lambda: Replay(memory_size=int(1e6), batch_size=32)
config.replay_fn = lambda: AsyncReplay(memory_size=int(1e6), batch_size=32)
config.async_actor = False
config.memory_size = int(2e5)
# memory_size = 200000
config.replay_fn = lambda: Replay(memory_size=config.memory_size, batch_size=64)
# config.replay_fn = lambda: AsyncReplay(memory_size=memory_size, batch_size=64)

config.batch_size = 32
config.batch_size = 64
config.state_normalizer = ImageNormalizer()
config.reward_normalizer = SignNormalizer()
config.discount = 0.99
config.target_network_update_freq = 10000
config.exploration_steps = 50000
config.exploration_steps = int(1e6)
config.sgd_update_frequency = 4
config.gradient_clip = 5
# config.double_q = True
config.double_q = False
config.max_steps = int(2e7)
config.eval_interval = int(1e4)
config.double_q = True
# config.double_q = False
config.max_steps = int(1e7)
# config.max_steps = int(1e4)
config.switch_interval = int(1e7)
config.eval_interval = int(5e5)
config.logger = get_logger(file_name=dqn_pixel_atari.__name__)

config.load_model = None
config.save_interval = int(1e5)

config.clearOnEnvChange = False
# config.max_steps = int(5e7)
config.logger.info(config.__dict__)
run_steps(DQNAgent(config))

def dqn_ram_atari(name):
@@ -554,9 +567,9 @@ def action_conditional_video_prediction():
mkdir('data/video')
mkdir('dataset')
mkdir('log')
set_one_thread()
select_device(-1)
# select_device(0)
# set_one_thread()
# select_device(-1)
select_device(0)

# dqn_cart_pole()
# quantile_regression_dqn_cart_pole()
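An invocation sketch for the updated entry point; 'breakout' is a hypothetical game id and must be in supported_new_env_games now that use_new_atari_env is True:

# dqn_pixel_atari('breakout')  # trains at difficulty 0, evaluates on difficulty 0 and 1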
1 change: 1 addition & 0 deletions test.txt
@@ -0,0 +1 @@
This is a test.