
Commit f232bde

Merge commit with 2 parents: cce4e89 + 3df30ae

6 files changed: +8792 −13 lines

README.md (+4)

```diff
@@ -3,6 +3,8 @@
 * (1/29/23) We have merged a new index updater feature and support for additional Hugging Face models! These are in beta so please give us feedback as you try them out.
 * (1/24/23) If you're looking for the **DSP** framework for composing ColBERTv2 and LLMs, it's at: https://github.com/stanfordnlp/dsp
 
+[<img align="center" src="https://colab.research.google.com/assets/colab-badge.svg" />](https://colab.research.google.com/github/stanford-futuredata/ColBERT/blob/main/docs/intro2new.ipynb)
+
 # ColBERT (v2)
 
 ### ColBERT is a _fast_ and _accurate_ retrieval model, enabling scalable BERT-based search over large text collections in tens of milliseconds.
@@ -66,6 +68,8 @@ Below, we illustrate these steps via an example run on the MS MARCO Passage Rank
 
 ## API Usage Notebook
 
+**NEW**: We have an experimental notebook on [Google Colab](https://colab.research.google.com/github/stanford-futuredata/ColBERT/blob/main/docs/intro2new.ipynb) that you can use with free GPUs. Indexing 10,000 passages on the free Colab T4 GPU takes six minutes.
+
 This Jupyter notebook **[docs/intro.ipynb notebook](docs/intro.ipynb)** illustrates using the key features of ColBERT with the new Python API.
 
 It includes how to download the ColBERTv2 model checkpoint trained on MS MARCO Passage Ranking and how to download our new LoTTE benchmark.
```
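For readers who haven't opened the notebook, the Python API it demonstrates follows an index-then-search flow. A minimal sketch under stated assumptions: `collection.tsv` is a placeholder tab-separated passage file, and `colbert-ir/colbertv2.0` is the public ColBERTv2 checkpoint on Hugging Face.

```python
# Minimal sketch of the index-then-search flow the notebook covers.
# Assumptions: collection.tsv is a local "pid \t passage" file (placeholder),
# and colbert-ir/colbertv2.0 is the public ColBERTv2 checkpoint.
from colbert import Indexer, Searcher
from colbert.infra import Run, RunConfig

if __name__ == '__main__':
    with Run().context(RunConfig(nranks=1, experiment="demo")):
        # Build a compressed ColBERTv2 index over the collection.
        indexer = Indexer(checkpoint="colbert-ir/colbertv2.0")
        indexer.index(name="demo.index", collection="collection.tsv", overwrite=True)

        # Load the index and retrieve the top-3 passages for a query.
        searcher = Searcher(index="demo.index")
        pids, ranks, scores = searcher.search("what does ColBERT stand for?", k=3)
```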

colbert/modeling/checkpoint.py (+3 −3)

```diff
@@ -40,13 +40,13 @@ def doc(self, *args, to_cpu=False, **kw_args):
 
             return D
 
-    def queryFromText(self, queries, bsize=None, to_cpu=False, context=None):
+    def queryFromText(self, queries, bsize=None, to_cpu=False, context=None, full_length_search=False):
         if bsize:
-            batches = self.query_tokenizer.tensorize(queries, context=context, bsize=bsize)
+            batches = self.query_tokenizer.tensorize(queries, context=context, bsize=bsize, full_length_search=full_length_search)
             batches = [self.query(input_ids, attention_mask, to_cpu=to_cpu) for input_ids, attention_mask in batches]
             return torch.cat(batches)
 
-        input_ids, attention_mask = self.query_tokenizer.tensorize(queries, context=context)
+        input_ids, attention_mask = self.query_tokenizer.tensorize(queries, context=context, full_length_search=full_length_search)
         return self.query(input_ids, attention_mask)
 
     def docFromText(self, docs, bsize=None, keep_dims=True, to_cpu=False, showprogress=False, return_tokens=False):
```
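At this layer the new flag is simply threaded through to the query tokenizer. A hedged sketch of exercising it directly, assuming the public `colbert-ir/colbertv2.0` checkpoint and a default `ColBERTConfig` (both assumptions, not part of the diff):

```python
from colbert.infra import ColBERTConfig
from colbert.modeling.checkpoint import Checkpoint

# Assumptions: colbert-ir/colbertv2.0 checkpoint, default ColBERTConfig.
ckpt = Checkpoint("colbert-ir/colbertv2.0", colbert_config=ColBERTConfig())

# full_length_search lifts the query_maxlen truncation; per the tokenizer
# change below, it is only supported for a single-query batch.
Q = ckpt.queryFromText(["one very long natural-language query ..."],
                       full_length_search=True)
print(Q.shape)  # (1, padded_query_length, dim)
```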

colbert/modeling/tokenization/query_tokenization.py (+21 −2)

```diff
@@ -48,14 +48,29 @@ def encode(self, batch_text, add_special_tokens=False):
 
         return ids
 
-    def tensorize(self, batch_text, bsize=None, context=None):
+    def tensorize(self, batch_text, bsize=None, context=None, full_length_search=False):
         assert type(batch_text) in [list, tuple], (type(batch_text))
 
         # add placeholder for the [Q] marker
         batch_text = ['. ' + x for x in batch_text]
 
+        # Full-length search is only available for single inference (for now);
+        # batched full-length search requires far deeper changes to the code base.
+        assert(full_length_search == False or (type(batch_text) == list and len(batch_text) == 1))
+
+        if full_length_search:
+            # Tokenize each string in the batch
+            un_truncated_ids = self.tok(batch_text, add_special_tokens=False)['input_ids']
+            # Get the longest length in the batch
+            max_length_in_batch = max(len(x) for x in un_truncated_ids)
+            # Set the max length
+            max_length = self.max_len(max_length_in_batch)
+        else:
+            # Max length is the default max length from the config
+            max_length = self.query_maxlen
+
         obj = self.tok(batch_text, padding='max_length', truncation=True,
-                       return_tensors='pt', max_length=self.query_maxlen)
+                       return_tensors='pt', max_length=max_length)
 
         ids, mask = obj['input_ids'], obj['attention_mask']
 
@@ -95,3 +110,7 @@ def tensorize(self, batch_text, bsize=None, context=None):
             print()
 
         return ids, mask
+
+    # Ensure that query_maxlen <= length <= 500 tokens
+    def max_len(self, length):
+        return min(500, max(self.query_maxlen, length))
```
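The `max_len` helper is the heart of the change: the untruncated query length is raised to at least `query_maxlen` and capped at 500 tokens. A standalone sketch of the same arithmetic, assuming the usual default `query_maxlen` of 32:

```python
QUERY_MAXLEN = 32  # assumed config default; the real value comes from the ColBERT config
HARD_CAP = 500     # upper bound enforced by max_len

def max_len(length: int) -> int:
    # Ensure that query_maxlen <= length <= 500 tokens, mirroring the commit.
    return min(HARD_CAP, max(QUERY_MAXLEN, length))

assert max_len(10) == 32     # short queries are still padded to query_maxlen
assert max_len(300) == 300   # long queries pass through untruncated
assert max_len(1200) == 500  # extreme queries are capped at 500 tokens
```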

colbert/searcher.py (+6 −6)

```diff
@@ -46,24 +46,24 @@ def __init__(self, index, checkpoint=None, collection=None, config=None):
     def configure(self, **kw_args):
         self.config.configure(**kw_args)
 
-    def encode(self, text: TextQueries):
+    def encode(self, text: TextQueries, full_length_search=False):
         queries = text if type(text) is list else [text]
         bsize = 128 if len(queries) > 128 else None
 
         self.checkpoint.query_tokenizer.query_maxlen = self.config.query_maxlen
-        Q = self.checkpoint.queryFromText(queries, bsize=bsize, to_cpu=True)
+        Q = self.checkpoint.queryFromText(queries, bsize=bsize, to_cpu=True, full_length_search=full_length_search)
 
         return Q
 
-    def search(self, text: str, k=10, filter_fn=None):
-        Q = self.encode(text)
+    def search(self, text: str, k=10, filter_fn=None, full_length_search=False):
+        Q = self.encode(text, full_length_search=full_length_search)
         return self.dense_search(Q, k, filter_fn=filter_fn)
 
-    def search_all(self, queries: TextQueries, k=10, filter_fn=None):
+    def search_all(self, queries: TextQueries, k=10, filter_fn=None, full_length_search=False):
         queries = Queries.cast(queries)
         queries_ = list(queries.values())
 
-        Q = self.encode(queries_)
+        Q = self.encode(queries_, full_length_search=full_length_search)
 
         return self._search_all_Q(queries, Q, k, filter_fn=filter_fn)
```
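End to end, the feature is opt-in at search time. A hedged usage sketch against an existing index (the index name and query text are placeholders; as the tokenizer assertion above enforces, `full_length_search=True` currently applies to single queries only):

```python
from colbert import Searcher

searcher = Searcher(index="msmarco.nbits=2")  # placeholder index name

long_query = "a question that runs well past the 32-token default " * 10

# Default behavior: the query is truncated to config.query_maxlen tokens.
pids, ranks, scores = searcher.search(long_query, k=10)

# New behavior: keep the full query, up to the 500-token cap.
pids, ranks, scores = searcher.search(long_query, k=10, full_length_search=True)
```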
