forked from drakesvoboda/DistributedTrainingExperiments
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathEASGDTrainer.py
More file actions
196 lines (142 loc) · 6.82 KB
/
EASGDTrainer.py
File metadata and controls
196 lines (142 loc) · 6.82 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
import concurrent.futures
from concurrent.futures import Future
import torch
import torch.distributed as dist
import torch.distributed.rpc as rpc
import torch.multiprocessing as mp
from queue import LifoQueue, Empty, Queue, PriorityQueue
from boilerplate import *
from ParameterServer import *
from dataclasses import dataclass, field
from typing import Any
@dataclass(order=True)
class PrioritizedWork:
    """A unit of work for the reducer's priority queue.

    Ordering (and therefore queue position) is determined by ``priority``
    alone; the payload is explicitly excluded from comparisons so that
    arbitrary, non-comparable items can ride along.
    """
    priority: int
    item: Any = field(compare=False)
class ReducerThread(threading.Thread):
    """Background thread that streams EASGD parameter updates to a set of
    parameter-server shards over torch RPC.

    Parameters are split into <=~250k-element chunks, assigned a priority
    equal to their owning module's position in ``named_modules()`` (earlier
    layers get a lower number, i.e. are sent first from the PriorityQueue),
    and round-robin sharded across ``num_shards`` parameter servers.

    Training code calls :meth:`reduce` to enqueue a chunk; this thread's
    :meth:`run`/:meth:`step` loop drains the queue and issues the RPCs,
    keeping at most ``max_in_flight`` outstanding requests.
    """

    def __init__(self, num_trainers: int, num_shards: int, max_in_flight: int, model: torch.nn.Module):
        threading.Thread.__init__(self)
        self.num_trainers = num_trainers
        self.num_shards = num_shards
        # Chunks waiting to be sent, ordered by module priority (see PrioritizedWork).
        self.to_send = PriorityQueue()
        # RPC futures for requests that have been issued but not yet completed.
        self.in_flight = []
        self.max_in_flight = max_in_flight

        def chunk_param(param):
            # Flatten, then split into at most ~250k-element chunks so no single
            # RPC message carries an arbitrarily large tensor.
            param = param.view(-1)
            num_chunks = (len(param)//250000) + 1
            return torch.chunk(param, chunks=num_chunks)

        def get_param_dict(module: torch.nn.Module):
            # Returns:
            #   res:        chunk index -> parameter chunk (a view of the live param)
            #   name2key:   "module.param.chunk_idx" -> chunk index
            #   priorities: chunk index -> priority (module order in named_modules())
            res = {}
            name2key = {}
            priorities = {}
            idx, priority = 0, 0
            # NOTE(review): the loop variable shadows the `module` argument; harmless
            # here because the generator was created before the rebinding.
            for module_name, module in module.named_modules():
                priority += 1
                for param_name, param in module.named_parameters(recurse=False):
                    for chunk_idx, chunk in enumerate(chunk_param(param)):
                        res[idx] = chunk
                        priorities[idx] = priority
                        name2key[f"{module_name}.{param_name}.{chunk_idx}"] = idx
                        idx += 1
            return res, name2key, priorities

        def shard_params(params, num_shards=1):
            # Round-robin assignment of chunks to parameter-server shards.
            shards = [{} for _ in range(num_shards)]
            for idx, (key, param) in enumerate(params.items()):
                shards[idx%num_shards][key] = param
            return shards

        self.model = model
        self.params, self.param_name_to_idx, self.priorities = get_param_dict(model)
        shards = shard_params(self.params, self.num_shards)
        # chunk index -> RRef of the parameter-server shard owning that chunk.
        self.ps_rref_map = {}
        # Get references to each parameter server shard
        for idx, shard in enumerate(shards):
            param_server_rref = rpc.remote(f"parameter_server_{idx}", get_parameter_server, args=(shard, num_trainers, idx))
            for param_name in shard.keys():
                self.ps_rref_map[param_name] = param_server_rref
        # Sync inital model parameters with parameter server
        with torch.no_grad():
            for key, p in self.params.items():
                fetched = remote_method(ParameterServer.fetch_param, self.ps_rref_map[key], key)
                p.copy_(fetched)

    def reduce(self, param_name) -> Future:
        """Enqueue the chunk named ``param_name`` ("module.param.chunk_idx")
        for an EASGD exchange with its parameter server.

        Returns a Future that :meth:`step` later fulfils with the torch RPC
        future of the actual request — callers therefore do
        ``reduce(...).result().wait()`` to obtain the update tensor.
        """
        param_idx = self.param_name_to_idx[param_name]
        fut = Future()
        work = PrioritizedWork(self.priorities[param_idx], (param_idx, fut))
        self.to_send.put(work)
        return fut

    def run(self):
        # Drain the queue forever; no_grad because we only read/copy tensors here.
        with torch.no_grad():
            while True:
                self.step()

    def step(self):
        """Issue at most one queued RPC, respecting the in-flight cap."""
        # Drop completed requests from the in-flight window.
        self.in_flight = [fut for fut in self.in_flight if not fut.done()]
        # NOTE(review): returning here makes run() busy-spin while the window
        # is saturated — a short sleep would reduce CPU burn. (Not changed:
        # behavior-preserving doc pass.)
        if len(self.in_flight) > self.max_in_flight: return
        try:
            send = self.to_send.get()
        except Empty:
            # NOTE(review): unreachable — get() with no timeout blocks rather
            # than raising Empty; kept as written.
            return
        param_idx, fut = send.item
        # Send the current local values of this chunk (a live view of the model
        # parameter) to its shard, and hand the RPC future to the waiting caller.
        work = rpc.rpc_async(self.ps_rref_map[param_idx].owner(), ParameterServer.easgd_update, args=(self.ps_rref_map[param_idx], param_idx, self.params[param_idx].data))
        fut.set_result(work)
        self.in_flight.append(work)
class EASGDTrainer(Trainer):
    """Elastic-Averaging-SGD trainer.

    Each trainer process optimizes its model locally and, every ``tau``
    local steps per module, exchanges an elastic update with a sharded
    parameter server through a background :class:`ReducerThread`.
    Communication overlaps with compute: the update is *requested* in a
    module's backward hook and *applied* in that module's next forward
    pre-hook.

    Args:
        model:      the local model replica.
        criterion:  loss function, forwarded to the ``Trainer`` base class.
        optim_fn:   factory mapping a parameter iterable to an optimizer;
                    one optimizer is created per leaf module.
        rank:       this trainer's rank in [0, world_size).
        world_size: number of trainer processes (an equal number of
                    parameter-server processes is spawned alongside).
        tau:        number of local steps between elastic exchanges.
        stagger:    offset each module's exchange schedule so modules do not
                    all communicate on the same iteration.
    """

    def __init__(self, model: torch.nn.Module, criterion: callable, optim_fn: callable, rank: int, world_size: int, tau: int, stagger=True):
        super().__init__(model, criterion, None)
        num_trainers = world_size
        num_shards = world_size
        # Spawn this trainer's co-located parameter-server shard. Trainers
        # occupy RPC ranks [0, world_size) and servers [world_size, 2*world_size).
        self.ctx = mp.spawn(run_parameter_server, nprocs=1, args=(world_size + rank, world_size * 2, rank), join=False)
        print(f"Trainer {rank} initializing RPC")
        rpc.init_rpc(name=f"trainer_{rank}", rank=rank, world_size=world_size * 2)
        print(f"Trainer {rank} initialized!")
        self.optim_fn = optim_fn
        self.tau = tau
        # Start reducer thread. FIX: mark it daemon — its run() loop never
        # terminates, so a non-daemon thread would prevent process exit after
        # rpc.shutdown().
        self.reducer = ReducerThread(num_trainers, num_shards, 10, model)
        reducer_thread = threading.Thread(target=self.reducer.run, daemon=True)
        reducer_thread.start()
        # Remember the first leaf module that owns parameters: PyTorch does not
        # fire the full-backward hook for it, so training_step() calls the hook
        # manually (see below).
        for name, module in self.model.named_modules():
            if len(list(module.children())) > 0: continue
            params = list(module.parameters())
            if len(params) > 0:
                self.first_module = module
                self.first_name = name
                break
        # Attach per-module optimizer state and communication hooks to every
        # leaf module that owns parameters.
        index = 0
        for module_name, module in model.named_modules():
            if len(list(module.parameters())) <= 0 or len(list(module.children())) > 0: continue
            module.updates = None
            # Stagger schedules so modules exchange on different iterations.
            module.iteration = index if stagger else 0
            index += 1
            module.optimizer = self.optim_fn(module.parameters(recurse=False))
            module.register_full_backward_hook(EASGDTrainer.backwards_pass_hook(self.reducer, self.tau, module_name))
            module.register_forward_pre_hook(EASGDTrainer.forward_pre_hook(num_trainers, module_name))

    @staticmethod
    def backwards_pass_hook(reducer, tau, module_name):
        """Build a backward hook that steps the module's local optimizer and,
        every ``tau`` iterations, requests an elastic exchange for each of the
        module's parameter chunks.
        """
        def hook(self, *args):
            if not self.training: return
            self.optimizer.step()
            self.optimizer.zero_grad()
            self.iteration += 1
            if self.iteration % tau == 0:
                # FIX: the reducer registers chunks under keys of the form
                # "module.param.chunk_idx", but this hook previously requested
                # "module.param" — a guaranteed KeyError. Request every chunk
                # of each parameter instead; dict order of param_name_to_idx
                # preserves ascending chunk order. The trailing '.' in the
                # prefix prevents "weight"/"weight2"-style prefix collisions.
                self.updates = {
                    param_name: [reducer.reduce(key)
                                 for key in reducer.param_name_to_idx
                                 if key.startswith(f"{module_name}.{param_name}.")]
                    for param_name, _ in self.named_parameters(recurse=False)
                }
        return hook

    @staticmethod
    def forward_pre_hook(num_trainers, module_name):
        """Build a pre-forward hook that applies any pending elastic updates
        to the module's parameters before they are used again.
        """
        def hook(self, *args):
            # FIX: `is None`, not `== None`.
            if self.updates is None: return
            with torch.no_grad():
                for param_name, param in self.named_parameters(recurse=False):
                    # Chunks were carved out of the flattened parameter, so
                    # apply each chunk's update to the matching flat slice.
                    flat = param.view(-1)
                    offset = 0
                    for fut in self.updates[param_name]:
                        # .result() blocks until the reducer has issued the
                        # RPC; .wait() blocks on the RPC reply, which carries
                        # the elastic difference for this chunk (shape assumed
                        # to match the sent chunk — mirrors easgd_update's
                        # contract on the server; verify against ParameterServer).
                        diff = fut.result().wait()
                        flat[offset:offset + diff.numel()] -= diff * (.9 / num_trainers)
                        offset += diff.numel()
            self.updates = None
        return hook

    def train(self, schedule: 'TrainingSchedule'):
        """Run the base training loop, then tear down RPC and reap the
        co-located parameter-server process."""
        super().train(schedule)
        rpc.shutdown()
        self.ctx.join()

    def training_step(self, input, label):
        output, loss = self.step(input, label)
        loss.backward()
        # Backwards hook for the first module in the network is not called by pytorch, call it here manually.
        EASGDTrainer.backwards_pass_hook(self.reducer, self.tau, self.first_name)(self.first_module)
        # self.model.zero_grad()
        return output, loss