Skip to content

Commit fcfc8ec

Browse files
Ahmed Shuaibi authored and facebook-github-bot committed
refactor: create functions for shard/tensor size calculations (#3257)
Summary: Pull Request resolved: #3257 - refactor to create function for checking if table is cached - refactor to create functions for tensor size calculations Reviewed By: levythu, aporialiao Differential Revision: D79007077 fbshipit-source-id: 9c514bd86cbef9f85cfb10238b2ba294d344d12b
1 parent 7cf5f60 commit fcfc8ec

File tree

1 file changed

+69
-30
lines changed

1 file changed

+69
-30
lines changed

torchrec/distributed/planner/shard_estimators.py

Lines changed: 69 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1197,16 +1197,9 @@ def calculate_shard_storages(
11971197
hbm_storage: int = tensor_storage.get("hbm", 0)
11981198
ddr_storage: int = tensor_storage.get("ddr", 0)
11991199

1200-
table_cached: bool = False
1201-
if compute_kernel in {
1202-
EmbeddingComputeKernel.FUSED_UVM_CACHING.value,
1203-
EmbeddingComputeKernel.QUANT_UVM_CACHING.value,
1204-
EmbeddingComputeKernel.KEY_VALUE.value,
1205-
EmbeddingComputeKernel.SSD_VIRTUAL_TABLE.value,
1206-
EmbeddingComputeKernel.DRAM_VIRTUAL_TABLE.value,
1207-
}:
1200+
table_cached = _is_table_cached(compute_kernel)
1201+
if table_cached:
12081202
hbm_storage = round(ddr_storage * caching_ratio)
1209-
table_cached = True
12101203

12111204
optimizer_class = getattr(tensor, "_optimizer_classes", [None])[0]
12121205

@@ -1304,6 +1297,20 @@ def calculate_shard_storages(
13041297
]
13051298

13061299

1300+
def _is_table_cached(
    compute_kernel: str,
) -> bool:
    """Return True if ``compute_kernel`` keeps the embedding table behind a cache.

    For these kernels the caller sizes HBM as a ``caching_ratio`` fraction of
    the DDR footprint (see ``calculate_shard_storages``) instead of holding
    the full table in HBM.

    Args:
        compute_kernel: string value of an ``EmbeddingComputeKernel`` member.

    Returns:
        True for UVM-caching, key-value, and virtual-table kernels.
    """
    # Direct membership test instead of `if ...: return True / return False`.
    return compute_kernel in {
        EmbeddingComputeKernel.FUSED_UVM_CACHING.value,
        EmbeddingComputeKernel.QUANT_UVM_CACHING.value,
        EmbeddingComputeKernel.KEY_VALUE.value,
        EmbeddingComputeKernel.SSD_VIRTUAL_TABLE.value,
        EmbeddingComputeKernel.DRAM_VIRTUAL_TABLE.value,
    }
1312+
1313+
13071314
def _calculate_shard_io_sizes(
13081315
sharding_type: str,
13091316
batch_sizes: List[int],
@@ -1565,27 +1572,20 @@ def _calculate_storage_specific_sizes(
15651572
is_inference: bool = False,
15661573
clf: Optional[float] = None,
15671574
) -> List[int]:
1568-
tensor_sizes: List[int] = [
1569-
(
1570-
math.ceil(storage * prod(size) / prod(shape))
1571-
if sharding_type != ShardingType.DATA_PARALLEL.value
1572-
else storage
1573-
)
1574-
for size in shard_sizes
1575-
]
1576-
optimizer_multipler: float = _get_optimizer_multipler(optimizer_class, shape)
1577-
1578-
optimizer_sizes: List[int] = [
1579-
math.ceil(tensor_size * optimizer_multipler) for tensor_size in tensor_sizes
1580-
]
1581-
1582-
# If a table has turned on UVM caching (meaning clf is not None), there'll be
1583-
# 4x of table hash size and 16x of cache slot size HBM storage cost dedicated to
1584-
# cache aux state (note that this is not the cache content itself)
1585-
cache_aux_state_sizes: List[int] = (
1586-
[0] * len(shard_sizes)
1587-
if clf is None
1588-
else [math.ceil(size[0] * (4 + clf * 16)) for size in shard_sizes]
1575+
tensor_sizes: List[int] = _calculate_tensor_sizes(
1576+
storage,
1577+
shape,
1578+
shard_sizes,
1579+
sharding_type,
1580+
)
1581+
optimizer_sizes = _calculate_optimizer_sizes(
1582+
tensor_sizes,
1583+
optimizer_class,
1584+
shape,
1585+
)
1586+
cache_aux_state_sizes: List[int] = _calculate_cache_aux_state_sizes(
1587+
shard_sizes,
1588+
clf,
15891589
)
15901590

15911591
return [
@@ -1600,6 +1600,45 @@ def _calculate_storage_specific_sizes(
16001600
]
16011601

16021602

1603+
def _calculate_tensor_sizes(
    storage: int, shape: torch.Size, shard_sizes: List[List[int]], sharding_type: str
) -> List[int]:
    """Return the storage (bytes) attributed to each shard of the tensor.

    Args:
        storage: total storage of the unsharded tensor.
        shape: full (unsharded) tensor shape.
        shard_sizes: per-shard dimensions.
        sharding_type: string value of a ``ShardingType`` member.

    Returns:
        One size per entry of ``shard_sizes``.
    """
    # Data-parallel replicates the whole tensor on every rank, so each
    # "shard" costs the full storage rather than a proportional slice.
    if sharding_type == ShardingType.DATA_PARALLEL.value:
        return [storage for _ in shard_sizes]
    total_elements = prod(shape)
    return [
        math.ceil(storage * prod(shard_size) / total_elements)
        for shard_size in shard_sizes
    ]
1614+
1615+
1616+
# If a table has turned on UVM caching (meaning clf is not None), there'll be
1617+
# 4x of table hash size and 16x of cache slot size HBM storage cost dedicated to
1618+
# cache aux state (note that this is not the cache content itself)
1619+
def _calculate_cache_aux_state_sizes(
1620+
shard_sizes: List[List[int]], clf: Optional[float]
1621+
) -> List[int]:
1622+
if clf is None:
1623+
return [0] * len(shard_sizes)
1624+
return [math.ceil(size[0] * (4 + clf * 16)) for size in shard_sizes]
1625+
1626+
1627+
def _calculate_optimizer_sizes(
    tensor_sizes: List[int],
    optimizer_class: Optional[Type[torch.optim.Optimizer]],
    sharding_tensor_shape: torch.Size,
) -> List[int]:
    """Return the per-shard optimizer-state storage for the given tensor sizes.

    Optimizer state is modeled as a multiple of each shard's tensor bytes;
    the multiplier comes from ``_get_optimizer_multipler`` (presumably based
    on the optimizer's per-parameter state — confirm against that helper).

    Args:
        tensor_sizes: per-shard tensor storage sizes.
        optimizer_class: optimizer attached to the tensor, if any.
        sharding_tensor_shape: full shape of the sharded tensor.

    Returns:
        One optimizer-state size per entry of ``tensor_sizes``.
    """
    multiplier: float = _get_optimizer_multipler(
        optimizer_class,
        sharding_tensor_shape,
    )
    sizes: List[int] = []
    for tensor_size in tensor_sizes:
        sizes.append(math.ceil(tensor_size * multiplier))
    return sizes
1640+
1641+
16031642
def _get_optimizer_multipler(
16041643
optimizer_class: Optional[Type[torch.optim.Optimizer]],
16051644
shape: torch.Size,

0 commit comments

Comments
 (0)