Right now, there's some overhead in gradient computations: `deepcopy` needs to be called on every computation to make sure gradients aren't shared between workers. This roughly doubles the time for a gradient computation (128 gradients with a Wide-ResNet 16-4 and the CIFAR10 dataset on a GPU).
Right now, we're using `Client(processes=False)` to avoid the need for serialization (we're using the GPU, so there's no need to send objects between CPU cores). But one complication is that the same memory address is broadcast to every worker, so they all modify the same object, which leads to a race condition. We call `deepcopy` to address this immediate need.
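For context, here is a minimal sketch of the current pattern (illustrative only; `compute_grad`, `model`, and `batches` are placeholder names, not from this codebase): with threaded workers every task sees the same model object, so each task takes a private copy before running backprop.

```python
from copy import deepcopy

import torch
import torch.nn.functional as F
from distributed import Client

client = Client(processes=False)  # threaded workers share memory; nothing is serialized

def compute_grad(model, inputs, targets):
    # Every task receives a reference to the *same* model object, so each one
    # takes a private copy; otherwise concurrent .backward() calls would race
    # on the shared .grad buffers.
    model = deepcopy(model)  # the ~2x overhead we'd like to eliminate
    loss = F.cross_entropy(model(inputs), targets)
    loss.backward()
    return [p.grad for p in model.parameters()]

# Usage (placeholders): one future per batch, then gather the gradients.
# futures = [client.submit(compute_grad, model, x, y) for x, y in batches]
# grads = client.gather(futures)
```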
How can the call to deepcopy be avoided?
Timing script for serialization
```python
# [ins] In [30]: %timeit deepcopy(m)
# 35.4 ms ± 237 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
# [ins] In [31]: %timeit deepcopy(x)
# 92.4 ms ± 1.8 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
# [ins] In [34]: %timeit deepcopy(y)
# 97.1 ms ± 2.05 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
from torchvision.models import resnet34
import pickle
from time import perf_counter as time
from io import BytesIO
import numpy as np

def np_pickle(x):
    # Time a pickle round-trip of a numpy array.
    start = time()
    y = pickle.loads(pickle.dumps(x))
    end = time()
    return end - start

def np_pickle2(x):
    # Time an np.save/np.load round-trip of a numpy array.
    start = time()
    with BytesIO() as f:
        np.save(f, x)
        f.seek(0)  # rewind before reading the array back
        y = np.load(f)
    end = time()
    return end - start

def torch_model(m):
    # Time a pickle round-trip of a PyTorch model.
    start = time()
    m2 = pickle.loads(pickle.dumps(m))
    end = time()
    return end - start

m = resnet34()
nparams = sum(p.nelement() for p in m.parameters())
x = np.random.uniform(size=nparams)
assert len(x) >= 1e6

n_runs = 30
times = []
for _ in range(n_runs):
    times.append(np_pickle(x))
ms = int(min(times) * 1000)
print(f"Serializing numpy array w/ pickle took at least {ms}ms")  # 265ms

times = []
for _ in range(n_runs):
    times.append(np_pickle2(x))
ms = int(min(times) * 1000)
print(f"Serializing numpy array took at least {ms}ms")  # 275ms

times = []
for _ in range(n_runs):
    times.append(torch_model(m))
ms = int(min(times) * 1000)
print(f"Serializing resnet34 took at least {ms}ms")  # 245ms
```
Fundamentally, there are two ways to avoid deepcopy:
1. Communicate the model to each worker.
2. Have each worker hold its own model.
Option (1) means sending the model to each worker, which requires either serialization or deepcopy. Option (2) amounts to a decentralized parameter server where each worker holds onto its own model. Option (2) has some distinct advantages: only the necessary communication happens (the gradients), and deepcopy/model serialization are avoided entirely, though it's slightly harder to change the number of workers.
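One way to sketch option (2) with the tools Dask already has is actors: each worker builds and keeps its own model, and only gradients (and parameter updates) cross the wire. This is only an illustration; `GradWorker`, `model_fn`, `n_workers`, and `batches` are hypothetical names, not part of this repo.

```python
import torch
import torch.nn.functional as F
from distributed import Client

class GradWorker:
    """Holds a private model on one worker; only gradients/params are communicated."""

    def __init__(self, model_fn):
        # The model is *built* on the worker, so no live model is ever
        # deepcopied or serialized.
        self.model = model_fn()

    def compute_grad(self, inputs, targets):
        self.model.zero_grad()
        loss = F.cross_entropy(self.model(inputs), targets)
        loss.backward()
        return [p.grad.clone() for p in self.model.parameters()]

    def set_params(self, params):
        # Decentralized "parameter server": push updated parameters back out.
        with torch.no_grad():
            for p, new in zip(self.model.parameters(), params):
                p.copy_(new)

client = Client(processes=False)
# Placeholders: `model_fn` builds e.g. a Wide-ResNet 16-4; one actor per worker.
# workers = [client.submit(GradWorker, model_fn, actor=True).result()
#            for _ in range(n_workers)]
# grads = [w.compute_grad(x, y).result() for w, (x, y) in zip(workers, batches)]
```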
Relevant work:
dask-pytorch-ddp, which allows Dask clusters to be used with PyTorch distributed.
I think integration with dask-pytorch-ddp is the most straightforward solution: it wraps a distributed protocol that uses some fancy communication tricks (overlapping communication and computation, etc.), as described in "PyTorch Distributed: Experiences on Accelerating Data Parallel Training."
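A rough sketch of what that integration could look like, going off dask-pytorch-ddp's documented `dispatch.run` entry point (the training function, `build_model`, and the scheduler address below are placeholders; exact setup details may differ):

```python
import torch
from torch.nn.parallel import DistributedDataParallel as DDP
from dask_pytorch_ddp import dispatch
from distributed import Client

def build_model():
    # Placeholder: stands in for e.g. a Wide-ResNet 16-4 constructor.
    return torch.nn.Linear(10, 10)

def train():
    # dispatch.run sets up the torch.distributed process group on each worker,
    # so the model lives on that worker; DDP then communicates only gradient
    # buckets, overlapping communication with the backward pass.
    model = DDP(build_model().cuda())
    # ... standard training loop: forward pass, loss.backward(), optimizer.step() ...

if __name__ == "__main__":
    client = Client("tcp://scheduler:8786")  # placeholder scheduler address
    futures = dispatch.run(client, train)    # one train() call per Dask worker
```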
stsievert changed the title from "[feature request] accelerate gradient computations" to "[feature request] remove overhead in gradient computations" on Feb 23, 2021.