Add embedding instance #29

Merged: 6 commits, Feb 18, 2025
2 changes: 1 addition & 1 deletion CMakeLists.txt
@@ -71,7 +71,7 @@ if(AC_LLAMA_BUILD_TESTS OR AC_LLAMA_BUILD_EXAMPLES)
        NAME ac-test-data-llama
        VERSION 1.0.0
        GIT_REPOSITORY https://huggingface.co/alpaca-core/ac-test-data-llama
        GIT_TAG 989ab9fe8f85a706453fe1f74a324700132f5881
        GIT_TAG 164e484873785c804fe724f88bdb96088b573ebc
    )
endif()

1 change: 1 addition & 0 deletions ac-local-plugin/code/LocalLlama.cpp
@@ -3,6 +3,7 @@
//
#include <ac/llama/Session.hpp>
#include <ac/llama/Instance.hpp>
#include <ac/llama/InstanceEmbedding.hpp>
#include <ac/llama/Init.hpp>
#include <ac/llama/Model.hpp>
#include <ac/llama/AntipromptManager.hpp>
2 changes: 1 addition & 1 deletion ac-local-plugin/example/ep-run.cpp
@@ -42,7 +42,7 @@ int main() try {

    llama.expectState<schema::StateInstance>();

    constexpr std::string prompt = "The first person to";
    const std::string prompt = "The first person to";

    std::vector<std::string> antiprompts;
    antiprompts.push_back("user:"); // change it to "name" to break the token generation with the default input
31 changes: 30 additions & 1 deletion ac-local-plugin/schema/ac/schema/LlamaCpp.hpp
@@ -46,7 +46,6 @@ struct StateInitial {
    using Outs = std::tuple<>;
};


struct StateModelLoaded {
    static constexpr auto id = "model-loaded";
    static constexpr auto desc = "Model loaded state";
@@ -185,6 +184,36 @@ struct StateChat {
    using Outs = std::tuple<>;
};

struct StateEmbeddingInstance {
    static constexpr auto id = "embedding-instance";
    static constexpr auto desc = "Embedding instance state";

    struct OpRun {
        static inline constexpr std::string_view id = "run";
        static inline constexpr std::string_view description = "Run to produce an embedding vector";

        struct Params {
            Field<std::string> prompt;

            template <typename Visitor>
            void visitFields(Visitor& v) {
                v(prompt, "prompt", "Prompt to generate the embedding for");
            }
        };

        struct Return {
            Field<std::vector<float>> result;

            template <typename Visitor>
            void visitFields(Visitor& v) {
                v(result, "result", "Generated result (embedding vector)");
            }
        };
    };

    using Ops = std::tuple<OpRun>;
};

struct Interface {
    static inline constexpr std::string_view id = "llama.cpp";
    static inline constexpr std::string_view desc = "Inference based on our fork of https://github.com/ggerganov/llama.cpp";
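As an aside on the schema shown above: the Params and Return structs expose their members through visitFields, which a visitor drives to serialize or document each field. Below is a minimal, self-contained sketch of that idiom. The Field stand-in and PrintVisitor are hypothetical, for illustration only; the real Field type and visitor machinery come from the ac schema library and may differ.

// Illustrative only: a hypothetical stand-in for the schema library's Field type
// and a simple visitor, to show how visitFields is meant to be driven.
#include <iostream>
#include <string>
#include <string_view>

template <typename T>
struct Field {
    T value = {};
};

struct PrintVisitor {
    template <typename T>
    void operator()(Field<T>&, std::string_view name, std::string_view desc) {
        std::cout << name << ": " << desc << '\n';
    }
};

struct Params {
    Field<std::string> prompt;

    template <typename Visitor>
    void visitFields(Visitor& v) {
        v(prompt, "prompt", "Prompt to generate the embedding for");
    }
};

int main() {
    Params p;
    PrintVisitor v;
    p.visitFields(v); // prints: prompt: Prompt to generate the embedding for
    return 0;
}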
2 changes: 2 additions & 0 deletions code/CMakeLists.txt
@@ -27,6 +27,8 @@ target_sources(ac-llama PRIVATE
    ac/llama/Sampler.cpp
    ac/llama/Instance.hpp
    ac/llama/Instance.cpp
    ac/llama/InstanceEmbedding.hpp
    ac/llama/InstanceEmbedding.cpp
    ac/llama/Session.hpp
    ac/llama/Session.cpp
    ac/llama/AntipromptManager.hpp
163 changes: 163 additions & 0 deletions code/ac/llama/InstanceEmbedding.cpp
@@ -0,0 +1,163 @@
// Copyright (c) Alpaca Core
// SPDX-License-Identifier: MIT
//
#include "InstanceEmbedding.hpp"
#include "Model.hpp"
#include "LoraAdapter.hpp"
#include "Logging.hpp"
#include "Session.hpp"
#include "ControlVector.hpp"

#include <llama.h>

#include <astl/throw_stdex.hpp>
#include <astl/iile.h>
#include <astl/move.hpp>
#include <astl/sentry.hpp>

#include <cassert>
#include <span>
#include <fstream>

namespace ac::llama {

namespace {
llama_context_params llamaFromInstanceInitParams(const InstanceEmbedding::InitParams& params) {
    llama_context_params llamaParams = llama_context_default_params();
    llamaParams.n_ctx = params.ctxSize;
    llamaParams.n_batch = params.batchSize;
    llamaParams.n_ubatch = params.ubatchSize;
    llamaParams.flash_attn = params.flashAttn;
    llamaParams.embeddings = true;
    return llamaParams;
}
} // namespace

InstanceEmbedding::InstanceEmbedding(Model& model, InitParams params)
    : m_model(model)
    , m_sampler(model, {})
    , m_params(std::move(params))
    , m_lctx(llama_new_context_with_model(model.lmodel(), llamaFromInstanceInitParams(params)), llama_free)
{
    if (!m_lctx) {
        throw_ex{} << "Failed to create llama context";
    }
    assert(model.lmodel() == llama_get_model(m_lctx.get()));

    const auto ctxLen = llama_n_ctx(m_lctx.get());
    const auto ctxTrain = model.trainCtxLength();
    if (ctxLen > ctxTrain) {
        LLAMA_LOG(Warning, "Instance requested context length ", ctxLen, " is greater than the model's training context length ", ctxTrain);
    }

    if (llama_model_has_encoder(m_model.lmodel()) && llama_model_has_decoder(m_model.lmodel())) {
        LLAMA_LOG(Error, "Computing embeddings in encoder-decoder models is not supported");
    }
}

InstanceEmbedding::~InstanceEmbedding() = default;

namespace {
void normalizeEmbedding(const float * inp, float * out, int n, int embd_norm) {
    double sum = 0.0;

    switch (embd_norm) {
        case -1: // no normalisation
            sum = 1.0;
            break;
        case 0: // max absolute
            for (int i = 0; i < n; i++) {
                if (sum < std::abs(inp[i])) {
                    sum = std::abs(inp[i]);
                }
            }
            sum /= 32760.0; // make an int16 range
            break;
        case 2: // euclidean
            for (int i = 0; i < n; i++) {
                sum += inp[i] * inp[i];
            }
            sum = std::sqrt(sum);
            break;
        default: // p-norm (euclidean is p-norm p=2)
            for (int i = 0; i < n; i++) {
                sum += std::pow(std::abs(inp[i]), embd_norm);
            }
            sum = std::pow(sum, 1.0 / embd_norm);
            break;
    }

    const float norm = sum > 0.0 ? 1.0 / sum : 0.0f;

    for (int i = 0; i < n; i++) {
        out[i] = inp[i] * norm;
    }
}

void batchAddSeq(llama_batch& batch, std::span<const Token> tokens, llama_seq_id seq_id) {
    for (size_t i = 0; i < tokens.size(); i++) {
        batch.token   [batch.n_tokens] = tokens[i];
        batch.pos     [batch.n_tokens] = i;
        batch.n_seq_id[batch.n_tokens] = 1;
        batch.seq_id  [batch.n_tokens][0] = seq_id;
        batch.logits  [batch.n_tokens] = true;

        batch.n_tokens++;
    }
}
}

std::vector<float> InstanceEmbedding::getEmbeddingVector(std::span<const Token> prompt, int32_t normalization) {
    const enum llama_pooling_type pooling_type = llama_pooling_type(m_lctx.get());
    llama_context* ctx = m_lctx.get();
    llama_model* model = m_model.lmodel();
    int n_embd_count = 1; // TODO: support multiple prompts

    // allocate output
    const int n_embd = llama_n_embd(model);
    std::vector<float> embeddings(n_embd_count * n_embd, 0);
    float* embData = embeddings.data();

    llama_batch batch = llama_batch_init(m_params.batchSize, 0, 1);
    batchAddSeq(batch, prompt, 0);

    llama_kv_cache_clear(ctx);

    if (llama_model_has_encoder(model) && !llama_model_has_decoder(model)) {
        if (llama_encode(ctx, batch) < 0) {
            LLAMA_LOG(Error, "Failed to encode!");
        }
    } else if (!llama_model_has_encoder(model) && llama_model_has_decoder(model)) {
        if (llama_decode(ctx, batch) < 0) {
            LLAMA_LOG(Error, "Failed to decode!");
        }
    }

    for (int i = 0; i < batch.n_tokens; i++) {
        if (!batch.logits[i]) {
            continue;
        }

        const float * embd = nullptr;
        int embd_pos = 0;

        if (pooling_type == LLAMA_POOLING_TYPE_NONE) {
            // try to get token embeddings
            embd = llama_get_embeddings_ith(ctx, i);
            embd_pos = i;
            assert(embd != NULL && "Failed to get token embeddings");
        } else {
            // try to get sequence embeddings - supported only when pooling_type is not NONE
            embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]);
            embd_pos = batch.seq_id[i][0];
            assert(embd != NULL && "Failed to get sequence embeddings");
        }

        float * outRes = embData + embd_pos * n_embd;
        normalizeEmbedding(embd, outRes, n_embd, normalization);
    }

    return embeddings;
}

} // namespace ac::llama
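A side note on the normalization modes handled by normalizeEmbedding above: with the default euclidean mode (normalization = 2) the output vector is scaled to unit length, which is what later makes dot products between embeddings behave like cosine similarities. The standalone sketch below is not part of the PR; it just walks that case on a toy vector.

// Standalone illustration of euclidean (L2) normalization, mirroring the
// embd_norm == 2 branch of normalizeEmbedding. Not part of the PR.
#include <cmath>
#include <cstdio>
#include <vector>

int main() {
    std::vector<float> v = {3.0f, 4.0f};    // toy "embedding"
    double sum = 0.0;
    for (float x : v) sum += x * x;         // 9 + 16 = 25
    const double norm = std::sqrt(sum);     // 5
    for (float& x : v) x = float(x / norm); // {0.6, 0.8}, unit length
    std::printf("%f %f\n", v[0], v[1]);
    return 0;
}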
49 changes: 49 additions & 0 deletions code/ac/llama/InstanceEmbedding.hpp
@@ -0,0 +1,49 @@
// Copyright (c) Alpaca Core
// SPDX-License-Identifier: MIT
//
#pragma once
#include "export.h"
#include "Sampler.hpp"
#include "Session.hpp"
#include <astl/mem_ext.hpp>

#include <vector>
#include <optional>

struct llama_context;

namespace ac::llama {
class Model;
class Session;
class StringSession;
class ControlVector;

class AC_LLAMA_EXPORT InstanceEmbedding {
public:
    struct InitParams {
        uint32_t ctxSize = 0; // context size for the model (0 = maximum allowed by model)
        uint32_t batchSize = 2048; // logical batch size for prompt processing (may be silently truncated to ctxSize)
        uint32_t ubatchSize = 512; // physical batch size for prompt processing (0 = batchSize)
        bool flashAttn = false; // enable flash attention
    };

    explicit InstanceEmbedding(Model& model, InitParams params);
    ~InstanceEmbedding();

    // Get the embedding vector for the given prompt
    // the normalization parameter is used to normalize the embeddings
    // values are (-1=none, 0=max absolute int16, 1=taxicab, 2=euclidean[default], >2=p-norm)
    std::vector<float> getEmbeddingVector(std::span<const Token> prompt, int32_t normalization = 2);

    const Model& model() const noexcept { return m_model; }
    Sampler& sampler() noexcept { return m_sampler; }

private:
    Model& m_model;
    Sampler m_sampler;
    InitParams m_params;
    astl::c_unique_ptr<llama_context> m_lctx;
    std::optional<Session> m_session;
};

} // namespace ac::llama
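A typical consumer of getEmbeddingVector compares two embeddings. The helper below is an illustrative sketch, not part of the PR; it assumes both vectors were produced with the default normalization (euclidean, so unit length), in which case cosine similarity reduces to a plain dot product.

// Illustrative helper, assuming both inputs come from getEmbeddingVector()
// with the default normalization (euclidean, i.e. unit-length vectors).
#include <cstddef>
#include <vector>

float cosineSimilarity(const std::vector<float>& a, const std::vector<float>& b) {
    float dot = 0.f;
    for (size_t i = 0; i < a.size() && i < b.size(); ++i) {
        dot += a[i] * b[i];
    }
    return dot; // for unit-length vectors the dot product is the cosine similarity
}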
19 changes: 12 additions & 7 deletions example/CMakeLists.txt
@@ -1,13 +1,18 @@
# Copyright (c) Alpaca Core
# SPDX-License-Identifier: MIT
#
add_executable(example-ac-llama-basic e-basic.cpp)
target_link_libraries(example-ac-llama-basic PRIVATE
    ac::llama
    ac-test-data::llama
    ac::jalog
)
set_target_properties(example-ac-llama-basic PROPERTIES FOLDER example)
function (add_example name)
    add_executable(example-ac-llama-${name} e-${name}.cpp)
    target_link_libraries(example-ac-llama-${name} PRIVATE
        ac::llama
        ac-test-data::llama
        ac::jalog
    )
    set_target_properties(example-ac-llama-${name} PROPERTIES FOLDER example)
endfunction()

add_example(basic)
add_example(embedding)

CPMAddPackage(gh:alpaca-core/[email protected])
if(TARGET ac-dev::imgui-sdl-app)
67 changes: 67 additions & 0 deletions example/e-embedding.cpp
@@ -0,0 +1,67 @@
// Copyright (c) Alpaca Core
// SPDX-License-Identifier: MIT
//

// trivial example of using alpaca-core's llama embedding API

// llama
#include <ac/llama/Init.hpp>
#include <ac/llama/Model.hpp>
#include <ac/llama/InstanceEmbedding.hpp>

// logging
#include <ac/jalog/Instance.hpp>
#include <ac/jalog/sinks/ColorSink.hpp>

// model source directory
#include "ac-test-data-llama-dir.h"

#include <iostream>
#include <string>

int main() try {
    ac::jalog::Instance jl;
    jl.setup().add<ac::jalog::sinks::ColorSink>();

    // initialize the library
    ac::llama::initLibrary();

    // load model
    std::string modelGguf = AC_TEST_DATA_LLAMA_DIR "/bge-small-en-v1.5-f16.gguf";
    ac::llama::Model::Params modelParams;
    auto modelLoadProgressCallback = [](float progress) {
        const int barWidth = 50;
        static float currProgress = 0;
        auto delta = int(progress * barWidth) - int(currProgress * barWidth);
        for (int i = 0; i < delta; i++) {
            std::cout.put('=');
        }
        currProgress = progress;
        if (progress == 1.f) {
            std::cout << '\n';
        }
        return true;
    };
    auto lmodel = ac::llama::ModelRegistry::getInstance().loadModel(modelGguf, modelLoadProgressCallback, modelParams);
    ac::llama::Model model(lmodel, modelParams);

    // create inference instance
    ac::llama::InstanceEmbedding instance(model, {});

    std::string prompt = "The main character in the story loved to eat pineapples.";
    std::vector<ac::llama::Token> tokens = model.vocab().tokenize(prompt, true, true);

    auto embeddings = instance.getEmbeddingVector(tokens);

    std::cout << "Embedding vector for prompt(" << prompt << "): ";
    for (uint64_t i = 0; i < embeddings.size(); i++) {
        std::cout << embeddings[i] << ' ';
    }
    std::cout << std::endl;

    return 0;
}
catch (const std::exception& e) {
    std::cerr << "Error: " << e.what() << std::endl;
    return 1;
}