Skip to content

Adding support for reranker and other utilities #258

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 35 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ For fast retrieval, indexing precomputes the ColBERT representations of passages

Example usage:

```
```python
from colbert.infra import Run, RunConfig, ColBERTConfig
from colbert import Indexer

Expand All @@ -120,7 +120,7 @@ if __name__=='__main__':

We typically recommend that you use ColBERT for **end-to-end** retrieval, where it directly finds its top-k passages from the full collection:

```
```python
from colbert.data import Queries
from colbert.infra import Run, RunConfig, ColBERTConfig
from colbert import Searcher
Expand Down Expand Up @@ -154,7 +154,7 @@ Training requires a JSONL triples file with a `[qid, pid+, pid-]` list per line.

Example usage (training on 4 GPUs):

```
```python
from colbert.infra import Run, RunConfig, ColBERTConfig
from colbert import Trainer

Expand All @@ -177,6 +177,38 @@ if __name__=='__main__':
print(f"Saved checkpoint to {checkpoint_path}...")
```

## Reranking
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is really cool, Herumb!


We provide an easy-to-use interface for reranking and for rank fusion across multiple rankings. `Reranker` takes a Queries, Collection, and Ranking object as input, and you can obtain the reranked results by passing a checkpoint to the reranker model. Here is an example:

```python
from colbert.data.ranking import Ranking
from colbert.data.queries import Queries
from colbert.data.collection import Collection

from colbert.infra import Run, RunConfig
from colbert.reranker import Reranker

if __name__=='__main__':
with Run().context(RunConfig(
nranks=number_of_gpu_devices,
root="path/to/experiments",
experiment="awesome_experiment",
name='awesome_run_name'
)):
ranking = Ranking(path = 'path/to/ranking/file')
queries = Queries(path = 'path/to/query/file')
collection = Collection(path = 'path/to/collection/file')

reranker = Reranker(ranking=ranking, queries=queries, collection=collection)
score_file = reranker.rerank(checkpoint)
```

Additionally, you can use ranx to fuse multiple rankings together! Just pass the rankings and a fusion strategy, and the rest will be taken care of:
```python

```

## Running a lightweight ColBERTv2 server
We provide a script to run a lightweight server which serves k (up to 100) results in ranked order for a given search query, in JSON format. This script can be used to power DSP programs.

Expand Down
26 changes: 24 additions & 2 deletions colbert/data/collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,12 @@
# I think multiprocessing.Manager can do that!

import os
import jsonl
import itertools

from colbert.evaluation.loaders import load_collection
from colbert.utils.utils import print_message
from colbert.infra.run import Run
from colbert.evaluation.loaders import load_collection


class Collection:
Expand Down Expand Up @@ -36,7 +38,27 @@ def _load_tsv(self, path):
return load_collection(path)

def _load_jsonl(self, path):
    """Load a JSONL collection file and return a list of passage texts.

    Each line must be a JSON object with an integer-like 'id' field equal to its
    zero-based line index and a 'contents' field holding the passage text.

    Returns:
        list[str]: passage texts, indexed by pid (line order).
    """
    # Local import: the module-level `import jsonl` does not provide `json.loads`.
    import json

    assert path.endswith('.jsonl'), "ColBERTv2.0 only support .tsv and .jsonl collection files for now."
    print_message("#> Loading collection...")

    collection = []

    # Stream line-by-line instead of materializing the whole file first.
    with open(path, 'r') as json_file:
        for line_idx, line in enumerate(json_file):
            if line_idx % (1000 * 1000) == 0:
                # Progress marker every 1M passages.
                print(f'{line_idx // 1000 // 1000}M', end=' ', flush=True)

            passage = json.loads(line)

            pid = passage['id']
            # pids must be dense, zero-based line indices so list positions double as pids.
            assert int(pid) == line_idx, f"pid={pid}, line_idx={line_idx}"

            collection.append(passage['contents'])

    print()

    # Return the parsed passage texts (the old code returned the raw JSON lines).
    return collection

def provenance(self):
return self.path
Expand Down
19 changes: 12 additions & 7 deletions colbert/data/examples.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,19 @@
from colbert.infra.run import Run
import os
import ujson
import random

from colbert.infra.run import Run
from colbert.utils.utils import print_message
from colbert.infra.provenance import Provenance
from utility.utils.save_metadata import get_metadata_only


class Examples:
def __init__(self, path=None, data=None, nway=None, provenance=None, shuffle=False):
    """Hold training examples, loaded either from `data` or from the file at `path`.

    Args:
        path: optional path to a JSONL examples file (used when `data` is falsy).
        data: optional pre-loaded list of examples.
        nway: optional cap on the number of candidates kept per example.
        provenance: optional provenance record; defaults to `path` or a fresh Provenance.
        shuffle: when True, shuffle examples after loading them from file.
    """
    self.__provenance = provenance or path or Provenance()
    self.nway = nway
    self.path = path
    # Must be set before _load_file runs, since _load_file reads it.
    self.shuffle = shuffle
    self.data = data or self._load_file(path)

def provenance(self):
Expand All @@ -28,7 +30,10 @@ def _load_file(self, path):
for line in f:
example = ujson.loads(line)[:nway]
examples.append(example)


if self.shuffle:
random.shuffle(examples)

return examples

def tolist(self, rank=None, nranks=None):
Expand All @@ -41,7 +46,7 @@ def tolist(self, rank=None, nranks=None):

if rank or nranks:
assert rank in range(nranks), (rank, nranks)
return [self.data[idx] for idx in range(0, len(self.data), nranks)] # if line_idx % nranks == rank
return [self.data[idx + rank] for idx in range(0, len(self.data), nranks) if idx + rank < len(self.data)] # if line_idx % nranks == rank

return list(self.data)

Expand All @@ -68,12 +73,12 @@ def save(self, new_path):
return output_path

@classmethod
def cast(cls, obj, nway=None):
def cast(cls, obj, nway=None, shuffle=False):
if type(obj) is str:
return cls(path=obj, nway=nway)
return cls(path=obj, nway=nway, shuffle=shuffle)

if isinstance(obj, list):
return cls(data=obj, nway=nway)
return cls(data=obj, nway=nway, shuffle=shuffle)

if type(obj) is cls:
assert nway is None, nway
Expand Down
22 changes: 15 additions & 7 deletions colbert/data/queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,13 @@


class Queries:
def __init__(self, path=None, data=None):
def __init__(self, path=None, data=None, query_id_key='id', query_key='query', body_key='body', include_body=True, body_char_limit=1000):
self.path = path
self.query_key = query_key
self.query_id_key = query_id_key
self.body_key = body_key
self.include_body = include_body
self.body_char_limit = body_char_limit

if data:
assert isinstance(data, dict), type(data)
Expand All @@ -37,7 +42,7 @@ def _load_data(self, data):

for qid, content in data.items():
if isinstance(content, dict):
self.data[qid] = content['question']
self.data[qid] = content[self.query_key]
self._qas[qid] = content
else:
self.data[qid] = content
Expand All @@ -48,7 +53,7 @@ def _load_data(self, data):
return True

def _load_file(self, path):
if not path.endswith('.json'):
if not (path.endswith('.json') or path.endswith('.jsonl')):
self.data = load_queries(path)
return True

Expand All @@ -60,9 +65,12 @@ def _load_file(self, path):
for line in f:
qa = ujson.loads(line)

assert qa['qid'] not in self.data
self.data[qa['qid']] = qa['question']
self._qas[qa['qid']] = qa
assert qa[self.query_id_key] not in self.data
if self.include_body and self.body_key in qa:
self.data[qa[self.query_id_key]] = qa[self.query_key] + '|' + qa[self.body_key][:self.body_char_limit]
else:
self.data[qa[self.query_id_key]] = qa[self.query_key]
self._qas[qa[self.query_id_key]] = qa

return self.data

Expand Down Expand Up @@ -98,7 +106,7 @@ def save_qas(self, new_path):

with open(new_path, 'w') as f:
for qid, qa in self._qas.items():
qa['qid'] = qid
qa[self.query_id_key] = qid
f.write(ujson.dumps(qa) + '\n')

def _load_tsv(self, path):
Expand Down
2 changes: 2 additions & 0 deletions colbert/distillation/hf_scorers/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from .bge_reranker import BGERerankerScorer
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do we need this here?

Copy link
Author

@krypticmouse krypticmouse Sep 29, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you mean the init file or bge models and hf scorers?

No in both. We can remove init file and we can remove hf scorers as whole too given they don't play any role as of now.

from .bge_large import BGELargeV15Scorer
31 changes: 31 additions & 0 deletions colbert/distillation/hf_scorers/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from colbert.infra import Run
from colbert.parameters import DEVICE
from colbert.utils.utils import flatten
from colbert.infra.launcher import Launcher

class BaseHFScorer:
    """Base class for HuggingFace-based (query, passage) scorers used in distillation.

    Subclasses implement `score(...)`. `launch(...)` fans the (qid, pid) pairs out
    across ranks via the Launcher and flattens the per-rank score lists back into one.
    """

    def __init__(self, queries, collection, model, bsize=32, maxlen=180):
        self.queries = queries        # mapping qid -> query text
        self.collection = collection  # mapping pid -> passage text
        self.model = model            # HF model name or checkpoint path

        self.device = DEVICE
        self.bsize = bsize            # scoring batch size
        self.maxlen = maxlen          # tokenizer truncation length

    def launch(self, qids, pids):
        """Score all (qid, pid) pairs across all ranks; return one flat list of scores."""
        launcher = Launcher(self._score_pairs_process, return_all=True)
        outputs = launcher.launch(Run().config, qids, pids)

        return flatten(outputs)

    def _score_pairs_process(self, config, qids, pids):
        """Score this rank's contiguous share of the pairs (only rank 0 shows progress)."""
        assert len(qids) == len(pids), (len(qids), len(pids))
        share = 1 + len(qids) // config.nranks
        offset = config.rank * share
        endpos = (1 + config.rank) * share

        return self.score(qids[offset:endpos], pids[offset:endpos], show_progress=(config.rank < 1))

    def score(self, qids, pids, show_progress=False):
        """Return a list of float scores, one per (qid, pid) pair. Implemented by subclasses.

        The `show_progress` parameter is part of the contract: _score_pairs_process
        always passes it, so the base signature must accept it (the old signature
        raised TypeError instead of NotImplementedError when called via launch).
        """
        raise NotImplementedError
63 changes: 63 additions & 0 deletions colbert/distillation/hf_scorers/bge_large.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
import torch
import tqdm
from transformers import AutoTokenizer, AutoModel

from colbert.infra import Run
from colbert.distillation.hf_scorers.base import BaseHFScorer

class BGELargeV15Scorer(BaseHFScorer):
    """Bi-encoder scorer for BGE-large-v1.5-style embedding models.

    Scores each (query, passage) pair as the dot product of their L2-normalized
    [CLS] embeddings, with queries prefixed by a retrieval instruction.
    """

    def __init__(self, queries, collection, model, bsize=32, maxlen=180, query_instruction=None):
        super().__init__(queries, collection, model, bsize=bsize, maxlen=maxlen)

        # Default instruction recommended for BGE v1.5 retrieval queries.
        self.query_instruction = query_instruction or "Represent this sentence for searching relevant passages:"

    def score(self, qids, pids, show_progress=False):
        """Return a list of float similarity scores, one per (qid, pid) pair."""
        tokenizer = AutoTokenizer.from_pretrained(self.model)
        model = AutoModel.from_pretrained(self.model).to(self.device)

        assert len(qids) == len(pids), (len(qids), len(pids))

        scores = []

        model.eval()
        with torch.inference_mode():
            with torch.cuda.amp.autocast():
                for offset in tqdm.tqdm(range(0, len(qids), self.bsize), disable=(not show_progress)):
                    endpos = offset + self.bsize

                    # __init__ always sets query_instruction, but guard anyway in case
                    # a caller cleared it after construction.
                    if self.query_instruction is None:
                        queries_ = [self.queries[qid] for qid in qids[offset:endpos]]
                    else:
                        queries_ = [self.query_instruction + self.queries[qid] for qid in qids[offset:endpos]]

                    try:
                        passages_ = [self.collection[pid] for pid in pids[offset:endpos]]
                    except (KeyError, IndexError):
                        # Surface the offending pids, then re-raise the original error
                        # (the old bare `raise Exception` discarded the real traceback).
                        print(pids[offset:endpos])
                        raise

                    query_features = tokenizer(queries_, padding='longest', truncation=True,
                                               return_tensors='pt', max_length=self.maxlen).to(self.device)

                    passage_features = tokenizer(passages_, padding='longest', truncation=True,
                                                 return_tensors='pt', max_length=self.maxlen).to(self.device)

                    # [CLS] token embedding, L2-normalized, for both sides.
                    query_embeddings = model(**query_features)
                    query_embeddings = query_embeddings[0][:, 0]
                    query_embeddings = torch.nn.functional.normalize(query_embeddings, p=2, dim=1)

                    passage_embeddings = model(**passage_features)
                    passage_embeddings = passage_embeddings[0][:, 0]
                    passage_embeddings = torch.nn.functional.normalize(passage_embeddings, p=2, dim=1)

                    # Row-wise dot product => cosine similarity per pair.
                    batch_scores = torch.einsum('nd,nd->n', query_embeddings, passage_embeddings)

                    scores.append(batch_scores)

        scores = torch.cat(scores)
        scores = scores.tolist()

        Run().print(f'Returning with {len(scores)} scores')

        return scores
54 changes: 54 additions & 0 deletions colbert/distillation/hf_scorers/bge_reranker.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
import torch
import tqdm
from transformers import AutoTokenizer, AutoModelForSequenceClassification

from colbert.infra import Run
from colbert.distillation.hf_scorers.base import BaseHFScorer

class BGERerankerScorer(BaseHFScorer):
    """Cross-encoder scorer for BGE reranker models.

    Feeds each (query, passage) pair jointly through a sequence-classification
    model and uses its single logit as the relevance score.
    """

    def __init__(self, queries, collection, model, bsize=32, maxlen=180, query_instruction=None):
        super().__init__(queries, collection, model, bsize=bsize, maxlen=maxlen)

        # No default instruction: BGE rerankers score raw (query, passage) pairs.
        self.query_instruction = query_instruction

    def score(self, qids, pids, show_progress=False):
        """Return a list of float relevance scores, one per (qid, pid) pair."""
        tokenizer = AutoTokenizer.from_pretrained(self.model)
        # Use self.device for consistency with BaseHFScorer (the old `.cuda()` call
        # ignored DEVICE).
        model = AutoModelForSequenceClassification.from_pretrained(self.model).to(self.device)

        assert len(qids) == len(pids), (len(qids), len(pids))

        scores = []

        model.eval()
        with torch.inference_mode():
            with torch.cuda.amp.autocast():
                for offset in tqdm.tqdm(range(0, len(qids), self.bsize), disable=(not show_progress)):
                    endpos = offset + self.bsize

                    if self.query_instruction is None:
                        queries_ = [self.queries[qid] for qid in qids[offset:endpos]]
                    else:
                        queries_ = [self.query_instruction + self.queries[qid] for qid in qids[offset:endpos]]

                    try:
                        passages_ = [self.collection[pid] for pid in pids[offset:endpos]]
                    except (KeyError, IndexError):
                        # Surface the offending pids, then re-raise the original error
                        # (the old bare `raise Exception` discarded the real traceback).
                        print(pids[offset:endpos])
                        raise

                    # Cross-encoder input: each pair is tokenized jointly.
                    pairs = [[q, p] for q, p in zip(queries_, passages_)]

                    features = tokenizer(pairs, padding='longest', truncation=True,
                                         return_tensors='pt', max_length=self.maxlen).to(self.device)

                    # Single-logit head: flatten to one score per pair.
                    batch_scores = model(**features, return_dict=True).logits.view(-1, ).float()

                    scores.append(batch_scores)

        scores = torch.cat(scores)
        scores = scores.tolist()

        Run().print(f'Returning with {len(scores)} scores')

        return scores
Loading