3434from  contextlib  import  nullcontext 
3535from  datetime  import  timedelta 
3636from  enum  import  Enum 
37- from  typing  import  Callable ,  cast , Dict , List , Optional , TYPE_CHECKING ,  TypeVar 
37+ from  typing  import  TYPE_CHECKING ,  Callable , Dict , List , Optional , TypeVar ,  cast 
3838
3939import  torch 
4040from  torch .distributed  import  ReduceOp , TCPStore 
@@ -106,6 +106,7 @@ def __init__(
106106        hostname : str  =  socket .gethostname (),
107107        heartbeat_interval : timedelta  =  timedelta (milliseconds = 100 ),
108108        checkpoint_transport : Optional [CheckpointTransport [Dict [str , T ]]] =  None ,
109+         init_sync : bool  =  True ,
109110    ) ->  None :
110111        """ 
111112        Args: 
@@ -143,6 +144,9 @@ def __init__(
143144            hostname: if rank==0, the hostname to advertise to the lighthouse server 
144145            checkpoint_transport: the checkpoint transport to use for 
145146                transferring checkpoints to recovering replicas, defaults to HTTPTransport 
147+             init_sync: whether to synchronize the model weights on step 0. If 
148+                 all of the model weights are initialized identically via 
149+                 ``torch.manual_seed`` you should set this to False. 
146150        """ 
147151        self ._load_state_dict  =  load_state_dict 
148152        self ._user_state_dict  =  state_dict 
@@ -152,6 +156,7 @@ def __init__(
152156        self ._quorum_timeout  =  quorum_timeout 
153157        self ._connect_timeout  =  connect_timeout 
154158        self ._world_size_mode  =  world_size_mode 
159+         self ._init_sync  =  init_sync 
155160
156161        store_addr  =  store_addr  or  os .environ ["MASTER_ADDR" ]
157162        store_port  =  store_port  or  int (os .environ ["MASTER_PORT" ])
@@ -455,7 +460,7 @@ def _async_quorum(
455460            checkpoint_metadata = self ._checkpoint_transport .metadata (),
456461            shrink_only = shrink_only ,
457462            timeout = quorum_timeout ,
458-             init_sync = self .init_sync ,
463+             init_sync = self ._init_sync ,
459464        )
460465
461466        quorum_id  =  quorum .quorum_id 
0 commit comments