From a9c0550a21f0a3f7a5c94e7a7664254f6c96a261 Mon Sep 17 00:00:00 2001
From: Plamen Minev
Date: Tue, 14 Jan 2025 16:39:12 +0200
Subject: [PATCH] feat: add example and test for the embedding, ref #24

---
 CMakeLists.txt                      |  2 +-
 code/ac/llama/InstanceEmbedding.cpp | 84 +++++++++++++++--------------
 code/ac/llama/InstanceEmbedding.hpp |  6 ++-
 example/CMakeLists.txt              | 19 ++++---
 example/e-embedding.cpp             | 67 +++++++++++++++++++++++
 test/t-integration.cpp              | 30 +++++++++++
 6 files changed, 158 insertions(+), 50 deletions(-)
 create mode 100644 example/e-embedding.cpp

diff --git a/CMakeLists.txt b/CMakeLists.txt
index 628cd17..725dd81 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -71,7 +71,7 @@ if(AC_LLAMA_BUILD_TESTS OR AC_LLAMA_BUILD_EXAMPLES)
         NAME ac-test-data-llama
         VERSION 1.0.0
         GIT_REPOSITORY https://huggingface.co/alpaca-core/ac-test-data-llama
-        GIT_TAG 989ab9fe8f85a706453fe1f74a324700132f5881
+        GIT_TAG 164e484873785c804fe724f88bdb96088b573ebc
     )
 endif()

diff --git a/code/ac/llama/InstanceEmbedding.cpp b/code/ac/llama/InstanceEmbedding.cpp
index 7518837..1a04725 100644
--- a/code/ac/llama/InstanceEmbedding.cpp
+++ b/code/ac/llama/InstanceEmbedding.cpp
@@ -28,6 +28,7 @@ llama_context_params llamaFromInstanceInitParams(const InstanceEmbedding::InitPa
     llamaParams.n_batch = params.batchSize;
     llamaParams.n_ubatch = params.ubatchSize;
     llamaParams.flash_attn = params.flashAttn;
+    llamaParams.embeddings = true;
     return llamaParams;
 }
 } // namespace
@@ -35,6 +36,7 @@ llama_context_params llamaFromInstanceInitParams(const InstanceEmbedding::InitPa
 InstanceEmbedding::InstanceEmbedding(Model& model, InitParams params)
     : m_model(model)
     , m_sampler(model, {})
-    , m_lctx(llama_new_context_with_model(model.lmodel(), llamaFromInstanceInitParams(params)), llama_free)
+    , m_params(std::move(params))
+    , m_lctx(llama_new_context_with_model(model.lmodel(), llamaFromInstanceInitParams(m_params)), llama_free)
 {
     if (!m_lctx) {
@@ -63,14 +65,6 @@ InstanceEmbedding::InstanceEmbedding(Model& model, InitParams params)
 InstanceEmbedding::~InstanceEmbedding() = default;
 
 namespace {
-llama_batch makeInputBatch(std::span<const Token> tokens) {
-    // well, llama.cpp does not touch the tokens for input batches, but llama_batch needs them to be non-const
-    // (mostly for stupid C reasons)
-    // so... we have to do something evil here
-    auto nonConstTokens = const_cast<Token*>(tokens.data());
-    return llama_batch_get_one(nonConstTokens, int32_t(tokens.size()));
-}
-
 void normalizeEmbedding(const float * inp, float * out, int n, int embd_norm) {
     double sum = 0.0;
@@ -106,48 +100,58 @@ void normalizeEmbedding(const float * inp, float * out, int n, int embd_norm) {
         out[i] = inp[i] * norm;
     }
 }
-}
-std::vector<float> InstanceEmbedding::getEmbeddingVector(std::span<const Token> prompt) {
-    // count number of embeddings
-    int n_embd_count = 0;
-    // if (pooling_type == LLAMA_POOLING_TYPE_NONE) {
-    //     for (int k = 0; k < n_prompts; k++) {
-    //         n_embd_count += inputs[k].size();
-    //     }
-    // } else {
-    n_embd_count = 1;//n_prompts;
-    // }
+void common_batch_add(
+    struct llama_batch & batch,
+    llama_token id,
+    llama_pos pos,
+    const std::vector<llama_seq_id> & seq_ids,
+    bool logits) {
+    GGML_ASSERT(batch.seq_id[batch.n_tokens] && "llama_batch size exceeded");
+
+    batch.token   [batch.n_tokens] = id;
+    batch.pos     [batch.n_tokens] = pos;
+    batch.n_seq_id[batch.n_tokens] = seq_ids.size();
+    for (size_t i = 0; i < seq_ids.size(); ++i) {
+        batch.seq_id[batch.n_tokens][i] = seq_ids[i];
+    }
+    batch.logits  [batch.n_tokens] = logits;
 
-    // allocate output
-    const int n_embd = llama_n_embd(m_model.lmodel());
-    std::vector<float> embeddings(n_embd_count * n_embd, 0);
-    float* emb = embeddings.data();
+    batch.n_tokens++;
+}
 
-    int e = 0;
-    // final batch
-    float * out = emb + e * n_embd;
-    llama_batch batch = makeInputBatch(prompt);
+void batch_add_seq(llama_batch& batch, std::span<const Token> tokens, llama_seq_id seq_id) {
+    size_t n_tokens = tokens.size();
+    for (size_t i = 0; i < n_tokens; i++) {
+        common_batch_add(batch, tokens[i], i, { seq_id }, true);
+    }
+}
+
+}
 
-    //batch_decode(ctx, batch, out, s, n_embd, params.embd_normalize);
+std::vector<float> InstanceEmbedding::getEmbeddingVector(std::span<const Token> prompt, int32_t normalization) {
     const enum llama_pooling_type pooling_type = llama_pooling_type(m_lctx.get());
     llama_context* ctx = m_lctx.get();
     llama_model* model = m_model.lmodel();
+    int n_embd_count = 1; // TODO: support multiple prompts
+
+    // allocate output
+    const int n_embd = llama_n_embd(model);
+    std::vector<float> embeddings(n_embd_count * n_embd, 0);
+    float* embData = embeddings.data();
+
+    llama_batch batch = llama_batch_init(m_params.batchSize, 0, 1);
+    batch_add_seq(batch, prompt, 0);
 
-    // clear previous kv_cache values (irrelevant for embeddings)
     llama_kv_cache_clear(ctx);
 
-    // run model
-    // LOG_INF("%s: n_tokens = %d, n_seq = %d\n", __func__, batch.n_tokens, n_seq);
     if (llama_model_has_encoder(model) && !llama_model_has_decoder(model)) {
-        // encoder-only model
         if (llama_encode(ctx, batch) < 0) {
-            // LOG_ERR("%s : failed to encode\n", __func__);
+            LLAMA_LOG(Error, "Failed to encode!");
         }
     } else if (!llama_model_has_encoder(model) && llama_model_has_decoder(model)) {
-        // decoder-only model
        if (llama_decode(ctx, batch) < 0) {
-            // LOG_ERR("%s : failed to decode\n", __func__);
+            LLAMA_LOG(Error, "Failed to decode!");
        }
    }
 
@@ -163,18 +167,16 @@ std::vector<float> InstanceEmbedding::getEmbeddingVector(std::span<const Token>
             // try to get token embeddings
             embd = llama_get_embeddings_ith(ctx, i);
             embd_pos = i;
-            // GGML_ASSERT(embd != NULL && "failed to get token embeddings");
+            assert(embd != NULL && "Failed to get token embeddings");
         } else {
             // try to get sequence embeddings - supported only when pooling_type is not NONE
             embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]);
             embd_pos = batch.seq_id[i][0];
-            // GGML_ASSERT(embd != NULL && "failed to get sequence embeddings");
+            assert(embd != NULL && "Failed to get sequence embeddings");
         }
-        float * outRes = out + embd_pos * n_embd;
-        // TODO: add normalization option
-        int embd_norm = 0; //params.embd_normalize;
-        normalizeEmbedding(embd, outRes, n_embd, embd_norm);
+        float * outRes = embData + embd_pos * n_embd;
+        normalizeEmbedding(embd, outRes, n_embd, normalization);
     }
 
+    // the batch from llama_batch_init owns heap allocations and must be freed
+    llama_batch_free(batch);
+
     return embeddings;

diff --git a/code/ac/llama/InstanceEmbedding.hpp b/code/ac/llama/InstanceEmbedding.hpp
index 4ee6da1..27bc839 100644
--- a/code/ac/llama/InstanceEmbedding.hpp
+++ b/code/ac/llama/InstanceEmbedding.hpp
@@ -29,7 +29,10 @@ class AC_LLAMA_EXPORT InstanceEmbedding {
     explicit InstanceEmbedding(Model& model, InitParams params);
     ~InstanceEmbedding();
 
-    std::vector<float> getEmbeddingVector(std::span<const Token> prompt);
+    // Get the embedding vector for the given prompt.
+    // The normalization parameter selects how the embedding is normalized:
+    // -1 = none, 0 = max absolute int16, 1 = taxicab, 2 = euclidean (default), >2 = p-norm
+    std::vector<float> getEmbeddingVector(std::span<const Token> prompt, int32_t normalization = 2);
 
     const Model& model() const noexcept { return m_model; }
     Sampler& sampler() noexcept { return m_sampler; }
@@ -37,6 +40,7 @@ class AC_LLAMA_EXPORT InstanceEmbedding {
 private:
     Model& m_model;
     Sampler m_sampler;
+    InitParams m_params;
     astl::c_unique_ptr<llama_context> m_lctx;
     std::optional<Session> m_session;
 };

diff --git a/example/CMakeLists.txt b/example/CMakeLists.txt
index 8817fe1..8eaf19c 100644
--- a/example/CMakeLists.txt
+++ b/example/CMakeLists.txt
@@ -1,13 +1,18 @@
 # Copyright (c) Alpaca Core
 # SPDX-License-Identifier: MIT
 #
-add_executable(example-ac-llama-basic e-basic.cpp)
-target_link_libraries(example-ac-llama-basic PRIVATE
-    ac::llama
-    ac-test-data::llama
-    ac::jalog
-)
-set_target_properties(example-ac-llama-basic PROPERTIES FOLDER example)
+function (add_example name)
+    add_executable(example-ac-llama-${name} e-${name}.cpp)
+    target_link_libraries(example-ac-llama-${name} PRIVATE
+        ac::llama
+        ac-test-data::llama
+        ac::jalog
+    )
+    set_target_properties(example-ac-llama-${name} PROPERTIES FOLDER example)
+endfunction()
+
+add_example(basic)
+add_example(embedding)
 
 CPMAddPackage(gh:alpaca-core/helper-imgui-sdl@1.0.0)
 if(TARGET ac-dev::imgui-sdl-app)

diff --git a/example/e-embedding.cpp b/example/e-embedding.cpp
new file mode 100644
index 0000000..715ac8c
--- /dev/null
+++ b/example/e-embedding.cpp
@@ -0,0 +1,67 @@
+// Copyright (c) Alpaca Core
+// SPDX-License-Identifier: MIT
+//
+
+// trivial example of using alpaca-core's llama embedding API
+
+// llama
+#include <ac/llama/Init.hpp>
+#include <ac/llama/ModelRegistry.hpp>
+#include <ac/llama/InstanceEmbedding.hpp>
+
+// logging
+#include <ac/jalog/Instance.hpp>
+#include <ac/jalog/sinks/ColorSink.hpp>
+
+// model source directory
+#include "ac-test-data-llama-dir.h"
+
+#include <iostream>
+#include <string>
+
+int main() try {
+    ac::jalog::Instance jl;
+    jl.setup().add<ac::jalog::sinks::ColorSink>();
+
+    // initialize the library
+    ac::llama::initLibrary();
+
+    // load model
+    std::string modelGguf = AC_TEST_DATA_LLAMA_DIR "/bge-small-en-v1.5-f16.gguf";
+    ac::llama::Model::Params modelParams;
+    auto modelLoadProgressCallback = [](float progress) {
+        const int barWidth = 50;
+        static float currProgress = 0;
+        auto delta = int(progress * barWidth) - int(currProgress * barWidth);
+        for (int i = 0; i < delta; i++) {
+            std::cout.put('=');
+        }
+        currProgress = progress;
+        if (progress == 1.f) {
+            std::cout << '\n';
+        }
+        return true;
+    };
+    auto lmodel = ac::llama::ModelRegistry::getInstance().loadModel(modelGguf, modelLoadProgressCallback, modelParams);
+    ac::llama::Model model(lmodel, modelParams);
+
+    // create embedding instance
+    ac::llama::InstanceEmbedding instance(model, {});
+
+    std::string prompt = "The main character in the story loved to eat pineapples.";
+    std::vector<ac::llama::Token> tokens = model.vocab().tokenize(prompt, true, true);
+
+    auto embeddings = instance.getEmbeddingVector(tokens);
+
+    std::cout << "Embedding vector for prompt(" << prompt << "): ";
+    for (uint64_t i = 0; i < embeddings.size(); i++) {
+        std::cout << embeddings[i] << ' ';
+    }
+    std::cout << std::endl;
+
+    return 0;
+}
+catch (const std::exception& e) {
+    std::cerr << "Error: " << e.what() << std::endl;
+    return 1;
+}

diff --git a/test/t-integration.cpp b/test/t-integration.cpp
index 7490092..3d33062 100644
--- a/test/t-integration.cpp
+++ b/test/t-integration.cpp
@@ -4,6 +4,7 @@
 #include
 #include
 #include
+#include <ac/llama/InstanceEmbedding.hpp>
 #include
 #include
@@ -320,3 +321,32 @@ TEST_CASE("control_vector") {
         }
     }
 }
+
+TEST_CASE("embedding") {
+    ac::llama::Model::Params iParams = {};
+    const char* Model_bge_small_en = AC_TEST_DATA_LLAMA_DIR "/bge-small-en-v1.5-f16.gguf";
+    auto lmodel = ac::llama::ModelRegistry::getInstance().loadModel(Model_bge_small_en, {}, iParams);
+    ac::llama::Model model(lmodel, iParams);
+    CHECK(!!model.lmodel());
+
+    auto& params = model.params();
+    CHECK(params.gpu);
+
+    CHECK(model.trainCtxLength() == 512);
+    CHECK_FALSE(model.hasEncoder());
+
+    ac::llama::InstanceEmbedding inst(model, {});
+    auto tokens = model.vocab().tokenize("The main character in the story loved to eat pineapples.", true, true);
+    auto embeddings = inst.getEmbeddingVector(tokens);
+    CHECK(embeddings.size() == 384);
+
+    std::vector<float> expected = { 0.00723457, 0.0672964, 0.00372222, -0.0458788, 0.00874835, 0.00432054, 0.109124, 0.00175256, 0.0172868, 0.0279001, -0.0223953, -0.00486074, 0.0112226, 0.0423849, 0.0285155, -0.00827027, 0.0247047, 0.0291312, -0.0786626, 0.0228906, 0.00884803, -0.0545553, 0.00242499, -0.0371614, 0.0145663, 0.0217592, -0.0379476, -0.012417, -0.031311, -0.0907524, -0.00270661, 0.0225516, 0.0166742, -0.023172, -0.0234313, 0.0518579, -0.00522299, 0.0011265, 0.00472722, -0.00702098, 0.0576354, 0.00290366, 0.0278902, -0.0283858, -0.00852266, -0.0349532, -0.0258749, 0.00864892, 0.0944385, -0.032376, -0.102357, -0.0570537, -0.0630057, -0.0366031, 0.0250621, 0.098078, 0.0734987, -0.0411082, -0.0521881, 0.00953602, 0.00460035, 0.014422, -0.0135636, 0.0487354, 0.0659704, -0.0510038, -0.0432206, 0.0347124, 0.000337169, 0.00681155, -0.0349383, 0.0462863, 0.0538792, 0.0218382, 0.0313523, 0.0300653, -0.00807435, -0.0203202, -0.0387424, 0.0531275, -0.0327624, 0.0274246, -0.000469622, 0.0148036, -0.0624161, -0.024254, 0.00340036, -0.0639136, -0.0116692, 0.0111668, 0.0197133, -0.0172656, -0.00784806, 0.0131758, -0.0579778, -0.00333637, -0.0446055, -0.0315641, -0.00882497, 0.354434, 0.0259944, -0.00811709, 0.060054, -0.0282549, -0.0194096, 0.0259942, -0.010753, -0.0537825, 0.0373867, 0.0552687, -0.0193146, 0.0116561, -0.00876264, 0.0234502, 0.0116844, 0.05702, 0.0531629, -0.0222655, -0.0866693, 0.0299643, 0.0295443, 0.0653484, -0.0565965, -0.00480344, -0.0103601, -0.0158926, 0.0853524, 0.0103825, 0.0322511, -0.0413097, 0.00330726, -0.0114999, -0.0119125, 0.0362464, 0.0276722, 0.0352711, 0.00796944, -0.0262156, -0.0402713, -0.0239314, -0.0561523, -0.0660272, -0.0442701, -0.0105944, 0.0156493, -0.0800205, 0.0467227, 0.0380684, -0.0314222, 0.109449, -0.031353, 0.0298688, -0.00155366, -0.00118869, 0.019166, -0.005014, 0.0258291, 0.0608314, 0.025612, 0.0432555, -0.010526, 0.0102892, 0.006778, -0.0804542, 0.0300636, 0.0019367, -0.00946688, 0.0633147, 0.00758261, 5.33199e-05, 0.034628, 0.0540261, -0.125455, 0.0102287, 0.00555666, 0.0565227, 0.00660611, 0.0497022, -0.0642718, -0.0175176, 0.0052292, -0.0916462, -0.0293923, 0.035024, 0.0503401, -0.0244895, 0.0903103, -0.007599, 0.039994, -0.0427364, 0.086443, 0.0564919, -0.0789255, -0.0167457, -0.0495721, -0.102541, 0.00512145, 0.00380079, -0.0334622, -0.00113675, -0.0529158, -0.0167595, -0.0920621, -0.0877459, 0.13931, -0.0685575, -0.00105833, 0.0327333, -0.0313494, -0.00404531, -0.0188106, 0.0216038, 0.0198488, 0.0505344, -0.00976201, 0.0336061, 0.0362691, 0.074989, 0.0155995, -0.0351994, 0.0128507, -0.0593599, 0.0247995, -0.265298, -0.0213482, -0.00865759, -0.0900854, -0.021827, 0.0103148, -0.0650073, -0.064416, 0.0544336, -0.0180563, -0.0126009, -0.0752656, 0.0396613, 0.0599272, 0.0281464, 0.0102912, 0.0458024, -0.058047, 0.0391549, 0.0234603, -0.00715374, -0.0155389, 0.0115466, -0.00202032, -0.0387425, 0.00196627, 0.189942, 0.138904, -0.031122, 0.00910502, -0.0201774, -0.00269432, -0.0330239, -0.0526063, 0.0205691, 0.0440849, 0.0738484, -0.0430935, -0.0378577, 0.00628437, 0.0127056, 0.0740211, -0.0536525, -0.0183475, -0.0520914, -0.0588744, 0.0223303, 0.0162849, 0.0259296, 0.0510308, 0.0436266, 0.0286193, -0.00156158, 0.0123141, -0.0173283, -0.030903, -0.0197604, 0.00607057, -0.055449, 0.0341534, -0.069812, 0.00289869, 0.000113235, -0.00571824, 0.00992975, -0.0031352, 0.00464151, -0.00241301, -0.0168796, 0.0110532, -0.0204679, -0.0672177, -0.0340668, -0.0370501, 0.0311332, 0.0710521, 0.0382394, -0.115705, -0.0437406, 0.00240175, -0.0409236, -0.00446289, -0.016308, 0.0365087, 0.0138439, -0.0697056, -0.00489864, 1.96082e-05, -0.00335489, -0.0200612, 0.058619, -2.70922e-05, -0.0262538, -0.0136708, 0.0375921, 0.0739009, -0.278277, 0.0240451, -0.0747427, 0.0138804, -0.00663228, 0.0299832, 0.028293, 0.0287869, -0.0257129, 0.0193498, 0.0975099, -0.0386528, 0.0509279, -0.0456842, -0.0403165, 0.0030311, -0.0409809, 0.017794, 0.0191697, -0.0300541, 0.0511827, 0.0638279, 0.148544, -0.0117107, -0.0472298, -0.0296059, -0.0162564, 0.0123344, -0.0239339, 0.0448291, 0.0605528, 0.0288511, 0.0759243, 0.0195688, 0.0373413, 0.0402353, 0.00830747, 0.000708879, 0.00346375, 0.0104776, -0.0347978, 0.0630426, -0.0580485, -0.0384997, 0.00238404, 0.00442908, -0.0406986, -0.00532351, -0.0112028, -0.0070308, 0.0222813, -0.0732604, 0.0689749, 0.0287737, 0.0242196, -0.0179569, -0.109264, 0.00263097, -0.0182948, -0.0285666, 0.00388148, -0.000162523, 0.00822485, 0.0211785, -0.00316543 };
+
+    CHECK(embeddings.size() == expected.size());
+
+    for (size_t i = 0; i < embeddings.size(); i++)
+    {
+        REQUIRE(embeddings[i] == doctest::Approx(expected[i]).epsilon(0.0001));
+    }
+}
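
Note on using the new API: getEmbeddingVector() defaults to euclidean normalization (normalization = 2), so the vectors it returns are unit length and the plain dot product of two embeddings already equals their cosine similarity. Below is a minimal sketch of comparing two prompts on top of e-embedding.cpp; the cosineSimilarity helper is illustrative and not part of this patch.

// assumes `model` and `instance` are set up exactly as in e-embedding.cpp above
#include <numeric>
#include <span>

float cosineSimilarity(std::span<const float> a, std::span<const float> b) {
    // embeddings from the same model have the same dimension; with the default
    // euclidean normalization both are unit vectors, so the dot product is the
    // cosine of the angle between them
    return std::inner_product(a.begin(), a.end(), b.begin(), 0.f);
}

// usage:
//   auto a = instance.getEmbeddingVector(model.vocab().tokenize("I love pineapples.", true, true));
//   auto b = instance.getEmbeddingVector(model.vocab().tokenize("Tropical fruit is tasty.", true, true));
//   std::cout << "cosine similarity: " << cosineSimilarity(a, b) << '\n';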