
Commit 676c808

emlin authored and facebook-github-bot committed
add auto feature score collection to EC (#3474)
Summary:
X-link: pytorch/FBGEMM#5030
X-link: facebookresearch/FBGEMM#2043

Enable automatic feature score collection in ShardedEmbeddingCollection based on a static feature-to-score mapping. If a user needs a custom score for a specific id, they can disable auto collection and change the model code explicitly to collect a score for each id.

Here is a sample eviction policy config in the embedding table config that enables auto score collection:

    virtual_table_eviction_policy=FeatureScoreBasedEvictionPolicy(
        training_id_eviction_trigger_count=260_000_000,  # 260M
        training_id_keep_count=160_000_000,  # 160M
        enable_auto_feature_score_collection=True,
        feature_score_mapping={
            "sparse_public_original_content_creator": 1.0,
        },
        feature_score_default_value=0.5,
    ),

Additionally, the counter previously collected during EC dedup is not used by the kvzch backend, so this diff removes that counter and lets the KJT transfer a single float32 weight tensor to the backend. This also enables feature score collection for EBC, since EBC may already carry another float weight for pooling.

Reviewed By: EddyLXJ

Differential Revision: D83945722
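For orientation, here is a hedged sketch of how the policy above might be attached to a full table config. Only the FeatureScoreBasedEvictionPolicy kwargs come from the summary; the surrounding EmbeddingConfig values (table name, dimensions, the second feature name) are hypothetical.

```python
from torchrec.modules.embedding_configs import (
    EmbeddingConfig,
    FeatureScoreBasedEvictionPolicy,
)

# Hypothetical virtual table; the policy kwargs mirror the sample in the summary,
# while the EmbeddingConfig fields around them are illustrative placeholders.
creator_table = EmbeddingConfig(
    name="t_creator",
    embedding_dim=128,
    num_embeddings=1_000_000_000,
    feature_names=[
        "sparse_public_original_content_creator",
        "sparse_other_creator_id",  # hypothetical second feature
    ],
    use_virtual_table=True,
    virtual_table_eviction_policy=FeatureScoreBasedEvictionPolicy(
        training_id_eviction_trigger_count=260_000_000,  # 260M
        training_id_keep_count=160_000_000,  # 160M
        enable_auto_feature_score_collection=True,
        feature_score_mapping={
            "sparse_public_original_content_creator": 1.0,
        },
        feature_score_default_value=0.5,  # applied to features absent from the mapping
    ),
)
```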
1 parent 80dbb88 commit 676c808

File tree

5 files changed: +852 −20 lines changed


torchrec/distributed/embedding.py

Lines changed: 33 additions & 20 deletions
@@ -46,6 +46,11 @@
     ShardedEmbeddingModule,
     ShardingType,
 )
+
+from torchrec.distributed.feature_score_utils import (
+    create_sharding_type_to_feature_score_mapping,
+    may_collect_feature_scores,
+)
 from torchrec.distributed.fused_params import (
     FUSED_PARAM_IS_SSD_TABLE,
     FUSED_PARAM_SSD_TABLE_LIST,
@@ -90,7 +95,6 @@
 from torchrec.modules.embedding_configs import (
     EmbeddingConfig,
     EmbeddingTableConfig,
-    FeatureScoreBasedEvictionPolicy,
     PoolingType,
 )
 from torchrec.modules.embedding_modules import (
@@ -460,12 +464,12 @@ def __init__(
         ] = {
             sharding_type: self.create_embedding_sharding(
                 sharding_type=sharding_type,
-                sharding_infos=embedding_confings,
+                sharding_infos=embedding_configs,
                 env=env,
                 device=device,
                 qcomm_codecs_registry=self.qcomm_codecs_registry,
             )
-            for sharding_type, embedding_confings in sharding_type_to_sharding_infos.items()
+            for sharding_type, embedding_configs in sharding_type_to_sharding_infos.items()
         }

         self.enable_embedding_update: bool = any(
@@ -487,16 +491,20 @@ def __init__(
         self._has_uninitialized_input_dist: bool = True
         logger.info(f"EC index dedup enabled: {self._use_index_dedup}.")

-        for config in self._embedding_configs:
-            virtual_table_eviction_policy = config.virtual_table_eviction_policy
-            if virtual_table_eviction_policy is not None and isinstance(
-                virtual_table_eviction_policy, FeatureScoreBasedEvictionPolicy
-            ):
-                self._enable_feature_score_weight_accumulation = True
-                break
-
+        self._enable_feature_score_weight_accumulation: bool = False
+        self._enabled_feature_score_auto_collection: bool = False
+        self._sharding_type_feature_score_mapping: Dict[str, Dict[str, float]] = {}
+        (
+            self._enable_feature_score_weight_accumulation,
+            self._enabled_feature_score_auto_collection,
+            self._sharding_type_feature_score_mapping,
+        ) = create_sharding_type_to_feature_score_mapping(
+            self._embedding_configs, sharding_type_to_sharding_infos
+        )
         logger.info(
-            f"EC feature score weight accumulation enabled: {self._enable_feature_score_weight_accumulation}."
+            f"EC feature score weight accumulation enabled: {self._enable_feature_score_weight_accumulation}, "
+            f"auto collection enabled: {self._enabled_feature_score_auto_collection}, "
+            f"sharding type to feature score mapping: {self._sharding_type_feature_score_mapping}"
         )

         # Get all fused optimizers and combine them.
@@ -1357,22 +1365,22 @@ def _dedup_indices(
                     source_weights.dtype == torch.float32
                 ), "Only float32 weights are supported for feature score eviction weights."

-                acc_weights = torch.ops.fbgemm.jagged_acc_weights_and_counts(
-                    source_weights.view(-1),
-                    reverse_indices,
+                # Accumulate weights using scatter_add
+                acc_weights = torch.zeros(
                     unique_indices.numel(),
+                    dtype=torch.float32,
+                    device=source_weights.device,
                 )

+                # Use PyTorch's scatter_add to accumulate weights
+                acc_weights.scatter_add_(0, reverse_indices, source_weights)
+
             dedup_features = KeyedJaggedTensor(
                 keys=input_feature.keys(),
                 lengths=lengths,
                 offsets=offsets,
                 values=unique_indices,
-                weights=(
-                    acc_weights.view(torch.float64).view(-1)
-                    if acc_weights is not None
-                    else None
-                ),
+                weights=(acc_weights.view(-1) if acc_weights is not None else None),
             )

             ctx.input_features.append(input_feature)
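To make the new accumulation path concrete, here is a small self-contained demonstration of scatter_add_-based weight accumulation over deduplicated ids. The ids and scores are toy data, and torch.unique merely stands in for however _dedup_indices actually derives unique_indices and reverse_indices.

```python
import torch

# Toy ids with duplicates and a per-occurrence float32 feature score.
values = torch.tensor([7, 3, 7, 9, 3, 3])
source_weights = torch.tensor([1.0, 0.5, 1.0, 0.5, 0.5, 0.5], dtype=torch.float32)

# Dedup: unique ids plus, for every original position, the index of its unique id.
unique_indices, reverse_indices = torch.unique(values, return_inverse=True)

# Accumulate the scores of all occurrences onto their unique id, as in the new code path.
acc_weights = torch.zeros(
    unique_indices.numel(), dtype=torch.float32, device=source_weights.device
)
acc_weights.scatter_add_(0, reverse_indices, source_weights)

print(unique_indices)  # tensor([3, 7, 9])
print(acc_weights)     # tensor([1.5000, 2.0000, 0.5000])
```

The replaced fbgemm op also tracked a per-id count, which the commit summary notes the kvzch backend does not consume, so only the float32 score sum is kept and handed to the backend as a single weight tensor.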
@@ -1491,6 +1499,11 @@ def input_dist(
                 self._features_order_tensor,
             )
         features_by_shards = features.split(self._feature_splits)
+        features_by_shards = may_collect_feature_scores(
+            features_by_shards,
+            self._enabled_feature_score_auto_collection,
+            self._sharding_type_feature_score_mapping,
+        )
         if self._use_index_dedup:
             features_by_shards = self._dedup_indices(ctx, features_by_shards)

torchrec/distributed/feature_score_utils.py

Lines changed: 183 additions & 0 deletions
@@ -0,0 +1,183 @@
+#!/usr/bin/env python3
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-strict
+import logging
+from typing import Dict, List, Sequence, Tuple
+
+import torch
+
+from torch.autograd.profiler import record_function
+
+from torchrec.distributed.embedding_sharding import EmbeddingShardingInfo
+from torchrec.distributed.embedding_types import ShardingType
+
+from torchrec.modules.embedding_configs import (
+    EmbeddingConfig,
+    FeatureScoreBasedEvictionPolicy,
+)
+from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
+
+logger: logging.Logger = logging.getLogger(__name__)
+
+
+def create_sharding_type_to_feature_score_mapping(
+    embedding_configs: Sequence[EmbeddingConfig],
+    sharding_type_to_sharding_infos: Dict[str, List[EmbeddingShardingInfo]],
+) -> Tuple[bool, bool, Dict[str, Dict[str, float]]]:
+    enable_feature_score_weight_accumulation = False
+    enabled_feature_score_auto_collection = False
+
+    # Validation for virtual table configurations
+    virtual_tables = [
+        config for config in embedding_configs if config.use_virtual_table
+    ]
+    if virtual_tables:
+        virtual_tables_with_eviction = [
+            config
+            for config in virtual_tables
+            if config.virtual_table_eviction_policy is not None
+        ]
+        if virtual_tables_with_eviction:
+            # Check if any virtual table uses FeatureScoreBasedEvictionPolicy
+            tables_with_feature_score_policy = [
+                config
+                for config in virtual_tables_with_eviction
+                if isinstance(
+                    config.virtual_table_eviction_policy,
+                    FeatureScoreBasedEvictionPolicy,
+                )
+            ]
+
+            # If any virtual table uses FeatureScoreBasedEvictionPolicy,
+            # then ALL virtual tables with eviction policies must use FeatureScoreBasedEvictionPolicy
+            if tables_with_feature_score_policy:
+                assert all(
+                    isinstance(
+                        config.virtual_table_eviction_policy,
+                        FeatureScoreBasedEvictionPolicy,
+                    )
+                    for config in virtual_tables_with_eviction
+                ), "If any virtual table uses FeatureScoreBasedEvictionPolicy, all virtual tables with eviction policies must use FeatureScoreBasedEvictionPolicy"
+                enable_feature_score_weight_accumulation = True
+
+                # Check if any table has enable_auto_feature_score_collection=True
+                tables_with_auto_collection = [
+                    config
+                    for config in tables_with_feature_score_policy
+                    if config.virtual_table_eviction_policy is not None
+                    and isinstance(
+                        config.virtual_table_eviction_policy,
+                        FeatureScoreBasedEvictionPolicy,
+                    )
+                    and config.virtual_table_eviction_policy.enable_auto_feature_score_collection
+                ]
+                if tables_with_auto_collection:
+                    # All virtual tables with FeatureScoreBasedEvictionPolicy must have enable_auto_feature_score_collection=True
+                    assert all(
+                        config.virtual_table_eviction_policy is not None
+                        and isinstance(
+                            config.virtual_table_eviction_policy,
+                            FeatureScoreBasedEvictionPolicy,
+                        )
+                        and config.virtual_table_eviction_policy.enable_auto_feature_score_collection
+                        for config in tables_with_feature_score_policy
+                    ), "If any virtual table has enable_auto_feature_score_collection=True, all virtual tables with FeatureScoreBasedEvictionPolicy must have enable_auto_feature_score_collection=True"
+                    enabled_feature_score_auto_collection = True
+
+    sharding_type_feature_score_mapping: Dict[str, Dict[str, float]] = {}
+    if enabled_feature_score_auto_collection:
+        for (
+            sharding_type,
+            sharding_info,
+        ) in sharding_type_to_sharding_infos.items():
+            feature_score_mapping: Dict[str, float] = {}
+            if sharding_type == ShardingType.DATA_PARALLEL.value:
+                sharding_type_feature_score_mapping[sharding_type] = (
+                    feature_score_mapping
+                )
+                continue
+            for config in sharding_info:
+                vtep = config.embedding_config.virtual_table_eviction_policy
+                if vtep is not None and isinstance(
+                    vtep, FeatureScoreBasedEvictionPolicy
+                ):
+                    if vtep.eviction_ttl_mins > 0:
+                        logger.info(
+                            f"Virtual table eviction policy enabled for table {config.embedding_config.name} {sharding_type} with eviction TTL {vtep.eviction_ttl_mins} mins."
+                        )
+                        feature_score_mapping.update(
+                            dict.fromkeys(config.embedding_config.feature_names, 0.0)
+                        )
+                        continue
+                    for k in config.embedding_config.feature_names:
+                        if (
+                            k
+                            # pyre-ignore [16]
+                            in config.embedding_config.virtual_table_eviction_policy.feature_score_mapping
+                        ):
+                            feature_score_mapping[k] = (
+                                config.embedding_config.virtual_table_eviction_policy.feature_score_mapping[
+                                    k
+                                ]
+                            )
+                        else:
+                            assert (
+                                # pyre-ignore [16]
+                                config.embedding_config.virtual_table_eviction_policy.feature_score_default_value
+                                is not None
+                            ), f"Table {config.embedding_config.name} eviction policy feature_score_default_value is not set but feature {k} is not in feature_score_mapping."
+                            feature_score_mapping[k] = (
+                                config.embedding_config.virtual_table_eviction_policy.feature_score_default_value
+                            )
+            sharding_type_feature_score_mapping[sharding_type] = feature_score_mapping
+    return (
+        enable_feature_score_weight_accumulation,
+        enabled_feature_score_auto_collection,
+        sharding_type_feature_score_mapping,
+    )
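For the sample policy from the commit summary, and assuming (hypothetically) a single row-wise-sharded virtual table serving two features, the helper above would produce values along these lines; the sharding-type key and the second feature name are illustrative, not from the source.

```python
# Hypothetical output of create_sharding_type_to_feature_score_mapping for the
# sample policy in the summary: one row-wise-sharded virtual table whose two
# features are "sparse_public_original_content_creator" (explicitly mapped) and
# "sparse_other_creator_id" (falls back to feature_score_default_value=0.5).
enable_feature_score_weight_accumulation = True
enabled_feature_score_auto_collection = True
sharding_type_feature_score_mapping = {
    "row_wise": {
        "sparse_public_original_content_creator": 1.0,
        "sparse_other_creator_id": 0.5,
    },
}
```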
+
+
+@torch.fx.wrap
+def may_collect_feature_scores(
+    input_feature_splits: List[KeyedJaggedTensor],
+    enabled_feature_score_auto_collection: bool,
+    sharding_type_feature_score_mapping: Dict[str, Dict[str, float]],
+) -> List[KeyedJaggedTensor]:
+    if not enabled_feature_score_auto_collection:
+        return input_feature_splits
+    with record_function("## collect_feature_score ##"):
+        for features, mapping in zip(
+            input_feature_splits, sharding_type_feature_score_mapping.values()
+        ):
+            assert (
+                features.weights_or_none() is None
+            ), f"Auto feature collection: {features.keys()=} has non-empty weights"
+            if (
+                mapping is None or len(mapping) == 0
+            ):  # collection is disabled for this sharding type
+                continue
+            feature_score_weights = []
+            device = features.device()
+            for f in features.keys():
+                # Input dist carries inputs for multiple lookups, covering both virtual-table
+                # and non-virtual-table features. We need to attach weights for all features
+                # due to KJT weights requirements, so set a 0.0 score for non-virtual-table features.
+                score = mapping[f] if f in mapping else 0.0
+                feature_score_weights.append(
+                    torch.ones_like(
+                        features[f].values(),
+                        dtype=torch.float32,
+                        device=device,
+                    )
+                    * score
+                )
+            features._weights = (
+                torch.cat(feature_score_weights, dim=0)
+                if feature_score_weights
+                else None
+            )
+    return input_feature_splits
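A minimal usage sketch of may_collect_feature_scores, assuming the module above is importable as torchrec.distributed.feature_score_utils; the keys, values, and sharding-type name are made up.

```python
import torch
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
# Assumes this new module lands as torchrec.distributed.feature_score_utils.
from torchrec.distributed.feature_score_utils import may_collect_feature_scores

# Toy split for a single (hypothetical) sharding type with one virtual-table
# feature ("f_creator") and one non-virtual-table feature ("f_other").
kjt = KeyedJaggedTensor.from_lengths_sync(
    keys=["f_creator", "f_other"],
    values=torch.tensor([10, 11, 12, 20, 21]),
    lengths=torch.tensor([2, 1, 1, 1]),  # two samples per key
)

splits = may_collect_feature_scores(
    [kjt],
    enabled_feature_score_auto_collection=True,
    sharding_type_feature_score_mapping={"row_wise": {"f_creator": 1.0}},
)

# Every value now carries a float32 score: 1.0 for f_creator ids, 0.0 for f_other ids.
print(splits[0].weights())  # tensor([1., 1., 1., 0., 0.])
```

Because the function only consumes mapping.values() in order, the per-sharding-type dicts are matched positionally with the feature splits, which mirrors how input_dist splits features by sharding type before calling it.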

0 commit comments
