From 98bbfd1b5cd90fc3643e3c6e9f33dcdfbb265b9e Mon Sep 17 00:00:00 2001 From: Ferdinand Schlatt Date: Mon, 24 Apr 2023 15:21:34 +0200 Subject: [PATCH 01/12] in-memory collections --- colbert/indexer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/colbert/indexer.py b/colbert/indexer.py index 266c4dbf..aebc7234 100644 --- a/colbert/indexer.py +++ b/colbert/indexer.py @@ -55,10 +55,10 @@ def erase(self): return deleted - def index(self, name, collection, overwrite=False): + def index(self, name, collection, collection_name=None, overwrite=False): assert overwrite in [True, False, 'reuse', 'resume'] - self.configure(collection=collection, index_name=name, resume=overwrite=='resume') + self.configure(collection=collection_name or collection, index_name=name, resume=overwrite=='resume') self.configure(bsize=64, partitions=None) self.index_path = self.config.index_path_ From 1bf1f91306b901a29ba88222bb26d55171f7f46b Mon Sep 17 00:00:00 2001 From: Ferdinand Schlatt Date: Mon, 24 Apr 2023 15:22:12 +0200 Subject: [PATCH 02/12] remove collection loading for searching --- colbert/searcher.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/colbert/searcher.py b/colbert/searcher.py index 6588712d..7ef5cb6b 100644 --- a/colbert/searcher.py +++ b/colbert/searcher.py @@ -33,8 +33,7 @@ def __init__(self, index, checkpoint=None, collection=None, config=None): self.checkpoint_config = ColBERTConfig.load_from_checkpoint(self.checkpoint) self.config = ColBERTConfig.from_existing(self.checkpoint_config, self.index_config, initial_config) - self.collection = Collection.cast(collection or self.config.collection) - self.configure(checkpoint=self.checkpoint, collection=self.collection) + self.configure(checkpoint=self.checkpoint) self.checkpoint = Checkpoint(self.checkpoint, colbert_config=self.config) use_gpu = self.config.total_visible_gpus > 0 From d74de28d28f9bf11e47543b54e00083e8ce80aee Mon Sep 17 00:00:00 2001 From: Ferdinand Schlatt Date: Wed, 30 Aug 2023 16:32:13 +1000 Subject: [PATCH 03/12] fix config loading in basecolbert --- colbert/modeling/colbert.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/colbert/modeling/colbert.py b/colbert/modeling/colbert.py index 088513da..17af385e 100644 --- a/colbert/modeling/colbert.py +++ b/colbert/modeling/colbert.py @@ -18,7 +18,7 @@ class ColBERT(BaseColBERT): def __init__(self, name='bert-base-uncased', colbert_config=None): super().__init__(name, colbert_config) - self.use_gpu = colbert_config.total_visible_gpus > 0 + self.use_gpu = self.colbert_config.total_visible_gpus > 0 ColBERT.try_load_torch_extensions(self.use_gpu) From 0e07ca78895133d8a0759aaf5c0f18862855a592 Mon Sep 17 00:00:00 2001 From: Ferdinand Schlatt Date: Wed, 30 Aug 2023 16:51:02 +1000 Subject: [PATCH 04/12] move transformers auto checking --- colbert/modeling/hf_colbert.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/colbert/modeling/hf_colbert.py b/colbert/modeling/hf_colbert.py index bc6334c2..68225844 100644 --- a/colbert/modeling/hf_colbert.py +++ b/colbert/modeling/hf_colbert.py @@ -44,9 +44,9 @@ class XLMRobertaPreTrainedModel(RobertaPreTrainedModel): } -transformers_module = dir(transformers) def find_class_names(model_type, class_type): + transformers_module = dir(transformers) model_type = model_type.replace("-", "").lower() for item in transformers_module: if model_type + class_type == item.lower(): From 273becf5813a6b5936b8341f3986c15ce057938e Mon Sep 17 00:00:00 2001 From: Ferdinand Schlatt Date: Thu, 31 Aug 2023 11:09:04 +1000 Subject: [PATCH 05/12] map pids to collections pids --- colbert/data/collection.py | 9 +++++++-- colbert/indexing/collection_indexer.py | 1 + colbert/indexing/index_saver.py | 6 ++++++ colbert/search/index_loader.py | 7 +++++++ colbert/search/index_storage.py | 1 + 5 files changed, 22 insertions(+), 2 deletions(-) diff --git a/colbert/data/collection.py b/colbert/data/collection.py index d5efc943..469d7c9f 100644 --- a/colbert/data/collection.py +++ b/colbert/data/collection.py @@ -14,7 +14,12 @@ class Collection: def __init__(self, path=None, data=None): self.path = path - self.data = data or self._load_file(path) + data = data or self._load_file(path) + if isinstance(data, dict): + self.pids, self.data = zip(*data.items()) + else: + self.pids = range(len(data)) + self.data = data def __iter__(self): # TODO: If __data isn't there, stream from disk! @@ -88,7 +93,7 @@ def cast(cls, obj): if type(obj) is str: return cls(path=obj) - if type(obj) is list: + if isinstance(obj, dict) or isinstance(obj, list): return cls(data=obj) if type(obj) is cls: diff --git a/colbert/indexing/collection_indexer.py b/colbert/indexing/collection_indexer.py index b99c99ed..b57a292e 100644 --- a/colbert/indexing/collection_indexer.py +++ b/colbert/indexing/collection_indexer.py @@ -375,6 +375,7 @@ def finalize(self): if self.rank > 0: return + self.saver.save_pid_map(self.collection.pids) self._check_all_files_are_saved() self._collect_embedding_id_offset() diff --git a/colbert/indexing/index_saver.py b/colbert/indexing/index_saver.py index 436d130b..6d35447e 100644 --- a/colbert/indexing/index_saver.py +++ b/colbert/indexing/index_saver.py @@ -17,6 +17,12 @@ def __init__(self, config): def save_codec(self, codec): codec.save(index_path=self.config.index_path_) + def save_pid_map(self, pids): + pid_map_path = os.path.join(self.config.index_path_, 'pid.map') + with open(pid_map_path, 'w') as output_pid_map: + for pid in pids: + output_pid_map.write(f'{pid}\n') + def load_codec(self): return ResidualCodec.load(index_path=self.config.index_path_) diff --git a/colbert/search/index_loader.py b/colbert/search/index_loader.py index a8f377e7..a357f993 100644 --- a/colbert/search/index_loader.py +++ b/colbert/search/index_loader.py @@ -16,6 +16,7 @@ def __init__(self, index_path, use_gpu=True): self.use_gpu = use_gpu self._load_codec() + self._load_pid_map() self._load_ivf() self._load_doclens() @@ -24,6 +25,12 @@ def __init__(self, index_path, use_gpu=True): def _load_codec(self): print_message(f"#> Loading codec...") self.codec = ResidualCodec.load(self.index_path) + + def _load_pid_map(self): + pid_map_path = os.path.join(self.index_path, 'pid.map') + with open(pid_map_path) as input_pid_map: + pids = [line.strip() for line in input_pid_map] + self.pids = pids def _load_ivf(self): print_message(f"#> Loading IVF...") diff --git a/colbert/search/index_storage.py b/colbert/search/index_storage.py index 6fd91d35..3ae27020 100644 --- a/colbert/search/index_storage.py +++ b/colbert/search/index_storage.py @@ -87,6 +87,7 @@ def rank(self, config, Q, filter_fn=None): scores_sorter = scores.sort(descending=True) pids, scores = pids[scores_sorter.indices].tolist(), scores_sorter.values.tolist() + pids = [self.pids[pid] for pid in pids] return pids, scores From 2c8eda373da146a04f678ba9e77cf41d85167768 Mon Sep 17 00:00:00 2001 From: Ferdinand Schlatt Date: Fri, 1 Sep 2023 14:57:14 +1000 Subject: [PATCH 06/12] fix indexing for non consecutive collections --- colbert/data/collection.py | 7 ++++--- colbert/indexing/collection_indexer.py | 2 +- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/colbert/data/collection.py b/colbert/data/collection.py index 469d7c9f..83683b53 100644 --- a/colbert/data/collection.py +++ b/colbert/data/collection.py @@ -16,9 +16,10 @@ def __init__(self, path=None, data=None): self.path = path data = data or self._load_file(path) if isinstance(data, dict): - self.pids, self.data = zip(*data.items()) + self.data = list(data.values()) + self.pid_doc_map = data else: - self.pids = range(len(data)) + self.pid_doc_map = {pid: doc for pid, doc in enumerate(data)} self.data = data def __iter__(self): @@ -27,7 +28,7 @@ def __iter__(self): def __getitem__(self, item): # TODO: Load from disk the first time this is called. Unless self.data is already not None. - return self.data[item] + return self.pid_doc_map[item] def __len__(self): # TODO: Load here too. Basically, let's make data a property function and, on first call, either load or get __data. diff --git a/colbert/indexing/collection_indexer.py b/colbert/indexing/collection_indexer.py index b57a292e..4bc67376 100644 --- a/colbert/indexing/collection_indexer.py +++ b/colbert/indexing/collection_indexer.py @@ -375,7 +375,7 @@ def finalize(self): if self.rank > 0: return - self.saver.save_pid_map(self.collection.pids) + self.saver.save_pid_map(self.collection.pid_doc_map) self._check_all_files_are_saved() self._collect_embedding_id_offset() From 60a7da06f25fef28a1426c3177adfdc5ddb2e3e9 Mon Sep 17 00:00:00 2001 From: Ferdinand Schlatt Date: Tue, 19 Dec 2023 09:51:44 +0100 Subject: [PATCH 07/12] reduce sampling size --- colbert/indexing/collection_indexer.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/colbert/indexing/collection_indexer.py b/colbert/indexing/collection_indexer.py index 4bc67376..6c207ab1 100644 --- a/colbert/indexing/collection_indexer.py +++ b/colbert/indexing/collection_indexer.py @@ -103,7 +103,7 @@ def setup(self): self.num_embeddings_est = num_passages * avg_doclen_est self.num_partitions = int(2 ** np.floor(np.log2(16 * np.sqrt(self.num_embeddings_est)))) - Run().print_main(f'Creaing {self.num_partitions:,} partitions.') + Run().print_main(f'Creating {self.num_partitions:,} partitions.') Run().print_main(f'*Estimated* {int(self.num_embeddings_est):,} embeddings.') self._save_plan() @@ -116,13 +116,17 @@ def _sample_pids(self): # So the formula is max(100% * min(total, 100k), 15% * min(total, 1M), ...) # Then we subsample the vectors to 100 * num_partitions - typical_doclen = 120 # let's keep sampling independent of the actual doc_maxlen - sampled_pids = 16 * np.sqrt(typical_doclen * num_passages) - # sampled_pids = int(2 ** np.floor(np.log2(1 + sampled_pids))) + # typical_doclen = 120 # let's keep sampling independent of the actual doc_maxlen + # sampled_pids = 16 * np.sqrt(typical_doclen * num_passages) + # # sampled_pids = int(2 ** np.floor(np.log2(1 + sampled_pids))) + + sampled_pids = np.sqrt(self.config.doc_maxlen * num_passages) sampled_pids = min(1 + int(sampled_pids), num_passages) sampled_pids = random.sample(range(num_passages), sampled_pids) - Run().print_main(f"# of sampled PIDs = {len(sampled_pids)} \t sampled_pids[:3] = {sampled_pids[:3]}") + Run().print_main( + f"# of sampled PIDs = {len(sampled_pids)} \t sampled_pids[:3] = {sampled_pids[:3]}" + ) return set(sampled_pids) From 63c5328b5d0fc523f3aa4866967f6911efd5d7e9 Mon Sep 17 00:00:00 2001 From: Ferdinand Schlatt Date: Tue, 19 Dec 2023 09:52:08 +0100 Subject: [PATCH 08/12] make assignment a property --- colbert/infra/config/core_config.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/colbert/infra/config/core_config.py b/colbert/infra/config/core_config.py index 7188eaa2..582ff882 100644 --- a/colbert/infra/config/core_config.py +++ b/colbert/infra/config/core_config.py @@ -14,7 +14,7 @@ @dataclass class DefaultVal: val: Any - + def __hash__(self): return hash(repr(self.val)) @@ -28,17 +28,22 @@ def __post_init__(self): Source: https://stackoverflow.com/a/58081120/1493011 """ - self.assigned = {} - for field in fields(self): field_val = getattr(self, field.name) if isinstance(field_val, DefaultVal) or field_val is None: setattr(self, field.name, field.default.val) + @property + def assigned(self): + assigned = {} + + for field in fields(self): + field_val = getattr(self, field.name) if not isinstance(field_val, DefaultVal): - self.assigned[field.name] = True - + assigned[field.name] = True + return assigned + def assign_defaults(self): for field in fields(self): setattr(self, field.name, field.default.val) From 6b7eccb26545f869b955610e834263199f5d2268 Mon Sep 17 00:00:00 2001 From: Ferdinand Schlatt Date: Tue, 19 Dec 2023 09:52:18 +0100 Subject: [PATCH 09/12] add prune settings --- colbert/infra/config/settings.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/colbert/infra/config/settings.py b/colbert/infra/config/settings.py index d9f1b209..e078e19e 100644 --- a/colbert/infra/config/settings.py +++ b/colbert/infra/config/settings.py @@ -109,6 +109,7 @@ class DocSettings: dim: int = DefaultVal(128) doc_maxlen: int = DefaultVal(220) mask_punctuation: bool = DefaultVal(True) + prune: bool = DefaultVal(False) @dataclass @@ -163,6 +164,8 @@ class IndexingSettings: resume: bool = DefaultVal(False) + prune_threshold: float = DefaultVal(None) + @property def index_path_(self): return self.index_path or os.path.join(self.index_root_, self.index_name) From 5d234ea51977f07a8c84865032c86022a7e11823 Mon Sep 17 00:00:00 2001 From: Ferdinand Schlatt Date: Tue, 19 Dec 2023 09:52:59 +0100 Subject: [PATCH 10/12] add pruning head --- colbert/modeling/checkpoint.py | 13 +++++-------- colbert/modeling/hf_colbert.py | 35 ++++++++++++++++++++++++++++++++++ 2 files changed, 40 insertions(+), 8 deletions(-) diff --git a/colbert/modeling/checkpoint.py b/colbert/modeling/checkpoint.py index 4167749e..d6e96923 100644 --- a/colbert/modeling/checkpoint.py +++ b/colbert/modeling/checkpoint.py @@ -73,16 +73,13 @@ def docFromText(self, docs, bsize=None, keep_dims=True, to_cpu=False, showprogre D, mask = [], [] for D_, mask_ in batches: - D.append(D_) - mask.append(mask_) + D.extend(D_) + mask.extend(mask_) - D, mask = torch.cat(D)[reverse_indices], torch.cat(mask)[reverse_indices] - - doclens = mask.squeeze(-1).sum(-1).tolist() - - D = D.view(-1, self.colbert_config.dim) - D = D[mask.bool().flatten()].cpu() + D = [D[idx][mask[idx][:, 0].bool()] for idx in reverse_indices.tolist()] + doclens = [mask[idx].sum().item() for idx in reverse_indices.tolist()] + D = torch.cat(D).cpu() return (D, doclens, *returned_text) assert keep_dims is False diff --git a/colbert/modeling/hf_colbert.py b/colbert/modeling/hf_colbert.py index 68225844..f59beb34 100644 --- a/colbert/modeling/hf_colbert.py +++ b/colbert/modeling/hf_colbert.py @@ -1,5 +1,8 @@ import importlib +import torch +from turtle import forward from unicodedata import name +from typing import Optional import torch.nn as nn import transformers from transformers import BertPreTrainedModel, BertModel, AutoTokenizer, AutoModel, AutoConfig @@ -54,6 +57,29 @@ def find_class_names(model_type, class_type): return None +class PruningHead(nn.Module): + def __init__(self, hidden_size: int) -> None: + super().__init__() + self.dim = 32 + self.heads = 4 + self.query = nn.Linear(hidden_size, self.dim) + self.key = nn.Linear(hidden_size, self.dim) + self.value = nn.Linear(hidden_size, self.dim) + self.attention = nn.MultiheadAttention(self.dim, self.heads, batch_first=True) + self.linear = nn.Linear(self.dim, 1) + + def forward(self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor: + query = self.query(hidden_states) + key = self.key(hidden_states) + value = self.value(hidden_states) + + if attention_mask is not None: + attention_mask = attention_mask[:, None].expand(-1, attention_mask.shape[-1], -1).repeat(self.heads, 1, 1) + + out = self.attention.forward(query, key, value, attn_mask=attention_mask) + + return self.linear(out[0]) + def class_factory(name_or_path): loadedConfig = AutoConfig.from_pretrained(name_or_path) @@ -92,6 +118,15 @@ def __init__(self, config, colbert_config): self.config = config self.dim = colbert_config.dim self.linear = nn.Linear(config.hidden_size, colbert_config.dim, bias=False) + self.pruning_head = None + if colbert_config.prune: + # self.pruning_head = nn.Sequential( + # nn.Linear(config.hidden_size, config.hidden_size), + # nn.ReLU(), + # nn.Linear(config.hidden_size, 1) + # ) + self.pruning_head = PruningHead(config.hidden_size) + # self.pruning_head = nn.Linear(config.hidden_size, 1) setattr(self,self.base_model_prefix, model_class_object(config)) # if colbert_config.relu: From 3554d14597fde1816c22453d68a1d41ff61d6b6c Mon Sep 17 00:00:00 2001 From: Ferdinand Schlatt Date: Tue, 19 Dec 2023 09:53:07 +0100 Subject: [PATCH 11/12] remove batch size from indexer --- colbert/indexer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/colbert/indexer.py b/colbert/indexer.py index aebc7234..4182cf57 100644 --- a/colbert/indexer.py +++ b/colbert/indexer.py @@ -59,7 +59,7 @@ def index(self, name, collection, collection_name=None, overwrite=False): assert overwrite in [True, False, 'reuse', 'resume'] self.configure(collection=collection_name or collection, index_name=name, resume=overwrite=='resume') - self.configure(bsize=64, partitions=None) + self.configure(partitions=None) self.index_path = self.config.index_path_ index_does_not_exist = (not os.path.exists(self.config.index_path_)) From 7217b24d00a0b1e942674af55dff4bb4f3587bcc Mon Sep 17 00:00:00 2001 From: Ferdinand Schlatt Date: Wed, 20 Dec 2023 16:49:33 +0100 Subject: [PATCH 12/12] set default for prune threshold --- colbert/infra/config/settings.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/colbert/infra/config/settings.py b/colbert/infra/config/settings.py index e078e19e..fbaafe00 100644 --- a/colbert/infra/config/settings.py +++ b/colbert/infra/config/settings.py @@ -164,7 +164,7 @@ class IndexingSettings: resume: bool = DefaultVal(False) - prune_threshold: float = DefaultVal(None) + prune_threshold: float = 0.0 @property def index_path_(self):