diff --git a/src/indexing/codecs/residual.jl b/src/indexing/codecs/residual.jl index ed547e9..9e0e7fa 100644 --- a/src/indexing/codecs/residual.jl +++ b/src/indexing/codecs/residual.jl @@ -46,16 +46,20 @@ function binarize(codec::ResidualCodec, residuals::Matrix{Float64}) residuals_packed = reshape(residuals_packed, (Int(dim / 8) * nbits, num_embeddings)) # reshape back to get compressions for each embedding end -# function compress(codec::ResidualCodec, embs::Matrix{Float64}) -# codes, residuals = Vector{Int}(), Vector{Matrix{Float64}}() -# -# offset = 1 -# bsize = 1 << 18 -# while (offset <= size(embs[2])) # batch on second dimension -# batch = embs[:, offset:min(size(embs[2]), offset + bsize - 1)] -# codes_ = compress_into_codes(codec, batch) # get centroid codes -# centroids_ = codec.centroids[:, codes_] # get corresponding centroids -# residuals_ = batch - centroids_ -# append(codes, codes_) -# end -# end +function compress(codec::ResidualCodec, embs::Matrix{Float64}) + codes, residuals = Vector{Int}(), Vector{Matrix{UInt8}}() + + offset = 1 + bsize = 1 << 18 + while (offset <= size(embs[2])) # batch on second dimension + batch = embs[:, offset:min(size(embs)[2], offset + bsize - 1)] + codes_ = compress_into_codes(codec, batch) # get centroid codes + centroids_ = codec.centroids[:, codes_] # get corresponding centroids + residuals_ = batch - centroids_ + append!(codes, codes_) + push!(residuals, binarize(codec, residuals_)) + end + residuals = cat(residuals..., dims = 2) + + codes, residuals +end