diff --git a/src/zeroband/diloco.py b/src/zeroband/diloco.py index e7113918..b132c707 100644 --- a/src/zeroband/diloco.py +++ b/src/zeroband/diloco.py @@ -107,7 +107,7 @@ def sync_inner_model(self, model: nn.Module): self._logger.debug("sync inner model") for param_offloaded, param in zip(self.param_list_cpu, model.parameters()): - param.data.copy_(param_offloaded.data.to(param.device)) # todo: use copy_ here + param.data.copy_(param_offloaded.data) # todo: use copy_ here def get_offloaded_param(self, model: nn.Module) -> list[nn.Parameter]: """