-
Notifications
You must be signed in to change notification settings - Fork 9
/
Copy pathtrain_client.py
90 lines (75 loc) · 2.74 KB
/
train_client.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import os
from datetime import datetime
import glob
import argparse
import pickle
import torch
import zmq
from redq import REDQ as TD3
from rollout_server import OBS_SIZE, ACT_SIZE
# Command-line interface for the training client.
parser = argparse.ArgumentParser(description='RealAnt training client')
parser.add_argument('--n_episodes', default=100, type=int,
                    help='number of episodes to collect and train on (default: 100)')
parser.add_argument('--resume', default='', type=str,
                    help='folder path of a past run to resume from (default: start a new run)')
parser.add_argument('--task', default='walk', type=str,
                    help='task name sent to the rollout server (default: walk)')
args = parser.parse_args()
if args.resume == '':
    # Fresh run: create a new output directory named by timestamp and task,
    # e.g. "2024_01_31_12_00_00_walk".
    now = datetime.now()
    project_dir = now.strftime('%Y_%m_%d_%H_%M_%S') + "_" + args.task
    os.mkdir(project_dir)
    start_episode = 0
else:
    # Resume: locate the most recently written agent checkpoint
    # ("td3_<episode>.pickle") and continue from the following episode.
    project_dir = args.resume
    list_of_files = glob.glob(f'{args.resume}/td3_*')
    if not list_of_files:
        # max() over an empty list would raise an opaque ValueError;
        # fail with an actionable message instead.
        raise SystemExit(f'No td3_* checkpoints found in {args.resume!r}; cannot resume.')
    latest_file = max(list_of_files, key=os.path.getctime)
    # Checkpoint filename is ".../td3_<episode>.pickle"; recover <episode>.
    start_episode = int(latest_file.split('.')[-2].split('_')[-1]) + 1
    print(f'resuming from {start_episode}')
# Prefer the GPU when one is available; the agent's networks live on this device.
use_cuda = torch.cuda.is_available()
device = torch.device('cuda' if use_cuda else 'cpu')
if args.resume == '':
    # Fresh run: build a new agent from scratch.
    td3 = TD3(device, OBS_SIZE, ACT_SIZE)
else:
    # Resuming: restore the entire agent object from the latest checkpoint.
    # NOTE(review): pickle.load is only safe on trusted checkpoint files.
    with open(latest_file, 'rb') as f:
        td3 = pickle.load(f)
# Connect to the rollout server driving the robot. REQ/REP socket: each
# send_pyobj in the episode loop is answered by exactly one recv_pyobj.
# NOTE(review): assumes the server listens on localhost:5555 — see rollout_server.
ctx = zmq.Context()
socket = ctx.socket(zmq.REQ)
socket.connect('tcp://localhost:5555')
# Number of initial episodes driven by random actions (exploration warm-up)
# before the learned policy is deployed and training begins.
RANDOM_EPISODES = 10

for episode in range(start_episode, start_episode + args.n_episodes):
    now = datetime.now().isoformat()
    print(f'\nEpisode {episode} {now}')
    print('Collecting data...')
    if episode < RANDOM_EPISODES:
        # Warm-up: ask the rollout server to act randomly (no policy weights).
        print("Random episode")
        socket.send_pyobj((args.task, None))
    else:
        # Ship the current actor weights; move them to CPU first so the
        # server does not need a GPU to deserialize, then move back.
        td3.actor.to('cpu')
        socket.send_pyobj((args.task, td3.actor.state_dict()))
        td3.actor.to(device)
    # Receive the episode rollout from the server: (transitions, info).
    new_data = socket.recv_pyobj()
    # Persist the raw episode data for offline analysis / replay.
    with open(f'{project_dir}/episode_{episode}.pickle', 'wb') as f:
        pickle.dump(new_data, f)
    new_transitions, new_info = new_data
    # Transitions unpack as (obs, action, reward, next_obs, done)-style
    # 5-tuples; only the reward column is needed here.
    _, _, rewards, _, _ = zip(*new_transitions)
    cum_rewards = sum(r[0] for r in rewards)
    print(f'Return: {cum_rewards}')
    if args.task == 'walk' and cum_rewards < -4:
        # A very negative walk return usually means the motion tracking lost
        # the robot; abort (non-zero status) so the operator can fix it.
        print('Return does not look right. Please check tracking.')
        raise SystemExit(1)
    # Grow the replay buffer with this episode's transitions.
    td3.replay_buffer.extend(new_transitions)
    if episode >= RANDOM_EPISODES - 1:
        print('Training...')
        # datetime.utcnow() is deprecated; plain now() is fine for elapsed time.
        start_time = datetime.now()
        # 10 gradient updates per collected transition.
        for _ in range(len(new_transitions) * 10):
            td3.update_parameters()
        elapsed = (datetime.now() - start_time).total_seconds()
        print("Training took: %3.2fs" % elapsed)
    # Checkpoint the full agent so the run can be resumed later.
    with open(f'{project_dir}/td3_{episode}.pickle', 'wb') as f:
        pickle.dump(td3, f)