-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: add example and test for the embedding, ref #24
- Loading branch information
Showing
6 changed files
with
158 additions
and
50 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,13 +1,18 @@ | ||
# Copyright (c) Alpaca Core
# SPDX-License-Identifier: MIT
#

# Declares an example executable target `example-ac-llama-<name>` built from
# `e-<name>.cpp`, linked against the libraries every example needs, and filed
# under the "example" IDE folder.
function(add_example name)
    set(tgt example-ac-llama-${name})
    add_executable(${tgt} e-${name}.cpp)
    target_link_libraries(${tgt} PRIVATE
        ac::llama
        ac-test-data::llama
        ac::jalog
    )
    set_target_properties(${tgt} PROPERTIES FOLDER example)
endfunction()

add_example(basic)
add_example(embedding)
|
||
CPMAddPackage(gh:alpaca-core/[email protected]) | ||
if(TARGET ac-dev::imgui-sdl-app) | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,67 @@ | ||
// Copyright (c) Alpaca Core | ||
// SPDX-License-Identifier: MIT | ||
// | ||
|
||
// trivial example of using alpaca-core's llama embedding API | ||
|
||
// llama | ||
#include <ac/llama/Init.hpp> | ||
#include <ac/llama/Model.hpp> | ||
#include <ac/llama/InstanceEmbedding.hpp> | ||
|
||
// logging | ||
#include <ac/jalog/Instance.hpp> | ||
#include <ac/jalog/sinks/ColorSink.hpp> | ||
|
||
// model source directory | ||
#include "ac-test-data-llama-dir.h" | ||
|
||
#include <iostream> | ||
#include <string> | ||
|
||
int main() try { | ||
ac::jalog::Instance jl; | ||
jl.setup().add<ac::jalog::sinks::ColorSink>(); | ||
|
||
// initialize the library | ||
ac::llama::initLibrary(); | ||
|
||
// load model | ||
std::string modelGguf = AC_TEST_DATA_LLAMA_DIR "/bge-small-en-v1.5-f16.gguf"; | ||
ac::llama::Model::Params modelParams; | ||
auto modelLoadProgressCallback = [](float progress) { | ||
const int barWidth = 50; | ||
static float currProgress = 0; | ||
auto delta = int(progress * barWidth) - int(currProgress * barWidth); | ||
for (int i = 0; i < delta; i++) { | ||
std::cout.put('='); | ||
} | ||
currProgress = progress; | ||
if (progress == 1.f) { | ||
std::cout << '\n'; | ||
} | ||
return true; | ||
}; | ||
auto lmodel = ac::llama::ModelRegistry::getInstance().loadModel(modelGguf, modelLoadProgressCallback, modelParams); | ||
ac::llama::Model model(lmodel, modelParams); | ||
|
||
// create inference instance | ||
ac::llama::InstanceEmbedding instance(model, {}); | ||
|
||
std::string prompt = "The main character in the story loved to eat pineapples."; | ||
std::vector<ac::llama::Token> tokens = model.vocab().tokenize(prompt, true, true); | ||
|
||
auto embeddings = instance.getEmbeddingVector(tokens); | ||
|
||
std::cout << "Embedding vector for prompt(" << prompt<< "): "; | ||
for (uint64_t i = 0; i < embeddings.size(); i++) { | ||
std::cout << embeddings[i] << ' '; | ||
} | ||
std::cout << std::endl; | ||
|
||
return 0; | ||
} | ||
catch (const std::exception& e) { | ||
std::cerr << "Error: " << e.what() << std::endl; | ||
return 1; | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters