-
Notifications
You must be signed in to change notification settings - Fork 301
/
Copy pathmanager.py
628 lines (539 loc) · 22 KB
/
manager.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import queue
from typing import Optional
from warnings import warn
import numpy as np
import torch
import torch.distributed as dist
from modulus.distributed.config import ProcessGroupConfig, ProcessGroupNode
class ModulusUndefinedGroupError(Exception):
    """Raised when an undefined process group is queried through the Modulus
    DistributedManager."""

    def __init__(self, name: str):
        """
        Parameters
        ----------
        name : str
            Name of the process group that was requested but does not exist.
        """
        super().__init__(
            f"Cannot query process group '{name}' before it is explicitly created."
        )
class ModulusUninitializedDistributedManagerWarning(Warning):
    """Signals that a DistributedManager is being used before initialization."""

    def __init__(self):
        # Adjacent string literals are concatenated by the parser into a single
        # explanatory message.
        super().__init__(
            "A DistributedManager object is being instantiated before "
            "this singleton class has been initialized. Instantiating a manager before "
            "initialization can lead to unexpected results where processes fail "
            "to communicate. Initialize the distributed manager via "
            "DistributedManager.initialize() before instantiating."
        )
class DistributedManager(object):
    """Distributed Manager for setting up distributed training environment.

    This is a singleton that creates a persistent class instance for storing parallel
    environment information throughout the lifetime of the program. This should be
    used to help set up Distributed Data Parallel and parallel datapipes.

    All instances share one attribute dictionary (the class-level ``_shared_state``),
    so every ``DistributedManager()`` observes the state written by ``initialize()``
    / ``setup()``.

    Note
    ----
    One should call `DistributedManager.initialize()` prior to constructing a manager
    object

    Example
    -------
    >>> DistributedManager.initialize()
    >>> manager = DistributedManager()
    >>> manager.rank
    0
    >>> manager.world_size
    1
    """

    _shared_state = {}

    def __new__(cls):
        obj = super(DistributedManager, cls).__new__(cls)
        # Every instance aliases the same class-level dict, so attributes written
        # through one instance are visible through all of them.
        obj.__dict__ = cls._shared_state

        # Fill in defaults only for attributes that are not present yet, so that
        # state written by `setup()` survives later instantiations. Fresh mutable
        # containers are built each time but only stored when the key is missing.
        cuda_available = torch.cuda.is_available()
        defaults = {
            "_rank": 0,
            "_world_size": 1,
            "_local_rank": 0,
            "_distributed": False,
            "_device": torch.device("cuda:0" if cuda_available else "cpu"),
            "_cuda": cuda_available,
            "_broadcast_buffers": False,
            "_find_unused_parameters": False,
            "_initialization_method": "None",
            "_groups": {},
            "_group_ranks": {},
            "_group_names": {},
            "_is_initialized": False,
        }
        for key, value in defaults.items():
            obj.__dict__.setdefault(key, value)
        return obj

    def __init__(self):
        # Constructing a manager before `initialize()`/`setup()` is treated as an
        # error: the warning class is raised, not emitted via `warnings.warn`.
        if not self._is_initialized:
            raise ModulusUninitializedDistributedManagerWarning()
        super().__init__()

    @property
    def rank(self):
        """Process rank"""
        return self._rank

    @property
    def local_rank(self):
        """Process rank on local machine"""
        return self._local_rank

    @property
    def world_size(self):
        """Number of processes in distributed environment"""
        return self._world_size

    @property
    def device(self):
        """Process device"""
        return self._device

    @property
    def distributed(self):
        """Whether the torch.distributed environment is available/in use"""
        return self._distributed

    @property
    def cuda(self):
        """If cuda is available"""
        return self._cuda

    @property
    def group_names(self):
        """
        Names of the named process groups this rank is a member of.

        Returned as a live keys view of the internal group registry (not a list).
        """
        return self._groups.keys()

    def group(self, name=None):
        """
        Returns a process group with the given name

        If name is None, group is also None indicating the default process group.
        If the named group does not exist, ModulusUndefinedGroupError is raised.
        """
        if name in self._groups:
            return self._groups[name]
        if name is None:
            return None
        raise ModulusUndefinedGroupError(name)

    def group_size(self, name=None):
        """
        Returns the size of the named process group (world size when name is None)
        """
        if name is None:
            return self._world_size
        group = self.group(name)
        return dist.get_world_size(group=group)

    def group_rank(self, name=None):
        """
        Returns the rank in the named process group (global rank when name is None)
        """
        if name is None:
            return self._rank
        group = self.group(name)
        return dist.get_rank(group=group)

    def group_name(self, group=None):
        """
        Returns the name of a process group.

        Returns None for the default group (``group is None``); raises KeyError for
        a group object that was not created through this manager.
        """
        if group is None:
            return None
        return self._group_names[group]

    @property
    def broadcast_buffers(self):
        """broadcast_buffers in PyTorch DDP"""
        return self._broadcast_buffers

    @broadcast_buffers.setter
    def broadcast_buffers(self, broadcast: bool):
        """Setter for broadcast_buffers"""
        self._broadcast_buffers = broadcast

    @property
    def find_unused_parameters(self):
        """find_unused_parameters in PyTorch DDP"""
        return self._find_unused_parameters

    @find_unused_parameters.setter
    def find_unused_parameters(self, find_params: bool):
        """Setter for find_unused_parameters"""
        if find_params:
            warn(
                "Setting `find_unused_parameters` in DDP to true, "
                "use only if necessary."
            )
        self._find_unused_parameters = find_params

    def __str__(self):
        output = (
            f"Initialized process {self.rank} of {self.world_size} using "
            f"method '{self._initialization_method}'. Device set to {str(self.device)}"
        )
        return output

    @classmethod
    def is_initialized(cls) -> bool:
        """If manager singleton has been initialized"""
        return cls._shared_state.get("_is_initialized", False)

    @staticmethod
    def get_available_backend():
        """Get communication backend: "nccl" when CUDA and NCCL are available, else "gloo"."""
        if torch.cuda.is_available() and torch.distributed.is_nccl_available():
            return "nccl"
        else:
            return "gloo"

    @staticmethod
    def _default_local_rank(rank: int) -> int:
        """Derive a local rank from the global rank when none was provided.

        Uses ``rank % torch.cuda.device_count()``, which assumes each node runs one
        process per visible GPU with consecutive global ranks. On hosts with no
        visible CUDA devices the global rank is returned unchanged instead of
        dividing by zero.
        """
        num_devices = torch.cuda.device_count()
        if num_devices == 0:
            return rank
        return rank % num_devices

    @staticmethod
    def initialize_env(**kwargs):
        """Setup method using generic initialization.

        Reads ``RANK`` and ``WORLD_SIZE`` (required), ``LOCAL_RANK`` (optional),
        ``MASTER_ADDR`` and ``MASTER_PORT`` from the environment. If ``RANK`` or
        ``WORLD_SIZE`` is unset, ``int(None)`` raises TypeError; ``initialize()``
        catches that to try the SLURM/OpenMPI methods instead.
        """
        rank = int(os.environ.get("RANK"))
        world_size = int(os.environ.get("WORLD_SIZE"))

        local_rank_env = os.environ.get("LOCAL_RANK")
        if local_rank_env is not None:
            local_rank = int(local_rank_env)
        else:
            local_rank = DistributedManager._default_local_rank(rank)

        # Read env variables
        addr = os.environ.get("MASTER_ADDR")
        port = os.environ.get("MASTER_PORT")

        DistributedManager.setup(
            rank=rank,
            world_size=world_size,
            local_rank=local_rank,
            addr=addr,
            port=port,
            backend=DistributedManager.get_available_backend(),
            **kwargs,
        )

    @staticmethod
    def initialize_open_mpi(addr, port, **kwargs):
        """Setup method using OpenMPI initialization.

        Reads ``OMPI_COMM_WORLD_RANK``, ``OMPI_COMM_WORLD_SIZE`` and
        ``OMPI_COMM_WORLD_LOCAL_RANK``; ``addr``/``port`` give the rendezvous point.
        """
        rank = int(os.environ.get("OMPI_COMM_WORLD_RANK"))
        world_size = int(os.environ.get("OMPI_COMM_WORLD_SIZE"))
        local_rank = int(os.environ.get("OMPI_COMM_WORLD_LOCAL_RANK"))

        DistributedManager.setup(
            rank=rank,
            world_size=world_size,
            local_rank=local_rank,
            addr=addr,
            port=port,
            backend=DistributedManager.get_available_backend(),
            method="openmpi",
            **kwargs,
        )

    @staticmethod
    def initialize_slurm(port, **kwargs):
        """Setup method using SLURM initialization.

        Reads ``SLURM_PROCID``, ``SLURM_NPROCS``, ``SLURM_LOCALID`` and uses
        ``SLURM_LAUNCH_NODE_IPADDR`` as the rendezvous address.
        """
        rank = int(os.environ.get("SLURM_PROCID"))
        world_size = int(os.environ.get("SLURM_NPROCS"))
        local_rank = int(os.environ.get("SLURM_LOCALID"))
        addr = os.environ.get("SLURM_LAUNCH_NODE_IPADDR")

        DistributedManager.setup(
            rank=rank,
            world_size=world_size,
            local_rank=local_rank,
            addr=addr,
            port=port,
            backend=DistributedManager.get_available_backend(),
            method="slurm",
            **kwargs,
        )

    @staticmethod
    def initialize(**kwargs):
        """
        Initialize distributed manager

        Current supported initialization methods are:
            `ENV`: PyTorch environment variable initialization
                https://pytorch.org/docs/stable/distributed.html#environment-variable-initialization
            `SLURM`: Initialization on SLURM systems.
                Uses `SLURM_PROCID`, `SLURM_NPROCS`, `SLURM_LOCALID` and
                `SLURM_LAUNCH_NODE_IPADDR` environment variables.
            `OPENMPI`: Initialization for OpenMPI launchers.
                Uses `OMPI_COMM_WORLD_RANK`, `OMPI_COMM_WORLD_SIZE` and
                `OMPI_COMM_WORLD_LOCAL_RANK` environment variables.

        Initialization by default is done using the first valid method in the order
        listed above. Initialization method can also be explicitly controlled using the
        `MODULUS_DISTRIBUTED_INITIALIZATION_METHOD` environment variable and setting it
        to one of the options above. If no method applies, a warning is emitted and
        the manager is marked initialized with its default single-process state.

        kwargs are passed down to torch.distributed.init_process_group directly. This can be used
        to set parameters like `timeout=timedelta(minutes=60)`

        Finally, NumPy's global random seed is set to this process's rank.

        Raises
        ------
        RuntimeError
            If `MODULUS_DISTRIBUTED_INITIALIZATION_METHOD` holds an unknown value.
        """
        if DistributedManager.is_initialized():
            warn("Distributed manager is already intialized")
            return

        addr = os.getenv("MASTER_ADDR", "localhost")
        port = os.getenv("MASTER_PORT", "12355")
        # https://pytorch.org/docs/master/notes/cuda.html#id5
        os.environ["NCCL_ASYNC_ERROR_HANDLING"] = "0"
        initialization_method = os.getenv("MODULUS_DISTRIBUTED_INITIALIZATION_METHOD")
        if initialization_method is None:
            try:
                DistributedManager.initialize_env(**kwargs)
            except TypeError:
                # ENV variables missing (int(None) / os.environ[...] = None raise
                # TypeError): fall back to launcher-specific variables.
                if "SLURM_PROCID" in os.environ:
                    DistributedManager.initialize_slurm(port, **kwargs)
                elif "OMPI_COMM_WORLD_RANK" in os.environ:
                    DistributedManager.initialize_open_mpi(addr, port, **kwargs)
                else:
                    warn(
                        "Could not initialize using ENV, SLURM or OPENMPI methods. Assuming this is a single process job"
                    )
                    DistributedManager._shared_state["_is_initialized"] = True
        elif initialization_method == "ENV":
            DistributedManager.initialize_env(**kwargs)
        elif initialization_method == "SLURM":
            DistributedManager.initialize_slurm(port, **kwargs)
        elif initialization_method == "OPENMPI":
            DistributedManager.initialize_open_mpi(addr, port, **kwargs)
        else:
            raise RuntimeError(
                "Unknown initialization method "
                f"{initialization_method}. "
                "Supported values for "
                "MODULUS_DISTRIBUTED_INITIALIZATION_METHOD are "
                "ENV, SLURM and OPENMPI"
            )

        # Set per rank numpy random seed for data sampling
        np.random.seed(seed=DistributedManager().rank)

    @staticmethod
    def setup(
        rank=0,
        world_size=1,
        local_rank=None,
        addr="localhost",
        port="12355",
        backend="nccl",
        method="env",
        **kwargs,
    ):
        """Set up PyTorch distributed process group and update manager attributes.

        Parameters
        ----------
        rank : int
            Global rank of this process.
        world_size : int
            Total number of processes.
        local_rank : Optional[int]
            Rank on the local machine. When None it is derived from ``rank`` and
            the number of visible CUDA devices.
        addr, port
            Rendezvous address and port, exported as ``MASTER_ADDR`` /
            ``MASTER_PORT``.
        backend : str
            torch.distributed backend name.
        method : str
            Label recorded as the initialization method.
        **kwargs
            Forwarded to ``torch.distributed.init_process_group``.
        """
        # NOTE: assigning None to os.environ raises TypeError (e.g. MASTER_ADDR unset
        # on the ENV path); initialize() catches that TypeError and falls back.
        os.environ["MASTER_ADDR"] = addr
        os.environ["MASTER_PORT"] = str(port)
        DistributedManager._shared_state["_is_initialized"] = True
        manager = DistributedManager()

        manager._distributed = torch.distributed.is_available()
        if manager._distributed:
            # Update rank and world_size if using distributed
            manager._rank = rank
            manager._world_size = world_size
            if local_rank is None:
                manager._local_rank = DistributedManager._default_local_rank(rank)
            else:
                manager._local_rank = local_rank

        manager._device = torch.device(
            f"cuda:{manager.local_rank}" if torch.cuda.is_available() else "cpu"
        )

        if manager._distributed:
            # Setup distributed process group
            common_args = {"rank": manager.rank, "world_size": manager.world_size}
            if manager.device.type == "cuda":
                try:
                    dist.init_process_group(
                        backend, device_id=manager.device, **common_args, **kwargs
                    )
                except TypeError:
                    # device_id only introduced in PyTorch 2.3
                    dist.init_process_group(backend, **common_args, **kwargs)
            else:
                # device_id binds the group to a CUDA device (NCCL optimization);
                # it is not passed for CPU devices, which some releases reject.
                dist.init_process_group(backend, **common_args, **kwargs)

        if torch.cuda.is_available():
            # Set device for this process and empty cache to optimize memory usage
            torch.cuda.set_device(manager.device)
            torch.cuda.empty_cache()

        manager._initialization_method = method

    @staticmethod
    def create_process_subgroup(
        name: str, size: int, group_name: Optional[str] = None, verbose: bool = False
    ):  # pragma: no cover
        """
        Create a process subgroup of a parent process group. This must be a collective
        call by all processes participating in this application.

        Parameters
        ----------
        name : str
            Name of the process subgroup to be created.
        size : int
            Size of the process subgroup to be created. This must be an integer factor of
            the parent group's size.
        group_name : Optional[str]
            Name of the parent process group, optional. If None, the default process group
            will be used. Default None.
        verbose : bool
            Print out ranks of each created process group, default False.

        Raises
        ------
        AssertionError
            If torch.distributed is unavailable, the name is already taken, or
            ``size`` does not evenly divide the parent group's size.
        """
        manager = DistributedManager()
        if not manager.distributed:
            raise AssertionError(
                "torch.distributed is unavailable. "
                "Check pytorch build to ensure the distributed package is available. "
                "If building PyTorch from source, set `USE_DISTRIBUTED=1` "
                "to enable the distributed package"
            )

        if name in manager._groups:
            raise AssertionError(f"Group with name {name} already exists")

        # Get parent group's params
        group = manager._groups[group_name] if group_name else None
        group_size = dist.get_world_size(group=group)
        num_groups = manager.world_size // group_size

        # Get number of sub-groups per parent group
        if group_size % size != 0:
            raise AssertionError(
                f"Cannot divide group size {group_size} evenly into subgroups of"
                f" size {size}"
            )
        num_subgroups = group_size // size

        # Create all the sub-groups
        # Note: all ranks in the job need to create all sub-groups in
        # the same order even if a rank is not part of a sub-group
        manager._group_ranks[name] = []

        for g in range(num_groups):
            for i in range(num_subgroups):
                # Get global ranks that are part of this sub-group
                start = i * size
                end = start + size
                if group_name:
                    ranks = manager._group_ranks[group_name][g][start:end]
                else:
                    ranks = list(range(start, end))
                # Create sub-group and keep track of ranks
                tmp_group = dist.new_group(ranks=ranks)
                manager._group_ranks[name].append(ranks)
                if manager.rank in ranks:
                    # Set group in manager only if this rank is part of the group
                    manager._groups[name] = tmp_group
                    manager._group_names[tmp_group] = name

        if verbose and manager.rank == 0:
            print(f"Process group '{name}':")
            for grp in manager._group_ranks[name]:
                print(" ", grp)

    @staticmethod
    def create_orthogonal_process_group(
        orthogonal_group_name: str, group_name: str, verbose: bool = False
    ):  # pragma: no cover
        """
        Create a process group that is orthogonal to the specified process group.

        The i-th orthogonal group collects the i-th rank of every subgroup that
        makes up ``group_name``.

        Parameters
        ----------
        orthogonal_group_name : str
            Name of the orthogonal process group to be created.
        group_name : str
            Name of the existing process group.
        verbose : bool
            Print out ranks of each created process group, default False.

        Raises
        ------
        AssertionError
            If torch.distributed is unavailable.
        ValueError
            If ``group_name`` does not exist or ``orthogonal_group_name`` already does.
        """
        manager = DistributedManager()
        if not manager.distributed:
            raise AssertionError(
                "torch.distributed is unavailable. "
                "Check pytorch build to ensure the distributed package is available. "
                "If building PyTorch from source, set `USE_DISTRIBUTED=1` "
                "to enable the distributed package"
            )

        if group_name not in manager._groups:
            raise ValueError(f"Group with name {group_name} does not exist")
        if orthogonal_group_name in manager._groups:
            raise ValueError(f"Group with name {orthogonal_group_name} already exists")

        group_ranks = manager._group_ranks[group_name]
        # Transpose the list of subgroups: column i holds the i-th member of each.
        orthogonal_ranks = [list(i) for i in zip(*group_ranks)]

        for ranks in orthogonal_ranks:
            tmp_group = dist.new_group(ranks=ranks)
            if manager.rank in ranks:
                # Set group in manager only if this rank is part of the group
                manager._groups[orthogonal_group_name] = tmp_group
                manager._group_names[tmp_group] = orthogonal_group_name

        manager._group_ranks[orthogonal_group_name] = orthogonal_ranks

        if verbose and manager.rank == 0:
            print(f"Process group '{orthogonal_group_name}':")
            for grp in manager._group_ranks[orthogonal_group_name]:
                print(" ", grp)

    @staticmethod
    def create_group_from_node(
        node: "ProcessGroupNode",
        parent: Optional[str] = None,
        verbose: bool = False,
    ):  # pragma: no cover
        """Create the subgroup described by ``node`` plus its orthogonal group.

        Parameters
        ----------
        node : ProcessGroupNode
            Node whose ``name`` and ``size`` define the subgroup; ``size`` must be set.
        parent : Optional[str]
            Name of the parent group to split; None means the default group.
        verbose : bool
            Print out ranks of each created process group, default False.

        Returns
        -------
        str
            Name of the orthogonal group (``"__orthogonal_to_<node.name>"``).

        Raises
        ------
        AssertionError
            If ``node.size`` is None.
        """
        if node.size is None:
            raise AssertionError(
                "Cannot create groups from a ProcessGroupNode that is not fully"
                " populated. Ensure that config.set_leaf_group_sizes is called first"
                " with `update_parent_sizes = True`"
            )

        DistributedManager.create_process_subgroup(
            node.name, node.size, group_name=parent, verbose=verbose
        )
        # Create orthogonal process group
        orthogonal_group = f"__orthogonal_to_{node.name}"
        DistributedManager.create_orthogonal_process_group(
            orthogonal_group, node.name, verbose=verbose
        )
        return orthogonal_group

    @staticmethod
    def create_groups_from_config(
        config: "ProcessGroupConfig", verbose: bool = False
    ):  # pragma: no cover
        """Create all process groups described by a ProcessGroupConfig tree.

        Parameters
        ----------
        config : ProcessGroupConfig
            Process group tree; node sizes must already be populated.
        verbose : bool
            Print the traversal (node ids / children) and the ranks of every
            created group, default False.
        """
        # Traverse process group tree in breadth first order
        # to create nested process groups
        q = queue.Queue()
        q.put(config.root_id)
        DistributedManager.create_group_from_node(config.root, verbose=verbose)

        while not q.empty():
            node_id = q.get()
            if verbose:
                print(f"Node ID: {node_id}")

            children = config.tree.children(node_id)
            if verbose:
                print(f" Children: {children}")

            # NOTE(review): uses the tree node id as the parent group name; this
            # assumes node identifiers equal group names — confirm in config.
            parent_group = node_id
            for child in children:
                # Create child group and replace parent group by orthogonal group so
                # that each child forms an independent block of processes
                parent_group = DistributedManager.create_group_from_node(
                    child.data,
                    parent=parent_group,
                    verbose=verbose,
                )

                # Add child ids to the queue
                q.put(child.identifier)

    @staticmethod
    def cleanup():
        """Clean up distributed group and singleton.

        When the manager was initialized as distributed and a process group is
        live, all ranks synchronize and the default process group is destroyed.
        The shared singleton state is reset in every case.
        """
        state = DistributedManager._shared_state
        # Destroying group.WORLD is enough for all process groups to get destroyed
        if (
            state.get("_is_initialized", False)
            and state.get("_distributed", False)
            # Skip teardown if no process group was ever created (e.g. setup()
            # failed part-way); dist.barrier() would raise in that case.
            and dist.is_initialized()
        ):
            # `device_ids` is an NCCL-specific barrier option.
            if torch.cuda.is_available() and dist.get_backend() == "nccl":
                dist.barrier(device_ids=[DistributedManager().local_rank])
            else:
                dist.barrier()
            dist.destroy_process_group()
        DistributedManager._shared_state = {}