12 changes: 11 additions & 1 deletion configs/config_all.yaml
@@ -1007,7 +1007,7 @@ process:
      redis_address: 'redis://localhost:6379'                   # the address of redis server
      lowercase: false                                          # whether to convert text to lower case
      ignore_non_character: false                               # whether to ignore non-alphabet characters, including whitespaces, digits, and punctuations
  - ray_bts_minhash_deduplicator:                               # the document deduplicator that can run on multi-nodes using minhashLSH algorithm
      tokenization: space                                       # tokenization method for text. One of [space, punctuation, character, sentencepiece]
      window_size: 5                                            # window size of shingling
      num_permutations: 256                                     # number of permutations in minhash computing
@@ -1027,6 +1027,16 @@ process:
      tmp_file_name: './outputs/ray-dedup-tmp/'                 # the temporary folder name for deduplication.

  # Selector ops
  - domain_diversity_selector:                                  # selector to select samples based on the data's domain diversity
      api_or_hf_model: 'text-embedding-v3'                      # API or huggingface embedding model name
      is_hf_model: false                                        # indicates if the model is from HuggingFace
      api_endpoint: '/embeddings'                               # embedding URL endpoint for the API
      response_path: 'data.0.embedding'                         # path to extract content from the API response
      model_params: {}                                          # parameters for initializing the API model
      select_ratio:                                             # the ratio to be sampled
Contributor review comment (severity: medium):

The select_ratio is left empty in this configuration example. While the code correctly handles a None value by skipping the operation, it would be more user-friendly for an example configuration to provide a default value (e.g., 0.5) and add a comment explaining its purpose.

Suggested change:
      select_ratio: 0.5                                         # the ratio to be sampled

      init_k: 3                                                 # the value of k in k-means algorithm
      ebd_dim: 512                                              # the embedding's dimension via API
      strategy: 'inter'                                         # the selection strategy based on the relation across domains
  - frequency_specified_field_selector:                         # selector to select samples based on the sorted frequency of specified field value
      field_key: ''                                             # the target keys corresponding to multi-level field information need to be separated by '.'
      top_ratio:                                                # ratio of selected top specified field value
2 changes: 2 additions & 0 deletions data_juicer/ops/selector/__init__.py
@@ -1,10 +1,12 @@
from .domain_diversity_selector import DomainDiversitySelector
from .frequency_specified_field_selector import FrequencySpecifiedFieldSelector
from .random_selector import RandomSelector
from .range_specified_field_selector import RangeSpecifiedFieldSelector
from .tags_specified_field_selector import TagsSpecifiedFieldSelector
from .topk_specified_field_selector import TopkSpecifiedFieldSelector

__all__ = [
    "DomainDiversitySelector",
    "FrequencySpecifiedFieldSelector",
    "RandomSelector",
    "RangeSpecifiedFieldSelector",
187 changes: 187 additions & 0 deletions data_juicer/ops/selector/domain_diversity_selector.py
@@ -0,0 +1,187 @@
from typing import Dict, Optional

import numpy as np
from pydantic import Field, PositiveInt
from sklearn.cluster import KMeans
from tqdm import tqdm
from typing_extensions import Annotated

from data_juicer.ops.base_op import OPERATORS, Selector
from data_juicer.utils.lazy_loader import LazyLoader
from data_juicer.utils.model_utils import get_model, prepare_model

torch = LazyLoader("torch")


@OPERATORS.register_module("domain_diversity_selector")
class DomainDiversitySelector(Selector):
"""Selector to select samples based on the data's domain diversity."""

_accelerator = "cuda"

def __init__(
self,
api_or_hf_model: str = "text-embedding-v3",
is_hf_model: bool = False,
api_endpoint: str = "/embeddings",
response_path: str = "data.0.embedding",
model_params: Dict = {},
select_ratio: Optional[Annotated[float, Field(ge=0, le=1)]] = None,
init_k: PositiveInt = 3,
ebd_dim: PositiveInt = 512,
strategy: str = "inter",
*args,
**kwargs,
):
"""
Initialization method.
:param api_or_hf_model: API or huggingface embedding model name.
:param is_hf_model: Indicates if the model is from HuggingFace.
:param api_endpoint: Embedding URL endpoint for the API.
:param response_path: Path to extract content from the API response.
Defaults to 'data.0.embedding' for embedding model.
:param model_params: Parameters for initializing the API model.
:param select_ratio: The ratio to select.
:param init_k: The value of k in k-means algorithm.
:param ebd_dim: The embedding's dimension via API.
:param strategy: 'inter' - Domain's inter diversity,
'intra' - Domain's intra diversity,
'global' - Diversity to global centroid.
:param args: extra args
:param kwargs: extra args
"""
        super().__init__(*args, **kwargs)
        self.api_or_hf_model = api_or_hf_model
        self.is_hf_model = is_hf_model
        self.api_endpoint = api_endpoint
        self.response_path = response_path
        self.select_ratio = select_ratio
        self.init_k = init_k
        self.ebd_dim = ebd_dim
        self.strategy = strategy

        if is_hf_model:
            self.model_key = prepare_model(
                model_type="embedding", model_path=api_or_hf_model, trust_remote_code=True, **model_params
            )
        else:
            self.model_key = prepare_model(
                model_type="api",
                model=api_or_hf_model,
                endpoint=self.api_endpoint,
                response_path=self.response_path,
                **model_params,
            )

    def dataset_embedding(self, dataset, rank=None):
        embeddings = []
        model = get_model(self.model_key, rank, self.use_cuda())

        if self.is_hf_model:
            # Embeddings extracted via the local model
            for sample in tqdm(dataset, desc="Embedding", unit="sample"):
                text = sample["text"]
                with torch.no_grad():
                    embedding = model.encode(text)
                embeddings.append(embedding)
        else:
            # Embeddings extracted via the API
            for sample in tqdm(dataset, desc="Embedding", unit="sample"):
                text = sample["text"]
                embedding = model(text, dimensions=self.ebd_dim, encoding_format="float")
                embeddings.append(embedding)
Comment on lines +81 to +93
Contributor review comment (severity: critical):

The current implementation for generating embeddings processes samples one by one, which is highly inefficient for large datasets. Both HuggingFace encode methods and many embedding APIs support batch processing. Refactoring this to process samples in batches will significantly improve performance.

Suggested change:
        if self.is_hf_model:
            # Embeddings extract via local models in batches for efficiency
            texts = [sample['text'] for sample in dataset]
            with torch.no_grad():
                embeddings = model.encode(texts, batch_size=self.batch_size, show_progress_bar=True)
        else:
            # Embeddings extract via API. Consider batching if the API supports it.
            for sample in tqdm(dataset, desc="Embedding", unit="sample"):
                text = sample["text"]
                embedding = model(text, dimensions=self.ebd_dim, encoding_format="float")
                embeddings.append(embedding)
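A note on the API branch of this suggestion: if the endpoint does not accept batched inputs, the per-sample calls can still be overlapped. Below is a minimal sketch of that idea (a hypothetical helper, not part of this PR; it reuses the model(text, dimensions=..., encoding_format=...) call shown in the diff):

from concurrent.futures import ThreadPoolExecutor

def embed_via_api(model, texts, ebd_dim, max_workers=8):
    # Hypothetical helper: issue the per-text API calls concurrently.
    # pool.map preserves input order, so results stay aligned with samples.
    def _embed(text):
        return model(text, dimensions=ebd_dim, encoding_format="float")

    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        return list(pool.map(_embed, texts))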


        embeddings = np.array(embeddings)
        return embeddings

    def domain_diversity_status(self, dataset):
        data_status = []

        embeddings_array = self.dataset_embedding(dataset)
        global_centroid = np.mean(embeddings_array, axis=0)

        # K-means clustering
        kmeans = KMeans(n_clusters=self.init_k, random_state=42)
        labels = kmeans.fit_predict(embeddings_array)

        centroid_embeddings = []
        for label in np.unique(labels):
            label_embeddings = embeddings_array[labels == label]
            centroid = np.mean(label_embeddings, axis=0)
            centroid_embeddings.append(centroid)

        centroid_embeddings = np.array(centroid_embeddings)

        # Sample-level cosine similarity to the other centroids
        for i, entry in tqdm(enumerate(dataset), total=len(dataset), desc="Calculating similarity"):
            current_embedding = embeddings_array[i]
            current_label = int(labels[i])

            similarities = []
            for j, centroid in enumerate(centroid_embeddings):
                if j != current_label:
                    similarity = torch.nn.functional.cosine_similarity(
                        torch.tensor(current_embedding).unsqueeze(0), torch.tensor(centroid).unsqueeze(0)
                    ).item()
                    similarities.append(similarity)

            own_centroid_similarity = torch.nn.functional.cosine_similarity(
                torch.tensor(current_embedding).unsqueeze(0),
                torch.tensor(centroid_embeddings[current_label]).unsqueeze(0),
            ).item()

            global_centroid_similarity = torch.nn.functional.cosine_similarity(
                torch.tensor(current_embedding).unsqueeze(0), torch.tensor(global_centroid).unsqueeze(0)
            ).item()
            total_similarity = sum(similarities)

            data_status.append(
                {
                    "text": entry["text"],
                    "label": current_label,
                    "similarity_with_other_centroids": similarities,
                    "total_similarity": total_similarity,
                    "similarity_with_own_centroid": own_centroid_similarity,
                    "global_centroid_similarity": global_centroid_similarity,
                    "original_index": i,
                }
            )
Comment on lines +117 to +150
Contributor review comment (severity: critical):

This section has critical performance and memory issues that should be addressed:

  1. Performance: Cosine similarities are calculated in a loop, creating and destroying PyTorch tensors for each sample. This is very slow. These calculations should be vectorized using sklearn.metrics.pairwise.cosine_similarity or torch.nn.functional.cosine_similarity on the entire matrices.
  2. Memory: The data_status list stores a dictionary for each sample, including the original text. For large datasets, this will consume a very large amount of memory. It's more memory-efficient to work with NumPy arrays for embeddings, labels, and calculated similarities, and only use indices to refer back to the original dataset.

A full refactor is recommended to process these calculations in a batched/vectorized manner and to avoid creating large intermediate data structures.
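For illustration, a minimal sketch of the vectorization this comment describes, assuming sklearn.metrics.pairwise.cosine_similarity and the arrays already computed above (a hypothetical helper, not part of this PR):

import numpy as np
from sklearn.metrics.pairwise import cosine_similarity

def vectorized_similarity_stats(embeddings_array, centroid_embeddings, global_centroid, labels):
    # (n_samples, n_clusters): similarity of every sample to every centroid
    sims = cosine_similarity(embeddings_array, centroid_embeddings)
    idx = np.arange(embeddings_array.shape[0])
    own_sim = sims[idx, labels]             # similarity to the sample's own centroid
    total_sim = sims.sum(axis=1) - own_sim  # summed similarity to the *other* centroids
    # similarity of every sample to the global centroid
    global_sim = cosine_similarity(embeddings_array, global_centroid[None, :]).ravel()
    # Returning plain NumPy arrays instead of per-sample dicts keeps memory bounded;
    # row indices stand in for the original samples.
    return own_sim, total_sim, global_sim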


        return data_status, labels

    def diversity_process(self, dataset):
        data_status, labels = self.domain_diversity_status(dataset)
        select_indices = []

        for label in np.unique(labels):
            label_data_status = [item for item in data_status if item["label"] == label]

            # Sort within the cluster according to the selected strategy
            if self.strategy == "inter":
                label_data_status.sort(key=lambda x: x["total_similarity"])
            elif self.strategy == "intra":
                label_data_status.sort(key=lambda x: x["similarity_with_own_centroid"], reverse=True)
            elif self.strategy == "global":
                label_data_status.sort(key=lambda x: x["global_centroid_similarity"])
            else:
                raise ValueError("Invalid strategy. Use 'inter', 'intra' or 'global'.")
Comment on lines +162 to +169
Contributor review comment (severity: medium):

The if/elif/else chain for selecting the sorting key based on the strategy can be made more concise and extensible by using a dictionary to map strategy names to their corresponding sorting logic (key and reverse flag).

Suggested change:
            strategy_configs = {
                'inter': {'key': 'total_similarity', 'reverse': False},
                'intra': {'key': 'similarity_with_own_centroid', 'reverse': True},
                'global': {'key': 'global_centroid_similarity', 'reverse': False},
            }
            if self.strategy not in strategy_configs:
                raise ValueError("Invalid strategy. Use 'inter', 'intra' or 'global'.")
            config = strategy_configs[self.strategy]
            label_data_status.sort(key=lambda x: x[config['key']], reverse=config['reverse'])


            num_to_select = max(1, int(self.select_ratio * len(label_data_status)))
            selected_indices = [item["original_index"] for item in label_data_status[:num_to_select]]
            select_indices.extend(selected_indices)

        select_dataset = dataset.select(select_indices)

        return select_dataset

    def process(self, dataset):
        if len(dataset) <= 1:
            return dataset
        if self.select_ratio is None:
            return dataset

        select_dataset = self.diversity_process(dataset)
        return select_dataset
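For context, a hypothetical usage sketch of the new op (the model name and parameter values below are illustrative, not defaults from this PR; dataset stands for any dataset object exposing __len__ and .select(), as process requires):

from data_juicer.ops.selector import DomainDiversitySelector

# Illustrative settings: a local HuggingFace embedding model, keeping half of each cluster.
selector = DomainDiversitySelector(
    api_or_hf_model="sentence-transformers/all-MiniLM-L6-v2",  # assumed model name
    is_hf_model=True,
    select_ratio=0.5,   # ratio of samples kept per cluster
    init_k=3,           # number of k-means clusters
    strategy="inter",   # keep samples least similar to the other clusters' centroids
)
selected = selector.process(dataset)  # `dataset`: any dataset with __len__ and .select()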