diff --git a/distributed/rpc/batch/parameter_server.py b/distributed/rpc/batch/parameter_server.py index 32c9c5cab9..06bc301c6e 100644 --- a/distributed/rpc/batch/parameter_server.py +++ b/distributed/rpc/batch/parameter_server.py @@ -54,7 +54,7 @@ def update_and_fetch_model(ps_rref, grads): p.grad /= self.batch_update_size self.curr_update_size = 0 self.optimizer.step() - self.optimizer.zero_grad() + self.optimizer.zero_grad(set_to_none=False) fut.set_result(self.model) timed_log("PS updated model") self.future_model = torch.futures.Future()