diff --git a/corelib/dynamicemb/dynamicemb/batched_dynamicemb_function.py b/corelib/dynamicemb/dynamicemb/batched_dynamicemb_function.py
index 24b32a92b..947b7c366 100644
--- a/corelib/dynamicemb/dynamicemb/batched_dynamicemb_function.py
+++ b/corelib/dynamicemb/dynamicemb/batched_dynamicemb_function.py
@@ -45,6 +45,58 @@
 )
 
 
+def _mask_embeddings_by_frequency(
+    cache: Optional[Cache],
+    storage: Storage,
+    unique_keys: torch.Tensor,
+    unique_embs: torch.Tensor,
+    frequency_threshold: int,
+    mask_dims: int,
+) -> None:
+    """
+    Mask low-frequency embeddings by setting their trailing dimensions to zero.
+
+    This function queries scores from cache and storage, then masks embeddings
+    whose scores are below the frequency threshold.
+
+    Args:
+        cache: Optional cache table (can be None if caching is disabled)
+        storage: Storage table (always present)
+        unique_keys: Keys to query scores for
+        unique_embs: Embeddings to mask (modified in-place)
+        frequency_threshold: Minimum score threshold
+        mask_dims: Number of dimensions to mask (from the end)
+    """
+    batch = unique_keys.size(0)
+    if batch == 0:
+        return
+    assert hasattr(
+        storage, "query_scores"
+    ), "Frequency masking requires the storage table to implement query_scores"
+    # Query scores from cache and storage
+    if cache is not None:
+        # 1. Query the cache first; a score of 0 means the key was not cached
+        cache_scores = cache.query_scores(unique_keys)
+        cache_founds = cache_scores > 0
+
+        # 2. Query storage for cache misses
+        if (~cache_founds).any():
+            missing_keys = unique_keys[~cache_founds]
+            storage_scores = storage.query_scores(missing_keys)
+            cache_scores[~cache_founds] = storage_scores
+
+        scores = cache_scores
+    else:
+        # Without cache: query from storage only
+        scores = storage.query_scores(unique_keys)
+
+    # Zero the trailing mask_dims dimensions of low-frequency embeddings
+    low_freq_mask = scores < frequency_threshold
+    if low_freq_mask.any():
+        unique_embs[low_freq_mask, -mask_dims:] = 0.0
+
+
+# TODO: BatchedDynamicEmbeddingFunction is more concrete.
 class DynamicEmbeddingBagFunction(torch.autograd.Function):
     @staticmethod
@@ -348,6 +403,8 @@ def forward(
         input_dist_dedup: bool = False,
         training: bool = True,
         frequency_counters: Optional[torch.Tensor] = None,
+        frequency_threshold: int = 0,
+        mask_dims: int = 0,
         *args,
     ):
         table_num = len(storages)
@@ -426,6 +483,17 @@ def forward(
             lfu_accumulated_frequency_per_table,
         )
 
+        # Apply frequency-based masking if enabled
+        if is_lfu_enabled and mask_dims > 0 and frequency_threshold > 0:
+            _mask_embeddings_by_frequency(
+                caches[i] if caching else None,
+                storages[i],
+                unique_indices_per_table,
+                unique_embs_per_table,
+                frequency_threshold,
+                mask_dims,
+            )
+
         if training or caching:
             output_embs = torch.empty(
                 indices.shape[0], emb_dim, dtype=output_dtype, device=indices.device
@@ -501,4 +569,4 @@ def backward(ctx, grads):
             optimizer,
         )
 
-        return (None,) * 14
+        return (None,) * 16
diff --git a/corelib/dynamicemb/dynamicemb/batched_dynamicemb_tables.py b/corelib/dynamicemb/dynamicemb/batched_dynamicemb_tables.py
index f4814d318..c56e5f140 100644
--- a/corelib/dynamicemb/dynamicemb/batched_dynamicemb_tables.py
+++ b/corelib/dynamicemb/dynamicemb/batched_dynamicemb_tables.py
@@ -483,7 +483,8 @@ def __init__(
         self._table_names = table_names
         self.bounds_check_mode_int: int = bounds_check_mode.value
         self._create_score()
-
+        self.frequency_threshold = table_option.frequency_threshold
+        self.mask_dims = table_option.mask_dims
         if device is not None:
             self.device_id = int(str(device)[-1])
         else:
@@ -984,6 +985,8 @@ def forward(
             self.use_index_dedup,
             self.training,
             per_sample_weights,  # Pass frequency counters as weights
+            self.frequency_threshold,
+            self.mask_dims,
             self._empty_tensor,
         )
         for cache in self._caches:
diff --git a/corelib/dynamicemb/dynamicemb/dynamicemb_config.py b/corelib/dynamicemb/dynamicemb/dynamicemb_config.py
index eb849db7c..4be9203a4 100644
--- a/corelib/dynamicemb/dynamicemb/dynamicemb_config.py
+++ b/corelib/dynamicemb/dynamicemb/dynamicemb_config.py
@@ -386,6 +386,9 @@ class DynamicEmbTableOptions(_ContextOptions):
     global_hbm_for_values: int = 0  # in bytes
     external_storage: Storage = None
     index_type: Optional[torch.dtype] = None
+    # Frequency-based masking parameters
+    frequency_threshold: int = 0  # frequency threshold for masking (0 = disabled)
+    mask_dims: int = 0  # number of dimensions to mask (0 = disabled)
 
     def __post_init__(self):
         assert (
diff --git a/corelib/dynamicemb/dynamicemb/key_value_table.py b/corelib/dynamicemb/dynamicemb/key_value_table.py
index 6e4950438..5f34a1bea 100644
--- a/corelib/dynamicemb/dynamicemb/key_value_table.py
+++ b/corelib/dynamicemb/dynamicemb/key_value_table.py
@@ -40,6 +40,7 @@
     erase,
     export_batch,
     export_batch_matched,
+    find,
     find_pointers,
     find_pointers_with_scores,
     insert_and_evict,
@@ -347,6 +348,29 @@ def create_scores(
         else:
             return None
 
+    def query_scores(self, unique_keys: torch.Tensor) -> torch.Tensor:
+        """Query scores for given keys from the table.
+
+        Returns:
+            scores: Tensor of scores, with 0 for keys not found in the table.
+        """
+
+        batch = unique_keys.size(0)
+        device = unique_keys.device
+
+        scores = torch.empty(batch, device=device, dtype=torch.long)
+        values = torch.empty(
+            batch, self.value_dim(), device=device, dtype=self.embedding_dtype()
+        )
+        founds = torch.empty(batch, device=device, dtype=torch.bool)
+
+        find(self.table, batch, unique_keys, values, founds, score=scores)
+
+        # For keys that were not found, report a score of 0
+        scores[~founds] = 0
+
+        return scores
+
     def insert(
         self,
         unique_keys: torch.Tensor,
diff --git a/corelib/dynamicemb/example/example.py b/corelib/dynamicemb/example/example.py
index 2725fb48f..283dcd16c 100644
--- a/corelib/dynamicemb/example/example.py
+++ b/corelib/dynamicemb/example/example.py
@@ -486,9 +486,11 @@ def get_planner(device, eb_configs, batch_size, optimizer_type, training, caching
             initializer_args=DynamicEmbInitializerArgs(
                 mode=DynamicEmbInitializerMode.NORMAL
             ),
-            score_strategy=DynamicEmbScoreStrategy.STEP,
+            score_strategy=DynamicEmbScoreStrategy.LFU,
             caching=caching,
             training=training,
+            frequency_threshold=10,
+            mask_dims=10,
         ),
     )
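
A quick standalone sketch of what `_mask_embeddings_by_frequency` computes, for review purposes. Plain tensors stand in for the real cache/storage tables (`toy_cache_scores` and `toy_storage_scores_for_misses` are hypothetical stand-ins, not part of the patch):

```python
import torch

# Stand-ins for cache.query_scores() / storage.query_scores(). A score of 0
# marks a cache miss, matching the query_scores() contract in this patch.
toy_cache_scores = torch.tensor([5, 0, 9, 0])
toy_storage_scores_for_misses = torch.tensor([2, 11])

founds = toy_cache_scores > 0
scores = toy_cache_scores.clone()
scores[~founds] = toy_storage_scores_for_misses  # fill misses from storage

# Masking rule: zero the trailing mask_dims dims of rows scoring below the
# threshold, exactly as the helper does with unique_embs.
frequency_threshold, mask_dims = 10, 4
embs = torch.ones(4, 8)
low_freq = scores < frequency_threshold
embs[low_freq, -mask_dims:] = 0.0

print(scores)       # tensor([ 5,  2,  9, 11])
print(embs[:, -1])  # tensor([0., 0., 0., 1.]) -> only the score-11 row stays intact
```

Note the double duty of a zero score here: `query_scores` reports 0 for keys it did not find, and the helper reuses `cache_scores > 0` as its cache-hit mask.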
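
The `(None,) * 14` to `(None,) * 16` change follows from the `torch.autograd.Function` contract: `backward` must return one gradient slot per `forward` input, and this patch adds two new non-differentiable inputs (`frequency_threshold` and `mask_dims`). A generic illustration of that rule, with made-up names unrelated to this codebase:

```python
import torch

class Scale(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x: torch.Tensor, factor: float, unused_flag: int):
        ctx.factor = factor
        return x * factor

    @staticmethod
    def backward(ctx, grad_out: torch.Tensor):
        # One slot per forward input: a gradient for x, then None for each
        # of the two non-differentiable Python scalars.
        return grad_out * ctx.factor, None, None

x = torch.ones(3, requires_grad=True)
Scale.apply(x, 2.0, 0).sum().backward()
print(x.grad)  # tensor([2., 2., 2.])
```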
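
Finally, the `example.py` hunk shows how the feature is enabled end to end. A condensed sketch of just the relevant knobs; the import path and the exact `DynamicEmbTableOptions` argument surface are assumed from the diff context, not verified:

```python
# Assumed import location; in the example these names come via the planner setup.
from dynamicemb.dynamicemb_config import DynamicEmbTableOptions

options = DynamicEmbTableOptions(
    frequency_threshold=10,  # keys with score < 10 get their tail dims zeroed
    mask_dims=10,            # number of trailing embedding dims to zero
)
```

Both knobs default to 0 (disabled), and masking only fires when LFU scoring is active, per the guard in `forward`: `is_lfu_enabled and mask_dims > 0 and frequency_threshold > 0`; hence the accompanying switch from `DynamicEmbScoreStrategy.STEP` to `DynamicEmbScoreStrategy.LFU` in the example.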