Skip to content

Commit

Permalink
Formatting the source files.
Browse files Browse the repository at this point in the history
  • Loading branch information
codetalker7 committed May 29, 2024
1 parent c258729 commit 52d345e
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 12 deletions.
7 changes: 3 additions & 4 deletions src/ColBERT.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ using CSV
using Logging
using Transformers


# datasets
include("data/collection.jl")
include("data/queries.jl")
Expand All @@ -12,9 +11,9 @@ export Collection, Queries
# config and other infra
include("infra/settings.jl")
include("infra/config.jl")
export RunSettings, TokenizerSettings, ResourceSettings,
DocSettings, QuerySettings, IndexingSettings,
SearchSettings, ColBERTConfig
export RunSettings, TokenizerSettings, ResourceSettings,
DocSettings, QuerySettings, IndexingSettings,
SearchSettings, ColBERTConfig

# models
include("modelling/checkpoint.jl")
Expand Down
12 changes: 8 additions & 4 deletions src/data/collection.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@ struct Collection
end

function Collection(path::String)
file = CSV.File(path; delim='\t', header = [:pid, :text], types = Dict(:pid => Int, :text => String), debug=true, quoted=false)
file = CSV.File(path; delim = '\t', header = [:pid, :text],
types = Dict(:pid => Int, :text => String), debug = true, quoted = false)
@info "Loaded $(length(file.text)[1]) passages."
Collection(path, file.text)
end
Expand All @@ -15,7 +16,9 @@ function get_chunksize(collection::Collection, nranks::Int)
min(25000, 1 + floor(length(collection.data) / nranks))
end

function enumerate_batches(collection::Collection, chunksize::Union{Int, Missing} = missing, nranks::Union{Int, Missing} = missing)
function enumerate_batches(
collection::Collection, chunksize::Union{Int, Missing} = missing,
nranks::Union{Int, Missing} = missing)
if ismissing(chunksize)
if ismissing(nranks)
error("Atleast one of the arguments chunksize or nranks must be specified!")
Expand All @@ -27,7 +30,9 @@ function enumerate_batches(collection::Collection, chunksize::Union{Int, Missing
batches = Vector{Tuple{Int, Int, Vector{String}}}()
chunk_idx, offset = 1, 1
while true
push!(batches, (chunk_idx, offset, collection.data[offset:min(offset + chunksize - 1, num_passages)]))
push!(batches,
(chunk_idx, offset,
collection.data[offset:min(offset + chunksize - 1, num_passages)]))
chunk_idx += 1
offset += chunksize

Expand All @@ -37,4 +42,3 @@ function enumerate_batches(collection::Collection, chunksize::Union{Int, Missing
end
batches
end

8 changes: 4 additions & 4 deletions src/modelling/checkpoint.jl
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
struct BaseColBERT
bert
linear
bert::Any
linear::Any
tokenizer::Transformers.TextEncoders.AbstractTransformerTextEncoder
end

struct Checkpoint
model::BaseColBERT
doc_tokenizer
colbert_config
doc_tokenizer::Any
colbert_config::Any
end

0 comments on commit 52d345e

Please sign in to comment.