Many design changes + optimizations. #27

Merged 59 commits on Aug 16, 2024
Commits
8dda46f
Moving all the config to just one struct + even better defaults.
codetalker7 Aug 10, 2024
e2e9274
Adding functions to load and save the config to JSON.
codetalker7 Aug 10, 2024
a5213cd
Changing name from `load` to `load_config` to avoid future
codetalker7 Aug 10, 2024
bdcbf51
Minor change in how config fields are accessed.
codetalker7 Aug 11, 2024
5fb3462
Updating the `BaseColBERT` constructor to only take the config. Also
codetalker7 Aug 11, 2024
a4d3ea0
Removing the `config` and `doc`/`query_tokenizers` from `Checkpoint`, as
codetalker7 Aug 11, 2024
408a1cd
Removing the `DocTokenizer` type (it's unnecessary), and changing
codetalker7 Aug 11, 2024
6f1cc7e
Adding `config` as an argument to `docs` and `docFromText`; this is a
codetalker7 Aug 11, 2024
91e251b
Minor edit in docs.
codetalker7 Aug 11, 2024
c1d7fbf
Removing the `QueryTokenizer` struct, and renaming `tensorize` to
codetalker7 Aug 11, 2024
48720bd
Adding the `config` as an argument to the `query` and `queryFromText`
codetalker7 Aug 11, 2024
20f0402
Simplifying the signatures of the setup functions; only using primitive
codetalker7 Aug 11, 2024
bda3168
Using `JLD2.save_object` instead of `JLD2.save`.
codetalker7 Aug 11, 2024
eca7bab
Creating the index dir in the setup function if it doesn't exist.
codetalker7 Aug 11, 2024
b33b7bf
Simplifying all the clustering-related code.
codetalker7 Aug 11, 2024
314219b
Simplifying saving and loading of the codec to/from the index path.
codetalker7 Aug 11, 2024
7ad3d1c
Minor change in `compress_into_codes`; using `centroids` as an argument
codetalker7 Aug 11, 2024
c5f39f3
Allowing a custom chunksize.
codetalker7 Aug 12, 2024
bebb4a5
Returning a codec `Dict` from `load_codec`.
codetalker7 Aug 12, 2024
07c2b8a
Simplifying the arguments of the `binarize` and `compress` functions;
codetalker7 Aug 12, 2024
a2870a5
Saving the chunksize in the indexing plan metadata.
codetalker7 Aug 12, 2024
4f9353d
Simplifying the `index` and `save_chunk` functions; using mostly
codetalker7 Aug 12, 2024
f6c7308
Removing unused exports.
codetalker7 Aug 12, 2024
fb5c0c3
Simplifying `load_codes`.
codetalker7 Aug 12, 2024
0b93025
Simplifying the `finalize` functions.
codetalker7 Aug 12, 2024
01d6a95
Updating docstrings for indexing functions.
codetalker7 Aug 12, 2024
b53801c
Updating docstrings of the functions in `residual.jl`.
codetalker7 Aug 12, 2024
92a46ae
Removing unnecessary files.
codetalker7 Aug 12, 2024
5a5be9c
Updating the examples file.
codetalker7 Aug 12, 2024
716eb76
Some minor optimizations in the code for `docFromText`.
codetalker7 Aug 12, 2024
ddf22cb
Making `tensorize_docs` return only the ids and mask.
codetalker7 Aug 12, 2024
262b7c7
Processing `config.passages_batch_size` passages in `encode_passages`,
codetalker7 Aug 12, 2024
f0b941a
Removing unnecessary util functions.
codetalker7 Aug 12, 2024
528c1ec
Applying format; changing the default `index_bsize` to `32`, and
codetalker7 Aug 12, 2024
ea61775
Adding some asserts.
codetalker7 Aug 12, 2024
9515226
Changing the internals of `Indexer`, and adding a new constructor.
codetalker7 Aug 14, 2024
78dc814
Keeping the index function; will change it later.
codetalker7 Aug 14, 2024
23cc67b
Adding some type checks.
codetalker7 Aug 15, 2024
6c63834
Making the `setup`, `train` and `_sample_embeddings` functions test
codetalker7 Aug 15, 2024
28244b7
Updating docstrings.
codetalker7 Aug 15, 2024
15c525f
Completing the `Indexer` functions.
codetalker7 Aug 15, 2024
7c98940
Updating the example indexing script.
codetalker7 Aug 15, 2024
de2db68
Simplifying modelling functions for queries; making them more test
codetalker7 Aug 16, 2024
88c1fc5
Saving all metadata in `plan.json`.
codetalker7 Aug 16, 2024
9b6a018
Minor fix in loading codes.
codetalker7 Aug 16, 2024
ac90533
Adding more fields to the `Searcher`, and updating its constructor.
codetalker7 Aug 16, 2024
cf69b8d
Minor change to the `encode_query` function.
codetalker7 Aug 16, 2024
c661775
Removing unnecessary functions.
codetalker7 Aug 16, 2024
1702ae3
Simplifying the decompression functions; making them test friendly.
codetalker7 Aug 16, 2024
65d8d52
Simplifying the search code, and removing the `IndexScorer`.
codetalker7 Aug 16, 2024
d8078eb
File rename.
codetalker7 Aug 16, 2024
affdb5e
Removing `ResidualCodec`.
codetalker7 Aug 16, 2024
5a49d92
Removing data structs; they aren't really needed.
codetalker7 Aug 16, 2024
6fda1e3
Moving loading/saving functions to their own files.
codetalker7 Aug 16, 2024
87d8534
Simplifying `Searcher` constructor.
codetalker7 Aug 16, 2024
2af3336
Refactoring the ranking code; making it more test friendly.
codetalker7 Aug 16, 2024
60a4e30
Updating the examples.
codetalker7 Aug 16, 2024
60e0d81
Removing strided tensors for now.
codetalker7 Aug 16, 2024
cb99f55
Running the Julia formatter.
codetalker7 Aug 16, 2024
44 changes: 5 additions & 39 deletions examples/indexing.jl
@@ -6,47 +6,13 @@ using Random
 # set the global seed
 Random.seed!(0)
 
 # create the config
-dataroot = "downloads/lotte"
-dataset = "lifestyle"
-datasplit = "dev"
-path = joinpath(dataroot, dataset, datasplit, "short_collection.tsv")
-
-collection = Collection(path)
-length(collection.data)
-
-nbits = 2 # encode each dimension with 2 bits
-doc_maxlen = 300 # truncate passages at 300 tokens
-
-checkpoint = "colbert-ir/colbertv2.0" # the HF checkpoint
-index_root = "experiments/notebook/indexes"
-index_name = "short_$(dataset).$(datasplit).$(nbits)bits"
-index_path = joinpath(index_root, index_name)
-
 config = ColBERTConfig(
-    RunSettings(
-        experiment = "notebook",
-        use_gpu = true
-    ),
-    TokenizerSettings(),
-    ResourceSettings(
-        checkpoint = checkpoint,
-        collection = collection,
-        index_name = index_name
-    ),
-    DocSettings(
-        doc_maxlen = doc_maxlen,
-    ),
-    QuerySettings(),
-    IndexingSettings(
-        index_path = index_path,
-        index_bsize = 3,
-        nbits = nbits,
-        kmeans_niters = 20
-    ),
-    SearchSettings()
+    use_gpu = true,
+    collection = "./short_collection",
+    doc_maxlen = 300,
+    index_path = "./short_collection_index/",
+    chunksize = 3
 )
 
 indexer = Indexer(config)
 index(indexer)
+ColBERT.save(config)
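
The new `ColBERT.save(config)` call persists the configuration as JSON (commits e2e9274 and a5213cd added the save/`load_config` pair). A minimal sketch of reading it back, assuming `load_config` takes the directory that `save` wrote to (the exact argument is an assumption):

using ColBERT
config = ColBERT.load_config("./short_collection_index/")
config.doc_maxlen # 300, per the example above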
216 changes: 216 additions & 0 deletions examples/load_colbert.jl
@@ -0,0 +1,216 @@
using Transformers
using JSON3
using Transformers.HuggingFace
using OrderedCollections # assumed missing import: `OrderedDict` is used in `_load_model`
const HF = Transformers.HuggingFace

"""
_load_tokenizer_config(path_config)
Load tokenizer config locally.
"""
function _load_tokenizer_config(path_config::AbstractString)
@assert isfile(path_config) "Tokenizer config file not found: $path_config"
return JSON3.read(read(path_config))
end
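
## eg, (hypothetical path) tkr_cfg = _load_tokenizer_config("colbert-ir/tokenizer_config.json");
## its `tokenizer_class` field drives `extract_tokenizer_type` below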

"""
extract_tokenizer_type(tkr_type::AbstractString)
Extract tokenizer type from config.
"""
function extract_tokenizer_type(tkr_type::AbstractString)
m = match(r"(\S+)Tokenizer(Fast)?", tkr_type)
isnothing(m) &&
error("Unknown tokenizer: $tkr_type")
tkr_type = Symbol(lowercase(m.captures[1]))
end
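
## For example, per the regex above:
## extract_tokenizer_type("BertTokenizerFast") == :bert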

"""
_load_tokenizer(cfg::HF.HGFConfig; path_tokenizer_config::AbstractString,
path_special_tokens_map::AbstractString, path_tokenizer::AbstractString)
Local tokenizer loader.
"""
function _load_tokenizer(cfg::HF.HGFConfig;
path_tokenizer_config::AbstractString,
path_special_tokens_map::AbstractString, path_tokenizer::AbstractString)
@assert isfile(path_tokenizer_config) "Tokenizer config file not found: $path_tokenizer_config"
@assert isfile(path_special_tokens_map) "Special tokens map file not found: $path_special_tokens_map"
@assert isfile(path_tokenizer) "Tokenizer file not found: $path_tokenizer"
## load tokenizer config
tkr_cfg = _load_tokenizer_config(path_tokenizer_config)
tkr_type_sym = extract_tokenizer_type(tkr_cfg.tokenizer_class)
tkr_type = HF.tokenizer_type(tkr_type_sym) # eg, Val(:bert)()
## load special tokens
special_tokens = HF.load_special_tokens_map(path_special_tokens_map)
## load tokenizer
kwargs = HF.extract_fast_tkr_kwargs(
tkr_type, tkr_cfg, cfg, special_tokens)
tokenizer, vocab, process_config, decode, textprocess = HF.load_fast_tokenizer(
tkr_type, path_tokenizer, cfg)
for (k, v) in process_config
kwargs[k] = v
end
## construct tokenizer and mutate the decode +textprocess pipelines
tkr = HF.encoder_construct(
tkr_type, tokenizer, vocab; kwargs...)
tkr = HF.setproperties!!(
tkr, (; decode, textprocess))
return tkr
end

"""
_load_model(cfg::HF.HGFConfig; path_model::AbstractString,
trainmode::Bool = false, lazy::Bool = false, mmap::Bool = true)
Local model loader.
"""
function _load_model(cfg::HF.HGFConfig;
path_model::AbstractString,
trainmode::Bool = false, lazy::Bool = false, mmap::Bool = true)
@assert isfile(path_model) "Model file not found: $path_model"
@assert endswith(path_model, ".bin") "Model file must end with .bin (type torch `pickle`): $path_model"
## Assume fixed
task = :model

## Load state dict
# We know we have pytorch_model.bin -> so format is :pickle and it's a single file
# status = HF.singlefilename(HF.WeightStatus{:pickle})
status = HF.HasSingleFile{:pickle}(path_model)
state_dict = HF.load_state_dict_from(
status; lazy, mmap)

##
model_type = HF.get_model_type(
HF.getconfigname(cfg), task)
basekey = String(HF.basemodelkey(model_type))
if HF.isbasemodel(model_type)
prefix = HF.haskeystartswith(
state_dict, basekey) ? basekey : ""
else
prefix = ""
if !HF.haskeystartswith(
state_dict, basekey)
new_state_dict = OrderedDict{
Any, Any}()
for (key, val) in state_dict
new_state_dict[joinname(basekey, key)] = val
end
state_dict = new_state_dict
end
end
model = load_model(
model_type, cfg, state_dict, prefix)
trainmode || (model = Layers.testmode(model))
return model
end
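
## eg, (hypothetical path; `lazy` and `mmap` are forwarded to the state dict loader)
## model = _load_model(cfg; path_model = "colbert-ir/pytorch_model.bin", mmap = true)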

"""
load_hgf_pretrained_local(dir_spec::AbstractString;
path_config::Union{Nothing, AbstractString} = nothing,
path_tokenizer_config::Union{Nothing, AbstractString} = nothing,
path_special_tokens_map::Union{Nothing, AbstractString} = nothing,
path_tokenizer::Union{Nothing, AbstractString} = nothing,
path_model::Union{Nothing, AbstractString} = nothing,
kwargs...
)
Local model loader. Honors the `load_hgf_pretrained` interface, where you can request
specific files to be loaded, eg, `my/dir/to/model:tokenizer` or `my/dir/to/model:config`.
# Arguments
- `dir_spec::AbstractString`: Directory specification (item specific after the colon is optional), eg, `my/dir/to/model` or `my/dir/to/model:tokenizer`.
- `path_config::Union{Nothing, AbstractString}`: Path to config file.
- `path_tokenizer_config::Union{Nothing, AbstractString}`: Path to tokenizer config file.
- `path_special_tokens_map::Union{Nothing, AbstractString}`: Path to special tokens map file.
- `path_tokenizer::Union{Nothing, AbstractString}`: Path to tokenizer file.
- `path_model::Union{Nothing, AbstractString}`: Path to model file.
- `kwargs...`: Additional keyword arguments for `_load_model` function like `mmap`, `lazy`, `trainmode`.
"""
function load_hgf_pretrained_local(
dir_spec::AbstractString;
path_config::Union{
Nothing, AbstractString} = nothing,
path_tokenizer_config::Union{
Nothing, AbstractString} = nothing,
path_special_tokens_map::Union{
Nothing, AbstractString} = nothing,
path_tokenizer::Union{
Nothing, AbstractString} = nothing,
path_model::Union{
Nothing, AbstractString} = nothing,
kwargs...
)

    ## Extract if item was provided
    name_item = rsplit(dir_spec, ':'; limit = 2)
    all = length(name_item) == 1
    dir, item = if all
        dir_spec, "model"
    else
        Iterators.map(String, name_item)
    end
    item = lowercase(item)
    ## Set paths
    @assert isdir(dir) "Local directory not found: $dir"
    if isnothing(path_config)
        path_config = joinpath(dir, "config.json")
    end
    if isnothing(path_tokenizer_config)
        path_tokenizer_config = joinpath(dir, "tokenizer_config.json")
    end
    if isnothing(path_special_tokens_map)
        path_special_tokens_map = joinpath(dir, "special_tokens_map.json")
    end
    if isnothing(path_tokenizer)
        path_tokenizer = joinpath(dir, "tokenizer.json")
    end
    if isnothing(path_model)
        path_model = joinpath(dir, "pytorch_model.bin")
    end
    ## Check if they exist
    @assert isfile(path_config) "Config file not found: $path_config"
    @assert isfile(path_tokenizer_config) "Tokenizer config file not found: $path_tokenizer_config"
    @assert isfile(path_special_tokens_map) "Special tokens map file not found: $path_special_tokens_map"
    @assert isfile(path_tokenizer) "Tokenizer file not found: $path_tokenizer"
    @assert isfile(path_model) "Model file not found: $path_model"

    ## load config
    cfg = HF._load_config(path_config)
    item == "config" && return cfg

    ## load tokenizer
    if item == "tokenizer" || all
        tkr = _load_tokenizer(cfg; path_tokenizer_config,
            path_special_tokens_map, path_tokenizer)
    end
    item == "tokenizer" && return tkr

    ## load model
    model = _load_model(cfg; path_model, kwargs...)

    if all
        return tkr, model
    else
        return model
    end
end

## Example
using Transformers.TextEncoders

# My files can be found in this directory
dir = "colbert-ir"
textenc, model = load_hgf_pretrained_local(dir)

encoded = encode(textenc, "test it")
output = model(encoded)
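
## A single artifact can also be requested via the colon spec handled above,
## eg (same hypothetical directory):
cfg_only = load_hgf_pretrained_local("colbert-ir:config")
textenc_only = load_hgf_pretrained_local("colbert-ir:tokenizer")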
22 changes: 7 additions & 15 deletions examples/searching.jl
@@ -1,26 +1,18 @@
 using ColBERT
-using CUDA
 
-# create the config
-dataroot = "downloads/lotte"
-dataset = "lifestyle"
-datasplit = "dev"
-path = joinpath(dataroot, dataset, datasplit, "short_collection.tsv")
-
-nbits = 2 # encode each dimension with 2 bits
-
-index_root = "experiments/notebook/indexes"
-index_name = "short_$(dataset).$(datasplit).$(nbits)bits"
-index_path = joinpath(index_root, index_name)
 
+# build the searcher
+index_path = "short_collection_index"
 searcher = Searcher(index_path)
 
+# load the collection
+collection = readlines(searcher.config.collection)
 
 # search for a query
 query = "what are white spots on raspberries?"
 pids, scores = search(searcher, query, 2)
-print(searcher.config.resource_settings.collection.data[pids])
+print(collection[pids])
 
 query = "are rabbits easy to housebreak?"
-pids, scores = search(searcher, query, 9)
-print(searcher.config.resource_settings.collection.data[pids])
+pids, scores = search(searcher, query, 1)
+print(collection[pids])
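
A small follow-on sketch for the example above, pairing each returned passage id with its score (the shapes of `pids` and `scores` are taken from the diff; the loop itself is illustrative):

for (pid, score) in zip(pids, scores)
    println(score, " -> ", collection[pid])
end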
23 changes: 8 additions & 15 deletions src/ColBERT.jl
@@ -16,36 +16,29 @@ using Transformers
 # utils
 include("utils/utils.jl")
 
-# datasets
-include("data/collection.jl")
-include("data/queries.jl")
-export Collection, Queries
-
 # config and other infra
-include("infra/settings.jl")
 include("infra/config.jl")
-export RunSettings, TokenizerSettings, ResourceSettings,
-    DocSettings, QuerySettings, IndexingSettings,
-    SearchSettings, ColBERTConfig
+export ColBERTConfig
 
 # models, document/query tokenizers
 include("modelling/tokenization/doc_tokenization.jl")
 include("modelling/tokenization/query_tokenization.jl")
 include("modelling/checkpoint.jl")
-export BaseColBERT, Checkpoint, DocTokenizer, QueryTokenizer
+export BaseColBERT, Checkpoint
 
 # indexer
 include("indexing/codecs/residual.jl")
 include("indexing.jl")
-include("indexing/collection_encoder.jl")
-include("indexing/index_saver.jl")
 include("indexing/collection_indexer.jl")
-export Indexer, CollectionIndexer, index
+export Indexer, index
 
 # searcher
-include("search/strided_tensor.jl")
-include("search/index_storage.jl")
+include("search/ranking.jl")
 include("searching.jl")
 export Searcher, search
 
+# loaders and savers
+include("loaders.jl")
+include("savers.jl")
+
 end
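
The trimmed export list still covers the whole pipeline. A minimal end-to-end sketch assembled from the two example scripts in this PR (paths are placeholders):

using ColBERT

config = ColBERTConfig(
    use_gpu = true,
    collection = "./short_collection",
    doc_maxlen = 300,
    index_path = "./short_collection_index/",
    chunksize = 3
)
indexer = Indexer(config)
index(indexer)
ColBERT.save(config)

searcher = Searcher("./short_collection_index/")
pids, scores = search(searcher, "what are white spots on raspberries?", 2)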