
Commit ab3d79e

emlin authored and facebook-github-bot committed
enable feature score auto collection in EBC (#3475)
Summary:
X-link: pytorch/FBGEMM#5031
X-link: facebookresearch/FBGEMM#2044

Enable feature score auto collection for EBC, in the same way it is done for EC. The embedding table configuration is unchanged:

virtual_table_eviction_policy=FeatureScoreBasedEvictionPolicy(
    training_id_eviction_trigger_count=260_000_000,  # 260M
    training_id_keep_count=160_000_000,  # 160M
    enable_auto_feature_score_collection=True,
    feature_score_mapping={
        "sparse_public_original_content_creator": 1.0,
    },
    feature_score_default_value=0.5,
),

Differential Revision: D85017179
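For reference, a minimal sketch of how such a policy might be attached to a single EBC table config. The table name, sizes, and feature list below are illustrative only, and the exact EmbeddingBagConfig fields are an assumption; the policy fields are the ones from the summary above.

# Illustrative sketch only, not part of this commit.
from torchrec.modules.embedding_configs import (
    EmbeddingBagConfig,
    FeatureScoreBasedEvictionPolicy,
)

creator_table = EmbeddingBagConfig(
    name="creator_id",  # hypothetical table name
    embedding_dim=128,
    num_embeddings=1_000_000_000,
    feature_names=["sparse_public_original_content_creator"],
    use_virtual_table=True,
    virtual_table_eviction_policy=FeatureScoreBasedEvictionPolicy(
        training_id_eviction_trigger_count=260_000_000,  # 260M
        training_id_keep_count=160_000_000,  # 160M
        enable_auto_feature_score_collection=True,
        feature_score_mapping={
            "sparse_public_original_content_creator": 1.0,
        },
        feature_score_default_value=0.5,
    ),
)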
1 parent 676c808 commit ab3d79e

File tree

5 files changed (+334, -7 lines)


torchrec/distributed/batched_embedding_kernel.py

Lines changed: 39 additions & 1 deletion
@@ -341,6 +341,8 @@ def _populate_zero_collision_tbe_params(
         tbe_params.pop("kvzch_eviction_trigger_mode")
     else:
         eviction_trigger_mode = 2  # 2 means mem_util based eviction
+
+    fs_evcition_enabled: bool = False
     for i, table in enumerate(config.embedding_tables):
         policy_t = table.virtual_table_eviction_policy
         if policy_t is not None:
@@ -369,6 +371,7 @@ def _populate_zero_collision_tbe_params(
                     raise ValueError(
                         f"Do not support multiple eviction strategy in one tbe {eviction_strategy} and 5 for tables {table_names}"
                     )
+                fs_evcition_enabled = True
             elif isinstance(policy_t, TimestampBasedEvictionPolicy):
                 training_id_eviction_trigger_count[i] = (
                     policy_t.training_id_eviction_trigger_count
@@ -440,6 +443,7 @@ def _populate_zero_collision_tbe_params(
         backend_return_whole_row=(backend_type == BackendType.DRAM),
         eviction_policy=eviction_policy,
         embedding_cache_mode=embedding_cache_mode_,
+        feature_score_collection_enabled=fs_evcition_enabled,
     )


@@ -2872,6 +2876,7 @@ def __init__(
         _populate_zero_collision_tbe_params(
             ssd_tbe_params, self._bucket_spec, config, backend_type
         )
+        self._kv_zch_params: KVZCHParams = ssd_tbe_params["kv_zch_params"]
         compute_kernel = config.embedding_tables[0].compute_kernel
         embedding_location = compute_kernel_to_embedding_location(compute_kernel)

@@ -3155,7 +3160,40 @@ def forward(self, features: KeyedJaggedTensor) -> torch.Tensor:
             self._split_weights_res = None
             self._optim.set_sharded_embedding_weight_ids(sharded_embedding_weight_ids=None)

-        return super().forward(features)
+        weights = features.weights_or_none()
+        per_sample_weights = None
+        score_weights = None
+        if weights is not None and weights.dtype == torch.float64:
+            fp32_weights = weights.view(torch.float32)
+            per_sample_weights = fp32_weights[:, 0]
+            score_weights = fp32_weights[:, 1]
+        elif weights is not None and weights.dtype == torch.float32:
+            if self._kv_zch_params.feature_score_collection_enabled:
+                score_weights = weights.view(-1)
+            else:
+                per_sample_weights = weights.view(-1)
+        if features.variable_stride_per_key() and isinstance(
+            self.emb_module,
+            (
+                SplitTableBatchedEmbeddingBagsCodegen,
+                DenseTableBatchedEmbeddingBagsCodegen,
+                SSDTableBatchedEmbeddingBags,
+            ),
+        ):
+            return self.emb_module(
+                indices=features.values().long(),
+                offsets=features.offsets().long(),
+                weights=score_weights,
+                per_sample_weights=per_sample_weights,
+                batch_size_per_feature_per_rank=features.stride_per_key_per_rank(),
+            )
+        else:
+            return self.emb_module(
+                indices=features.values().long(),
+                offsets=features.offsets().long(),
+                weights=score_weights,
+                per_sample_weights=per_sample_weights,
+            )


 class BatchedFusedEmbeddingBag(
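Note on the weights handling above: when feature score auto collection is enabled, the lookup layer (embedding_lookup.py below) packs the per-sample weights and the feature scores as two fp32 columns reinterpreted as a single fp64 tensor, and this kernel unpacks them by viewing the fp64 weights as fp32 again. A standalone sketch of that round trip, outside TorchRec, with made-up values:

# Sketch of the fp32-pair-in-fp64 packing convention used above (illustrative values).
import torch

per_sample_weights = torch.tensor([[1.0], [2.0], [0.5]], dtype=torch.float32)
feature_scores = torch.tensor([[1.0], [0.5], [0.5]], dtype=torch.float32)

# Pack: concatenate the two fp32 columns, then reinterpret each 8-byte row as one fp64.
packed = torch.cat([per_sample_weights, feature_scores], dim=1).view(torch.float64)  # shape [3, 1]

# Unpack: reinterpret the bytes as fp32 again and split the columns back out.
unpacked = packed.view(torch.float32)  # shape [3, 2]
assert torch.equal(unpacked[:, 0], per_sample_weights.view(-1))
assert torch.equal(unpacked[:, 1], feature_scores.view(-1))

The bit patterns are preserved end to end, so no precision is lost; fp64 is used only as a transport container for the two fp32 values.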

torchrec/distributed/embedding_lookup.py

Lines changed: 35 additions & 4 deletions
@@ -66,6 +66,7 @@
     QuantBatchedEmbeddingBag,
 )
 from torchrec.distributed.types import rank_device, ShardedTensor, ShardingType
+from torchrec.modules.embedding_configs import FeatureScoreBasedEvictionPolicy
 from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

 logger: logging.Logger = logging.getLogger(__name__)
@@ -490,6 +491,23 @@ def __init__(
     ) -> None:
         super().__init__()
         self._emb_modules: nn.ModuleList = nn.ModuleList()
+        self._feature_score_auto_collections: List[bool] = []
+        for config in grouped_configs:
+            collection = False
+            for table in config.embedding_tables:
+                if table.use_virtual_table and isinstance(
+                    table.virtual_table_eviction_policy, FeatureScoreBasedEvictionPolicy
+                ):
+                    if (
+                        table.virtual_table_eviction_policy.enable_auto_feature_score_collection
+                    ):
+                        collection = True
+            self._feature_score_auto_collections.append(collection)
+
+        logger.info(
+            f"GroupedPooledEmbeddingsLookup: {self._feature_score_auto_collections=}"
+        )
+
         for config in grouped_configs:
             self._emb_modules.append(
                 self._create_embedding_kernel(config, device, pg, sharding_type)
@@ -663,8 +681,11 @@ def forward(
         features_by_group = sparse_features.split(
             self._feature_splits,
         )
-        for config, emb_op, features in zip(
-            self.grouped_configs, self._emb_modules, features_by_group
+        for config, emb_op, features, fs_auto_collection in zip(
+            self.grouped_configs,
+            self._emb_modules,
+            features_by_group,
+            self._feature_score_auto_collections,
         ):
             if (
                 config.has_feature_processor
@@ -674,9 +695,19 @@ def forward(
                 features = self._feature_processor(features)

             if config.is_weighted:
-                features._weights = CommOpGradientScaling.apply(
+                feature_weights = CommOpGradientScaling.apply(
                     features._weights, self._scale_gradient_factor
-                )
+                ).float()
+
+                if fs_auto_collection and features.weights_or_none() is not None:
+                    score_weights = features.weights().float()
+                    assert (
+                        feature_weights.numel() == score_weights.numel()
+                    ), f"feature_weights.numel() {feature_weights.numel()} != score_weights.numel() {score_weights.numel()}"
+                    cat_weights = torch.cat(
+                        [feature_weights, score_weights], dim=1
+                    ).view(torch.float64)
+                    features._weights = cat_weights

             embeddings.append(emb_op(features))
torchrec/distributed/embeddingbag.py

Lines changed: 27 additions & 0 deletions
@@ -51,6 +51,10 @@
     KJTList,
     ShardedEmbeddingModule,
 )
+from torchrec.distributed.feature_score_utils import (
+    create_sharding_type_to_feature_score_mapping,
+    may_collect_feature_scores,
+)
 from torchrec.distributed.fused_params import (
     FUSED_PARAM_IS_SSD_TABLE,
     FUSED_PARAM_SSD_TABLE_LIST,
@@ -565,6 +569,24 @@ def __init__(
         # forward pass flow control
         self._has_uninitialized_input_dist: bool = True
         self._has_features_permute: bool = True
+
+        self._enable_feature_score_weight_accumulation: bool = False
+        self._enabled_feature_score_auto_collection: bool = False
+        self._sharding_type_feature_score_mapping: Dict[str, Dict[str, float]] = {}
+        (
+            self._enable_feature_score_weight_accumulation,
+            self._enabled_feature_score_auto_collection,
+            self._sharding_type_feature_score_mapping,
+        ) = create_sharding_type_to_feature_score_mapping(
+            self._embedding_bag_configs, self.sharding_type_to_sharding_infos
+        )
+
+        logger.info(
+            f"EBC feature score weight accumulation enabled: {self._enable_feature_score_weight_accumulation}, "
+            f"auto collection enabled: {self._enabled_feature_score_auto_collection}, "
+            f"sharding type to feature score mapping: {self._sharding_type_feature_score_mapping}"
+        )
+
         # Get all fused optimizers and combine them.
         optims = []
         for lookup in self._lookups:
@@ -1565,6 +1587,11 @@ def input_dist(
         features_by_shards = features.split(
             self._feature_splits,
         )
+        features_by_shards = may_collect_feature_scores(
+            features_by_shards,
+            self._enabled_feature_score_auto_collection,
+            self._sharding_type_feature_score_mapping,
+        )
         awaitables = []
         for input_dist, features_by_shard, sharding_type in zip(
             self._input_dists,
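may_collect_feature_scores is defined in feature_score_utils.py and its body is not part of this diff. Purely as an illustration of the idea (each feature name is mapped to its configured score, falling back to the default, and the scores are attached as per-value weights on the KJT), a rough sketch with a hypothetical helper name and signature:

# Hypothetical sketch, not the actual implementation in feature_score_utils.py.
import torch
from typing import Dict
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor


def attach_feature_scores(
    kjt: KeyedJaggedTensor,
    feature_score_mapping: Dict[str, float],
    feature_score_default_value: float = 0.5,
) -> KeyedJaggedTensor:
    # One score per value: repeat each feature's configured score across all of its ids.
    scores = torch.cat(
        [
            torch.full((num_values,), feature_score_mapping.get(key, feature_score_default_value))
            for key, num_values in zip(kjt.keys(), kjt.length_per_key())
        ]
    )
    return KeyedJaggedTensor(
        keys=kjt.keys(),
        values=kjt.values(),
        lengths=kjt.lengths(),
        weights=scores,
    )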

torchrec/distributed/feature_score_utils.py

Lines changed: 2 additions & 2 deletions
@@ -17,7 +17,7 @@
 from torchrec.distributed.embedding_types import ShardingType

 from torchrec.modules.embedding_configs import (
-    EmbeddingConfig,
+    BaseEmbeddingConfig,
     FeatureScoreBasedEvictionPolicy,
 )
 from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
@@ -26,7 +26,7 @@


 def create_sharding_type_to_feature_score_mapping(
-    embedding_configs: Sequence[EmbeddingConfig],
+    embedding_configs: Sequence[BaseEmbeddingConfig],
     sharding_type_to_sharding_infos: Dict[str, List[EmbeddingShardingInfo]],
 ) -> Tuple[bool, bool, Dict[str, Dict[str, float]]]:
     enable_feature_score_weight_accumulation = False
