-
Notifications
You must be signed in to change notification settings - Fork 0
/
run_dqn.py
executable file
·31 lines (18 loc) · 907 Bytes
/
run_dqn.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
import os

from unityagents import UnityEnvironment

from src import package_path
from src.agents import Agent, ReplayDDQNAgent, PriorityReplayDDQNAgent
from src.environment_utils import Execution_Manager
# Hidden-layer sizes for a freshly constructed agent.
# NOTE(review): not referenced by the visible script code (the agent is loaded
# from a checkpoint); kept because other modules may import it.
LAYERS = [64, 64]

# Location of the Unity "Banana" environment build, relative to the package.
ENV_PATH = os.path.join(package_path, "unity", "Banana.app")

# File name of the saved model, and its full path under assets/models/.
in_file = 'trained_model.pth'
checkpoint_path = os.path.join(package_path, 'assets', 'models', in_file)
def main():
    """Play a single episode with an agent restored from ``checkpoint_path``.

    Loads the agent via ``Agent.from_file``, wraps it together with the Unity
    environment at ``ENV_PATH`` in an ``Execution_Manager``, plays one
    episode and prints the resulting score. The manager's ``exit()`` is
    always called, even if the episode raises.

    Returns:
        The score returned by ``Execution_Manager.play_episode``.
    """
    print(package_path)
    agent = Agent.from_file(checkpoint_path)
    training_manager = Execution_Manager(agent, ENV_PATH)
    try:
        # NOTE(review): 200 is presumably the per-episode step limit and
        # eps=0.0 presumably means no random (exploratory) actions -- confirm
        # against Execution_Manager.play_episode.
        score = training_manager.play_episode(200, eps=0.0)
        print('Episode Terminated with score: {}'.format(score))
    finally:
        # Run cleanup on every path; previously an exception in play_episode
        # skipped exit() and left whatever it releases (presumably the Unity
        # environment) open.
        training_manager.exit()
    return score


if __name__ == '__main__':
    main()