Skip to content

Commit

Permalink
Writing the binarize function to compress residuals into nbits bits.
Browse files Browse the repository at this point in the history
  • Loading branch information
codetalker7 committed Jun 18, 2024
1 parent b7e7ac1 commit e1b2036
Showing 1 changed file with 37 additions and 0 deletions.
37 changes: 37 additions & 0 deletions src/indexing/codecs/residual.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,40 @@ function compress_into_codes(codec::ResidualCodec, embs::Matrix{Float64})

codes
end

function binarize(codec::ResidualCodec, residuals::Matrix{Float64})
dim = codec.config.doc_settings.dim
nbits = codec.config.indexing_settings.nbits
num_embeddings = size(residuals)[2]

if dim % (nbits * 8) != 0
error("The embeddings dimension must be a multiple of nbits * 8!")
end

# need to subtract one here, to preserve the number of options (2 ^ nbits)
bucket_indices = (x -> searchsortedfirst(codec.bucket_cutoffs, x)).(residuals) .- 1 # torch.bucketize
bucket_indices = stack([bucket_indices for i in 1:nbits], dims = 1) # add an nbits-wide extra dimension
positionbits = fill(1, (nbits, 1, 1))
for i in 1:nbits
positionbits[i, :, :] .= 1 << (i - 1)
end

bucket_indices = Int.(floor.(bucket_indices ./ positionbits)) # divide by 2^bit for each bit position
bucket_indices = bucket_indices .& 1 # apply mod 1 to binarize
residuals_packed = reinterpret(UInt8, BitArray(vec(bucket_indices)).chunks) # flatten out the bits, and pack them into UInt8
residuals_packed = reshape(residuals_packed, (Int(dim / 8) * nbits, num_embeddings))# reshape back to get compressions for each embedding
end

# function compress(codec::ResidualCodec, embs::Matrix{Float64})
# codes, residuals = Vector{Int}(), Vector{Matrix{Float64}}()
#
# offset = 1
# bsize = 1 << 18
# while (offset <= size(embs[2])) # batch on second dimension
# batch = embs[:, offset:min(size(embs[2]), offset + bsize - 1)]
# codes_ = compress_into_codes(codec, batch) # get centroid codes
# centroids_ = codec.centroids[:, codes_] # get corresponding centroids
# residuals_ = batch - centroids_
# append(codes, codes_)
# end
# end

0 comments on commit e1b2036

Please sign in to comment.