diff --git a/src/modelling/checkpoint.jl b/src/modelling/checkpoint.jl
index 3df5265..b7f617e 100644
--- a/src/modelling/checkpoint.jl
+++ b/src/modelling/checkpoint.jl
@@ -4,7 +4,34 @@ struct BaseColBERT
     tokenizer::Transformers.TextEncoders.AbstractTransformerTextEncoder
 end
 
+"""
+    BaseColBERT(checkpoint::String, config::ColBERTConfig)
+
+Construct a `BaseColBERT` from `checkpoint`.
+
+The config, state dict, `:bert` model and tokenizer are all loaded from `checkpoint`.
+The `"linear"` layer is built separately via `HuggingFace._load_dense`, using
+`config.doc_settings.dim` together with values from the loaded config.
+"""
+function BaseColBERT(checkpoint::String, config::ColBERTConfig)
+    hf_config = Transformers.load_config(checkpoint)
+    state_dict = HuggingFace.load_state_dict(checkpoint)
+    encoder = HuggingFace.load_model(
+        :bert, checkpoint, :model, state_dict; config = hf_config
+    )
+
+    # Transformers.jl doesn't support local loading, so the linear layer is
+    # loaded by hand from the state dict.
+    projection = HuggingFace._load_dense(
+        state_dict, "linear", hf_config.hidden_size, config.doc_settings.dim,
+        hf_config.initializer_range, true,
+    )
+
+    return BaseColBERT(encoder, projection, Transformers.load_tokenizer(checkpoint))
+end
+
 struct Checkpoint
+    model::BaseColBERT
     doc_tokenizer::Any
     colbert_config::Any