Skip to content

Commit 4ff4641

Browse files
committed
manager: final changes for init_sync
1 parent c594088 commit 4ff4641

File tree

3 files changed

+24
-15
lines changed

3 files changed

+24
-15
lines changed

src/manager.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -901,6 +901,7 @@ mod tests {
901901
step: 0,
902902
world_size: 1,
903903
shrink_only: false,
904+
data: String::new(),
904905
},
905906
QuorumMember {
906907
replica_id: "replica_1".to_string(),
@@ -909,6 +910,7 @@ mod tests {
909910
step: 1,
910911
world_size: 1,
911912
shrink_only: false,
913+
data: String::new(),
912914
},
913915
QuorumMember {
914916
replica_id: "replica_2".to_string(),
@@ -917,6 +919,7 @@ mod tests {
917919
step: 0,
918920
world_size: 1,
919921
shrink_only: false,
922+
data: String::new(),
920923
},
921924
QuorumMember {
922925
replica_id: "replica_3".to_string(),
@@ -925,6 +928,7 @@ mod tests {
925928
step: 1,
926929
world_size: 1,
927930
shrink_only: false,
931+
data: String::new(),
928932
},
929933
QuorumMember {
930934
replica_id: "replica_4".to_string(),
@@ -933,6 +937,7 @@ mod tests {
933937
step: 0,
934938
world_size: 1,
935939
shrink_only: false,
940+
data: String::new(),
936941
},
937942
],
938943
created: None,

torchft/manager.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434
from contextlib import nullcontext
3535
from datetime import timedelta
3636
from enum import Enum
37-
from typing import Callable, cast, Dict, List, Optional, TYPE_CHECKING, TypeVar
37+
from typing import TYPE_CHECKING, Callable, Dict, List, Optional, TypeVar, cast
3838

3939
import torch
4040
from torch.distributed import ReduceOp, TCPStore
@@ -106,6 +106,7 @@ def __init__(
106106
hostname: str = socket.gethostname(),
107107
heartbeat_interval: timedelta = timedelta(milliseconds=100),
108108
checkpoint_transport: Optional[CheckpointTransport[Dict[str, T]]] = None,
109+
init_sync: bool = True,
109110
) -> None:
110111
"""
111112
Args:
@@ -143,6 +144,9 @@ def __init__(
143144
hostname: if rank==0, the hostname to advertise to the lighthouse server
144145
checkpoint_transport: the checkpoint transport to use for
145146
transferring checkpoints to recovering replicas, defaults to HTTPTransport
147+
init_sync: whether to synchronize the model weights on step 0. If
148+
all of the model weights are initialized identically via
149+
``torch.manual_seed``, you should set this to False.
146150
"""
147151
self._load_state_dict = load_state_dict
148152
self._user_state_dict = state_dict
@@ -152,6 +156,7 @@ def __init__(
152156
self._quorum_timeout = quorum_timeout
153157
self._connect_timeout = connect_timeout
154158
self._world_size_mode = world_size_mode
159+
self._init_sync = init_sync
155160

156161
store_addr = store_addr or os.environ["MASTER_ADDR"]
157162
store_port = store_port or int(os.environ["MASTER_PORT"])
@@ -455,7 +460,7 @@ def _async_quorum(
455460
checkpoint_metadata=self._checkpoint_transport.metadata(),
456461
shrink_only=shrink_only,
457462
timeout=quorum_timeout,
458-
init_sync=self.init_sync,
463+
init_sync=self._init_sync,
459464
)
460465

461466
quorum_id = quorum.quorum_id

torchft/manager_test.py

Lines changed: 12 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ def _create_manager(
4040
min_replica_size: int = 2,
4141
world_size_mode: WorldSizeMode = WorldSizeMode.DYNAMIC,
4242
timeout: timedelta = timedelta(seconds=10),
43+
init_sync: bool = True,
4344
) -> Manager:
4445
pg = create_autospec(ProcessGroup)
4546
pg.errored.return_value = None
@@ -67,6 +68,7 @@ def _create_manager(
6768
use_async_quorum=use_async_quorum,
6869
world_size_mode=world_size_mode,
6970
timeout=timeout,
71+
init_sync=init_sync,
7072
)
7173
self.manager = manager
7274
return manager
@@ -617,7 +619,12 @@ def test_quorum_happy_timeouts(self, client_mock: MagicMock) -> None:
617619

618620
@patch("torchft.manager.ManagerClient", autospec=True)
619621
def test_quorum_skip_init(self, client_mock: MagicMock) -> None:
620-
manager = self._create_manager(use_async_quorum=False)
622+
manager = self._create_manager(
623+
use_async_quorum=False,
624+
init_sync=False,
625+
)
626+
627+
self.assertFalse(manager._init_sync)
621628

622629
quorum = QuorumResult()
623630
quorum.quorum_id = 123
@@ -633,16 +640,8 @@ def test_quorum_skip_init(self, client_mock: MagicMock) -> None:
633640
client_mock()._quorum.return_value = quorum
634641

635642
manager.start_quorum()
636-
self.assertEqual(
637-
client_mock()._quorum.call_args.kwargs["init_sync"], True
638-
)
643+
self.assertEqual(client_mock()._quorum.call_args.kwargs["init_sync"], False)
639644

640-
manager.start_quorum(init_sync=True)
641-
self.assertEqual(
642-
client_mock()._quorum.call_args.kwargs["init_sync"], True
643-
)
644-
645-
manager.start_quorum(init_sync=False)
646-
self.assertEqual(
647-
client_mock()._quorum.call_args.kwargs["init_sync"], False
648-
)
645+
manager._init_sync = True
646+
manager.start_quorum()
647+
self.assertEqual(client_mock()._quorum.call_args.kwargs["init_sync"], True)

0 commit comments

Comments
 (0)