v2.5 training #361

Open · wants to merge 3 commits into main
33 changes: 33 additions & 0 deletions colbert/infra/config/settings.py
@@ -119,6 +119,10 @@ class QuerySettings:
query_maxlen: int = DefaultVal(32)
attend_to_mask_tokens: bool = DefaultVal(False)
interaction: str = DefaultVal("colbert")
# V2.5
cap_padding: int = DefaultVal(0)
dynamic_query_maxlen: bool = DefaultVal(False)
dynamic_querylen_multiples: int = DefaultVal(32)


@dataclass
@@ -156,6 +160,35 @@ class TrainingSettings:

model_name: str = DefaultVal(None) # DefaultVal('bert-base-uncased')

# V2.5

schedule_free: bool = DefaultVal(False)

schedule_free_wd: float = DefaultVal(0.0)

kldiv_loss: bool = DefaultVal(True)

marginmse_loss: bool = DefaultVal(False)

kldiv_weight: float = DefaultVal(1.0)

marginmse_weight: float = DefaultVal(0.05)

ib_loss_weight: float = DefaultVal(1.0)

normalise_training_scores: bool = DefaultVal(False)

# Can be 'minmax', 'querylen'
normalization_method: str = DefaultVal("minmax")

# TODO

quant_aware: bool = DefaultVal(False)

highest_quant_level: int = DefaultVal(8)

lowest_quant_level: int = DefaultVal(2)


@dataclass
class IndexingSettings:
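
Since the new fields are ordinary dataclass settings, they compose with the existing ColBERTConfig plumbing. A minimal sketch of a v2.5-style configuration, assuming this branch is installed; every value below is illustrative rather than a recommended recipe:

    from colbert.infra import ColBERTConfig

    config = ColBERTConfig(
        checkpoint="bert-base-uncased",
        query_maxlen=32,
        cap_padding=8,                  # per-query budget of extra [MASK] tokens
        dynamic_query_maxlen=False,
        dynamic_querylen_multiples=32,
        schedule_free=True,             # AdamWScheduleFree instead of AdamW + linear decay
        schedule_free_wd=0.0,
        kldiv_loss=True,
        marginmse_loss=True,
        kldiv_weight=1.0,
        marginmse_weight=0.05,
        normalise_training_scores=True,
        normalization_method="minmax",  # or "querylen"
    )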
191 changes: 191 additions & 0 deletions colbert/modeling/tokenization/query_tokenization.py
@@ -77,9 +77,30 @@ def tensorize(self, batch_text, bsize=None, context=None, full_length_search=Fal
ids, mask = obj['input_ids'], obj['attention_mask']

# postprocess for the [Q] marker and the [MASK] augmentation
ids[:, 1] = self.Q_marker_token_id
unpadded_sizes = (ids != self.pad_token_id).sum(dim=1)
# Log original sizes
original_sizes = unpadded_sizes.clone()
ids[ids == self.pad_token_id] = self.mask_token_id

# Shorten ids and mask if necessary
if self.config.cap_padding > 0:
for i in range(ids.size(0)):
unpadded_size = unpadded_sizes[i].item()
# Allow up to cap_padding extra [MASK] tokens beyond this query's unpadded length
max_allowed_length = unpadded_size + self.config.cap_padding
if ids.size(1) > max_allowed_length:
ids[i, max_allowed_length:] = self.pad_token_id
mask[i, max_allowed_length:] = 0
# Trim the batch to the maximum allowed length across all queries
max_length = max(unpadded_size + self.config.cap_padding for unpadded_size in unpadded_sizes)
max_length = min(max_length, ids.size(1))
ids = ids[:, :max_length]
mask = mask[:, :max_length]
# Note: this already grants each query its own cap_padding extra tokens individually


if context is not None:
assert len(context) == len(batch_text), (len(context), len(batch_text))

@@ -116,3 +137,173 @@ def tensorize(self, batch_text, bsize=None, context=None, full_length_search=Fal
# Ensure that query_maxlen <= length <= 500 tokens
def max_len(self, length):
return min(500, max(self.query_maxlen, length))
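
For intuition, cap_padding bounds the [MASK] augmentation instead of always padding queries out to query_maxlen: each query keeps at most cap_padding mask tokens past its real length, and the batch is trimmed to the widest per-query allowance. A standalone toy version of the loop above, with illustrative BERT-style token ids:

    import torch

    pad_id, mask_id, cap_padding = 0, 103, 8
    # One query: 5 real tokens, already mask-augmented out to query_maxlen=32
    ids = torch.tensor([[101, 1, 2025, 7592, 102] + [mask_id] * 27])
    mask = torch.ones_like(ids)
    unpadded_sizes = torch.tensor([5])  # counted before pads become masks

    for i in range(ids.size(0)):
        max_allowed = unpadded_sizes[i].item() + cap_padding
        if ids.size(1) > max_allowed:
            ids[i, max_allowed:] = pad_id
            mask[i, max_allowed:] = 0

    max_length = min(int(unpadded_sizes.max()) + cap_padding, ids.size(1))
    ids, mask = ids[:, :max_length], mask[:, :max_length]
    print(ids.shape)  # torch.Size([1, 13]): 5 real tokens + 8 masks instead of 32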


import torch
# import math

# from colbert.modeling.hf_colbert import class_factory
# from colbert.infra import ColBERTConfig
# from colbert.modeling.tokenization.utils import _split_into_batches
# from colbert.utils.utils import batch
# from colbert.parameters import DEVICE


# class QueryTokenizer():
# def __init__(self, config: ColBERTConfig, verbose: int = 3):
# HF_ColBERT = class_factory(config.checkpoint)
# self.tok = HF_ColBERT.raw_tokenizer_from_pretrained(config.checkpoint)
# self.verbose = verbose

# self.config = config
# self.query_maxlen = config.query_maxlen
# self.background_maxlen = 512 - self.query_maxlen + 1 # FIXME: Make this configurable

# self.Q_marker_token, self.Q_marker_token_id = config.query_token, self.tok.convert_tokens_to_ids(config.query_token_id)
# self.cls_token, self.cls_token_id = self.tok.cls_token, self.tok.cls_token_id
# self.sep_token, self.sep_token_id = self.tok.sep_token, self.tok.sep_token_id
# self.mask_token, self.mask_token_id = self.tok.mask_token, self.tok.mask_token_id
# self.pad_token,self.pad_token_id = self.tok.pad_token,self.tok.pad_token_id
# self.used = False

# def tokenize(self, batch_text, add_special_tokens=False):
# assert type(batch_text) in [list, tuple], (type(batch_text))

# tokens = [self.tok.tokenize(x, add_special_tokens=False) for x in batch_text]

# if not add_special_tokens:
# return tokens

# prefix, suffix = [self.cls_token, self.Q_marker_token], [self.sep_token]
# tokens = [prefix + lst + suffix + [self.mask_token] * (self.query_maxlen - (len(lst)+3)) for lst in tokens]

# return tokens

# def encode(self, batch_text, add_special_tokens=False):
# assert type(batch_text) in [list, tuple], (type(batch_text))

# ids = self.tok(batch_text, add_special_tokens=False).to(DEVICE)['input_ids']

# if not add_special_tokens:
# return ids

# prefix, suffix = [self.cls_token_id, self.Q_marker_token_id], [self.sep_token_id]
# ids = [prefix + lst + suffix + [self.mask_token_id] * (self.query_maxlen - (len(lst)+3)) for lst in ids]

# return ids

# def tensorize(self, batch_text, bsize=None, context=None, full_length_search=False):
# assert type(batch_text) in [list, tuple], (type(batch_text))

# # add placehold for the [Q] marker
# batch_text = ['. ' + x for x in batch_text]

# # Full length search is only available for single inference (for now)
# # Batched full length search requires far deeper changes to the code base
# assert(full_length_search == False or (type(batch_text) == list and len(batch_text) == 1))

# if full_length_search:
# # Tokenize each string in the batch
# un_truncated_ids = self.tok(batch_text, add_special_tokens=False).to(DEVICE)['input_ids']
# # Get the longest length in the batch
# max_length_in_batch = max(len(x) for x in un_truncated_ids)
# # Set the max length
# max_length = self.max_len(max_length_in_batch)
# else:
# # Max length is the default max length from the config
# max_length = self.query_maxlen

# if self.config.dynamic_query_maxlen:
# max_length = self.config.doc_maxlen
# obj = self.tok(batch_text, padding=False, truncation=True,
# return_tensors='pt', max_length=max_length).to(DEVICE)

# ids, mask = obj['input_ids'], obj['attention_mask']

# # postprocess for the [Q] marker and the [MASK] augmentation
# ids[:, 1] = self.Q_marker_token_id
# unpadded_sizes = (ids != self.pad_token_id).sum(dim=1)
# # Log original sizes
# original_sizes = unpadded_sizes.clone()
# ids[ids == self.pad_token_id] = self.mask_token_id

# # Shorten ids and mask if necessary
# if self.config.cap_padding > 0:
# for i in range(ids.size(0)):
# unpadded_size = unpadded_sizes[i].item()
# # Add 8 to the query size itself, per query
# max_allowed_length = unpadded_size + self.config.cap_padding
# if ids.size(1) > max_allowed_length:
# ids[i, max_allowed_length:] = self.pad_token_id
# mask[i, max_allowed_length:] = 0
# # Trim the batch to the maximum allowed length across all queries
# max_length = max(unpadded_size + self.config.cap_padding for unpadded_size in unpadded_sizes)
# max_length = min(max_length, ids.size(1))
# ids = ids[:, :max_length]
# mask = mask[:, :max_length]
# # Note: This implementation already adds 8 (or the value of cap_padding) to each query individually

# if self.config.dynamic_query_maxlen:
# new_ids = []
# new_mask = []
# for i in range(ids.size(0)):
# original_length = original_sizes[i].item()
# if original_length % self.config.dynamic_querylen_multiples <= 8:
# QLEN = original_length + 8
# else:
# QLEN = math.ceil(original_length / self.config.dynamic_querylen_multiples) * self.config.dynamic_querylen_multiples

# if original_length < QLEN:
# print("Entering padding")
# print("Original length: ", original_length)
# print("QLEN: ", QLEN)
# pad_length = QLEN - original_length
# padded_ids = ids[i, :original_length].tolist() + [self.mask_token_id] * pad_length
# padded_mask = mask[i, :original_length].tolist() + [0] * pad_length
# else:
# padded_ids = ids[i].tolist()
# padded_mask = mask[i].tolist()

# new_ids.append(padded_ids)
# new_mask.append(padded_mask)

# ids = torch.tensor(new_ids, device=DEVICE)
# mask = torch.tensor(new_mask, device=DEVICE)

# if context is not None:
# assert len(context) == len(batch_text), (len(context), len(batch_text))

# obj_2 = self.tok(context, padding='longest', truncation=True,
# return_tensors='pt', max_length=self.background_maxlen).to(DEVICE)

# ids_2, mask_2 = obj_2['input_ids'][:, 1:], obj_2['attention_mask'][:, 1:] # Skip the first [SEP]

# ids = torch.cat((ids, ids_2), dim=-1)
# mask = torch.cat((mask, mask_2), dim=-1)

# if self.config.attend_to_mask_tokens:
# mask[ids == self.mask_token_id] = 1
# assert mask.sum().item() == mask.size(0) * mask.size(1), mask

# if bsize:
# batches = _split_into_batches(ids, mask, bsize)
# return batches

# if self.used is False:
# self.used = True

# firstbg = (context is None) or context[0]
# if self.verbose > 1:
# print()
# print("#> QueryTokenizer.tensorize(batch_text[0], batch_background[0], bsize) ==")
# print(f"#> Input: {batch_text[0]}, \t\t {firstbg}, \t\t {bsize}")
# print(f"#> Output IDs: {ids[0].size()}, {ids[0]}")
# print(f"#> Output Mask: {mask[0].size()}, {mask[0]}")
# print()

# return ids, mask

# # Ensure that query_maxlen <= length <= 500 tokens
# def max_len(self, length):
# return min(500, max(self.query_maxlen, length))
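
The dynamic_query_maxlen path only exists in this commented-out copy, but its length rule is mechanical. A standalone sketch of the same computation with a few worked values, assuming the default dynamic_querylen_multiples of 32:

    import math

    def dynamic_qlen(original_length: int, multiples: int = 32) -> int:
        # Queries within 8 tokens past a multiple get a flat +8 mask budget;
        # everything else is rounded up to the next multiple.
        if original_length % multiples <= 8:
            return original_length + 8
        return math.ceil(original_length / multiples) * multiples

    for n in (5, 30, 36, 64):
        print(n, "->", dynamic_qlen(n))  # 5 -> 13, 30 -> 32, 36 -> 44, 64 -> 72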
1 change: 1 addition & 0 deletions colbert/parameters.py
@@ -5,6 +5,7 @@
SAVED_CHECKPOINTS = [32*1000, 100*1000, 150*1000, 200*1000, 250*1000, 300*1000, 400*1000]
SAVED_CHECKPOINTS += [10*1000, 20*1000, 30*1000, 40*1000, 50*1000, 60*1000, 70*1000, 80*1000, 90*1000]
SAVED_CHECKPOINTS += [25*1000, 50*1000, 75*1000]
SAVED_CHECKPOINTS += [2000, 5000, 10000, 15000, 20000, 25000, 30000, 35000, 40000, 45000]

SAVED_CHECKPOINTS = set(SAVED_CHECKPOINTS)

2 changes: 1 addition & 1 deletion colbert/searcher.py
@@ -55,7 +55,7 @@ def configure(self, **kw_args):

def encode(self, text: TextQueries, full_length_search=False):
queries = text if type(text) is list else [text]
bsize = 512 if len(queries) > 512 else None

self.checkpoint.query_tokenizer.query_maxlen = self.config.query_maxlen
Q = self.checkpoint.queryFromText(queries, bsize=bsize, to_cpu=True, full_length_search=full_length_search)
160 changes: 137 additions & 23 deletions colbert/training/training.py
@@ -18,10 +18,11 @@
from colbert.utils.utils import print_message
from colbert.training.utils import print_progress, manage_checkpoints

from schedulefree import AdamWScheduleFree


def train(config: ColBERTConfig, triples, queries=None, collection=None):
config.checkpoint = config.checkpoint or "bert-base-uncased"

if config.rank < 1:
config.help()
@@ -34,13 +35,32 @@ def train(config: ColBERTConfig, triples, queries=None, collection=None):
assert config.bsize % config.nranks == 0, (config.bsize, config.nranks)
config.bsize = config.bsize // config.nranks

print("Using config.bsize =", config.bsize, "(per process) and config.accumsteps =", config.accumsteps)
print(
"Using config.bsize =",
config.bsize,
"(per process) and config.accumsteps =",
config.accumsteps,
)

if collection is not None:
if config.reranker:
reader = RerankBatcher(
config,
triples,
queries,
collection,
(0 if config.rank == -1 else config.rank),
config.nranks,
)
else:
reader = LazyBatcher(
config,
triples,
queries,
collection,
(0 if config.rank == -1 else config.rank),
config.nranks,
)
else:
raise NotImplementedError()

@@ -52,18 +72,45 @@ def train(config: ColBERTConfig, triples, queries=None, collection=None):
colbert = colbert.to(DEVICE)
colbert.train()

colbert = torch.nn.parallel.DistributedDataParallel(
colbert,
device_ids=[config.rank],
output_device=config.rank,
find_unused_parameters=True,
)

if config.schedule_free is False:
optimizer = AdamW(
filter(lambda p: p.requires_grad, colbert.parameters()),
lr=config.lr,
eps=1e-8,
)
else:
print("WARNING, USING SCHEDULE FREE")
print("WARNING, USING SCHEDULE FREE")
print("WARNING, USING SCHEDULE FREE")
print("WARNING, USING SCHEDULE FREE")
print("WARNING, USING SCHEDULE FREE")
optimizer = AdamWScheduleFree(
filter(lambda p: p.requires_grad, colbert.parameters()),
lr=config.lr,
warmup_steps=config.warmup,
weight_decay=config.schedule_free_wd,
)
if config.schedule_free:
optimizer.train()
optimizer.zero_grad()

scheduler = None
if config.warmup is not None and config.schedule_free is False:
print(
f"#> LR will use {config.warmup} warmup steps and linear decay over {config.maxsteps} steps."
)
scheduler = get_linear_schedule_with_warmup(
optimizer,
num_warmup_steps=config.warmup,
num_training_steps=config.maxsteps,
)

warmup_bert = config.warmup_bert
if warmup_bert is not None:
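
Schedule-free optimizers fold warmup and iterate averaging into the optimizer itself, which is why no linear-decay scheduler is built when config.schedule_free is set. The train()/eval() toggling comes from the schedulefree package's contract rather than anything specific to this diff: train mode steps the fast iterate, eval mode swaps in the averaged weights. A minimal usage sketch:

    import torch
    from schedulefree import AdamWScheduleFree

    model = torch.nn.Linear(16, 1)
    optimizer = AdamWScheduleFree(model.parameters(), lr=1e-3, warmup_steps=100)

    optimizer.train()                 # training weights active
    for _ in range(10):
        loss = model(torch.randn(4, 16)).pow(2).mean()
        loss.backward()
        optimizer.step()              # no scheduler.step(); warmup is internal
        optimizer.zero_grad()

    optimizer.eval()                  # averaged weights active for eval/checkpointing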
@@ -100,27 +147,74 @@ def train(config: ColBERTConfig, triples, queries=None, collection=None):
encoding, target_scores = batch
encoding = [encoding.to(DEVICE)]

if not config.quant_aware:
scores = colbert(*encoding)
else:
raise NotImplementedError

if config.use_ib_negatives:
scores, ib_loss = scores
ib_loss = ib_loss * config.ib_loss_weight

scores = scores.view(-1, config.nway)
if config.normalise_training_scores:
if config.normalization_method == "minmax":
scores = (scores - scores.min(dim=-1, keepdim=True)[0]) / (
scores.max(dim=-1, keepdim=True)[0]
- scores.min(dim=-1, keepdim=True)[0]
+ 1e-8
)
elif config.normalization_method == "querylen":
scores = scores / (
config.query_maxlen + 1e-8
) # Divide by the number of tokens in the queries
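# Worked example (illustrative): with a row of scores [12.0, 9.0, 7.5],
# minmax yields (scores - 7.5) / (12.0 - 7.5 + 1e-8) ~= [1.0, 0.333, 0.0];
# "querylen" instead divides every score by config.query_maxlen.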

if len(target_scores) and not config.ignore_scores:
target_scores = (
torch.tensor(target_scores).view(-1, config.nway).to(DEVICE)
)
target_scores = target_scores * config.distillation_alpha

if config.kldiv_loss:
# Log-softmax the teacher scores into a separate variable so the
# MarginMSE branch below still sees the raw target_scores.
log_target_scores = torch.nn.functional.log_softmax(
target_scores, dim=-1
)

log_scores = torch.nn.functional.log_softmax(scores, dim=-1)
kldivloss = torch.nn.KLDivLoss(
reduction="batchmean", log_target=True
)(log_scores, log_target_scores)

if config.marginmse_loss:
margin = scores[:, 0].unsqueeze(1) - scores[:, 1:]
target_margin = target_scores[:, 0].unsqueeze(1) - target_scores[:, 1:]
marginmse_loss = torch.nn.MSELoss()(margin, target_margin)

if config.kldiv_loss and config.marginmse_loss:
weighted_kldiv = kldivloss * config.kldiv_weight
weighted_marginmse = marginmse_loss * config.marginmse_weight
loss = weighted_kldiv + weighted_marginmse
elif config.kldiv_loss:
loss = kldivloss
elif config.marginmse_loss:
loss = marginmse_loss
else:
raise ValueError(
"One or both of config.kldiv_loss and config.marginmse_loss must be True if distillation is enabled!"
)
else:
raise ValueError("crossentropy loss shouldn't be used here")
loss = nn.CrossEntropyLoss()(scores, labels[: scores.size(0)])

if config.use_ib_negatives:
if config.rank < 1:
print("\t\t\t\t", loss.item(), ib_loss.item())

og_loss = loss
loss += ib_loss

loss = loss / config.accumsteps
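
To see how the two distillation objectives combine, here is a self-contained toy version of the branch above. The scores are made up and nway=3; the loss arithmetic matches the diff, with the weights playing the role of config.kldiv_weight and config.marginmse_weight:

    import torch

    kldiv_weight, marginmse_weight = 1.0, 0.05
    scores = torch.tensor([[12.0, 9.0, 7.5], [10.0, 9.5, 4.0]])         # model, nway=3
    target_scores = torch.tensor([[11.0, 8.0, 8.0], [10.5, 9.0, 5.0]])  # teacher

    # KL divergence between softmax distributions over the nway candidates
    log_scores = torch.nn.functional.log_softmax(scores, dim=-1)
    log_targets = torch.nn.functional.log_softmax(target_scores, dim=-1)
    kldivloss = torch.nn.KLDivLoss(reduction="batchmean", log_target=True)(
        log_scores, log_targets
    )

    # MarginMSE: match the positive-vs-negative margins to the teacher's
    margin = scores[:, 0].unsqueeze(1) - scores[:, 1:]
    target_margin = target_scores[:, 0].unsqueeze(1) - target_scores[:, 1:]
    marginmse_loss = torch.nn.MSELoss()(margin, target_margin)

    loss = kldiv_weight * kldivloss + marginmse_weight * marginmse_loss
    print(float(kldivloss), float(marginmse_loss), float(loss))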
@@ -135,20 +229,40 @@ def train(config: ColBERTConfig, triples, queries=None, collection=None):
train_loss = this_batch_loss if train_loss is None else train_loss
train_loss = train_loss_mu * train_loss + (1 - train_loss_mu) * this_batch_loss

if config.schedule_free:
assert scheduler is None

amp.step(colbert, optimizer, scheduler)

if config.rank < 1:
if config.use_ib_negatives:
print_message(f"IB Loss: {ib_loss}")
print_message(f"KL-D loss: {og_loss}")
if config.kldiv_loss and config.marginmse_loss:
TOTAL = weighted_kldiv + weighted_marginmse
kldiv_proportion = weighted_kldiv / TOTAL
marginmse_proportion = weighted_marginmse / TOTAL
print_message(f"Weighted KL-D loss: {weighted_kldiv:.4f}")
print_message(f"Weighted MarginMSE loss: {weighted_marginmse:.4f}")
print_message(f"Respective proportions: KL-D {kldiv_proportion:.2%}, MarginMSE {marginmse_proportion:.2%}")
print_message(batch_idx, train_loss)
manage_checkpoints(config, colbert, optimizer, batch_idx + 1, savepath=None)

if config.rank < 1:
print_message("#> Done with all triples!")
ckpt_path = manage_checkpoints(
config,
colbert,
optimizer,
batch_idx + 1,
savepath=None,
consumed_all_triples=True,
is_schedule_free=config.schedule_free,
)

return ckpt_path # TODO: This should validate and return the best checkpoint, not just the last one.



def set_bert_grad(colbert, value):
try:
for p in colbert.bert.parameters():
26 changes: 20 additions & 6 deletions colbert/training/utils.py
@@ -8,16 +8,23 @@


def print_progress(scores):
positive_avg, negative_avg = (
round(scores[:, 0].mean().item(), 2),
round(scores[:, 1].mean().item(), 2),
)
print(
"#>>> ", positive_avg, negative_avg, "\t\t|\t\t", positive_avg - negative_avg
)


def manage_checkpoints(
args, colbert, optimizer, batch_idx, savepath=None, consumed_all_triples=False, is_schedule_free=False
):
# arguments = dict(args)

# TODO: Call provenance() on the values that support it??

checkpoints_path = savepath or os.path.join(Run().path_, "checkpoints")
name = None

try:
@@ -27,29 +34,36 @@ def manage_checkpoints(args, colbert, optimizer, batch_idx, savepath=None, consu

if not os.path.exists(checkpoints_path):
os.makedirs(checkpoints_path)

path_save = None

if consumed_all_triples or (batch_idx % 2000 == 0):
# name = os.path.join(path, "colbert.dnn")
# save_checkpoint(name, 0, batch_idx, colbert, optimizer, arguments)
path_save = os.path.join(checkpoints_path, "colbert")
if is_schedule_free:
optimizer.eval()

if batch_idx in SAVED_CHECKPOINTS:
# name = os.path.join(path, "colbert-{}.dnn".format(batch_idx))
# save_checkpoint(name, 0, batch_idx, colbert, optimizer, arguments)
path_save = os.path.join(checkpoints_path, f"colbert-{batch_idx}")
if is_schedule_free:
optimizer.eval()

if path_save:
print(f"#> Saving a checkpoint to {path_save} ..")

checkpoint = {}
checkpoint["batch"] = batch_idx
# checkpoint['epoch'] = 0
# checkpoint['model_state_dict'] = model.state_dict()
# checkpoint['optimizer_state_dict'] = optimizer.state_dict()
# checkpoint['arguments'] = arguments

save(path_save)

if not consumed_all_triples and is_schedule_free:
optimizer.train()

return path_save
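
The optimizer.eval() before each save and optimizer.train() afterwards mirror the schedulefree package's contract: in train mode the parameters hold the fast-moving iterate, while eval() swaps in the averaged weights, which are what should be serialized. Restoring train mode (unless all triples are consumed) lets training continue from the correct state.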
6 changes: 4 additions & 2 deletions colbert/utils/amp.py
@@ -23,12 +23,14 @@ def backward(self, loss):
def step(self, colbert, optimizer, scheduler=None):
if self.activated:
self.scaler.unscale_(optimizer)
if scheduler is not None:
torch.nn.utils.clip_grad_norm_(colbert.parameters(), 2.0, error_if_nonfinite=False)

self.scaler.step(optimizer)
self.scaler.update()
else:
if scheduler is not None:
torch.nn.utils.clip_grad_norm_(colbert.parameters(), 2.0)
optimizer.step()

if scheduler is not None:
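
Since scheduler is None on the schedule-free path (see training.py above), the practical effect of this change is that gradient clipping now applies only to conventional AdamW-plus-linear-decay runs; schedule-free steps proceed unclipped.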