Commit b67f6ae

Boris Sarana authored and facebook-github-bot committed
Remove unnecessary deep copy of plan proposals (#3373)
Summary: Pull Request resolved: #3373
Reviewed By: iamzainhuda
Differential Revision: D78922919
fbshipit-source-id: c9e26f45c55996468354f31f7a1dae2b567e4b26
1 parent 73a5c05 commit b67f6ae

File tree

5 files changed: +135 −77 lines changed


torchrec/distributed/model_parallel.py

Lines changed: 2 additions & 2 deletions
@@ -898,10 +898,10 @@ def __init__(
             # pyre-ignore[16]
             ctx.sharded_module = self._sharder_map[ctx.module].sharded_module_type
 
-        consolidated_plan = copy.deepcopy(self._ctxs[0].plan)
+        consolidated_plan = self._ctxs[0].plan
         for ctx in self._ctxs[1:]:
             for key, val in ctx.plan.plan.items():
-                consolidated_plan.plan[key] = copy.deepcopy(val)
+                consolidated_plan.plan[key] = val
 
         logger.info(
             "[TorchRec 2D Parallel] Consolidated sharding plan:\n%s", consolidated_plan

torchrec/distributed/planner/partitioners.py

Lines changed: 77 additions & 36 deletions
@@ -29,7 +29,11 @@
     Storage,
     Topology,
 )
-from torchrec.distributed.planner.utils import bytes_to_gb, reset_shard_rank
+from torchrec.distributed.planner.utils import (
+    bytes_to_gb,
+    mb_to_bytes,
+    reset_shard_rank,
+)
 from torchrec.distributed.types import ShardingType
 
 logger: logging.Logger = logging.getLogger(__name__)
@@ -57,6 +61,25 @@ def _get_uniform_sharding_options(
     return uniform_sharding_options
 
 
+def _get_shards_assignment(
+    sharding_options: List[ShardingOption],
+) -> List[List[Optional[int]]]:
+    assignment_per_option = []
+    for sharding_option in sharding_options:
+        assignment_per_option.append(sharding_option.get_shards_assignment())
+    return assignment_per_option
+
+
+def _apply_shards_assignment(
+    sharding_options: List[ShardingOption],
+    assignment_per_option: List[List[Optional[int]]],
+) -> None:
+    assert len(sharding_options) == len(assignment_per_option)
+    for sharding_option, assignment in zip(sharding_options, assignment_per_option):
+        for shard_id, rank in enumerate(assignment):
+            sharding_option.shards[shard_id].rank = rank
+
+
 @dataclass
 class ShardingOptionGroup:
     sharding_options: List[ShardingOption]
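These two helpers replace deep-copying entire plans with a snapshot of the only state the partitioner mutates: each shard's assigned rank. A self-contained sketch of the capture/restore round trip, with simplified Shard/Option types (the real ShardingOption.get_shards_assignment() is assumed to return exactly this list of ranks):

from dataclasses import dataclass, field
from typing import List, Optional

@dataclass
class Shard:
    rank: Optional[int] = None  # device rank; None while unassigned

@dataclass
class Option:  # simplified stand-in for ShardingOption
    shards: List[Shard] = field(default_factory=list)

    def get_shards_assignment(self) -> List[Optional[int]]:
        # snapshot only the ranks -- the one field partitioning mutates
        return [shard.rank for shard in self.shards]

opts = [Option([Shard(), Shard()]), Option([Shard()])]
opts[0].shards[0].rank = 3                        # a good plan is found
best = [o.get_shards_assignment() for o in opts]  # capture it
opts[0].shards[0].rank = 7                        # a later attempt clobbers ranks
for o, ranks in zip(opts, best):                  # restore the best plan
    for shard_id, rank in enumerate(ranks):
        o.shards[shard_id].rank = rank
assert opts[0].shards[0].rank == 3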
@@ -171,6 +194,7 @@ def partition(
         self,
         proposal: List[ShardingOption],
         storage_constraint: Topology,
+        hbm_per_device: Optional[int] = None,
     ) -> List[ShardingOption]:
         """
         Places sharding options on topology based on each sharding option's
@@ -230,13 +254,27 @@ def partition(
             f"GreedyPerfPartitioner - sort_by: {self._sort_by}, balance_modules: {self._balance_modules}"
         )
 
-        _topology: Topology = copy.deepcopy(storage_constraint)
         minheap_devices: Optional[List[OrderedDeviceHardware]] = None
-        _host_level_devices = self._get_host_level_devices(_topology)
+
+        # Don't store the topology since topology cannot be changed
+        # the algorithm will only be modifying the device perf & storage sizes so copy them only
+        devices = [
+            DeviceHardware(
+                rank=d.rank,
+                storage=Storage(hbm=hbm_per_device or d.storage.hbm, ddr=d.storage.ddr),
+                perf=copy.deepcopy(d.perf),
+            )
+            for d in storage_constraint.devices
+        ]
+
+        host_level_devices = GreedyPerfPartitioner._get_host_level_devices(
+            storage_constraint, devices
+        )
 
         # first partition the uniform sharding options (RW & DP)
         uniform_sharding_options = _get_uniform_sharding_options(proposal)
-        self._uniform_partition(uniform_sharding_options, _topology.devices)
+
+        GreedyPerfPartitioner._uniform_partition(uniform_sharding_options, devices)
 
         # group the rest sharding options by colocation type (co-host, co-device, none)
         # and sort the groups by storage in reverse order
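The deepcopy of the whole Topology is replaced by rebuilding only the per-device mutable state (storage budget and perf accumulator), and hbm_per_device lets MemoryBalancedPartitioner trial a tighter memory cap without touching the caller's topology. A sketch of the same copy-only-what-you-mutate pattern with simplified types (Storage/Device below are stand-ins, not torchrec's classes):

from dataclasses import dataclass
from typing import List, Optional

@dataclass
class Storage:
    hbm: int
    ddr: int

@dataclass
class Device:
    rank: int
    storage: Storage

def fresh_devices(devices: List[Device], hbm_cap: Optional[int] = None) -> List[Device]:
    # rebuild only the fields the algorithm mutates; `hbm_cap or d.storage.hbm`
    # falls back to the device's real capacity when no override is given
    return [
        Device(rank=d.rank, storage=Storage(hbm=hbm_cap or d.storage.hbm, ddr=d.storage.ddr))
        for d in devices
    ]

topo = [Device(0, Storage(hbm=80, ddr=512)), Device(1, Storage(hbm=80, ddr=512))]
trial = fresh_devices(topo, hbm_cap=40)
trial[0].storage.hbm -= 10        # consume budget on the copy...
assert topo[0].storage.hbm == 80  # ...the original topology is untouched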
@@ -249,15 +287,15 @@ def partition(
                 sharding_option_group.sharding_options[0].partition_by
                 == PartitionByType.MULTI_HOST.value
             ):
-                self._multi_hosts_partition(sharding_option_group, _host_level_devices)
+                self._multi_hosts_partition(sharding_option_group, host_level_devices)
                 # _multi_hosts_partition invalidates minheap_devices, force rebuild before using
                 minheap_devices = None
 
             elif (
                 sharding_option_group.sharding_options[0].partition_by
                 == PartitionByType.HOST.value
             ):
-                self._cohost_partition(sharding_option_group, _host_level_devices)
+                self._cohost_partition(sharding_option_group, host_level_devices)
                 # _cohost_partition invalidates minheap_devices, force rebuild before using
                 minheap_devices = None
             elif (
@@ -266,7 +304,7 @@
             ):
                 if minheap_devices is None:
                     minheap_devices = self._establish_minheap(
-                        _topology.devices, _topology.local_world_size
+                        devices, storage_constraint.local_world_size
                     )
                 assert (
                     len(sharding_option_group.sharding_options) == 1
@@ -279,8 +317,6 @@ def partition(
             raise RuntimeError(
                 f"Unexpected sharding option group {sharding_option_group}"
             )
-        # pyre-ignore [16]: `GreedyPerfPartitioner` has no attribute `_topology`.
-        self._topology: Topology = _topology
         return proposal
 
     @classmethod
@@ -432,7 +468,9 @@ def _multi_hosts_partition(
         sharding_option = sharding_option_group.sharding_options[0]
         try:
             if sharding_option.sharding_type == ShardingType.GRID_SHARD.value:
-                cls._uniform_partition([sharding_option], host_devices)
+                GreedyPerfPartitioner._uniform_partition(
+                    [sharding_option], host_devices
+                )
             else:
                 raise PlannerError(
                     error_type=PlannerErrorType.PARTITION,
@@ -486,7 +524,9 @@ def _cohost_partition(
                     sharding_option.sharding_type
                     == ShardingType.TABLE_ROW_WISE.value
                 ):
-                    cls._uniform_partition([sharding_option], host_devices)
+                    GreedyPerfPartitioner._uniform_partition(
+                        [sharding_option], host_devices
+                    )
                     # _uniform_partition invalidates minheap_devices, force rebuild
                     # before using
                     minheap_devices = None
@@ -521,20 +561,22 @@
                 message=f"can't find a host for sharding option group {sharding_option_group}",
             )
 
-    @classmethod
-    def _get_host_level_devices(cls, _topology: Topology) -> List[List[DeviceHardware]]:
-        num_hosts: int = _topology.world_size // _topology.local_world_size
+    @staticmethod
+    def _get_host_level_devices(
+        topology: Topology, all_devices: List[DeviceHardware]
+    ) -> List[List[DeviceHardware]]:
+        num_hosts: int = topology.world_size // topology.local_world_size
         host_level_devices: List[List[DeviceHardware]] = []
         for i in range(num_hosts):
-            devices_in_host = _topology.devices[
-                i * _topology.local_world_size : (i + 1) * _topology.local_world_size
+            devices_in_host = all_devices[
+                i * topology.local_world_size : (i + 1) * topology.local_world_size
             ]
             host_level_devices.append(devices_in_host)
         return host_level_devices
 
-    @classmethod
+    @staticmethod
     def _uniform_partition(
-        cls, sharding_options: List[ShardingOption], devices: List[DeviceHardware]
+        sharding_options: List[ShardingOption], devices: List[DeviceHardware]
     ) -> None:
         for sharding_option in sharding_options:
             if sharding_option.num_shards != len(devices):
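_get_host_level_devices is now a @staticmethod that slices the freshly copied device list into per-host groups, using the topology only for world_size and local_world_size; contiguous ranks are assumed to share a host. The same grouping in isolation (host_slices is a hypothetical helper name, assuming len(devices) == world_size):

from typing import List, TypeVar

T = TypeVar("T")

def host_slices(devices: List[T], local_world_size: int) -> List[List[T]]:
    # ranks [0..lws-1] form host 0, [lws..2*lws-1] host 1, and so on
    return [
        devices[i : i + local_world_size]
        for i in range(0, len(devices), local_world_size)
    ]

assert host_slices(["d0", "d1", "d2", "d3"], 2) == [["d0", "d1"], ["d2", "d3"]]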
@@ -543,16 +585,17 @@ def _uniform_partition(
                     message=f"For a uniform partition, the number of shards ({sharding_option.num_shards}) must equal the number of devices ({len(devices)})",
                 )
             for i in range(len(devices)):
-                storage_needed = cast(Storage, sharding_option.shards[i].storage)
+                shard = sharding_option.shards[i]
+                storage_needed = cast(Storage, shard.storage)
                 if not storage_needed.fits_in(devices[i].storage):
                     raise PlannerError(
                         error_type=PlannerErrorType.PARTITION,
                         message=f"Shard of size {storage_needed} bytes does not fit on any rank. Device memory cap: {devices[i].storage}.",
                     )
                 else:
-                    sharding_option.shards[i].rank = devices[i].rank
+                    shard.rank = devices[i].rank
                     devices[i].storage -= storage_needed
-                    devices[i].perf += cast(Perf, sharding_option.shards[i].perf)
+                    devices[i].perf += cast(Perf, shard.perf)
 
 
 class MemoryBalancedPartitioner(Partitioner):
@@ -598,16 +641,15 @@ def partition(
         _partitioner = GreedyPerfPartitioner(
             sort_by=SortBy.PERF, balance_modules=self._balance_modules
         )
-        # copying storage_constraint, since we modify it in place
-        _topology: Topology = copy.deepcopy(storage_constraint)
 
         # set up default plan to fall back on
-        default_plan = _partitioner.partition(proposal, _topology)
-        default_plan = copy.deepcopy(default_plan)
+        default_plan = _partitioner.partition(proposal, storage_constraint)
+        best_shard_assignment = _get_shards_assignment(default_plan)
+
         original_plan_perf = _perf_model.rate(default_plan)
 
         # compute shard and default plan HBM stats
-        hbm_by_rank = [0] * _topology.world_size
+        hbm_by_rank = [0] * storage_constraint.world_size
         hbm_requirement: int = 0
         max_shard_hbm: int = 0
         for sharding_option in default_plan:
@@ -626,7 +668,7 @@
             )
 
         # Lower bound for the search is the maximum of avg. HBM usage or the biggest shard
-        avg_hbm_usage: int = int(hbm_requirement / _topology.world_size)
+        avg_hbm_usage: int = int(hbm_requirement / storage_constraint.world_size)
         min_hbm_per_device: int = max(avg_hbm_usage, max_shard_hbm)
         logger.info(
             "Searching in the range (min_hbm_per_device, max_hbm_per_device): "
@@ -636,16 +678,19 @@
 
         # binary search with (min, max] setting
         search_count = 0
+        hbm_diff = mb_to_bytes(10)  # 10MB
         while (
             search_count < self._max_search_count
-            and min_hbm_per_device + 10 * 1024**2 < max_hbm_per_device  # 10MB
+            and min_hbm_per_device + hbm_diff < max_hbm_per_device
         ):
             search_count += 1
             reset_shard_rank(proposal)
             mid_hbm_per_device: int = (max_hbm_per_device + min_hbm_per_device) // 2
-            set_hbm_per_device(_topology, mid_hbm_per_device)
+
             try:
-                new_plan = _partitioner.partition(proposal, _topology)
+                new_plan = _partitioner.partition(
+                    proposal, storage_constraint, mid_hbm_per_device
+                )
                 new_plan_perf = _perf_model.rate(new_plan)
                 perf_diff = (
                     (new_plan_perf - original_plan_perf) / original_plan_perf
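The search itself is unchanged in spirit: binary-search the smallest per-device HBM budget in (min, max] whose plan both partitions and stays within the perf threshold, stopping once the window shrinks below 10 MB (now computed via mb_to_bytes) or the search budget runs out. A generic sketch of that loop (try_budget is a hypothetical stand-in for "partition at this cap and rate the plan"):

from typing import Callable

def search_min_budget(
    try_budget: Callable[[int], bool],  # True if a plan at this cap is acceptable
    lo: int,
    hi: int,
    max_steps: int,
) -> int:
    eps = 10 * 1024 * 1024  # stop once the (lo, hi] window is under 10 MB
    steps = 0
    while steps < max_steps and lo + eps < hi:
        steps += 1
        mid = (lo + hi) // 2
        if try_budget(mid):
            hi = mid  # feasible: a balanced plan exists at this cap
        else:
            lo = mid  # infeasible (or too slow): need a larger cap
    return hi

# toy check: converges to within 10 MB of the smallest acceptable cap (600)
assert abs(search_min_budget(lambda b: b >= 600, 0, 1 << 30, 64) - 600) < 10 * 1024 * 1024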
@@ -674,7 +719,7 @@
                     f"Found a more memory-balanced plan with {round(bytes_to_gb(mid_hbm_per_device), 3)} "
                     f"GB per device for embedding tables. The new plan is {perf_diff_str}"
                 )
-                default_plan = copy.deepcopy(new_plan)
+                best_shard_assignment = _get_shards_assignment(new_plan)
                 max_hbm_per_device = mid_hbm_per_device
             except PlannerError:
                 logger.info(
@@ -683,9 +728,5 @@
                 )
                 min_hbm_per_device = mid_hbm_per_device
 
+        _apply_shards_assignment(default_plan, best_shard_assignment)
         return default_plan
-
-
-def set_hbm_per_device(storage_constraint: Topology, hbm_per_device: int) -> None:
-    for device in storage_constraint.devices:
-        device.storage.hbm = hbm_per_device
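Net effect: a single plan object flows through the whole search; each candidate is re-partitioned in place, and only the winning rank assignment is retained and re-applied once at the end. A toy measurement of why that matters, with a hypothetical shard layout standing in for real plan metadata:

import copy
import timeit
from dataclasses import dataclass, field
from typing import List, Optional

@dataclass
class Shard:
    rank: Optional[int]
    meta: List[int] = field(default_factory=list)  # stands in for heavy shard metadata

plan = [Shard(rank=i % 8, meta=list(range(1_000))) for i in range(500)]

deep = timeit.timeit(lambda: copy.deepcopy(plan), number=10)  # old approach per candidate
snap = timeit.timeit(lambda: [s.rank for s in plan], number=10)  # new rank snapshot
print(f"deepcopy: {deep:.4f}s  rank snapshot: {snap:.6f}s")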
