forked from SDSMT-SC2AI/ResourceGather
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
38 lines (28 loc) · 1.01 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
from utils import parse_args, ensure_dir
import os
import tensorflow as tf
from pysc2 import maps
from pysc2.env import available_actions_printer
from pysc2.env import run_loop
from pysc2.env import sc2_env
from pysc2.lib import stopwatch
from dummy import Dummy
# Module-level placeholder for the parsed command-line flags; None at import time.
FLAGS = None
def main():
    """Parse flags, configure profiling and TensorFlow, and prepare output dirs.

    Steps, in order:
      * store the result of ``parse_args()`` in the module-level ``FLAGS``;
      * enable the pysc2 stopwatch when ``FLAGS.profile`` or ``FLAGS.trace``
        is set, and turn on tracing when ``FLAGS.trace`` is set;
      * look up the map (``FLAGS.map``, falling back to the agent class's
        ``map_name``) via ``pysc2.maps.get``;
      * reset the default TensorFlow graph and build a ``ConfigProto`` whose
        intra-/inter-op thread pools are sized by ``FLAGS.num_envs``;
      * create ``FLAGS.base_output_dir`` and its summary, checkpoint,
        logging and test sub-directories;
      * open a TensorFlow session with that config.

    Raises:
        NotImplementedError: always, after the setup above. The A2C agent
            this script is meant to build (``A2C`` and its configuration) is
            not defined or imported in this file yet.
    """
    global FLAGS  # publish the parsed flags at module level, not just locally
    FLAGS = parse_args()

    stopwatch.sw.enabled = FLAGS.profile or FLAGS.trace
    stopwatch.sw.trace = FLAGS.trace

    agent_cls = Dummy
    # The return value is discarded, so this is only a lookup check.
    # pysc2's maps.get raises for an unknown map name.
    maps.get(FLAGS.map or agent_cls.map_name)

    tf.reset_default_graph()
    config = tf.ConfigProto(
        allow_soft_placement=True,
        intra_op_parallelism_threads=FLAGS.num_envs,
        inter_op_parallelism_threads=FLAGS.num_envs)
    config.gpu_options.allow_growth = True  # allocate GPU memory on demand

    # Make sure all the directories are ready to go.
    ensure_dir(FLAGS.base_output_dir)
    for sub_dir in (FLAGS.summary_dir, FLAGS.checkpoint_dir,
                    FLAGS.logging_dir, FLAGS.test_dir):
        ensure_dir(os.path.join(FLAGS.base_output_dir, sub_dir))

    # Use a context manager so the session is closed even if agent setup fails.
    with tf.Session(config=config) as sess:
        # TODO: import A2C and build its configuration object, then replace
        # this raise with ``a2c = A2C(sess, <config>)``. Neither ``A2C`` nor
        # ``config_args`` is defined in this file, so the previous call could
        # only fail with NameError.
        raise NotImplementedError(
            "A2C agent is not wired up: A2C and its config are not defined "
            "(session %r was created)" % (sess,))