@@ -1197,16 +1197,9 @@ def calculate_shard_storages(
     hbm_storage: int = tensor_storage.get("hbm", 0)
     ddr_storage: int = tensor_storage.get("ddr", 0)
 
-    table_cached: bool = False
-    if compute_kernel in {
-        EmbeddingComputeKernel.FUSED_UVM_CACHING.value,
-        EmbeddingComputeKernel.QUANT_UVM_CACHING.value,
-        EmbeddingComputeKernel.KEY_VALUE.value,
-        EmbeddingComputeKernel.SSD_VIRTUAL_TABLE.value,
-        EmbeddingComputeKernel.DRAM_VIRTUAL_TABLE.value,
-    }:
+    table_cached = _is_table_cached(compute_kernel)
+    if table_cached:
         hbm_storage = round(ddr_storage * caching_ratio)
-        table_cached = True
 
     optimizer_class = getattr(tensor, "_optimizer_classes", [None])[0]
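For context, the branch this hunk keeps estimates a cached table's HBM footprint as the DDR footprint scaled by the cache load factor. A minimal, self-contained sketch of that estimate (the string kernel names are illustrative stand-ins for the `EmbeddingComputeKernel` values, not torchrec API):

```python
# Stand-ins for the caching / virtual-table kernel values checked above.
CACHED_KERNELS = {
    "fused_uvm_caching",
    "quant_uvm_caching",
    "key_value",
    "ssd_virtual_table",
    "dram_virtual_table",
}


def estimate_hbm_storage(
    compute_kernel: str, hbm_storage: int, ddr_storage: int, caching_ratio: float
) -> int:
    # For cached kernels, HBM holds only the cache, estimated as the DDR
    # footprint scaled by the cache load factor; otherwise keep the raw value.
    if compute_kernel in CACHED_KERNELS:
        return round(ddr_storage * caching_ratio)
    return hbm_storage


# e.g. a 10 GiB table cached at a 0.2 load factor reserves ~2 GiB of HBM
print(estimate_hbm_storage("fused_uvm_caching", 0, 10 * 2**30, 0.2))
```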
@@ -1304,6 +1297,20 @@ def calculate_shard_storages(
     ]
 
 
+def _is_table_cached(
+    compute_kernel: str,
+) -> bool:
+    if compute_kernel in {
+        EmbeddingComputeKernel.FUSED_UVM_CACHING.value,
+        EmbeddingComputeKernel.QUANT_UVM_CACHING.value,
+        EmbeddingComputeKernel.KEY_VALUE.value,
+        EmbeddingComputeKernel.SSD_VIRTUAL_TABLE.value,
+        EmbeddingComputeKernel.DRAM_VIRTUAL_TABLE.value,
+    }:
+        return True
+    return False
+
+
 def _calculate_shard_io_sizes(
     sharding_type: str,
     batch_sizes: List[int],
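A hypothetical unit-test sketch for the extracted helper; the test and the import path are assumptions based on this file, not part of the diff:

```python
import pytest

from torchrec.distributed.embedding_types import EmbeddingComputeKernel
# Assumed module path for the function added in this diff.
from torchrec.distributed.planner.shard_estimators import _is_table_cached


@pytest.mark.parametrize(
    "kernel,expected",
    [
        (EmbeddingComputeKernel.FUSED_UVM_CACHING.value, True),
        (EmbeddingComputeKernel.QUANT_UVM_CACHING.value, True),
        (EmbeddingComputeKernel.KEY_VALUE.value, True),
        (EmbeddingComputeKernel.FUSED.value, False),
        (EmbeddingComputeKernel.DENSE.value, False),
    ],
)
def test_is_table_cached(kernel: str, expected: bool) -> None:
    # Only caching / virtual-table kernels should be treated as cached.
    assert _is_table_cached(kernel) is expected
```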
@@ -1565,27 +1572,20 @@ def _calculate_storage_specific_sizes(
     is_inference: bool = False,
     clf: Optional[float] = None,
 ) -> List[int]:
-    tensor_sizes: List[int] = [
-        (
-            math.ceil(storage * prod(size) / prod(shape))
-            if sharding_type != ShardingType.DATA_PARALLEL.value
-            else storage
-        )
-        for size in shard_sizes
-    ]
-    optimizer_multipler: float = _get_optimizer_multipler(optimizer_class, shape)
-
-    optimizer_sizes: List[int] = [
-        math.ceil(tensor_size * optimizer_multipler) for tensor_size in tensor_sizes
-    ]
-
-    # If a table has turned on UVM caching (meaning clf is not None), there'll be
-    # 4x of table hash size and 16x of cache slot size HBM storage cost dedicated to
-    # cache aux state (note that this is not the cache content itself)
-    cache_aux_state_sizes: List[int] = (
-        [0] * len(shard_sizes)
-        if clf is None
-        else [math.ceil(size[0] * (4 + clf * 16)) for size in shard_sizes]
+    tensor_sizes: List[int] = _calculate_tensor_sizes(
+        storage,
+        shape,
+        shard_sizes,
+        sharding_type,
+    )
+    optimizer_sizes = _calculate_optimizer_sizes(
+        tensor_sizes,
+        optimizer_class,
+        shape,
+    )
+    cache_aux_state_sizes: List[int] = _calculate_cache_aux_state_sizes(
+        shard_sizes,
+        clf,
     )
 
     return [
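A worked example of the proportional scaling that moves into `_calculate_tensor_sizes` below: outside of data parallel, each shard is charged the full table's dense storage times the fraction of elements it owns. The numbers are illustrative, and `prod` is defined locally so the snippet runs standalone:

```python
import math

storage = 4_000_000      # full table: 1_000_000 rows x 4 cols x 4 bytes (fp32)
shape = (1_000_000, 4)   # full (unsharded) table shape
shard_sizes = [[250_000, 4]] * 4  # row-wise sharding across 4 ranks


def prod(dims):
    # Simple product over dimensions, as used by the planner.
    out = 1
    for d in dims:
        out *= d
    return out


# Each shard owns 1/4 of the elements, so it is charged 1/4 of the storage.
tensor_sizes = [math.ceil(storage * prod(size) / prod(shape)) for size in shard_sizes]
print(tensor_sizes)  # [1000000, 1000000, 1000000, 1000000]
```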
@@ -1600,6 +1600,45 @@ def _calculate_storage_specific_sizes(
     ]
 
 
+def _calculate_tensor_sizes(
+    storage: int, shape: torch.Size, shard_sizes: List[List[int]], sharding_type: str
+) -> List[int]:
+    return [
+        (
+            math.ceil(storage * prod(size) / prod(shape))
+            if sharding_type != ShardingType.DATA_PARALLEL.value
+            else storage
+        )
+        for size in shard_sizes
+    ]
+
+
+# If a table has turned on UVM caching (meaning clf is not None), there'll be
+# 4x of table hash size and 16x of cache slot size HBM storage cost dedicated to
+# cache aux state (note that this is not the cache content itself)
+def _calculate_cache_aux_state_sizes(
+    shard_sizes: List[List[int]], clf: Optional[float]
+) -> List[int]:
+    if clf is None:
+        return [0] * len(shard_sizes)
+    return [math.ceil(size[0] * (4 + clf * 16)) for size in shard_sizes]
+
+
+def _calculate_optimizer_sizes(
+    tensor_sizes: List[int],
+    optimizer_class: Optional[Type[torch.optim.Optimizer]],
+    sharding_tensor_shape: torch.Size,
+) -> List[int]:
+    optimizer_multiplier: float = _get_optimizer_multipler(
+        optimizer_class,
+        sharding_tensor_shape,
+    )
+    optimizer_sizes: List[int] = [
+        math.ceil(tensor_size * optimizer_multiplier) for tensor_size in tensor_sizes
+    ]
+    return optimizer_sizes
+
+
 def _get_optimizer_multipler(
     optimizer_class: Optional[Type[torch.optim.Optimizer]],
     shape: torch.Size,
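A worked example of the cache-aux-state formula preserved in `_calculate_cache_aux_state_sizes`: with UVM caching enabled (`clf` is not None), each shard is charged roughly 4 bytes per row of hash state plus 16 bytes per cache slot (`clf * rows` slots), independent of the embedding width. The shard sizes and `clf` here are illustrative:

```python
import math

# Two row-wise shards of 500_000 rows each, cache load factor (clf) of 0.2.
shard_sizes = [[500_000, 128], [500_000, 128]]
clf = 0.2

# Mirrors size[0] * (4 + clf * 16) above: 4 bytes of hash state per row plus
# 16 bytes per cache slot, with clf * rows slots per shard.
cache_aux_state_sizes = [math.ceil(size[0] * (4 + clf * 16)) for size in shard_sizes]
print(cache_aux_state_sizes)  # [3600000, 3600000] -> ~3.4 MiB of aux state per shard
```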