Skip to content

Commit

Permalink
Adding a constructor to load the BaseColBERT.
Browse files Browse the repository at this point in the history
  • Loading branch information
codetalker7 committed May 30, 2024
1 parent 79f1340 commit b28af5c
Showing 1 changed file with 14 additions and 0 deletions.
14 changes: 14 additions & 0 deletions src/modelling/checkpoint.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,21 @@ struct BaseColBERT
tokenizer::Transformers.TextEncoders.AbstractTransformerTextEncoder
end

"""
    BaseColBERT(checkpoint::String, config::ColBERTConfig)

Build a `BaseColBERT` from the HuggingFace checkpoint at `checkpoint`: the
BERT backbone and tokenizer are loaded via `Transformers`, and the final
linear projection (mapping `hidden_size` to `config.doc_settings.dim`) is
loaded manually from the checkpoint's state dict.
"""
function BaseColBERT(checkpoint::String, config::ColBERTConfig)
    cfg = Transformers.load_config(checkpoint)
    state = HuggingFace.load_state_dict(checkpoint)

    model = HuggingFace.load_model(:bert, checkpoint, :model, state; config = cfg)
    # Transformers.jl doesn't support loading this layer from a local
    # checkpoint, so the dense projection is constructed by hand.
    projection = HuggingFace._load_dense(state, "linear", cfg.hidden_size,
        config.doc_settings.dim, cfg.initializer_range, true)
    encoder = Transformers.load_tokenizer(checkpoint)

    return BaseColBERT(model, projection, encoder)
end

struct Checkpoint

model::BaseColBERT
doc_tokenizer::Any
colbert_config::Any
Expand Down

0 comments on commit b28af5c

Please sign in to comment.