Commit
Merge pull request #155 from NouamaneTazi/nouamane/better-multi-gpu
Support Multi-node evaluation
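Each rank encodes its own shard of the corpus (via `datasets.distributed.split_dataset_by_node`), computes local top-k scores, and the candidates are merged across ranks with `torch.distributed` collectives. The code reads the `RANK` and `WORLD_SIZE` environment variables, so a run launched with e.g. `torchrun --nnodes=2 --nproc_per_node=8 your_eval_script.py` (script name illustrative, not part of this repo) is sharded automatically.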
thakur-nandan authored Feb 4, 2025
2 parents f062f03 + 57cd308 commit 1b38876
Showing 2 changed files with 153 additions and 159 deletions.
274 changes: 132 additions & 142 deletions beir/retrieval/search/dense/exact_search_multi_gpu.py
@@ -2,88 +2,58 @@
 from .util import cos_sim, dot_score
-from sentence_transformers import SentenceTransformer
 from torch.utils.data import DataLoader
-from datasets import Features, Value
-from datasets.utils.filelock import FileLock
-from datasets import Array2D, Dataset
+from datasets import Dataset
 from tqdm.autonotebook import tqdm
-from typing import Dict, List
+from typing import Dict, List, Optional

 import logging
 import torch
 import math
-import queue
 import os
 import time
-import numpy as np
+from datasets.distributed import split_dataset_by_node
+import torch.distributed as dist
+from contextlib import contextmanager

 logger = logging.getLogger(__name__)

-import importlib.util
+@contextmanager
+def main_rank_first(group: dist.ProcessGroup):
+    is_main_rank = dist.get_rank(group) == 0
+    if is_main_rank:
+        yield

-### HuggingFace Evaluate library (pip install evaluate) only available with Python >= 3.7.
-### Hence for no import issues with Python 3.6, we move DummyMetric if ``evaluate`` library is found.
-if importlib.util.find_spec("evaluate") is not None:
-    from evaluate.module import EvaluationModule, EvaluationModuleInfo
-
-    class DummyMetric(EvaluationModule):
-        len_queries = None
-
-        def _info(self):
-            return EvaluationModuleInfo(
-                description="dummy metric to handle storing middle results",
-                citation="",
-                features=Features(
-                    {"cos_scores_top_k_values": Array2D((None, self.len_queries), "float32"), "cos_scores_top_k_idx": Array2D((None, self.len_queries), "int32"), "batch_index": Value("int32")},
-                ),
-            )
+    dist.barrier(group)

+    if not is_main_rank:
+        yield

-        def _compute(self, cos_scores_top_k_values, cos_scores_top_k_idx, batch_index):
-            for i in range(len(batch_index) - 1, -1, -1):
-                if batch_index[i] == -1:
-                    del cos_scores_top_k_values[i]
-                    del cos_scores_top_k_idx[i]
-            cos_scores_top_k_values = np.concatenate(cos_scores_top_k_values, axis=0)
-            cos_scores_top_k_idx = np.concatenate(cos_scores_top_k_idx, axis=0)
-            return cos_scores_top_k_values, cos_scores_top_k_idx
-
-        def warmup(self):
-            """
-            Add dummy batch to acquire filelocks for all processes and avoid getting errors
-            """
-            self.add_batch(cos_scores_top_k_values=torch.ones((1, 1, self.len_queries), dtype=torch.float32), cos_scores_top_k_idx=torch.ones((1, 1, self.len_queries), dtype=torch.int32), batch_index=-torch.ones(1, dtype=torch.int32))

-#Parent class for any dense model
+# parent class for any dense model that can be used for retrieval
 # Abstract class is BaseSearch
 class DenseRetrievalParallelExactSearch(BaseSearch):

-    def __init__(self, model, batch_size: int = 128, corpus_chunk_size: int = None, target_devices: List[str] = None, **kwargs):
+    def __init__(self, model, batch_size: int = 128, corpus_chunk_size: Optional[int] = None, **kwargs):
         #model is class that provides encode_corpus() and encode_queries()
         self.model = model
-        self.batch_size = batch_size
-        if target_devices is None:
-            if torch.cuda.is_available():
-                target_devices = ['cuda:{}'.format(i) for i in range(torch.cuda.device_count())]
-            else:
-                logger.info("CUDA is not available. Start 4 CPU worker")
-                target_devices = ['cpu']*1 # 4
-        self.target_devices = target_devices # PyTorch target devices, e.g. cuda:0, cuda:1... If None, all available CUDA devices will be used, or 4 CPU processes
+        self.encoding_batch_size = batch_size
         self.score_functions = {'cos_sim': cos_sim, 'dot': dot_score}
         self.score_function_desc = {'cos_sim': "Cosine Similarity", 'dot': "Dot Product"}
-        self.corpus_chunk_size = corpus_chunk_size
-        self.show_progress_bar = kwargs.get("show_progress_bar", True)
+        self.corpus_chunk_size = batch_size * 100 if corpus_chunk_size is None else corpus_chunk_size
+        self.show_progress_bar = kwargs.get("show_progress_bar", True) if int(os.getenv("RANK", 0)) == 0 else False
         self.convert_to_tensor = kwargs.get("convert_to_tensor", True)
         self.results = {}

-        self.query_embeddings = {}
-        self.top_k = None
-        self.score_function = None
-        self.sort_corpus = True
-        self.experiment_id = "exact_search_multi_gpu" # f"test_{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}"

+    @torch.no_grad()
     def search(self,
                corpus: Dataset,
                queries: Dataset,
                top_k: int,
                score_function: str,
+               ignore_identical_ids: bool = True,
                **kwargs) -> Dict[str, Dict[str, float]]:
         #Create embeddings for all queries using model.encode_queries()
         #Runs semantic search against the corpus embeddings
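Note on the `main_rank_first` helper added above: it lets rank 0 execute a block before all other ranks, so a cached `datasets` operation (such as the `corpus.filter(..., load_from_cache_file=True)` call in the next hunk) is computed once and the remaining ranks read the warm cache. A minimal standalone sketch of the pattern, assuming a `torchrun` launch; this script is illustrative and not part of the commit:

```python
# Sketch: rank 0 runs the block first, the barrier releases the others,
# and they then re-run the block against rank 0's cached result.
from contextlib import contextmanager

import torch.distributed as dist

@contextmanager
def main_rank_first(group: dist.ProcessGroup):
    is_main_rank = dist.get_rank(group) == 0
    if is_main_rank:
        yield
    dist.barrier(group)
    if not is_main_rank:
        yield

if __name__ == "__main__":
    # torchrun sets RANK/WORLD_SIZE/MASTER_ADDR/MASTER_PORT for init
    dist.init_process_group(backend="gloo")
    with main_rank_first(dist.group.WORLD):
        # e.g. Dataset.filter(...): rank 0 computes and writes the cache,
        # the other ranks then load it instead of recomputing
        print(f"rank {dist.get_rank()} running the cached operation")
    dist.destroy_process_group()
```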
@@ -92,114 +62,134 @@ def search(self,
             raise ValueError("score function: {} must be either (cos_sim) for cosine similarity or (dot) for dot product".format(score_function))
         logger.info("Scoring Function: {} ({})".format(self.score_function_desc[score_function], score_function))

-        if importlib.util.find_spec("evaluate") is None:
-            raise ImportError("evaluate library not available. Please do ``pip install evaluate`` library with Python>=3.7 (not available with Python 3.6) to use distributed and multigpu evaluation.")
-
-        self.corpus_chunk_size = min(math.ceil(len(corpus) / len(self.target_devices) / 10), 5000) if self.corpus_chunk_size is None else self.corpus_chunk_size
-        self.corpus_chunk_size = min(self.corpus_chunk_size, len(corpus)-1) # to avoid getting error in metric.compute()
-
-        if self.sort_corpus:
-            logger.info("Sorting Corpus by document length (Longest first)...")
-            corpus = corpus.map(lambda x: {'len': len(x.get("title", "") + x.get("text", ""))}, num_proc=4)
-            corpus = corpus.sort('len', reverse=True)
+        world_size = int(os.getenv("WORLD_SIZE", 1))
+        rank = int(os.getenv("RANK", 0))

+        if ignore_identical_ids:
+            with main_rank_first(dist.group.WORLD):
+                # We remove the query from results if it exists in corpus
+                logger.info("Ignoring identical ids between queries and corpus...")
+                start_time = time.time()
+                # batched filter must return list of bools
+                queries_ids = queries["id"]
+                corpus = corpus.filter(lambda batch: [cid not in queries_ids for cid in batch["id"]], batched=True, num_proc=12, load_from_cache_file=True)
+                logger.info("Ignore identical ids took {:.2f} seconds".format(time.time() - start_time))

+        # add index column
+        corpus = corpus.add_column("mteb_index", range(len(corpus)))

+        # Split corpus across devices to parallelize encoding
+        assert isinstance(corpus, Dataset), f"Corpus must be a Dataset object, but got {type(corpus)}"
+        local_corpus = split_dataset_by_node(corpus, rank=rank, world_size=world_size)
+        logger.info(f"Splitted corpus into {world_size} chunks. Rank {int(os.getenv('RANK', 0))} has {len(local_corpus)} docs.")

+        all_ranks_corpus_start_idx = local_corpus["mteb_index"][0] # I'm assuming `split_dataset_by_node` splits local_corpus evenly while keeping same order

         # Initiate dataloader
         queries_dl = DataLoader(queries, batch_size=self.corpus_chunk_size)
-        corpus_dl = DataLoader(corpus, batch_size=self.corpus_chunk_size)
+        corpus_dl = DataLoader(local_corpus, batch_size=self.corpus_chunk_size)

-        # Encode queries
-        logger.info("Encoding Queries in batches...")
+        # Encode queries (all ranks do this)
+        logger.info(f"Encoding Queries ({len(queries)} queries) in {len(queries) // self.corpus_chunk_size + 1} batches... Warning: This might take a while!")
         query_embeddings = []
-        for step, queries_batch in enumerate(queries_dl):
-            with torch.no_grad():
-                q_embeds = self.model.encode_queries(
-                    queries_batch['text'], batch_size=self.batch_size, show_progress_bar=self.show_progress_bar, convert_to_tensor=self.convert_to_tensor)
+        for step, queries_batch in tqdm(enumerate(queries_dl), total=len(queries) // self.corpus_chunk_size + 1, desc="Encode Queries", disable=rank != 0):
+            q_embeds = self.model.encode_queries(
+                queries_batch['text'], batch_size=self.encoding_batch_size, show_progress_bar=self.show_progress_bar, convert_to_tensor=self.convert_to_tensor)
             query_embeddings.append(q_embeds)
         query_embeddings = torch.cat(query_embeddings, dim=0)

-        # copy the query embeddings to all target devices
-        self.query_embeddings = query_embeddings
-        self.top_k = top_k
-        self.score_function = score_function
-
-        # Start the multi-process pool on all target devices
-        SentenceTransformer._encode_multi_process_worker = self._encode_multi_process_worker
-        pool = self.model.start_multi_process_pool(self.target_devices)
-
-        logger.info("Encoding Corpus in batches... Warning: This might take a while!")
+        # Encode corpus (each rank encodes `local_corpus`)
+        logger.info(f"Encoding Corpus ({len(local_corpus)} docs) in {len(local_corpus) // self.corpus_chunk_size + 1} batches... Warning: This might take a while!")
         start_time = time.time()
-        for chunk_id, corpus_batch in tqdm(enumerate(corpus_dl), total=len(corpus) // self.corpus_chunk_size):
-            with torch.no_grad():
-                self.model.encode_corpus_parallel(
-                    corpus_batch, pool=pool, batch_size=self.batch_size, chunk_id=chunk_id)
+        query_ids = queries["id"]
+        self.results = {qid: {} for qid in query_ids}
+        all_chunks_cos_scores_top_k_values = []
+        all_chunks_cos_scores_top_k_idx = []
+        for chunk_id, corpus_batch in tqdm(enumerate(corpus_dl), total=len(local_corpus) // self.corpus_chunk_size + 1, desc="Encode Corpus", disable=rank != 0):
+            corpus_start_idx = chunk_id * self.corpus_chunk_size

+            # Encode chunk of local_corpus
+            sub_corpus_embeddings = self.model.encode_corpus(
+                corpus_batch,
+                batch_size=self.encoding_batch_size,
+                show_progress_bar=self.show_progress_bar,
+                convert_to_tensor = self.convert_to_tensor
+            )

+            # Compute similarites using either cosine-similarity or dot product
+            cos_scores = self.score_functions[score_function](query_embeddings, sub_corpus_embeddings) # (num_queries, self.corpus_chunk_size)
+            cos_scores[torch.isnan(cos_scores)] = -1

+            # Get top-k values
+            cos_scores_top_k_values, cos_scores_top_k_idx = torch.topk(cos_scores, min(top_k, cos_scores.shape[1]), dim=1, largest=True) # (num_queries, top_k)

+            # Displace index by corpus_start_idx
+            cos_scores_top_k_idx += corpus_start_idx

+            all_chunks_cos_scores_top_k_values.append(cos_scores_top_k_values)
+            all_chunks_cos_scores_top_k_idx.append(cos_scores_top_k_idx)

-        # Stop the proccesses in the pool and free memory
-        self.model.stop_multi_process_pool(pool)

         end_time = time.time()
         logger.info("Encoded all batches in {:.2f} seconds".format(end_time - start_time))
-        # Gather all results
-        DummyMetric.len_queries = len(queries)
-        metric = DummyMetric(experiment_id=self.experiment_id, num_process=len(self.target_devices), process_id=0)
-        metric.filelock = FileLock(os.path.join(metric.data_dir, f"{metric.experiment_id}-{metric.num_process}-{metric.process_id}.arrow.lock"))
-        metric.cache_file_name = os.path.join(metric.data_dir, f"{metric.experiment_id}-{metric.num_process}-{metric.process_id}.arrow")
+        # TODO: have shapes (top_k, num_queries) instead?
+        # keep only top_k top scoring docs from this rank
+        all_chunks_cos_scores_top_k_values = torch.cat(all_chunks_cos_scores_top_k_values, dim=1) # (num_queries, (top_k)*n_chunks)
+        all_chunks_cos_scores_top_k_idx = torch.cat(all_chunks_cos_scores_top_k_idx, dim=1) # (num_queries, (top_k)*n_chunks)
+        cos_scores_top_k_values, temp_cos_scores_top_k_idx = torch.topk(all_chunks_cos_scores_top_k_values, min(top_k, all_chunks_cos_scores_top_k_values.shape[1]), dim=1, largest=True) # (num_queries, top_k)

-        cos_scores_top_k_values, cos_scores_top_k_idx = metric.compute()
+        # `all_chunks_cos_scores_top_k_idx` should index `corpus_ids` and not `all_chunks_cos_scores_top_k_values`
+        all_chunks_cos_scores_top_k_idx = torch.gather(all_chunks_cos_scores_top_k_idx, dim=1, index=temp_cos_scores_top_k_idx) # (num_queries, top_k) // indexes between (0, len(local_corpus))

-        # sort similar docs for each query by cosine similarity and keep only top_k
-        sorted_idx = np.argsort(cos_scores_top_k_values, axis=0)[::-1]
-        sorted_idx = sorted_idx[:self.top_k+1]
-        cos_scores_top_k_values = np.take_along_axis(cos_scores_top_k_values, sorted_idx, axis=0)
-        cos_scores_top_k_idx = np.take_along_axis(cos_scores_top_k_idx, sorted_idx, axis=0)
+        # Displace index by all_ranks_corpus_start_idx so that we index `corpus` and not `local_corpus`
+        all_chunks_cos_scores_top_k_idx += all_ranks_corpus_start_idx

-        logger.info("Formatting results...")
-        # Load corpus ids in memory
-        query_ids = queries['id']
-        corpus_ids = corpus['id']
-        self.results = {qid: {} for qid in query_ids}
-        for query_itr in tqdm(range(len(query_embeddings))):
-            query_id = query_ids[query_itr]
-            for i in range(len(cos_scores_top_k_values)):
-                sub_corpus_id = cos_scores_top_k_idx[i][query_itr]
-                score = cos_scores_top_k_values[i][query_itr].item() # convert np.float to float
-                corpus_id = corpus_ids[sub_corpus_id]
-                if corpus_id != query_id:
-                    self.results[query_id][corpus_id] = score
-        return self.results

-    def _encode_multi_process_worker(self, process_id, device, model, input_queue, results_queue):
-        """
-        (taken from UKPLab/sentence-transformers/sentence_transformers/SentenceTransformer.py)
-        Internal working process to encode sentences in multi-process setup.
-        Note: Added distributed similarity computing and finding top k similar docs.
-        """
-        DummyMetric.len_queries = len(self.query_embeddings)
-        metric = DummyMetric(experiment_id=self.experiment_id, num_process=len(self.target_devices), process_id=process_id)
-        metric.warmup()
-        with torch.no_grad():
-            while True:
-                try:
-                    id, batch_size, sentences = input_queue.get()
-                    corpus_embeds = model.encode(
-                        sentences, device=device, show_progress_bar=False, convert_to_tensor=True, batch_size=batch_size
-                    ).detach()
-
-                    cos_scores = self.score_functions[self.score_function](self.query_embeddings.to(corpus_embeds.device), corpus_embeds).detach()
-                    cos_scores[torch.isnan(cos_scores)] = -1
-
-                    #Get top-k values
-                    cos_scores_top_k_values, cos_scores_top_k_idx = torch.topk(cos_scores, min(self.top_k+1, len(cos_scores[1])), dim=1, largest=True, sorted=False)
-                    cos_scores_top_k_values = cos_scores_top_k_values.T.unsqueeze(0).detach()
-                    cos_scores_top_k_idx = cos_scores_top_k_idx.T.unsqueeze(0).detach()
-
-                    # correct sentence ids
-                    cos_scores_top_k_idx += id * self.corpus_chunk_size
-
-                    # Store results in an Apache Arrow table
-                    metric.add_batch(cos_scores_top_k_values=cos_scores_top_k_values, cos_scores_top_k_idx=cos_scores_top_k_idx, batch_index=[id]*len(cos_scores_top_k_values))
-
-                    # Alarm that process finished processing a batch
-                    results_queue.put(None)
-                except queue.Empty:
-                    break
+        # If local_corpus doesn't have top_k samples we pad scores to `pad_gathered_tensor_to`
+        if len(local_corpus) < top_k:
+            pad_gathered_tensor_to = math.ceil(len(corpus) / world_size)
+            cos_scores_top_k_values = torch.cat(
+                [cos_scores_top_k_values, torch.ones((cos_scores_top_k_values.shape[0], pad_gathered_tensor_to - cos_scores_top_k_values.shape[1]), device=cos_scores_top_k_values.device) * -1],
+                dim=1
+            )
+            all_chunks_cos_scores_top_k_idx = torch.cat(
+                [all_chunks_cos_scores_top_k_idx, torch.ones((all_chunks_cos_scores_top_k_idx.shape[0], pad_gathered_tensor_to - all_chunks_cos_scores_top_k_idx.shape[1]), device=all_chunks_cos_scores_top_k_idx.device, dtype=all_chunks_cos_scores_top_k_idx.dtype) * -1],
+                dim=1
+            )
+        else:
+            pad_gathered_tensor_to = top_k

+        # all gather top_k results from all ranks
+        n_queries = len(query_ids)
+        all_ranks_top_k_values = torch.empty((world_size, n_queries, pad_gathered_tensor_to), dtype=cos_scores_top_k_values.dtype, device="cuda")
+        all_ranks_top_k_idx = torch.empty((world_size, n_queries, pad_gathered_tensor_to), dtype=all_chunks_cos_scores_top_k_idx.dtype, device="cuda")
+        dist.barrier()
+        logger.info(f"All gather top_k values from all ranks...")
+        dist.all_gather_into_tensor(all_ranks_top_k_values, cos_scores_top_k_values)
+        logger.info(f"All gather top_k idx from all ranks...")
+        dist.all_gather_into_tensor(all_ranks_top_k_idx, all_chunks_cos_scores_top_k_idx)
+        logger.info(f"All gather ... Done!")

+        all_ranks_top_k_values = all_ranks_top_k_values.permute(1, 0, 2).reshape(n_queries, -1) # (n_queries, world_size*(pad_gathered_tensor_to))
+        all_ranks_top_k_idx = all_ranks_top_k_idx.permute(1, 0, 2).reshape(n_queries, -1) # (n_queries, world_size*(pad_gathered_tensor_to))

+        # keep only top_k top scoring docs from all ranks
+        cos_scores_top_k_values, temp_cos_scores_top_k_idx = torch.topk(all_ranks_top_k_values, min(top_k, all_ranks_top_k_values.shape[1]), dim=1, largest=True) # (num_queries, top_k)

+        # `all_ranks_top_k_idx` should index `corpus_ids` and not `all_ranks_top_k_values`
+        all_ranks_top_k_idx = torch.gather(all_ranks_top_k_idx, dim=1, index=temp_cos_scores_top_k_idx) # (num_queries, top_k) // indexes between (0, len(corpus))

+        # fill in results
+        cos_scores_top_k_values = cos_scores_top_k_values.tolist()
+        all_ranks_top_k_idx = all_ranks_top_k_idx.tolist()
+        corpus_i_to_idx = corpus["id"]
+        for qid, top_k_values, top_k_idx in zip(query_ids, cos_scores_top_k_values, all_ranks_top_k_idx):
+            for score, corpus_i in zip(top_k_values, top_k_idx):
+                corpus_idx = corpus_i_to_idx[corpus_i]
+                assert ignore_identical_ids is False or corpus_idx != qid, "Query id and corpus id should not be the same if ignore_identical_ids is set to True."
+                self.results[qid][corpus_idx] = score
+        return self.results
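The merge above is a two-stage top-k: each rank first reduces its chunk candidates to a local top_k (padding scores with -1 when its shard holds fewer than top_k documents, so `dist.all_gather_into_tensor` sees equally shaped tensors), then every rank takes a second top-k over the gathered scores and maps the winning positions back to global corpus indices with `torch.gather`. A toy single-process sketch of that merge logic, with `torch.cat` standing in for the collective and all names illustrative:

```python
# Two-stage top-k merge: local top-k per "rank", gather, global top-k,
# then gather the index tensor to recover corpus positions.
import torch

torch.manual_seed(0)
n_queries, docs_per_rank, world_size, top_k = 2, 5, 2, 3

gathered_values, gathered_idx = [], []
for rank in range(world_size):
    scores = torch.rand(n_queries, docs_per_rank)         # similarities vs this rank's shard
    values, idx = torch.topk(scores, top_k, dim=1)        # local top-k per query
    idx += rank * docs_per_rank                           # shift to global corpus positions
    gathered_values.append(values)
    gathered_idx.append(idx)

# stand-in for dist.all_gather_into_tensor: all candidates in one tensor
all_values = torch.cat(gathered_values, dim=1)            # (n_queries, world_size * top_k)
all_idx = torch.cat(gathered_idx, dim=1)

final_values, pos = torch.topk(all_values, top_k, dim=1)  # global top-k over candidates
final_idx = torch.gather(all_idx, dim=1, index=pos)       # map back to corpus positions

print(final_values)
print(final_idx)
```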