diff --git a/src/ColBERT.jl b/src/ColBERT.jl index 41fb017..a52d9a4 100644 --- a/src/ColBERT.jl +++ b/src/ColBERT.jl @@ -3,7 +3,6 @@ using CSV using Logging using Transformers - # datasets include("data/collection.jl") include("data/queries.jl") @@ -12,9 +11,9 @@ export Collection, Queries # config and other infra include("infra/settings.jl") include("infra/config.jl") -export RunSettings, TokenizerSettings, ResourceSettings, - DocSettings, QuerySettings, IndexingSettings, - SearchSettings, ColBERTConfig +export RunSettings, TokenizerSettings, ResourceSettings, + DocSettings, QuerySettings, IndexingSettings, + SearchSettings, ColBERTConfig # models include("modelling/checkpoint.jl") diff --git a/src/data/collection.jl b/src/data/collection.jl index dbfeec0..db21e86 100644 --- a/src/data/collection.jl +++ b/src/data/collection.jl @@ -6,7 +6,8 @@ struct Collection end function Collection(path::String) - file = CSV.File(path; delim='\t', header = [:pid, :text], types = Dict(:pid => Int, :text => String), debug=true, quoted=false) + file = CSV.File(path; delim = '\t', header = [:pid, :text], + types = Dict(:pid => Int, :text => String), debug = true, quoted = false) @info "Loaded $(length(file.text)[1]) passages." Collection(path, file.text) end @@ -15,7 +16,9 @@ function get_chunksize(collection::Collection, nranks::Int) min(25000, 1 + floor(length(collection.data) / nranks)) end -function enumerate_batches(collection::Collection, chunksize::Union{Int, Missing} = missing, nranks::Union{Int, Missing} = missing) +function enumerate_batches( + collection::Collection, chunksize::Union{Int, Missing} = missing, + nranks::Union{Int, Missing} = missing) if ismissing(chunksize) if ismissing(nranks) error("Atleast one of the arguments chunksize or nranks must be specified!") @@ -27,7 +30,9 @@ function enumerate_batches(collection::Collection, chunksize::Union{Int, Missing batches = Vector{Tuple{Int, Int, Vector{String}}}() chunk_idx, offset = 1, 1 while true - push!(batches, (chunk_idx, offset, collection.data[offset:min(offset + chunksize - 1, num_passages)])) + push!(batches, + (chunk_idx, offset, + collection.data[offset:min(offset + chunksize - 1, num_passages)])) chunk_idx += 1 offset += chunksize @@ -37,4 +42,3 @@ function enumerate_batches(collection::Collection, chunksize::Union{Int, Missing end batches end - diff --git a/src/modelling/checkpoint.jl b/src/modelling/checkpoint.jl index 8faa736..3df5265 100644 --- a/src/modelling/checkpoint.jl +++ b/src/modelling/checkpoint.jl @@ -1,11 +1,11 @@ struct BaseColBERT - bert - linear + bert::Any + linear::Any tokenizer::Transformers.TextEncoders.AbstractTransformerTextEncoder end struct Checkpoint model::BaseColBERT - doc_tokenizer - colbert_config + doc_tokenizer::Any + colbert_config::Any end