1010"""
1111import logging
1212from types import TracebackType
13- from typing import Any , Callable , Dict , Iterator , List , Mapping , Optional , Type
13+ from typing import Any , Callable , Dict , Iterator , List , Optional , Tuple , Type
1414
1515import torch
1616from torch import nn , optim
@@ -59,8 +59,6 @@ def __init__(
         model: nn.Module,
         optimizer: optim.Optimizer,
         sync_every: int,
-        backup_device: Optional[torch.device] = None,
-        pin_memory: bool = True,
     ) -> None:
         """
         Args:
@@ -78,21 +76,8 @@ def __init__(
         self._local_step = 0
         self._sync_every = sync_every
         assert sync_every >= 1, "sync_every must be greater than or equal to 1"
-        device = backup_device or torch.device("cpu")
-        self._backup_parameters: Dict[str, torch.Tensor] = {}
-        for name, p in self._model.named_parameters():
-            t = torch.empty(*tuple(p.shape), dtype=p.dtype, device=device)
-            if (
-                pin_memory
-                and t.device == torch.device("cpu")
-                and torch.cuda.is_available()
-            ):
-                t = t.pin_memory()
-            self._backup_parameters[name] = t

         self._hooks: List[RemovableHandle] = []
-        # Need to copy the parameters to the host to be safe if we are on the first step.
-        self._save_parameters()

     def __enter__(self) -> "LocalSGD":
         # Add optimizer hook which increments the local step counter and syncs if necessary
@@ -108,30 +93,15 @@ def __exit__(
         traceback: Optional[TracebackType],
     ) -> bool:
         # Handle any cleanup or error handling here
-        if exc_type is not None:
-            # If an exception occurred, restore parameters
-            self._restore_parameters()
         # Clean up hooks
         for hook in self._hooks:
             hook.remove()
         self._hooks.clear()

         return False  # Propagate exceptions

-    def _save_parameters(self) -> None:
-        with torch.no_grad():
-            # TODO: consider running copy on a separate stream
-            for name, p in self._model.named_parameters():
-                self._backup_parameters[name].copy_(p.data, non_blocking=True)
-
-    def _restore_parameters(self) -> None:
-        with torch.no_grad():
-            # TODO: consider running copy on a separate stream
-            for name, p in self._model.named_parameters():
-                p.data.copy_(self._backup_parameters[name], non_blocking=False)
-
     def _step_post_hook(
-        self, _optim: optim.Optimizer, _args: List[object], _kwargs: Dict[str, object]
+        self, _optim: optim.Optimizer, _args: Tuple[Any, ...], _kwargs: Dict[str, Any]
     ) -> None:
         """
         This hook is registered on the optimizer and is called after the optimizer step.
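
A minimal, self-contained sketch (not part of the diff) of the hook mechanism this change types more precisely: torch's Optimizer.register_step_post_hook calls the hook as hook(optimizer, args, kwargs) after every step(), which is why _args/_kwargs become Tuple[Any, ...] and Dict[str, Any]. The toy model, counter, and sync_every value below are illustrative only.

from typing import Any, Dict, Tuple

import torch
from torch import nn, optim

model = nn.Linear(4, 2)
opt = optim.SGD(model.parameters(), lr=0.1)
state = {"local_step": 0, "sync_every": 3}

def step_post_hook(_optim: optim.Optimizer, _args: Tuple[Any, ...], _kwargs: Dict[str, Any]) -> None:
    # Fired by torch after each opt.step(); count steps and "sync" periodically.
    state["local_step"] += 1
    if state["local_step"] >= state["sync_every"]:
        state["local_step"] = 0
        print("sync would run here")

handle = opt.register_step_post_hook(step_post_hook)
for _ in range(6):
    opt.zero_grad()
    model(torch.randn(8, 4)).sum().backward()
    opt.step()  # post hook runs after every step
handle.remove()
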
@@ -151,30 +121,31 @@ def sync(self) -> None:
     def _perform_sync(self) -> None:
         """
         Performs the synchronization of the model weights across the manager.
-        This method is intended to be overridden by subclasses to implement custom
-        synchronization logic.
         """
-        self._average()
+        averaged_parameters = self._average()
         if self._manager.should_commit():
-            self._save_parameters()
-        else:
-            # commit failed, restore from the backup parameters
-            self._restore_parameters()
-
-    def _average(self) -> None:
-        # TODO: do we need to broadcast buffers like DDP does?
+            # Update the model parameters with the averaged values
+            for param, avg_param in zip(self._model.parameters(), averaged_parameters):
+                param.data.copy_(avg_param)

+    def _average(self) -> list[torch.Tensor]:
+        """
+        Averages the model parameters across the manager and returns the averaged parameters.
+        """
         works = []
-
+        averaged_parameters = []
         for p in self._model.parameters():
-            # TODO: bucketize parameters
-            works.append(self._manager.allreduce(p.data.detach()))
-
+            # Create a new tensor to store the averaged parameter
+            p.data.grad = None
+            avg_param = p.data.clone()
+            works.append(self._manager.allreduce(avg_param))
+            averaged_parameters.append(avg_param)
         for work in works:
             work.wait()
+        return averaged_parameters


-class DiLoCo(LocalSGD):
+class DiLoCo:
     """
     DiLoCo is a subclass of LocalSGD that overrides the synchronization
     mechanism to average and synchronize the pseudogradients (delta of the previous global weight and current local weights).
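
A sketch (not part of the diff) of the control flow the new _perform_sync/_average pair implements: parameters are cloned, the clones are allreduce-averaged, and the averaged values are copied back only if the commit succeeds, so a failed commit leaves the local weights untouched. torchft's Manager is replaced by a stand-in here; its allreduce/should_commit bodies are assumptions made only to keep the example runnable.

import torch
from torch import nn

class _FakeWork:
    def wait(self) -> None:
        pass

class _FakeManager:
    def allreduce(self, t: torch.Tensor) -> _FakeWork:
        # The real manager averages t across the replica group and returns a waitable handle.
        return _FakeWork()

    def should_commit(self) -> bool:
        return True

def sync_parameters(model: nn.Module, manager: _FakeManager) -> None:
    works = []
    averaged = []
    for p in model.parameters():
        avg = p.data.clone()          # average a clone, not the live parameter
        works.append(manager.allreduce(avg))
        averaged.append(avg)
    for w in works:
        w.wait()
    if manager.should_commit():       # only touch the model once the commit succeeds
        for p, avg in zip(model.parameters(), averaged):
            p.data.copy_(avg)

sync_parameters(nn.Linear(4, 2), _FakeManager())
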
@@ -197,27 +168,96 @@ def __init__(
197168 "Using DiLoCo require synchronous quorum to be enabled. "
198169 "Ensure that the manager is initialized with use_async_quorum=False"
199170 )
200- super ().__init__ (
201- manager , model , inner_optimizer , sync_every , backup_device , pin_memory
202- )
171+ super ().__init__ ()
172+ self ._manager = manager
173+ self ._model = model
174+ self ._local_optimizer = inner_optimizer
175+ self ._local_step = 0
176+ self ._sync_every = sync_every
177+ assert sync_every >= 1 , "sync_every must be greater than or equal to 1"
178+ self ._backup_device = backup_device
179+ self ._pin_memory = pin_memory
180+
181+ self ._hooks : List [RemovableHandle ] = []
203182 self ._outer_optimizer = outer_optimizer
183+ self .original_parameters : Dict [str , torch .Tensor ] = {}
184+ for name , p in self ._model .named_parameters ():
185+ t = torch .empty (* tuple (p .shape ), dtype = p .dtype , device = self ._backup_device )
186+ if (
187+ self ._pin_memory
188+ and t .device == torch .device ("cpu" )
189+ and torch .cuda .is_available ()
190+ ):
191+ t = t .pin_memory ()
192+ self .original_parameters [name ] = t
193+
194+ # Need to copy the parameters to the host to be safe if we are on the first step.
195+ self ._save_parameters ()
196+
197+ def _save_parameters (self ) -> None :
198+ with torch .no_grad ():
199+ # TODO: consider running copy on a separate stream
200+ for name , p in self ._model .named_parameters ():
201+ self .original_parameters [name ].copy_ (p .data , non_blocking = True )
202+
203+ def _restore_parameters (self ) -> None :
204+ with torch .no_grad ():
205+ # TODO: consider running copy on a separate stream
206+ for name , p in self ._model .named_parameters ():
207+ p .data .copy_ (self .original_parameters [name ], non_blocking = False )
208+
209+ def __enter__ (self ) -> "DiLoCo" :
210+ # Add optimizer hook which increments the local step counter and syncs if necessary
211+ self ._hooks .append (
212+ self ._local_optimizer .register_step_post_hook (self ._step_post_hook )
213+ )
214+ return self
215+
216+ def __exit__ (
217+ self ,
218+ exc_type : Optional [Type [BaseException ]],
219+ exc_value : Optional [BaseException ],
220+ traceback : Optional [TracebackType ],
221+ ) -> bool :
222+ # Handle any cleanup or error handling here
223+ # Clean up hooks
224+ for hook in self ._hooks :
225+ hook .remove ()
226+ self ._hooks .clear ()
227+
228+ return False # Propagate exceptions
229+
230+ def _step_post_hook (
231+ self , _optim : optim .Optimizer , _args : Tuple [Any , ...], _kwargs : Dict [str , Any ]
232+ ) -> None :
233+ """
234+ This hook is registered on the optimizer and is called after the optimizer step.
235+ """
236+ self ._local_step += 1
237+ if self ._local_step >= self ._sync_every :
238+ self .sync ()
239+
240+ def sync (self ) -> None :
241+ """
242+ Synchronizes and averages the model weights across the manager.
243+ """
244+ self ._manager .start_quorum ()
245+ self ._perform_sync ()
246+ self ._local_step = 0
204247
205248 def _perform_sync (self ) -> None :
206249 """
207250 Overrides the sync method to calculate the pseugradient, average them across the manager group, and
208251 step using the outer optimizer.
209252 """
210-
211253 # Set the .grad field of each parameter to its pseudogradient
212254 for name , p in self ._model .named_parameters ():
213- assert name in self ._backup_parameters
214- pseudogradient = p .data - self ._backup_parameters [name ]
255+ pseudogradient = p .data - self .original_parameters [name ]
215256 p .grad = pseudogradient
216257
217258 self ._average_grads ()
218259 # Restore the parameters back to the previous state
219260 self ._restore_parameters ()
220-
221261 if self ._manager .should_commit ():
222262 # Use the outer optimizer to update the model parameters
223263 self ._outer_optimizer .step ()
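
For reference, a self-contained sketch (not from this PR) of the DiLoCo outer step in isolation, with the manager and allreduce left out: the pseudogradient (local weights minus the saved copy, matching the sign used above) is written into .grad, the weights are restored to the saved copy, and the outer optimizer applies the step. The model, learning rate, and fake local update below are illustrative only.

import torch
from torch import nn, optim

model = nn.Linear(4, 2)
outer_opt = optim.SGD(model.parameters(), lr=0.7)

# Save the "global" weights before any local steps.
original = {n: p.data.clone() for n, p in model.named_parameters()}

# Stand-in for a few local optimizer steps.
with torch.no_grad():
    for p in model.parameters():
        p.add_(0.01 * torch.randn_like(p))

# Outer step: write the pseudogradient into .grad, restore the saved weights,
# then let the outer optimizer take a step from the restored point.
for n, p in model.named_parameters():
    p.grad = p.data - original[n]  # would be allreduce-averaged across replicas in torchft
with torch.no_grad():
    for n, p in model.named_parameters():
        p.data.copy_(original[n])
outer_opt.step()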