
Commit 397988d

Author: Omar Khattab
Add colbert/data

1 parent 4120feb · commit 397988d

File tree

7 files changed: +433 −4 lines

.gitignore (+12 −4)

@@ -1,7 +1,10 @@
-experiments/
-checkpoints/
-data/
-logs/
+/experiments/
+/checkpoints/
+/data/
+/logs/
+/mlruns/
+/profiler/
+/logs/
 
 # Byte-compiled / optimized / DLL files
 __pycache__/
@@ -10,6 +13,11 @@ __pycache__/
 
 # Jupyter Notebook
 .ipynb_checkpoints
+# notebooks/
 
 # mac
 .DS_Store
+
+# Other
+.vscode
+*.tsv

colbert/data/__init__.py (new file, +5 lines)

from .collection import *
from .queries import *

from .ranking import *
from .examples import *

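A brief usage note, not part of the commit: with these wildcard imports in place, the data classes can be pulled in directly from the package. A minimal sketch, assuming the colbert package is importable (including the .ranking module referenced above):

# Hypothetical usage; assumes the colbert package and its ranking module are importable.
from colbert.data import Collection, Queries, Examples
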
colbert/data/collection.py (new file, +97 lines)


# Could be .tsv or .json. The latter always allows more customization via optional parameters.
# I think it could be worth doing some kind of parallel reads too, if the file exceeds 1 GiBs.
# Just need to use a datastructure that shares things across processes without too much pickling.
# I think multiprocessing.Manager can do that!

import os
import itertools

from colbert.evaluation.loaders import load_collection
from colbert.infra.run import Run


class Collection:
    def __init__(self, path=None, data=None):
        self.path = path
        self.data = data or self._load_file(path)

    def __iter__(self):
        # TODO: If __data isn't there, stream from disk!
        return self.data.__iter__()

    def __getitem__(self, item):
        # TODO: Load from disk the first time this is called. Unless self.data is already not None.
        return self.data[item]

    def __len__(self):
        # TODO: Load here too. Basically, let's make data a property function and, on first call, either load or get __data.
        return len(self.data)

    def _load_file(self, path):
        self.path = path
        return self._load_tsv(path) if path.endswith('.tsv') else self._load_jsonl(path)

    def _load_tsv(self, path):
        return load_collection(path)

    def _load_jsonl(self, path):
        raise NotImplementedError()

    def provenance(self):
        return self.path

    def save(self, new_path):
        assert new_path.endswith('.tsv'), "TODO: Support .json[l] too."
        assert not os.path.exists(new_path), new_path

        with Run().open(new_path, 'w') as f:
            # TODO: expects content to always be a string here; no separate title!
            for pid, content in enumerate(self.data):
                content = f'{pid}\t{content}\n'
                f.write(content)

            return f.name

    def enumerate(self, rank):
        for _, offset, passages in self.enumerate_batches(rank=rank):
            for idx, passage in enumerate(passages):
                yield (offset + idx, passage)

    def enumerate_batches(self, rank, chunksize=None):
        assert rank is not None, "TODO: Add support for the rank=None case."

        chunksize = chunksize or self.get_chunksize()

        offset = 0
        iterator = iter(self)

        for chunk_idx, owner in enumerate(itertools.cycle(range(Run().nranks))):
            L = [line for _, line in zip(range(chunksize), iterator)]

            if len(L) > 0 and owner == rank:
                yield (chunk_idx, offset, L)

            offset += len(L)

            if len(L) < chunksize:
                return

    def get_chunksize(self):
        return min(25_000, 1 + len(self) // Run().nranks)  # 25k is great, 10k allows things to reside on GPU??

    @classmethod
    def cast(cls, obj):
        if type(obj) is str:
            return cls(path=obj)

        if type(obj) is list:
            return cls(data=obj)

        if type(obj) is cls:
            return obj

        assert False, f"obj has type {type(obj)} which is not compatible with cast()"


# TODO: Look up path in some global [per-thread or thread-safe] list.

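A minimal usage sketch, not part of the commit, with made-up passages and assuming the colbert package is importable: Collection.cast accepts a .tsv path, an in-memory list of passage strings, or an existing Collection, and the resulting object supports len(), indexing, and iteration.

# Illustrative only; the passages are made-up.
from colbert.data.collection import Collection

passages = ['first passage text', 'second passage text', 'third passage text']

collection = Collection.cast(passages)    # wraps the in-memory list
collection = Collection.cast(collection)  # casting an existing Collection returns it unchanged

print(len(collection))                    # 3
print(collection[1])                      # 'second passage text'

for pid, passage in enumerate(collection):
    print(pid, passage)

# Collection.cast('path/to/collection.tsv') would instead load pid\tpassage lines via
# load_collection(). enumerate_batches(rank=...) additionally splits the collection into
# round-robin chunks across Run().nranks processes, so it is meant to run inside a Run context.
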
colbert/data/dataset.py (new file, +14 lines)



# Not just the corpus, but also an arbitrary number of query sets, indexed by name in a dictionary/dotdict.
# And also query sets with top-k PIDs.
# QAs too? TripleSets too?


class Dataset:
    def __init__(self):
        pass

    def select(self, key):
        # Select the {corpus, queryset, tripleset, rankingset} determined by uniqueness or by key and return a "unique" dataset (e.g., for key=train)
        pass

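Dataset is only a stub here; the comments above it sketch its intended shape. A purely hypothetical illustration of that shape, not the author's implementation, assuming named query sets kept in a plain dict:

# Hypothetical sketch only; names and structure are assumptions, not part of the commit.
class DatasetSketch:
    def __init__(self, collection, query_sets):
        self.collection = collection   # the shared corpus
        self.query_sets = query_sets   # e.g., {'train': Queries(...), 'dev': Queries(...)}

    def select(self, key):
        # Pick the query set named by `key` (e.g., 'train') and pair it with the corpus.
        return self.collection, self.query_sets[key]
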
colbert/data/examples.py (new file, +64 lines)

from colbert.infra.run import Run
import os
import ujson

from colbert.utils.utils import print_message


class Examples:
    def __init__(self, path=None, data=None):
        self.path = path
        self.data = data or self._load_file(path)

    def provenance(self):
        return self.path

    def _load_file(self, path):
        examples = []

        with open(path) as f:
            for line in f:
                examples.append(ujson.loads(line))

        return examples


    def tolist(self, rank=None, nranks=None):
        """
        NOTE: For distributed sampling, this isn't equivalent to perfectly uniform sampling.
        In particular, each subset is perfectly represented in every batch! However, since we never
        repeat passes over the data, we never repeat any particular triple, and the split across
        nodes is random (since the underlying file is pre-shuffled), there's no concern here.
        """

        if rank or nranks:
            assert rank in range(nranks), (rank, nranks)
            return [self.data[idx] for idx in range(rank, len(self.data), nranks)]  # if line_idx % nranks == rank

        return list(self.data)

    def save(self, new_path):
        assert 'json' in new_path.strip('/').split('/')[-1].split('.'), "TODO: Support .json[l] too."

        print_message(f"#> Writing {len(self.data) / 1000_000.0}M examples to {new_path}")

        with Run().open(new_path, 'w') as f:
            for example in self.data:
                ujson.dump(example, f)
                f.write('\n')

            return f.name
            # print_message(f"#> Saved ranking of {len(self.data)} queries and {len(self.flat_ranking)} lines to {new_path}")

    @classmethod
    def cast(cls, obj):
        if type(obj) is str:
            return cls(path=obj)

        if isinstance(obj, list):
            return cls(data=obj)

        if type(obj) is cls:
            return obj

        assert False, f"obj has type {type(obj)} which is not compatible with cast()"

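A minimal sketch of the rank-based splitting in tolist(), not part of the commit, assuming an in-memory list of made-up training triples and that the colbert package is importable:

# Illustrative only; the (qid, positive pid, negative pid) triples are made-up.
from colbert.data.examples import Examples

triples = [[qid, qid + 100, qid + 200] for qid in range(8)]
examples = Examples.cast(triples)

# Each rank takes every nranks-th triple, so the two ranks partition the list.
print(examples.tolist(rank=0, nranks=2))   # triples 0, 2, 4, 6
print(examples.tolist(rank=1, nranks=2))   # triples 1, 3, 5, 7

print(len(examples.tolist()))              # without rank/nranks: all 8 triples
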
colbert/data/queries.py (new file, +160 lines)

from colbert.infra.run import Run
import os
import ujson

from colbert.evaluation.loaders import load_queries

# TODO: Look up path in some global [per-thread or thread-safe] list.
# TODO: path could be a list of paths...? But then how can we tell it's not a list of queries..


class Queries:
    def __init__(self, path=None, data=None):
        self.path = path

        if data:
            assert isinstance(data, dict), type(data)
        self._load_data(data) or self._load_file(path)

    def __len__(self):
        return len(self.data)

    def __iter__(self):
        return iter(self.data.items())

    def provenance(self):
        return self.path

    def _load_data(self, data):
        if data is None:
            return None

        self.data = {}
        self._qas = {}

        for qid, content in data.items():
            if isinstance(content, dict):
                self.data[qid] = content['question']
                self._qas[qid] = content
            else:
                self.data[qid] = content

        if len(self._qas) == 0:
            del self._qas

        return True

    def _load_file(self, path):
        if path.endswith('.tsv'):
            self.data = load_queries(path)
            return True

        # Load QAs
        self.data = {}
        self._qas = {}

        with open(path) as f:
            for line in f:
                qa = ujson.loads(line)

                assert qa['qid'] not in self.data
                self.data[qa['qid']] = qa['question']
                self._qas[qa['qid']] = qa

        return self.data

    def qas(self):
        return dict(self._qas)

    def __getitem__(self, key):
        return self.data[key]

    def keys(self):
        return self.data.keys()

    def values(self):
        return self.data.values()

    def items(self):
        return self.data.items()

    def save(self, new_path):
        assert new_path.endswith('.tsv')
        assert not os.path.exists(new_path), new_path

        with Run().open(new_path, 'w') as f:
            for qid, content in self.data.items():
                content = f'{qid}\t{content}\n'
                f.write(content)

            return f.name

    def save_qas(self, new_path):
        assert new_path.endswith('.json')
        assert not os.path.exists(new_path), new_path

        with open(new_path, 'w') as f:
            for qid, qa in self._qas.items():
                qa['qid'] = qid
                f.write(ujson.dumps(qa) + '\n')

    def _load_tsv(self, path):
        raise NotImplementedError

    def _load_jsonl(self, path):
        raise NotImplementedError

    @classmethod
    def cast(cls, obj):
        if type(obj) is str:
            return cls(path=obj)

        if isinstance(obj, dict) or isinstance(obj, list):
            return cls(data=obj)

        if type(obj) is cls:
            return obj

        assert False, f"obj has type {type(obj)} which is not compatible with cast()"


# class QuerySet:
#     def __init__(self, *paths, renumber=False):
#         self.paths = paths
#         self.original_queries = [load_queries(path) for path in paths]

#         if renumber:
#             self.queries = flatten([q.values() for q in self.original_queries])
#             self.queries = {idx: text for idx, text in enumerate(self.queries)}

#         else:
#             self.queries = {}

#             for queries in self.original_queries:
#                 assert len(set.intersection(set(queries.keys()), set(self.queries.keys()))) == 0, \
#                     "renumber=False requires non-overlapping query IDs"

#                 self.queries.update(queries)

#             assert len(self.queries) == sum(map(len, self.original_queries))

#     def todict(self):
#         return dict(self.queries)

#     def tolist(self):
#         return list(self.queries.values())

#     def query_sets(self):
#         return self.original_queries

#     def split_rankings(self, rankings):
#         assert type(rankings) is list
#         assert len(rankings) == len(self.queries)

#         sub_rankings = []
#         offset = 0
#         for source in self.original_queries:
#             sub_rankings.append(rankings[offset:offset+len(source)])
#             offset += len(source)

#         return sub_rankings

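A minimal usage sketch, not part of the commit, with made-up queries and assuming the colbert package is importable: Queries accepts a dict mapping a qid either to a question string or, for QA-style data, to a dict with at least a 'question' field; qas() returns the full QA records when the latter form is present.

# Illustrative only; the queries and answers are made-up.
from colbert.data.queries import Queries

data = {
    0: 'what is late interaction in colbert',                                   # plain question
    1: {'question': 'who authored this commit', 'answers': ['Omar Khattab']},   # QA-style record
}

queries = Queries.cast(data)

print(len(queries))        # 2
print(queries[1])          # 'who authored this commit'
print(queries.qas()[1])    # the full QA dict for qid 1

for qid, question in queries:
    print(qid, question)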