feat: add example and test for the embedding, ref #24
pminev committed Jan 14, 2025
1 parent 39e541e commit a9c0550
Showing 6 changed files with 158 additions and 50 deletions.
2 changes: 1 addition & 1 deletion CMakeLists.txt
@@ -71,7 +71,7 @@ if(AC_LLAMA_BUILD_TESTS OR AC_LLAMA_BUILD_EXAMPLES)
NAME ac-test-data-llama
VERSION 1.0.0
GIT_REPOSITORY https://huggingface.co/alpaca-core/ac-test-data-llama
GIT_TAG 989ab9fe8f85a706453fe1f74a324700132f5881
GIT_TAG 164e484873785c804fe724f88bdb96088b573ebc
)
endif()

84 changes: 43 additions & 41 deletions code/ac/llama/InstanceEmbedding.cpp
@@ -28,13 +28,15 @@ llama_context_params llamaFromInstanceInitParams(const InstanceEmbedding::InitParams
llamaParams.n_batch = params.batchSize;
llamaParams.n_ubatch = params.ubatchSize;
llamaParams.flash_attn = params.flashAttn;
llamaParams.embeddings = true;
return llamaParams;
}
} // namespace

InstanceEmbedding::InstanceEmbedding(Model& model, InitParams params)
: m_model(model)
, m_sampler(model, {})
, m_params(std::move(params))
, m_lctx(llama_new_context_with_model(model.lmodel(), llamaFromInstanceInitParams(params)), llama_free)
{
if (!m_lctx) {
@@ -63,14 +65,6 @@ InstanceEmbedding::InstanceEmbedding(Model& model, InitParams params)
InstanceEmbedding::~InstanceEmbedding() = default;

namespace {
llama_batch makeInputBatch(std::span<const Token> tokens) {
// well, llama.cpp does not touch the tokens for input batches, but llama_batch needs them to be non-const
// (mostly for stupid C reasons)
// so... we have to do something evil here
auto nonConstTokens = const_cast<Token*>(tokens.data());
return llama_batch_get_one(nonConstTokens, int32_t(tokens.size()));
}

void normalizeEmbedding(const float * inp, float * out, int n, int embd_norm) {
double sum = 0.0;

@@ -106,48 +100,58 @@ void normalizeEmbedding(const float * inp, float * out, int n, int embd_norm) {
out[i] = inp[i] * norm;
}
}
}

std::vector<float> InstanceEmbedding::getEmbeddingVector(std::span<const Token> prompt) {
// count number of embeddings
int n_embd_count = 0;
// if (pooling_type == LLAMA_POOLING_TYPE_NONE) {
//     for (int k = 0; k < n_prompts; k++) {
//         n_embd_count += inputs[k].size();
//     }
// } else {
n_embd_count = 1; // n_prompts
// }

// allocate output
const int n_embd = llama_n_embd(m_model.lmodel());
std::vector<float> embeddings(n_embd_count * n_embd, 0);
float* emb = embeddings.data();

int e = 0;
// final batch
float * out = emb + e * n_embd;
llama_batch batch = makeInputBatch(prompt);

//batch_decode(ctx, batch, out, s, n_embd, params.embd_normalize);

void common_batch_add(
struct llama_batch & batch,
llama_token id,
llama_pos pos,
const std::vector<llama_seq_id> & seq_ids,
bool logits) {
GGML_ASSERT(batch.seq_id[batch.n_tokens] && "llama_batch size exceeded");

batch.token   [batch.n_tokens] = id;
batch.pos     [batch.n_tokens] = pos;
batch.n_seq_id[batch.n_tokens] = seq_ids.size();
for (size_t i = 0; i < seq_ids.size(); ++i) {
batch.seq_id[batch.n_tokens][i] = seq_ids[i];
}
batch.logits  [batch.n_tokens] = logits;

batch.n_tokens++;
}

void batch_add_seq(llama_batch& batch, std::span<const Token> tokens, llama_seq_id seq_id) {
size_t n_tokens = tokens.size();
for (size_t i = 0; i < n_tokens; i++) {
common_batch_add(batch, tokens[i], i, { seq_id }, true);
}
}

}
std::vector<float> InstanceEmbedding::getEmbeddingVector(std::span<const Token> prompt, int32_t normalization) {
const enum llama_pooling_type pooling_type = llama_pooling_type(m_lctx.get());
llama_context* ctx = m_lctx.get();
llama_model* model = m_model.lmodel();
int n_embd_count = 1; // TODO: support multiple prompts

// allocate output
const int n_embd = llama_n_embd(model);
std::vector<float> embeddings(n_embd_count * n_embd, 0);
float* embData = embeddings.data();

llama_batch batch = llama_batch_init(m_params.batchSize, 0, 1);
batch_add_seq(batch, prompt, 0);

// clear previous kv_cache values (irrelevant for embeddings)
llama_kv_cache_clear(ctx);

// run model
// LOG_INF("%s: n_tokens = %d, n_seq = %d\n", __func__, batch.n_tokens, n_seq);
if (llama_model_has_encoder(model) && !llama_model_has_decoder(model)) {
// encoder-only model
if (llama_encode(ctx, batch) < 0) {
LLAMA_LOG(Error, "Failed to encode!");
}
} else if (!llama_model_has_encoder(model) && llama_model_has_decoder(model)) {
// decoder-only model
if (llama_decode(ctx, batch) < 0) {
LLAMA_LOG(Error, "Failed to decode!");
}
}

@@ -163,18 +167,16 @@ std::vector<float> InstanceEmbedding::getEmbeddingVector(std::span<const Token>
// try to get token embeddings
embd = llama_get_embeddings_ith(ctx, i);
embd_pos = i;
assert(embd != NULL && "Failed to get token embeddings");
} else {
// try to get sequence embeddings - supported only when pooling_type is not NONE
embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]);
embd_pos = batch.seq_id[i][0];
assert(embd != NULL && "Failed to get sequence embeddings");
}

float * outRes = out + embd_pos * n_embd;
// TODO: add normalization option
int embd_norm = 0; //params.embd_normalize;
normalizeEmbedding(embd, outRes, n_embd, embd_norm);
float * outRes = embData + embd_pos * n_embd;
normalizeEmbedding(embd, outRes, n_embd, normalization);
}

return embeddings;
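
Note: the diff viewer collapses the middle of normalizeEmbedding above. For orientation, here is a minimal sketch of what such an embd_norm dispatch typically computes — modeled on llama.cpp's common_embd_normalize and offered as an assumption about the collapsed lines, not a verbatim copy (assumes <cmath>):

// Sketch only: assumed shape of the collapsed normalizeEmbedding body.
// -1 = none, 0 = max absolute int16, 1 = taxicab, 2 = euclidean, >2 = p-norm
double sum = 0.0;
switch (embd_norm) {
    case -1: // no normalization: pass values through unchanged
        sum = 1.0;
        break;
    case 0: // scale the largest component into the int16 range
        for (int i = 0; i < n; i++) {
            if (sum < std::abs(inp[i])) sum = std::abs(inp[i]);
        }
        sum /= 32760.0;
        break;
    case 2: // euclidean (L2) norm, producing a unit-length vector
        for (int i = 0; i < n; i++) {
            sum += inp[i] * inp[i];
        }
        sum = std::sqrt(sum);
        break;
    default: // general p-norm; p = 1 gives the taxicab case
        for (int i = 0; i < n; i++) {
            sum += std::pow(std::abs(inp[i]), embd_norm);
        }
        sum = std::pow(sum, 1.0 / embd_norm);
        break;
}
const float norm = sum > 0.0 ? float(1.0 / sum) : 0.0f;
// followed by the visible tail: out[i] = inp[i] * norm for each i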
6 changes: 5 additions & 1 deletion code/ac/llama/InstanceEmbedding.hpp
@@ -29,14 +29,18 @@ class AC_LLAMA_EXPORT InstanceEmbedding {
explicit InstanceEmbedding(Model& model, InitParams params);
~InstanceEmbedding();

std::vector<float> getEmbeddingVector(std::span<const Token> prompt);
// Get the embedding vector for the given prompt.
// The normalization parameter controls how the embedding is normalized:
// -1 = none, 0 = max absolute int16, 1 = taxicab, 2 = euclidean (default), >2 = p-norm
std::vector<float> getEmbeddingVector(std::span<const Token> prompt, int32_t normalization = 2);

const Model& model() const noexcept { return m_model; }
Sampler& sampler() noexcept { return m_sampler; }

private:
Model& m_model;
Sampler m_sampler;
InitParams m_params;
astl::c_unique_ptr<llama_context> m_lctx;
std::optional<Session> m_session;
};
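
Seen from the caller's side, the new parameter reads like this (a hedged sketch; `inst` and `tokens` are assumed to be set up as in the example below):

auto unit = inst.getEmbeddingVector(tokens);      // default normalization: 2 = euclidean (unit length)
auto raw  = inst.getEmbeddingVector(tokens, -1);  // -1 = skip normalization, raw pooled values
auto l1   = inst.getEmbeddingVector(tokens, 1);   // 1 = taxicab (L1) normalization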
19 changes: 12 additions & 7 deletions example/CMakeLists.txt
@@ -1,13 +1,18 @@
# Copyright (c) Alpaca Core
# SPDX-License-Identifier: MIT
#
add_executable(example-ac-llama-basic e-basic.cpp)
target_link_libraries(example-ac-llama-basic PRIVATE
ac::llama
ac-test-data::llama
ac::jalog
)
set_target_properties(example-ac-llama-basic PROPERTIES FOLDER example)
function (add_example name)
add_executable(example-ac-llama-${name} e-${name}.cpp)
target_link_libraries(example-ac-llama-${name} PRIVATE
ac::llama
ac-test-data::llama
ac::jalog
)
set_target_properties(example-ac-llama-${name} PROPERTIES FOLDER example)
endfunction()

add_example(basic)
add_example(embedding)
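
With the helper in place, a future sample only needs one line here — e.g. a hypothetical `add_example(chat)` for a file named e-chat.cpp.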

CPMAddPackage(gh:alpaca-core/[email protected])
if(TARGET ac-dev::imgui-sdl-app)
67 changes: 67 additions & 0 deletions example/e-embedding.cpp
@@ -0,0 +1,67 @@
// Copyright (c) Alpaca Core
// SPDX-License-Identifier: MIT
//

// trivial example of using alpaca-core's llama embedding API

// llama
#include <ac/llama/Init.hpp>
#include <ac/llama/Model.hpp>
#include <ac/llama/InstanceEmbedding.hpp>

// logging
#include <ac/jalog/Instance.hpp>
#include <ac/jalog/sinks/ColorSink.hpp>

// model source directory
#include "ac-test-data-llama-dir.h"

#include <iostream>
#include <string>

int main() try {
ac::jalog::Instance jl;
jl.setup().add<ac::jalog::sinks::ColorSink>();

// initialize the library
ac::llama::initLibrary();

// load model
std::string modelGguf = AC_TEST_DATA_LLAMA_DIR "/bge-small-en-v1.5-f16.gguf";
ac::llama::Model::Params modelParams;
auto modelLoadProgressCallback = [](float progress) {
const int barWidth = 50;
static float currProgress = 0;
auto delta = int(progress * barWidth) - int(currProgress * barWidth);
for (int i = 0; i < delta; i++) {
std::cout.put('=');
}
currProgress = progress;
if (progress == 1.f) {
std::cout << '\n';
}
return true;
};
auto lmodel = ac::llama::ModelRegistry::getInstance().loadModel(modelGguf, modelLoadProgressCallback, modelParams);
ac::llama::Model model(lmodel, modelParams);

// create inference instance
ac::llama::InstanceEmbedding instance(model, {});

std::string prompt = "The main character in the story loved to eat pineapples.";
std::vector<ac::llama::Token> tokens = model.vocab().tokenize(prompt, true, true);

auto embeddings = instance.getEmbeddingVector(tokens);

std::cout << "Embedding vector for prompt(" << prompt<< "): ";
for (uint64_t i = 0; i < embeddings.size(); i++) {
std::cout << embeddings[i] << ' ';
}
std::cout << std::endl;

return 0;
}
catch (const std::exception& e) {
std::cerr << "Error: " << e.what() << std::endl;
return 1;
}
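
Since the default normalization (2 = euclidean) yields unit-length vectors, comparing two prompts reduces to a dot product. A minimal sketch building on the example above — the helper `dot` is illustrative, not part of the library API:

// Sketch: with euclidean-normalized embeddings, dot product == cosine similarity.
float dot(const std::vector<float>& a, const std::vector<float>& b) {
    float sum = 0.f;
    for (size_t i = 0; i < a.size() && i < b.size(); ++i) {
        sum += a[i] * b[i];
    }
    return sum;
}

// usage (assuming `instance` and `model` from the example above):
// auto a = instance.getEmbeddingVector(model.vocab().tokenize("I like pineapples.", true, true));
// auto b = instance.getEmbeddingVector(model.vocab().tokenize("The weather is bad.", true, true));
// std::cout << "similarity: " << dot(a, b) << '\n'; // closer to 1 => more similar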
30 changes: 30 additions & 0 deletions test/t-integration.cpp
@@ -4,6 +4,7 @@
#include <ac/llama/Init.hpp>
#include <ac/llama/Model.hpp>
#include <ac/llama/Instance.hpp>
#include <ac/llama/InstanceEmbedding.hpp>
#include <ac/llama/Session.hpp>
#include <ac/llama/ControlVector.hpp>

@@ -320,3 +321,32 @@ TEST_CASE("control_vector") {
}
}
}

TEST_CASE("embedding") {
ac::llama::Model::Params iParams = {};
const char* Model_bge_small_en = AC_TEST_DATA_LLAMA_DIR "/bge-small-en-v1.5-f16.gguf";
auto lmodel = ac::llama::ModelRegistry::getInstance().loadModel(Model_bge_small_en, {}, iParams);
ac::llama::Model model(lmodel, iParams);
CHECK(!!model.lmodel());

auto& params = model.params();
CHECK(params.gpu);

CHECK(model.trainCtxLength() == 512);
CHECK_FALSE(model.hasEncoder());

ac::llama::InstanceEmbedding inst(model, {});
auto tokens = model.vocab().tokenize("The main character in the story loved to eat pineapples.", true, true);
auto embeddings = inst.getEmbeddingVector(tokens);
CHECK(embeddings.size() == 384);

std::vector<float> expected = { 0.00723457, 0.0672964, 0.00372222, -0.0458788, 0.00874835, 0.00432054, 0.109124, 0.00175256, 0.0172868, 0.0279001, -0.0223953, -0.00486074, 0.0112226, 0.0423849, 0.0285155, -0.00827027, 0.0247047, 0.0291312, -0.0786626, 0.0228906, 0.00884803, -0.0545553, 0.00242499, -0.0371614, 0.0145663, 0.0217592, -0.0379476, -0.012417, -0.031311, -0.0907524, -0.00270661, 0.0225516, 0.0166742, -0.023172, -0.0234313, 0.0518579, -0.00522299, 0.0011265, 0.00472722, -0.00702098, 0.0576354, 0.00290366, 0.0278902, -0.0283858, -0.00852266, -0.0349532, -0.0258749, 0.00864892, 0.0944385, -0.032376, -0.102357, -0.0570537, -0.0630057, -0.0366031, 0.0250621, 0.098078, 0.0734987, -0.0411082, -0.0521881, 0.00953602, 0.00460035, 0.014422, -0.0135636, 0.0487354, 0.0659704, -0.0510038, -0.0432206, 0.0347124, 0.000337169, 0.00681155, -0.0349383, 0.0462863, 0.0538792, 0.0218382, 0.0313523, 0.0300653, -0.00807435, -0.0203202, -0.0387424, 0.0531275, -0.0327624, 0.0274246, -0.000469622, 0.0148036, -0.0624161, -0.024254, 0.00340036, -0.0639136, -0.0116692, 0.0111668, 0.0197133, -0.0172656, -0.00784806, 0.0131758, -0.0579778, -0.00333637, -0.0446055, -0.0315641, -0.00882497, 0.354434, 0.0259944, -0.00811709, 0.060054, -0.0282549, -0.0194096, 0.0259942, -0.010753, -0.0537825, 0.0373867, 0.0552687, -0.0193146, 0.0116561, -0.00876264, 0.0234502, 0.0116844, 0.05702, 0.0531629, -0.0222655, -0.0866693, 0.0299643, 0.0295443, 0.0653484, -0.0565965, -0.00480344, -0.0103601, -0.0158926, 0.0853524, 0.0103825, 0.0322511, -0.0413097, 0.00330726, -0.0114999, -0.0119125, 0.0362464, 0.0276722, 0.0352711, 0.00796944, -0.0262156, -0.0402713, -0.0239314, -0.0561523, -0.0660272, -0.0442701, -0.0105944, 0.0156493, -0.0800205, 0.0467227, 0.0380684, -0.0314222, 0.109449, -0.031353, 0.0298688, -0.00155366, -0.00118869, 0.019166, -0.005014, 0.0258291, 0.0608314, 0.025612, 0.0432555, -0.010526, 0.0102892, 0.006778, -0.0804542, 0.0300636, 0.0019367, -0.00946688, 0.0633147, 0.00758261, 5.33199e-05, 0.034628, 0.0540261, -0.125455, 0.0102287, 0.00555666, 0.0565227, 0.00660611, 0.0497022, -0.0642718, -0.0175176, 0.0052292, -0.0916462, -0.0293923, 0.035024, 0.0503401, -0.0244895, 0.0903103, -0.007599, 0.039994, -0.0427364, 0.086443, 0.0564919, -0.0789255, -0.0167457, -0.0495721, -0.102541, 0.00512145, 0.00380079, -0.0334622, -0.00113675, -0.0529158, -0.0167595, -0.0920621, -0.0877459, 0.13931, -0.0685575, -0.00105833, 0.0327333, -0.0313494, -0.00404531, -0.0188106, 0.0216038, 0.0198488, 0.0505344, -0.00976201, 0.0336061, 0.0362691, 0.074989, 0.0155995, -0.0351994, 0.0128507, -0.0593599, 0.0247995, -0.265298, -0.0213482, -0.00865759, -0.0900854, -0.021827, 0.0103148, -0.0650073, -0.064416, 0.0544336, -0.0180563, -0.0126009, -0.0752656, 0.0396613, 0.0599272, 0.0281464, 0.0102912, 0.0458024, -0.058047, 0.0391549, 0.0234603, -0.00715374, -0.0155389, 0.0115466, -0.00202032, -0.0387425, 0.00196627, 0.189942, 0.138904, -0.031122, 0.00910502, -0.0201774, -0.00269432, -0.0330239, -0.0526063, 0.0205691, 0.0440849, 0.0738484, -0.0430935, -0.0378577, 0.00628437, 0.0127056, 0.0740211, -0.0536525, -0.0183475, -0.0520914, -0.0588744, 0.0223303, 0.0162849, 0.0259296, 0.0510308, 0.0436266, 0.0286193, -0.00156158, 0.0123141, -0.0173283, -0.030903, -0.0197604, 0.00607057, -0.055449, 0.0341534, -0.069812, 0.00289869, 0.000113235, -0.00571824, 0.00992975, -0.0031352, 0.00464151, -0.00241301, -0.0168796, 0.0110532, -0.0204679, -0.0672177, -0.0340668, -0.0370501, 0.0311332, 0.0710521, 0.0382394, -0.115705, -0.0437406, 0.00240175, -0.0409236, 
-0.00446289, -0.016308, 0.0365087, 0.0138439, -0.0697056, -0.00489864, 1.96082e-05, -0.00335489, -0.0200612, 0.058619, -2.70922e-05, -0.0262538, -0.0136708, 0.0375921, 0.0739009, -0.278277, 0.0240451, -0.0747427, 0.0138804, -0.00663228, 0.0299832, 0.028293, 0.0287869, -0.0257129, 0.0193498, 0.0975099, -0.0386528, 0.0509279, -0.0456842, -0.0403165, 0.0030311, -0.0409809, 0.017794, 0.0191697, -0.0300541, 0.0511827, 0.0638279, 0.148544, -0.0117107, -0.0472298, -0.0296059, -0.0162564, 0.0123344, -0.0239339, 0.0448291, 0.0605528, 0.0288511, 0.0759243, 0.0195688, 0.0373413, 0.0402353, 0.00830747, 0.000708879, 0.00346375, 0.0104776, -0.0347978, 0.0630426, -0.0580485, -0.0384997, 0.00238404, 0.00442908, -0.0406986, -0.00532351, -0.0112028, -0.0070308, 0.0222813, -0.0732604, 0.0689749, 0.0287737, 0.0242196, -0.0179569, -0.109264, 0.00263097, -0.0182948, -0.0285666, 0.00388148, -0.000162523, 0.00822485, 0.0211785, -0.00316543 };

CHECK(embeddings.size() == expected.size());

for (size_t i = 0; i < embeddings.size(); i++) {
REQUIRE(embeddings[i] == doctest::Approx(expected[i]).epsilon(0.0001));
}
}
