
[feature request] remove overhead in gradient computations #21

Open
stsievert opened this issue Feb 23, 2021 · 1 comment

Comments

stsievert (Owner) commented Feb 23, 2021

Right now, there's some overhead in gradient computations: deepcopy needs to be called on every gradient computation to make sure gradients aren't shared between workers. This roughly doubles the time for a gradient computation (measured for 128 gradients with a Wide-ResNet 16-4 on the CIFAR10 dataset on a GPU).

Right now, we're using Client(processes=False) to avoid the need for serialization (we're using the GPU, so there's no need to send objects between CPU cores). One complication of this is that the same memory address is broadcast to every worker, so they all modify the same object, which leads to a race condition. We call deepcopy to address this immediate need.
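For concreteness, the current workaround looks roughly like the sketch below (a minimal stand-in with a toy model and made-up task/loss, not the actual code in this repo):

import torch
from copy import deepcopy

def _gradient_task(model, x, y):
    # With Client(processes=False) every worker receives the *same* model
    # object, so copy it before computing gradients; otherwise all workers
    # accumulate into the same .grad tensors. (Hypothetical task, not the
    # repo's actual code.)
    local_model = deepcopy(model)
    loss = torch.nn.functional.mse_loss(local_model(x), y)
    loss.backward()
    return {name: p.grad for name, p in local_model.named_parameters()}

# Stand-in usage with a tiny model and random data:
model = torch.nn.Linear(10, 1)
grads = _gradient_task(model, torch.randn(32, 10), torch.randn(32, 1))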

How can the call to deepcopy be avoided?

Timing script for serialization
#  [ins] In [30]: %timeit deepcopy(m)
#  35.4 ms ± 237 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
#  [ins] In [31]: %timeit deepcopy(x)
#  92.4 ms ± 1.8 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
#  [ins] In [34]: %timeit deepcopy(y)
#  97.1 ms ± 2.05 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

from torchvision.models import resnet34
import pickle
from time import perf_counter as time
from io import BytesIO

import numpy as np

def np_pickle(x):
    # Time a round trip of the array through pickle.
    start = time()
    y = pickle.loads(pickle.dumps(x))
    end = time()
    return end - start

def np_pickle2(x):
    # Time a round trip of the array through np.save/np.load.
    start = time()
    with BytesIO() as f:
        np.save(f, x)
        f.seek(0)  # rewind before reading back
        y = np.load(f)
    end = time()
    return end - start

def torch_model(m):
    # Time a round trip of the PyTorch model through pickle.
    start = time()
    m2 = pickle.loads(pickle.dumps(m))
    end = time()
    return end - start

m = resnet34()
nparams = sum(p.nelement() for p in m.parameters())
x = np.random.uniform(size=nparams)
assert len(x) >= 1e6
n_runs = 30

times = []
for _ in range(n_runs):
    times.append(np_pickle(x))
ms = int(min(times) * 1000)
print(f"Serializing numpy array w/ pickle took at least {ms}ms")  # 265ms

times = []
for _ in range(n_runs):
    times.append(np_pickle2(x))
ms = int(min(times) * 1000)
print(f"Serializing numpy array w/ np.save took at least {ms}ms")  # 275ms

times = []
for _ in range(n_runs):
    times.append(torch_model(m))
ms = int(min(times) * 1000)
print(f"Serializing resnet34 took at least {ms}ms")  # 245ms

stsievert (Owner) commented Feb 23, 2021

Fundamentally, there are two ways to avoid deepcopy:

  1. Communicate the model to each worker.
  2. Have each worker hold its own model.

(1) requires communicating the model to each worker, which means either serialization or deepcopy. (2) requires a decentralized parameter server in which each worker holds onto its own copy of the model. (2) has some distinct advantages: only the necessary communication is performed (communicating gradients), and deepcopy/model serialization are avoided, though it makes it slightly harder to change the number of workers.
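A minimal sketch of (2) using Dask worker-local state via dask.distributed.get_worker() (the attribute name _model and the toy Linear model are assumptions, not this repo's code):

import torch
from dask.distributed import get_worker

def _gradient_task(x, y):
    # Each worker builds and keeps its own model the first time it runs a
    # task, so no deepcopy or serialization happens per gradient computation.
    worker = get_worker()
    if not hasattr(worker, "_model"):
        worker._model = torch.nn.Linear(10, 1)  # build/load the real model here
    model = worker._model
    model.zero_grad()
    loss = torch.nn.functional.mse_loss(model(x), y)
    loss.backward()
    return {name: p.grad.clone() for name, p in model.named_parameters()}

# Usage (on the client):
#   futures = [client.submit(_gradient_task, xb, yb) for xb, yb in batches]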

Relevant work:

I think integration with dask-pytorch-ddp is the most straightforward solution: it wraps a distributed protocol that uses some fancy communication tricks (overlapping communication and computation, etc.), as described in "PyTorch Distributed: Experiences on Accelerating Data Parallel Training."
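A rough sketch of what that integration might look like, following the pattern in the dask-pytorch-ddp README (the model, data, and optimizer are placeholders, and dispatch.run is assumed to set up the torch.distributed process group on each worker before calling the training function):

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from dask_pytorch_ddp import dispatch

def train():
    # Runs on every Dask worker; DDP averages gradients across workers and
    # overlaps that communication with the backward pass.
    model = DDP(torch.nn.Linear(10, 1))
    opt = torch.optim.SGD(model.parameters(), lr=0.1)
    for _ in range(10):
        x, y = torch.randn(32, 10), torch.randn(32, 1)
        opt.zero_grad()
        loss = torch.nn.functional.mse_loss(model(x), y)
        loss.backward()
        opt.step()
    return dist.get_rank()

# Usage (GPU workers in the real setting):
#   client = Client(cluster)
#   futures = dispatch.run(client, train)
#   ranks = client.gather(futures)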

@stsievert stsievert changed the title [feature request] accelerate gradient computations [feature request] remove overhead in gradient computations Feb 23, 2021