Commit 6eb7ff5

Minor updates to Baleen

1 parent c0b180a · commit 6eb7ff5

5 files changed (+16 -32 lines)

baleen/condenser/condense.py (+11 -5)

@@ -22,9 +22,9 @@ def __init__(self, collectionX_path, checkpointL1, checkpointL2, deviceL1='cuda'
 
     def condense(self, query, backs, ranking):
         stage1_preds = self._stage1(query, backs, ranking)
-        stage2_preds = self._stage2(query, stage1_preds)
+        stage2_preds, stage2_preds_L3x = self._stage2(query, stage1_preds)
 
-        return stage1_preds, stage2_preds
+        return stage1_preds, stage2_preds, stage2_preds_L3x
 
     def _load_model(self, path, device):
         model = torch.load(path, map_location='cpu')
@@ -128,8 +128,14 @@ def _stage2(self, query, preds):
 
         preds = [(score, (pid, sid)) for (pid, sid), score in zip(preds, scores)]
         preds = sorted(preds, reverse=True)[:5]
+
+        preds_L3x = [x for score, x in preds if score > min(0, preds[1][0] - 1e-10)]  # Take at least 2!
         preds = [x for score, x in preds if score > 0]
 
-        # TODO: Apply L3x for final stage.
-
-        return preds
+        earliest_pids = f7([pid for pid, _ in preds_L3x])[:4]  # Take at most 4 docs.
+        preds_L3x = [(pid, sid) for pid, sid in preds_L3x if pid in earliest_pids]
+
+        assert len(preds_L3x) >= 2
+        assert len(f7([pid for pid, _ in preds_L3x])) <= 4
+
+        return preds, preds_L3x
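The reworked _stage2 keeps the top-5 scored facts, but the new preds_L3x list relaxes the positive-score cutoff just enough to guarantee at least two facts, and the follow-up filter caps the result at four distinct documents. The f7 helper it leans on is presumably the usual order-preserving deduplication recipe; a minimal sketch under that assumption:

# Hedged sketch: assumes f7 is the classic order-preserving dedup helper.
def f7(seq):
    seen = set()
    return [x for x in seq if not (x in seen or seen.add(x))]

# Example: first-seen pids, later capped at 4 distinct documents via [:4].
assert f7([9, 3, 9, 7, 3, 5, 2])[:4] == [9, 3, 7, 5]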

baleen/engine.py (+2 -2)

@@ -36,12 +36,12 @@ def search(self, query, num_hops, depth=100, verbose=False):
                 if len(pids_bag) < k * (hop_idx+1):
                     pids_bag.add(pid)
 
-            stage1_preds, facts = condenser.condense(query, backs=facts, ranking=ranking_)
+            stage1_preds, facts, stage2_L3x = condenser.condense(query, backs=facts, ranking=ranking_)
             context = ' [SEP] '.join([collectionX.get((pid, sid), '') for pid, sid in facts])
 
        assert len(pids_bag) == depth
 
-        return facts, pids_bag, stage1_preds
+        return stage2_L3x, pids_bag, stage1_preds
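With condense now returning a third value, search hands back the L3x-filtered facts in place of the thresholded ones. A hedged sketch of how a caller unpacks the new triple (the variable names and argument values here are illustrative, not from the repo):

# Hypothetical caller of Baleen.search after this commit.
facts_L3x, pids_bag, stage1_preds = baleen.search(query, num_hops=4, depth=100)

# facts_L3x:    (pid, sid) pairs chosen by the L3x rule (>= 2 facts, <= 4 docs)
# pids_bag:     the `depth` passage ids accumulated across hops
# stage1_preds: stage-1 condenser predictions from the final hop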

colbert/indexing/collection_indexer.py (-15)

@@ -299,26 +299,11 @@ def _collect_embedding_id_offset(self):
         assert len(self.embedding_offsets) == self.num_chunks
 
     def _build_ivf(self):
-        # TODO: If this is slow or memory intensive, it can be done as a torch.sort, torch.add offset, torch.unique
-        # operations over the concatenated codes.
-        # On MS MARCO, this seems to take 10 minutes! I can imagine that's 40 minutes on Wikipedia.
-
-
-        """
-        codes = ResidualCodec.Embeddings.load_all_codes(index_path)
-
-        codes.sort().{values, indices}
-
-        values.unique_consecutive -> counts -> cumsum -> offsets
-        (indices.int32(), offsets)
-
-
         # Maybe we should several small IVFs? Every 250M embeddings, so that's every 1 GB.
         # It would save *memory* here and *disk space* regarding the int64.
         # But we'd have to decide how many IVFs to use during retrieval: many (loop) or one?
         # A loop seems nice if we can find a size that's large enough for speed yet small enough to fit on GPU!
         # Then it would help nicely for batching later: 1GB.
-        """
 
         codes = torch.empty(self.num_embeddings,)
         print_memory_stats(f'RANK:{self.rank}')
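The comment block deleted here sketched building the IVF by sorting the concatenated centroid codes and counting runs with unique_consecutive. A small self-contained illustration of that idea (synthetic codes; this is not the repository's actual _build_ivf):

import torch

# Toy stand-in for the concatenated centroid code of every embedding.
codes = torch.tensor([3, 1, 3, 0, 1, 3, 2], dtype=torch.int64)

# Sort so embeddings that share a centroid become contiguous.
values, indices = codes.sort()

# Run-length counts per centroid, then cumulative offsets into `indices`.
_, counts = torch.unique_consecutive(values, return_counts=True)
offsets = torch.cumsum(counts, dim=0)

# (indices.int32(), offsets): each centroid's embedding ids occupy one
# contiguous slice of the sorted `indices`, delimited by `offsets`.
ivf = indices.to(torch.int32)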

colbert/utilities/annotate_em.py (+3 -5)

@@ -106,10 +106,8 @@ def save(self, new_path):
 
 
 if __name__ == '__main__':
-    r = '/future/u/okhattab/root/unit/experiments/2021.08/retrieve.py/2021-09-04_15.50.02/ranking.tsv'
-    r = '/future/u/okhattab/root/unit/experiments/2021.08/retrieve.py/2021-09-04_15.59.37/ranking.tsv'
-    r = sys.argv[1]
+    r = sys.argv[2]
 
-    a = AnnotateEM(collection='/future/u/okhattab/root/unit/data/NQ-mini/collection.tsv',
-                   qas='/future/u/okhattab/root/unit/data/NQ-mini/dev/qas.json')
+    a = AnnotateEM(collection='/dfs/scratch0/okhattab/OpenQA/collection.tsv',
+                   qas=sys.argv[1])
     a.annotate(ranking=r)
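After this change the script reads the qas file from sys.argv[1] and the ranking file from sys.argv[2], while the collection path is pinned to a fixed cluster location. An invocation presumably looks like `python colbert/utilities/annotate_em.py <qas.json> <ranking.tsv>`; the invocation form is an assumption, only the argument order comes from the diff.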

colbert/utils/amp.py (-5)

@@ -3,13 +3,9 @@
 from contextlib import contextmanager
 from colbert.utils.utils import NullContextManager
 
-PyTorch_over_1_6 = True  # float('.'.join(torch.__version__.split('.')[0:2])) >= 1.6
-
 
 class MixedPrecisionManager():
     def __init__(self, activated):
-        assert (not activated) or PyTorch_over_1_6, "Cannot use AMP for PyTorch version < 1.6"
-
         self.activated = activated
 
         if self.activated:
@@ -22,7 +18,6 @@ def backward(self, loss):
         if self.activated:
             self.scaler.scale(loss).backward()
         else:
-            assert False, "for now"
             loss.backward()
 
     def step(self, colbert, optimizer, scheduler=None):
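With the version gate and the placeholder assert gone, MixedPrecisionManager can now run with AMP disabled: backward() falls through to a plain loss.backward(). A minimal sketch of the resulting gating, assuming the standard torch.cuda.amp pieces and not mirroring the full colbert class:

import torch

class MixedPrecisionManagerSketch:
    """Illustrative only: mirrors the activated/plain gating, not the real class."""

    def __init__(self, activated):
        self.activated = activated
        if self.activated:
            self.scaler = torch.cuda.amp.GradScaler()

    def backward(self, loss):
        if self.activated:
            self.scaler.scale(loss).backward()  # AMP path: scale, then backward
        else:
            loss.backward()                     # non-AMP path, no longer asserts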
