Skip to content

Commit

Permalink
TextVectorStore support save/load.
Browse files Browse the repository at this point in the history
  • Loading branch information
wangzhaode committed Jan 11, 2024
1 parent 4cb5c44 commit f558df9
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 7 deletions.
7 changes: 7 additions & 0 deletions demo/store_demo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,5 +46,12 @@ int main(int argc, const char* argv[]) {
for (const auto& text : similar_texts) {
std::cout << text << std::endl;
}
store->save("./tmp.mnn");
store.reset(TextVectorStore::load("./tmp.mnn"));
store->set_embedding(embedding);
similar_texts = store->search_similar_texts(text, 2);
for (const auto& text : similar_texts) {
std::cout << text << std::endl;
}
return 0;
}
25 changes: 18 additions & 7 deletions src/llm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -791,17 +791,28 @@ VARP Bge::gen_position_ids(int seq_len) {
// Embedding end

// TextVectorStore start

/// Restore a TextVectorStore from a file previously written by save().
///
/// The file is read with Variable::load. The first variable becomes the
/// embedding matrix (vectors_); every following variable is interpreted as the
/// raw bytes of one stored text and appended to texts_ in file order.
///
/// @param path  File to read.
/// @return A newly allocated store owned by the caller, or nullptr when the
///         file holds fewer than two variables or a text variable cannot be
///         read.
TextVectorStore* TextVectorStore::load(const std::string& path) {
    auto vars = Variable::load(path.c_str());
    if (vars.size() < 2) {
        return nullptr;
    }
    TextVectorStore* store = new TextVectorStore;
    store->vectors_ = vars[0];
    for (size_t i = 1; i < vars.size(); i++) {
        const auto* info = vars[i]->getInfo();
        const char* txt = (info != nullptr) ? vars[i]->readMap<char>() : nullptr;
        const size_t len = (info != nullptr) ? static_cast<size_t>(info->size) : 0;
        if (info == nullptr || (txt == nullptr && len > 0)) {
            // Unreadable entry: do not hand back a half-populated store.
            delete store;
            return nullptr;
        }
        // save() writes exactly text.size() bytes with no NUL terminator, so
        // the length must come from the variable rather than from strlen-style
        // scanning of the mapped buffer.
        store->texts_.push_back(txt != nullptr ? std::string(txt, len) : std::string());
    }
    return store;
}

void TextVectorStore::save(const std::string& path) {
std::vector<VARP> vars;
vars.push_back(vectors_);
for (auto text : texts_) {
auto text_var = _Const(text.data(), {text.size()}, NHWC, halide_type_of<int8_t>());
vars.push_back(text_var);
}
Variable::save(vars, path.c_str());
// TODO
}

void TextVectorStore::add_text(const std::string& text) {
Expand All @@ -812,6 +823,7 @@ void TextVectorStore::add_text(const std::string& text) {
} else {
vectors_ = _Concat({vectors_, vector}, 0);
}
vectors_.fix(VARP::CONSTANT);
}

void TextVectorStore::add_texts(const std::vector<std::string>& texts) {
Expand All @@ -824,8 +836,7 @@ std::vector<std::string> TextVectorStore::search_similar_texts(const std::string
auto vector = text2vector(text);
auto dist = _Sqrt(_ReduceSum(_Square(vectors_ - vector), {-1}));
auto indices = _Sort(dist, 0, true);
auto ptr = dist->readMap<float>();
auto iptr = indices->readMap<int>();
// auto ptr = dist->readMap<float>();
auto idx_ptr = indices->readMap<int>();
std::vector<std::string> res;
for (int i = 0; i < topk; i++) {
Expand All @@ -848,8 +859,8 @@ void TextVectorStore::bench() {
auto vec = _RandomUnifom(shape1, halide_type_of<float>());
auto start = std::chrono::high_resolution_clock::now();
auto dist = _Sqrt(_ReduceSum(_Square(vectors_ - vec), {-1}));
auto ptr = dist->readMap<float>();
auto indices = _Sort(dist, 0, true);
auto ptr = dist->readMap<float>();
auto iptr = indices->readMap<int>();
auto end = std::chrono::high_resolution_clock::now();
auto duration = std::chrono::duration_cast<std::chrono::milliseconds>(end - start);
Expand Down

0 comments on commit f558df9

Please sign in to comment.