     Storage,
     Topology,
 )
-from torchrec.distributed.planner.utils import bytes_to_gb, reset_shard_rank
+from torchrec.distributed.planner.utils import (
+    bytes_to_gb,
+    mb_to_bytes,
+    reset_shard_rank,
+)
 from torchrec.distributed.types import ShardingType
 
 logger: logging.Logger = logging.getLogger(__name__)
@@ -57,6 +61,25 @@ def _get_uniform_sharding_options(
     return uniform_sharding_options
 
 
+def _get_shards_assignment(
+    sharding_options: List[ShardingOption],
+) -> List[List[Optional[int]]]:
+    assignment_per_option = []
+    for sharding_option in sharding_options:
+        assignment_per_option.append(sharding_option.get_shards_assignment())
+    return assignment_per_option
+
+
+def _apply_shards_assignment(
+    sharding_options: List[ShardingOption],
+    assignment_per_option: List[List[Optional[int]]],
+) -> None:
+    assert len(sharding_options) == len(assignment_per_option)
+    for sharding_option, assignment in zip(sharding_options, assignment_per_option):
+        for shard_id, rank in enumerate(assignment):
+            sharding_option.shards[shard_id].rank = rank
+
+
 @dataclass
 class ShardingOptionGroup:
     sharding_options: List[ShardingOption]
@@ -171,6 +194,7 @@ def partition(
         self,
         proposal: List[ShardingOption],
         storage_constraint: Topology,
+        hbm_per_device: Optional[int] = None,
     ) -> List[ShardingOption]:
         """
         Places sharding options on topology based on each sharding option's
@@ -230,13 +254,27 @@ def partition(
             f"GreedyPerfPartitioner - sort_by: {self._sort_by}, balance_modules: {self._balance_modules}"
         )
 
-        _topology: Topology = copy.deepcopy(storage_constraint)
         minheap_devices: Optional[List[OrderedDeviceHardware]] = None
-        _host_level_devices = self._get_host_level_devices(_topology)
+
+        # Don't store the topology since topology cannot be changed;
+        # the algorithm will only be modifying the device perf & storage sizes, so copy them only
+        devices = [
+            DeviceHardware(
+                rank=d.rank,
+                storage=Storage(hbm=hbm_per_device or d.storage.hbm, ddr=d.storage.ddr),
+                perf=copy.deepcopy(d.perf),
+            )
+            for d in storage_constraint.devices
+        ]
+
+        host_level_devices = GreedyPerfPartitioner._get_host_level_devices(
+            storage_constraint, devices
+        )
 
         # first partition the uniform sharding options (RW & DP)
         uniform_sharding_options = _get_uniform_sharding_options(proposal)
-        self._uniform_partition(uniform_sharding_options, _topology.devices)
+
+        GreedyPerfPartitioner._uniform_partition(uniform_sharding_options, devices)
 
         # group the rest sharding options by colocation type (co-host, co-device, none)
         # and sort the groups by storage in reverse order
@@ -249,15 +287,15 @@ def partition(
                 sharding_option_group.sharding_options[0].partition_by
                 == PartitionByType.MULTI_HOST.value
             ):
-                self._multi_hosts_partition(sharding_option_group, _host_level_devices)
+                self._multi_hosts_partition(sharding_option_group, host_level_devices)
                 # _multi_hosts_partition invalidates minheap_devices, force rebuild before using
                 minheap_devices = None
 
             elif (
                 sharding_option_group.sharding_options[0].partition_by
                 == PartitionByType.HOST.value
             ):
-                self._cohost_partition(sharding_option_group, _host_level_devices)
+                self._cohost_partition(sharding_option_group, host_level_devices)
                 # _cohost_partition invalidates minheap_devices, force rebuild before using
                 minheap_devices = None
             elif (
@@ -266,7 +304,7 @@ def partition(
             ):
                 if minheap_devices is None:
                     minheap_devices = self._establish_minheap(
-                        _topology.devices, _topology.local_world_size
+                        devices, storage_constraint.local_world_size
                     )
                 assert (
                     len(sharding_option_group.sharding_options) == 1
@@ -279,8 +317,6 @@ def partition(
                 raise RuntimeError(
                     f"Unexpected sharding option group {sharding_option_group}"
                 )
-        # pyre-ignore [16]: `GreedyPerfPartitioner` has no attribute `_topology`.
-        self._topology: Topology = _topology
         return proposal
 
     @classmethod
@@ -432,7 +468,9 @@ def _multi_hosts_partition(
             sharding_option = sharding_option_group.sharding_options[0]
             try:
                 if sharding_option.sharding_type == ShardingType.GRID_SHARD.value:
-                    cls._uniform_partition([sharding_option], host_devices)
+                    GreedyPerfPartitioner._uniform_partition(
+                        [sharding_option], host_devices
+                    )
                 else:
                     raise PlannerError(
                         error_type=PlannerErrorType.PARTITION,
@@ -486,7 +524,9 @@ def _cohost_partition(
                         sharding_option.sharding_type
                         == ShardingType.TABLE_ROW_WISE.value
                     ):
-                        cls._uniform_partition([sharding_option], host_devices)
+                        GreedyPerfPartitioner._uniform_partition(
+                            [sharding_option], host_devices
+                        )
                         # _uniform_partition invalidates minheap_devices, force rebuild
                         # before using
                         minheap_devices = None
@@ -521,20 +561,22 @@ def _cohost_partition(
             message=f"can't find a host for sharding option group {sharding_option_group}",
         )
 
-    @classmethod
-    def _get_host_level_devices(cls, _topology: Topology) -> List[List[DeviceHardware]]:
-        num_hosts: int = _topology.world_size // _topology.local_world_size
+    @staticmethod
+    def _get_host_level_devices(
+        topology: Topology, all_devices: List[DeviceHardware]
+    ) -> List[List[DeviceHardware]]:
+        num_hosts: int = topology.world_size // topology.local_world_size
         host_level_devices: List[List[DeviceHardware]] = []
         for i in range(num_hosts):
-            devices_in_host = _topology.devices[
-                i * _topology.local_world_size : (i + 1) * _topology.local_world_size
+            devices_in_host = all_devices[
+                i * topology.local_world_size : (i + 1) * topology.local_world_size
             ]
             host_level_devices.append(devices_in_host)
         return host_level_devices
 
-    @classmethod
+    @staticmethod
     def _uniform_partition(
-        cls, sharding_options: List[ShardingOption], devices: List[DeviceHardware]
+        sharding_options: List[ShardingOption], devices: List[DeviceHardware]
     ) -> None:
         for sharding_option in sharding_options:
             if sharding_option.num_shards != len(devices):
@@ -543,16 +585,17 @@ def _uniform_partition(
                     message=f"For a uniform partition, the number of shards ({sharding_option.num_shards}) must equal the number of devices ({len(devices)})",
                 )
             for i in range(len(devices)):
-                storage_needed = cast(Storage, sharding_option.shards[i].storage)
+                shard = sharding_option.shards[i]
+                storage_needed = cast(Storage, shard.storage)
                 if not storage_needed.fits_in(devices[i].storage):
                     raise PlannerError(
                         error_type=PlannerErrorType.PARTITION,
                         message=f"Shard of size {storage_needed} bytes does not fit on any rank. Device memory cap: {devices[i].storage}.",
                     )
                 else:
-                    sharding_option.shards[i].rank = devices[i].rank
+                    shard.rank = devices[i].rank
                     devices[i].storage -= storage_needed
-                    devices[i].perf += cast(Perf, sharding_option.shards[i].perf)
+                    devices[i].perf += cast(Perf, shard.perf)
 
 
 class MemoryBalancedPartitioner(Partitioner):
@@ -598,16 +641,15 @@ def partition(
         _partitioner = GreedyPerfPartitioner(
            sort_by=SortBy.PERF, balance_modules=self._balance_modules
         )
-        # copying storage_constraint, since we modify it in place
-        _topology: Topology = copy.deepcopy(storage_constraint)
 
         # set up default plan to fall back on
-        default_plan = _partitioner.partition(proposal, _topology)
-        default_plan = copy.deepcopy(default_plan)
+        default_plan = _partitioner.partition(proposal, storage_constraint)
+        best_shard_assignment = _get_shards_assignment(default_plan)
+
         original_plan_perf = _perf_model.rate(default_plan)
 
         # compute shard and default plan HBM stats
-        hbm_by_rank = [0] * _topology.world_size
+        hbm_by_rank = [0] * storage_constraint.world_size
         hbm_requirement: int = 0
         max_shard_hbm: int = 0
         for sharding_option in default_plan:
@@ -626,7 +668,7 @@ def partition(
         )
 
         # Lower bound for the search is the maximum of avg. HBM usage or the biggest shard
-        avg_hbm_usage: int = int(hbm_requirement / _topology.world_size)
+        avg_hbm_usage: int = int(hbm_requirement / storage_constraint.world_size)
         min_hbm_per_device: int = max(avg_hbm_usage, max_shard_hbm)
         logger.info(
             "Searching in the range (min_hbm_per_device, max_hbm_per_device): "
@@ -636,16 +678,19 @@ def partition(
 
         # binary search with (min, max] setting
         search_count = 0
+        hbm_diff = mb_to_bytes(10)  # 10MB
         while (
             search_count < self._max_search_count
-            and min_hbm_per_device + 10 * 1024**2 < max_hbm_per_device  # 10MB
+            and min_hbm_per_device + hbm_diff < max_hbm_per_device
         ):
             search_count += 1
             reset_shard_rank(proposal)
             mid_hbm_per_device: int = (max_hbm_per_device + min_hbm_per_device) // 2
-            set_hbm_per_device(_topology, mid_hbm_per_device)
+
             try:
-                new_plan = _partitioner.partition(proposal, _topology)
+                new_plan = _partitioner.partition(
+                    proposal, storage_constraint, mid_hbm_per_device
+                )
                 new_plan_perf = _perf_model.rate(new_plan)
                 perf_diff = (
                     (new_plan_perf - original_plan_perf) / original_plan_perf
@@ -674,7 +719,7 @@ def partition(
                     f"Found a more memory-balanced plan with {round(bytes_to_gb(mid_hbm_per_device), 3)} "
                     f"GB per device for embedding tables. The new plan is {perf_diff_str}"
                 )
-                default_plan = copy.deepcopy(new_plan)
+                best_shard_assignment = _get_shards_assignment(new_plan)
                 max_hbm_per_device = mid_hbm_per_device
             except PlannerError:
                 logger.info(
@@ -683,9 +728,5 @@ def partition(
                 )
                 min_hbm_per_device = mid_hbm_per_device
 
+        _apply_shards_assignment(default_plan, best_shard_assignment)
         return default_plan
-
-
-def set_hbm_per_device(storage_constraint: Topology, hbm_per_device: int) -> None:
-    for device in storage_constraint.devices:
-        device.storage.hbm = hbm_per_device
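
The diff above replaces deep-copying entire plans with recording and replaying per-shard rank assignments. A minimal sketch of that record/replay pattern, using hypothetical stand-in classes rather than torchrec's actual Shard/ShardingOption API:

from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class Shard:
    # Hypothetical stand-in: only the assigned rank matters for this pattern.
    rank: Optional[int] = None


@dataclass
class ShardingOption:
    # Hypothetical stand-in holding the shards of one table.
    shards: List[Shard] = field(default_factory=list)


def get_assignment(options: List[ShardingOption]) -> List[List[Optional[int]]]:
    # Snapshot just the rank of every shard; much cheaper than deep-copying plans.
    return [[s.rank for s in opt.shards] for opt in options]


def apply_assignment(
    options: List[ShardingOption], assignment: List[List[Optional[int]]]
) -> None:
    # Restore the recorded ranks onto the same sharding options, in place.
    for opt, ranks in zip(options, assignment):
        for shard, rank in zip(opt.shards, ranks):
            shard.rank = rank


# Usage: remember the best assignment seen during a search, then reapply it.
plan = [ShardingOption(shards=[Shard(), Shard()])]
plan[0].shards[0].rank = 0
plan[0].shards[1].rank = 1
best = get_assignment(plan)   # [[0, 1]]
plan[0].shards[0].rank = 3    # a later, worse attempt mutates the plan in place
apply_assignment(plan, best)  # roll back to the remembered ranks
assert [s.rank for s in plan[0].shards] == [0, 1]

Since the proposal itself is reused across binary-search attempts, snapshotting only the ranks (plus passing hbm_per_device into GreedyPerfPartitioner.partition instead of mutating a copied Topology) presumably keeps each search iteration allocation-free apart from the per-device copies.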