Commit 4120feb

Initial commit with the new API and residual compression
1 parent c4e79e8 commit 4120feb


78 files changed: +6583 −0 lines

.vscode/settings.json

+4
```json
{
    "jupyter.jupyterServerType": "local",
    "python.formatting.autopep8Args": ["--max-line-length", "120"],
}
```

LICENSE

+21
MIT License

Copyright (c) 2019, 2020 Stanford Future Data Systems

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

README.md

+40
# ColBERT

### ColBERT is a _fast_ and _accurate_ retrieval model, enabling scalable BERT-based search over large text collections in tens of milliseconds.

<p align="center">
  <img align="center" src="docs/images/ColBERT-Framework-MaxSim-W370px.png" />
</p>
<p align="center">
  <b>Figure 1:</b> ColBERT's late interaction, efficiently scoring the fine-grained similarity between a query and a passage.
</p>

As Figure 1 illustrates, ColBERT relies on fine-grained **contextual late interaction**: it encodes each passage into a **matrix** of token-level embeddings (shown above in blue). Then at search time, it embeds every query into another matrix (shown in green) and efficiently finds passages that contextually match the query using scalable vector-similarity (`MaxSim`) operators.
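To make the scoring concrete, here is a minimal PyTorch sketch of the `MaxSim` operator for a single query-passage pair, assuming the embedding matrices are already L2-normalized (so dot products equal cosine similarities); the function name `maxsim_score` is illustrative, not part of ColBERT's API:

```python
import torch

def maxsim_score(Q: torch.Tensor, D: torch.Tensor) -> torch.Tensor:
    """Late-interaction relevance score between one query and one passage.

    Q: (num_query_tokens, dim) matrix of query token embeddings.
    D: (num_doc_tokens, dim) matrix of passage token embeddings.
    """
    sim = Q @ D.T                       # token-by-token similarity matrix
    return sim.max(dim=1).values.sum()  # best passage match per query token, summed
```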
These rich interactions allow ColBERT to surpass the quality of _single-vector_ representation models, while scaling efficiently to large corpora. You can read more in our papers:

* [**ColBERT: Efficient and Effective Passage Search via Contextualized Late Interaction over BERT**](https://arxiv.org/abs/2004.12832) (SIGIR'20).
* [**Relevance-guided Supervision for OpenQA with ColBERT**](https://arxiv.org/abs/2007.00814) (TACL'21).

----

## Installation

ColBERT (currently: [v0.4.6](#releases)) requires Python 3.7+ and PyTorch 1.9+ and uses the [HuggingFace Transformers](https://github.com/huggingface/transformers) library.

We strongly recommend creating a conda environment using the commands below. (If you don't have conda, follow the official [conda installation guide](https://docs.anaconda.com/anaconda/install/linux/#installation).)

```
conda env create -f conda_env.yml
conda activate colbert-v0.4.2
```

If you face any problems, please [open a new issue](https://github.com/stanford-futuredata/ColBERT/issues) and we'll help you promptly!

## NEW: API Usage Notebook

The Jupyter notebook **[docs/intro.ipynb](docs/intro.ipynb)** illustrates the key features of ColBERT with the new Python API.

colbert/__init__.py

+4
```python
from .trainer import Trainer
from .indexer import Indexer
from .searcher import Searcher
```
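The package now exposes `Trainer`, `Indexer`, and `Searcher` at the top level. A hypothetical usage sketch follows; the constructor and method arguments shown are assumptions rather than signatures confirmed by this diff, so consult [docs/intro.ipynb](docs/intro.ipynb) for the actual API:

```python
from colbert import Indexer, Searcher

# Hypothetical arguments: the checkpoint path, index name, and collection
# path below are placeholders, not signatures confirmed by this commit.
indexer = Indexer(checkpoint='/path/to/colbert.dnn')
indexer.index(name='msmarco.index', collection='/path/to/collection.tsv')

searcher = Searcher(index='msmarco.index')
results = searcher.search('what is contextualized late interaction?', k=10)
```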

colbert/evaluation/__init__.py

Whitespace-only changes.

colbert/evaluation/load_model.py

+28
```python
import os
import ujson
import torch
import random

from collections import defaultdict, OrderedDict

from colbert.parameters import DEVICE
from colbert.modeling.colbert import ColBERT
from colbert.utils.utils import print_message, load_checkpoint


def load_model(args, do_print=True):
    colbert = ColBERT.from_pretrained('bert-base-uncased',
                                      query_maxlen=args.query_maxlen,
                                      doc_maxlen=args.doc_maxlen,
                                      dim=args.dim,
                                      similarity_metric=args.similarity,
                                      mask_punctuation=args.mask_punctuation)
    colbert = colbert.to(DEVICE)

    print_message("#> Loading model checkpoint.", condition=do_print)

    checkpoint = load_checkpoint(args.checkpoint, colbert, do_print=do_print)

    colbert.eval()

    return colbert, checkpoint
```
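`load_model` reads its configuration off an `args` object. A minimal sketch of the attributes it expects, based only on the fields accessed above; the values and checkpoint path are placeholders:

```python
from types import SimpleNamespace

from colbert.evaluation.load_model import load_model

# These attributes mirror exactly what load_model reads off `args`;
# the values are illustrative and the checkpoint path is a placeholder.
args = SimpleNamespace(
    query_maxlen=32,
    doc_maxlen=180,
    dim=128,
    similarity='cosine',
    mask_punctuation=True,
    checkpoint='/path/to/colbert.dnn',
)
colbert, checkpoint = load_model(args)
```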

colbert/evaluation/loaders.py

+196
```python
import os
import ujson
import torch
import random

from collections import defaultdict, OrderedDict

from colbert.parameters import DEVICE
from colbert.modeling.colbert import ColBERT
from colbert.utils.utils import print_message, load_checkpoint
from colbert.evaluation.load_model import load_model
from colbert.utils.runs import Run


def load_queries(queries_path):
    queries = OrderedDict()

    print_message("#> Loading the queries from", queries_path, "...")

    with open(queries_path) as f:
        for line in f:
            qid, query, *_ = line.strip().split('\t')
            qid = int(qid)

            assert (qid not in queries), ("Query QID", qid, "is repeated!")
            queries[qid] = query

    print_message("#> Got", len(queries), "queries. All QIDs are unique.\n")

    return queries


def load_qrels(qrels_path):
    if qrels_path is None:
        return None

    print_message("#> Loading qrels from", qrels_path, "...")

    qrels = OrderedDict()
    with open(qrels_path, mode='r', encoding="utf-8") as f:
        for line in f:
            qid, x, pid, y = map(int, line.strip().split('\t'))
            assert x == 0 and y == 1
            qrels[qid] = qrels.get(qid, [])
            qrels[qid].append(pid)

    assert all(len(qrels[qid]) == len(set(qrels[qid])) for qid in qrels)

    avg_positive = round(sum(len(qrels[qid]) for qid in qrels) / len(qrels), 2)

    print_message("#> Loaded qrels for", len(qrels), "unique queries with",
                  avg_positive, "positives per query on average.\n")

    return qrels


def load_topK(topK_path):
    queries = OrderedDict()
    topK_docs = OrderedDict()
    topK_pids = OrderedDict()

    print_message("#> Loading the top-k per query from", topK_path, "...")

    with open(topK_path) as f:
        for line_idx, line in enumerate(f):
            if line_idx and line_idx % (10*1000*1000) == 0:
                print(line_idx, end=' ', flush=True)

            qid, pid, query, passage = line.split('\t')
            qid, pid = int(qid), int(pid)

            assert (qid not in queries) or (queries[qid] == query)
            queries[qid] = query
            topK_docs[qid] = topK_docs.get(qid, [])
            topK_docs[qid].append(passage)
            topK_pids[qid] = topK_pids.get(qid, [])
            topK_pids[qid].append(pid)

    print()

    assert all(len(topK_pids[qid]) == len(set(topK_pids[qid])) for qid in topK_pids)

    Ks = [len(topK_pids[qid]) for qid in topK_pids]

    print_message("#> max(Ks) =", max(Ks), ", avg(Ks) =", round(sum(Ks) / len(Ks), 2))
    print_message("#> Loaded the top-k per query for", len(queries), "unique queries.\n")

    return queries, topK_docs, topK_pids


def load_topK_pids(topK_path, qrels):
    topK_pids = defaultdict(list)
    topK_positives = defaultdict(list)

    print_message("#> Loading the top-k PIDs per query from", topK_path, "...")

    with open(topK_path) as f:
        for line_idx, line in enumerate(f):
            if line_idx and line_idx % (10*1000*1000) == 0:
                print(line_idx, end=' ', flush=True)

            qid, pid, *rest = line.strip().split('\t')
            qid, pid = int(qid), int(pid)

            topK_pids[qid].append(pid)

            assert len(rest) in [1, 2, 3]

            if len(rest) > 1:
                *_, label = rest
                label = int(label)
                assert label in [0, 1]

                if label >= 1:
                    topK_positives[qid].append(pid)

    print()

    assert all(len(topK_pids[qid]) == len(set(topK_pids[qid])) for qid in topK_pids)
    assert all(len(topK_positives[qid]) == len(set(topK_positives[qid])) for qid in topK_positives)

    # Make them sets for fast lookups later
    topK_positives = {qid: set(topK_positives[qid]) for qid in topK_positives}

    Ks = [len(topK_pids[qid]) for qid in topK_pids]

    print_message("#> max(Ks) =", max(Ks), ", avg(Ks) =", round(sum(Ks) / len(Ks), 2))
    print_message("#> Loaded the top-k per query for", len(topK_pids), "unique queries.\n")

    if len(topK_positives) == 0:
        topK_positives = None
    else:
        assert len(topK_pids) >= len(topK_positives)

        for qid in set.difference(set(topK_pids.keys()), set(topK_positives.keys())):
            topK_positives[qid] = []

        assert len(topK_pids) == len(topK_positives)

        avg_positive = round(sum(len(topK_positives[qid]) for qid in topK_positives) / len(topK_pids), 2)

        print_message("#> Concurrently got annotations for", len(topK_positives), "unique queries with",
                      avg_positive, "positives per query on average.\n")

    assert qrels is None or topK_positives is None, "Cannot have both qrels and an annotated top-K file!"

    if topK_positives is None:
        topK_positives = qrels

    return topK_pids, topK_positives


def load_collection(collection_path):
    print_message("#> Loading collection...")

    collection = []

    with open(collection_path) as f:
        for line_idx, line in enumerate(f):
            if line_idx % (1000*1000) == 0:
                print(f'{line_idx // 1000 // 1000}M', end=' ', flush=True)

            pid, passage, *rest = line.strip().split('\t')
            assert pid == 'id' or int(pid) == line_idx

            if len(rest) >= 1:
                title = rest[0]
                passage = title + ' | ' + passage

            collection.append(passage)

    print()

    return collection


def load_colbert(args, do_print=True):
    colbert, checkpoint = load_model(args, do_print)

    # TODO: If the parameters below were not specified on the command line, their *checkpoint* values should be used.
    # I.e., not their purely (i.e., training) default values.

    for k in ['query_maxlen', 'doc_maxlen', 'dim', 'similarity', 'amp']:
        if 'arguments' in checkpoint and hasattr(args, k):
            if k in checkpoint['arguments'] and checkpoint['arguments'][k] != getattr(args, k):
                a, b = checkpoint['arguments'][k], getattr(args, k)
                Run.warn(f"Got checkpoint['arguments']['{k}'] != args.{k} (i.e., {a} != {b})")

    if 'arguments' in checkpoint:
        if args.rank < 1:
            print(ujson.dumps(checkpoint['arguments'], indent=4))

    if do_print:
        print('\n')

    return colbert, checkpoint
```
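The parsing logic above implies the tab-separated layouts these loaders expect. A small usage sketch, with the file names as placeholders:

```python
from colbert.evaluation.loaders import load_queries, load_qrels, load_collection

# Layouts implied by the parsing code above (paths are placeholders):
#   queries.tsv:     qid \t query
#   qrels.tsv:       qid \t 0 \t pid \t 1         (columns 2 and 4 must be exactly 0 and 1)
#   collection.tsv:  pid \t passage [\t title]    (pid must equal its 0-based line index)
queries = load_queries('queries.tsv')           # OrderedDict: qid -> query text
qrels = load_qrels('qrels.tsv')                 # OrderedDict: qid -> list of positive pids
collection = load_collection('collection.tsv')  # list of passages, indexed by pid
```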
