Skip to content

Commit

Permalink
Merge pull request #27 from codetalker7/simpler_config
Browse files Browse the repository at this point in the history
Many design changes + optimizations.
  • Loading branch information
codetalker7 authored Aug 16, 2024
2 parents 2690cfa + cb99f55 commit 302b68c
Show file tree
Hide file tree
Showing 23 changed files with 1,595 additions and 1,808 deletions.
44 changes: 5 additions & 39 deletions examples/indexing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,47 +6,13 @@ using Random
# set the global seed
Random.seed!(0)

# create the config
dataroot = "downloads/lotte"
dataset = "lifestyle"
datasplit = "dev"
path = joinpath(dataroot, dataset, datasplit, "short_collection.tsv")

collection = Collection(path)
length(collection.data)

nbits = 2 # encode each dimension with 2 bits
doc_maxlen = 300 # truncate passages at 300 tokens

checkpoint = "colbert-ir/colbertv2.0" # the HF checkpoint
index_root = "experiments/notebook/indexes"
index_name = "short_$(dataset).$(datasplit).$(nbits)bits"
index_path = joinpath(index_root, index_name)

config = ColBERTConfig(
RunSettings(
experiment = "notebook",
use_gpu = true
),
TokenizerSettings(),
ResourceSettings(
checkpoint = checkpoint,
collection = collection,
index_name = index_name
),
DocSettings(
doc_maxlen = doc_maxlen,
),
QuerySettings(),
IndexingSettings(
index_path = index_path,
index_bsize = 3,
nbits = nbits,
kmeans_niters = 20
),
SearchSettings()
use_gpu = true,
collection = "./short_collection",
doc_maxlen = 300,
index_path = "./short_collection_index/",
chunksize = 3
)

indexer = Indexer(config)
index(indexer)
ColBERT.save(config)
216 changes: 216 additions & 0 deletions examples/load_colbert.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,216 @@
using Transformers
using JSON3
using Transformers.HuggingFace
const HF = Transformers.HuggingFace

"""
    _load_tokenizer_config(path_config::AbstractString)

Read and parse a HuggingFace tokenizer config (JSON) from a local file.

# Arguments
- `path_config`: Path to the `tokenizer_config.json` file.

# Throws
- `ArgumentError` if the file does not exist.
"""
function _load_tokenizer_config(path_config::AbstractString)
    # Validate eagerly with a real exception: @assert may be disabled at
    # higher optimization levels and must not guard user input.
    isfile(path_config) ||
        throw(ArgumentError("Tokenizer config file not found: $path_config"))
    # JSON3.read over raw bytes yields a JSON3.Object with property access
    # (eg, cfg.tokenizer_class).
    return JSON3.read(read(path_config))
end

"""
    extract_tokenizer_type(tkr_type::AbstractString)

Extract the tokenizer family from a HuggingFace `tokenizer_class` name and
return it as a lowercase `Symbol`, eg, `"BertTokenizerFast"` -> `:bert`.

# Throws
- An error if `tkr_type` does not match the `<Family>Tokenizer[Fast]` pattern.
"""
function extract_tokenizer_type(tkr_type::AbstractString)
    m = match(r"(\S+)Tokenizer(Fast)?", tkr_type)
    isnothing(m) &&
        error("Unknown tokenizer: $tkr_type")
    # Return the captured family name directly; the original rebound the
    # argument and relied on the implicit value of an assignment expression.
    return Symbol(lowercase(m.captures[1]))
end

"""
    _load_tokenizer(cfg::HF.HGFConfig; path_tokenizer_config::AbstractString,
        path_special_tokens_map::AbstractString, path_tokenizer::AbstractString)

Build a Transformers.jl text encoder entirely from local tokenizer files
(no network access).

# Arguments
- `cfg`: The loaded HuggingFace model config.
- `path_tokenizer_config`: Path to `tokenizer_config.json`.
- `path_special_tokens_map`: Path to `special_tokens_map.json`.
- `path_tokenizer`: Path to the fast tokenizer `tokenizer.json`.

# Throws
- `ArgumentError` if any of the three files does not exist.
"""
function _load_tokenizer(cfg::HF.HGFConfig;
        path_tokenizer_config::AbstractString,
        path_special_tokens_map::AbstractString, path_tokenizer::AbstractString)
    # Explicit input validation (not @assert, which may be compiled out).
    isfile(path_tokenizer_config) ||
        throw(ArgumentError("Tokenizer config file not found: $path_tokenizer_config"))
    isfile(path_special_tokens_map) ||
        throw(ArgumentError("Special tokens map file not found: $path_special_tokens_map"))
    isfile(path_tokenizer) ||
        throw(ArgumentError("Tokenizer file not found: $path_tokenizer"))
    ## load tokenizer config to determine the tokenizer family
    tkr_cfg = _load_tokenizer_config(path_tokenizer_config)
    tkr_type_sym = extract_tokenizer_type(tkr_cfg.tokenizer_class)
    tkr_type = HF.tokenizer_type(tkr_type_sym) # eg, Val(:bert)()
    ## load special tokens ([CLS], [SEP], ...)
    special_tokens = HF.load_special_tokens_map(path_special_tokens_map)
    ## load the fast tokenizer and its processing configuration
    kwargs = HF.extract_fast_tkr_kwargs(
        tkr_type, tkr_cfg, cfg, special_tokens)
    tokenizer, vocab, process_config, decode, textprocess = HF.load_fast_tokenizer(
        tkr_type, path_tokenizer, cfg)
    # process_config entries (padding, truncation, ...) extend/override kwargs
    for (k, v) in process_config
        kwargs[k] = v
    end
    ## construct the encoder and mutate the decode + textprocess pipelines
    tkr = HF.encoder_construct(
        tkr_type, tokenizer, vocab; kwargs...)
    tkr = HF.setproperties!!(
        tkr, (; decode, textprocess))
    return tkr
end

"""
    _load_model(cfg::HF.HGFConfig; path_model::AbstractString,
        trainmode::Bool = false, lazy::Bool = false, mmap::Bool = true)

Load model weights from a local `pytorch_model.bin` (torch `pickle` format)
and construct the model for config `cfg` (no network access).

# Arguments
- `cfg`: The loaded HuggingFace model config.
- `path_model`: Path to the `.bin` weights file.
- `trainmode`: If `false` (default), put the model into inference/test mode.
- `lazy`, `mmap`: Passed through to the state-dict loader.

# Throws
- `ArgumentError` if the file is missing or is not a `.bin` file.
"""
function _load_model(cfg::HF.HGFConfig;
        path_model::AbstractString,
        trainmode::Bool = false, lazy::Bool = false, mmap::Bool = true)
    # Explicit input validation (not @assert, which may be compiled out).
    isfile(path_model) ||
        throw(ArgumentError("Model file not found: $path_model"))
    endswith(path_model, ".bin") ||
        throw(ArgumentError("Model file must end with .bin (type torch `pickle`): $path_model"))
    ## Assume fixed task
    task = :model

    ## Load state dict
    # We know we have pytorch_model.bin -> so format is :pickle and it's a single file
    status = HF.HasSingleFile{:pickle}(path_model)
    state_dict = HF.load_state_dict_from(
        status; lazy, mmap)

    ## Resolve the concrete model type for this config + task
    model_type = HF.get_model_type(
        HF.getconfigname(cfg), task)
    basekey = String(HF.basemodelkey(model_type))
    if HF.isbasemodel(model_type)
        # Base model: strip the base-model key prefix only if present
        prefix = HF.haskeystartswith(
            state_dict, basekey) ? basekey : ""
    else
        prefix = ""
        # Task model: keys must carry the base-model prefix; add it if missing
        if !HF.haskeystartswith(
            state_dict, basekey)
            new_state_dict = OrderedDict{
                Any, Any}()
            for (key, val) in state_dict
                new_state_dict[joinname(basekey, key)] = val
            end
            state_dict = new_state_dict
        end
    end
    model = load_model(
        model_type, cfg, state_dict, prefix)
    # Inference mode (disables dropout etc.) unless training was requested
    trainmode || (model = Layers.testmode(model))
    return model
end

"""
    load_hgf_pretrained_local(dir_spec::AbstractString;
        path_config::Union{Nothing, AbstractString} = nothing,
        path_tokenizer_config::Union{Nothing, AbstractString} = nothing,
        path_special_tokens_map::Union{Nothing, AbstractString} = nothing,
        path_tokenizer::Union{Nothing, AbstractString} = nothing,
        path_model::Union{Nothing, AbstractString} = nothing,
        kwargs...
    )

Local model loader. Honors the `load_hgf_pretrained` interface, where you can request
specific files to be loaded, eg, `my/dir/to/model:tokenizer` or `my/dir/to/model:config`.

# Arguments
- `dir_spec::AbstractString`: Directory specification (item after the colon is optional), eg, `my/dir/to/model` or `my/dir/to/model:tokenizer`.
- `path_config::Union{Nothing, AbstractString}`: Path to config file.
- `path_tokenizer_config::Union{Nothing, AbstractString}`: Path to tokenizer config file.
- `path_special_tokens_map::Union{Nothing, AbstractString}`: Path to special tokens map file.
- `path_tokenizer::Union{Nothing, AbstractString}`: Path to tokenizer file.
- `path_model::Union{Nothing, AbstractString}`: Path to model file.
- `kwargs...`: Additional keyword arguments for `_load_model` function like `mmap`, `lazy`, `trainmode`.

# Returns
- `item == "config"`: the config; `item == "tokenizer"`: the text encoder;
  `item == "model"` with no colon in `dir_spec`: `(tokenizer, model)`;
  otherwise just the model.

# Throws
- `ArgumentError` if the directory or any required file is missing.
"""
function load_hgf_pretrained_local(
        dir_spec::AbstractString;
        path_config::Union{
            Nothing, AbstractString} = nothing,
        path_tokenizer_config::Union{
            Nothing, AbstractString} = nothing,
        path_special_tokens_map::Union{
            Nothing, AbstractString} = nothing,
        path_tokenizer::Union{
            Nothing, AbstractString} = nothing,
        path_model::Union{
            Nothing, AbstractString} = nothing,
        kwargs...
)

    ## Extract the requested item (after ':'), defaulting to everything.
    # `load_all` was named `all` before; renamed to avoid shadowing Base.all.
    name_item = rsplit(dir_spec, ':'; limit = 2)
    load_all = length(name_item) == 1
    dir, item = if load_all
        dir_spec, "model"
    else
        Iterators.map(String, name_item)
    end
    item = lowercase(item)

    ## Set default paths relative to `dir` (explicit throws, not @assert)
    isdir(dir) || throw(ArgumentError("Local directory not found: $dir"))
    isnothing(path_config) &&
        (path_config = joinpath(dir, "config.json"))
    isnothing(path_tokenizer_config) &&
        (path_tokenizer_config = joinpath(dir, "tokenizer_config.json"))
    isnothing(path_special_tokens_map) &&
        (path_special_tokens_map = joinpath(dir, "special_tokens_map.json"))
    isnothing(path_tokenizer) &&
        (path_tokenizer = joinpath(dir, "tokenizer.json"))
    isnothing(path_model) &&
        (path_model = joinpath(dir, "pytorch_model.bin"))

    ## Check that all required files exist
    for (desc, path) in (
            ("Config", path_config),
            ("Tokenizer config", path_tokenizer_config),
            ("Special tokens map", path_special_tokens_map),
            ("Tokenizer", path_tokenizer),
            ("Model", path_model))
        isfile(path) || throw(ArgumentError("$desc file not found: $path"))
    end

    ## load config
    cfg = HF._load_config(path_config)
    item == "config" && return cfg

    ## load tokenizer (also needed when the full pipeline is requested)
    if item == "tokenizer" || load_all
        tkr = _load_tokenizer(
            cfg; path_tokenizer_config,
            path_special_tokens_map,
            path_tokenizer)
    end
    item == "tokenizer" && return tkr

    ## load model
    model = _load_model(
        cfg; path_model, kwargs...)

    return load_all ? (tkr, model) : model
end

## Example: load a local checkpoint and run a single forward pass
using Transformers.TextEncoders

# My files can be found in this directory
# (expects config.json, tokenizer_config.json, special_tokens_map.json,
#  tokenizer.json and pytorch_model.bin inside — TODO confirm layout)
dir = "colbert-ir"
textenc, model = load_hgf_pretrained_local(dir)

# Encode a sample string with the tokenizer, then run the model on it
encoded = encode(textenc, "test it")
output = model(encoded)
22 changes: 7 additions & 15 deletions examples/searching.jl
Original file line number Diff line number Diff line change
@@ -1,26 +1,18 @@
using ColBERT
using CUDA

# create the config
dataroot = "downloads/lotte"
dataset = "lifestyle"
datasplit = "dev"
path = joinpath(dataroot, dataset, datasplit, "short_collection.tsv")

nbits = 2 # encode each dimension with 2 bits

index_root = "experiments/notebook/indexes"
index_name = "short_$(dataset).$(datasplit).$(nbits)bits"
index_path = joinpath(index_root, index_name)

# build the searcher
index_path = "short_collection_index"
searcher = Searcher(index_path)

# load the collection
collection = readlines(searcher.config.collection)

# search for a query
query = "what are white spots on raspberries?"
pids, scores = search(searcher, query, 2)
print(searcher.config.resource_settings.collection.data[pids])
print(collection[pids])

query = "are rabbits easy to housebreak?"
pids, scores = search(searcher, query, 9)
print(searcher.config.resource_settings.collection.data[pids])
pids, scores = search(searcher, query, 1)
print(collection[pids])
23 changes: 8 additions & 15 deletions src/ColBERT.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,36 +16,29 @@ using Transformers
# utils
include("utils/utils.jl")

# datasets
include("data/collection.jl")
include("data/queries.jl")
export Collection, Queries

# config and other infra
include("infra/settings.jl")
include("infra/config.jl")
export RunSettings, TokenizerSettings, ResourceSettings,
DocSettings, QuerySettings, IndexingSettings,
SearchSettings, ColBERTConfig
export ColBERTConfig

# models, document/query tokenizers
include("modelling/tokenization/doc_tokenization.jl")
include("modelling/tokenization/query_tokenization.jl")
include("modelling/checkpoint.jl")
export BaseColBERT, Checkpoint, DocTokenizer, QueryTokenizer
export BaseColBERT, Checkpoint

# indexer
include("indexing/codecs/residual.jl")
include("indexing.jl")
include("indexing/collection_encoder.jl")
include("indexing/index_saver.jl")
include("indexing/collection_indexer.jl")
export Indexer, CollectionIndexer, index
export Indexer, index

# searcher
include("search/strided_tensor.jl")
include("search/index_storage.jl")
include("search/ranking.jl")
include("searching.jl")
export Searcher, search

# loaders and savers
include("loaders.jl")
include("savers.jl")

end
Loading

0 comments on commit 302b68c

Please sign in to comment.