From 84a52f401fde86db477c55e839f167bb8f464fb2 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E9=9B=81=E8=A1=8C?=
Date: Wed, 6 Dec 2023 16:40:25 +0800
Subject: [PATCH] build qwen-1.8b apk.

---
 README.md                                      | 18 ++--
 android/.idea/compiler.xml                     |  2 +-
 android/app/src/main/assets/tokenizer          |  1 -
 .../main/java/com/mnn/llm/Conversation.java    |  9 +-
 .../main/java/com/mnn/llm/DownloadModel.java   |  6 +-
 .../main/java/com/mnn/llm/MainActivity.java    | 56 +++++------
 .../DownloadRecyclerView.java                  |  4 +-
 .../ConversationRecyclerView.java              |  1 -
 .../com/mnn/llm/recylcerchat/HolderYou.java    | 11 +--
 android/app/src/main/jni/CMakeLists.txt        |  1 +
 android/app/src/main/jni/llm_mnn_jni.cpp       | 15 +--
 .../app/src/main/res/layout/activity_main.xml  |  2 +-
 .../src/main/res/layout/layout_holder_you.xml  | 16 +---
 android/app/src/main/res/values/models.xml     | 87 +++++------------
 android/app/src/main/res/values/strings.xml    |  4 +-
 android/gradle.properties                      |  2 +-
 include/llm.hpp                                |  7 +-
 src/llm.cpp                                    | 96 +++++++++++--------
 18 files changed, 155 insertions(+), 183 deletions(-)
 delete mode 120000 android/app/src/main/assets/tokenizer

diff --git a/README.md b/README.md
index c4a7d574..1764e906 100644
--- a/README.md
+++ b/README.md
@@ -25,6 +25,7 @@ llm模型导出onnx模型请使用[llm-export](https://github.com/wangzhaode/llm
 其他版本:
 - Qwen-1_8B-Chat-int8:[![Download][download-qwen-1.8b-mnn-int8]][release-qwen-1.8b-mnn-int8]
+- Android APK: [![Download][download-qwen-1.8b-apk]][release-qwen-1.8b-apk]
 
 [download-chatglm-6b-onnx]: https://img.shields.io/github/downloads/wangzhaode/llm-export/chatglm-6b-onnx/total
 [download-chatglm2-6b-onnx]: https://img.shields.io/github/downloads/wangzhaode/llm-export/chatglm2-6b-onnx/total
@@ -51,6 +52,8 @@ llm模型导出onnx模型请使用[llm-export](https://github.com/wangzhaode/llm
 [download-llama2-7b-chat-mnn]: https://img.shields.io/github/downloads/wangzhaode/mnn-llm/llama2-7b-chat-mnn/total
 [download-qwen-1.8b-mnn]: https://img.shields.io/github/downloads/wangzhaode/mnn-llm/qwen-1.8b-mnn/total
 [download-qwen-1.8b-mnn-int8]: https://img.shields.io/github/downloads/wangzhaode/mnn-llm/qwen-1.8b-mnn-int8/total
+[download-qwen-1.8b-apk]: https://img.shields.io/github/downloads/wangzhaode/mnn-llm/qwen-1.8b-apk/total
+
 [release-chatglm-6b-mnn]: https://github.com/wangzhaode/mnn-llm/releases/tag/chatglm-6b-mnn
 [release-chatglm2-6b-mnn]: https://github.com/wangzhaode/mnn-llm/releases/tag/chatglm2-6b-mnn
 [release-chatglm3-6b-mnn]: https://github.com/wangzhaode/mnn-llm/releases/tag/chatglm3-6b-mnn
@@ -60,6 +63,8 @@ llm模型导出onnx模型请使用[llm-export](https://github.com/wangzhaode/llm
 [release-llama2-7b-chat-mnn]: https://github.com/wangzhaode/mnn-llm/releases/tag/llama2-7b-chat-mnn
 [release-qwen-1.8b-mnn]: https://github.com/wangzhaode/mnn-llm/releases/tag/qwen-1.8b-mnn
 [release-qwen-1.8b-mnn-int8]: https://github.com/wangzhaode/mnn-llm/releases/tag/qwen-1.8b-mnn-int8
+[release-qwen-1.8b-apk]: https://github.com/wangzhaode/mnn-llm/releases/tag/qwen-1.8b-apk
+
 
 ### 速度
 
@@ -157,12 +162,13 @@ adb shell "cd /data/local/tmp && export LD_LIBRARY_PATH=. && ./cli_demo -m model
 
 ## Reference
-- [chatglm-6b](https://huggingface.co/THUDM/chatglm-6b)
-- [chatglm2-6b](https://huggingface.co/THUDM/chatglm2-6b)
-- [codegeex2-6b](https://huggingface.co/THUDM/codegeex2-6b)
-- [Baichuan2-7B-Chat](https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat)
-- [Qwen-7B-Chat](https://huggingface.co/tangger/Qwen-7B-Chat)
+- [chatglm-6b](https://modelscope.cn/models/ZhipuAI/chatglm-6b/summary)
+- [chatglm2-6b](https://modelscope.cn/models/ZhipuAI/chatglm2-6b/summary)
+- [chatglm3-6b](https://modelscope.cn/models/ZhipuAI/chatglm3-6b/summary)
+- [codegeex2-6b](https://modelscope.cn/models/ZhipuAI/codegeex2-6b/summary)
+- [Baichuan2-7B-Chat](https://modelscope.cn/models/baichuan-inc/baichuan-7B/summary)
+- [Qwen-7B-Chat](https://modelscope.cn/models/qwen/Qwen-7B-Chat/summary)
+- [Qwen-1.8B-Chat](https://modelscope.cn/models/qwen/Qwen-1_8B-Chat/summary)
 - [cpp-httplib](https://github.com/yhirose/cpp-httplib)
 - [chatgpt-web](https://github.com/xqdoo00o/chatgpt-web)
-- [cppjieba](https://github.com/yanyiwu/cppjieba)
 - [ChatViewDemo](https://github.com/BrettFX/ChatViewDemo)

diff --git a/android/.idea/compiler.xml b/android/.idea/compiler.xml
index 443b5d2b..61a9130c 100644
--- a/android/.idea/compiler.xml
+++ b/android/.idea/compiler.xml
@@ -1,6 +1,6 @@
-
+
\ No newline at end of file

diff --git a/android/app/src/main/assets/tokenizer b/android/app/src/main/assets/tokenizer
deleted file mode 120000
index 7496190d..00000000
--- a/android/app/src/main/assets/tokenizer
+++ /dev/null
@@ -1 +0,0 @@
-../../../../../resource/tokenizer
\ No newline at end of file

diff --git a/android/app/src/main/java/com/mnn/llm/Conversation.java b/android/app/src/main/java/com/mnn/llm/Conversation.java
index e45e568f..8a017513 100644
--- a/android/app/src/main/java/com/mnn/llm/Conversation.java
+++ b/android/app/src/main/java/com/mnn/llm/Conversation.java
@@ -31,7 +31,7 @@ public class Conversation extends BaseActivity {
     private Button send;
     private DateFormat mDateFormat;
     private Chat mChat;
-    private boolean mHistory = false;
+    private boolean mHistory = true;
 
     @Override
     protected void onCreate(Bundle savedInstanceState) {
@@ -138,12 +138,16 @@ public boolean onCreateOptionsMenu(Menu menu) {
 
     @Override
     public boolean onOptionsItemSelected(MenuItem item) {
+        /*
         if (mHistory) {
             Toast.makeText(getBaseContext(), "关闭上下文", Toast.LENGTH_SHORT).show();
         } else {
             Toast.makeText(getBaseContext(), "打开上下文", Toast.LENGTH_SHORT).show();
         }
         mHistory = !mHistory;
+        */
+        Toast.makeText(getBaseContext(), "清空记忆", Toast.LENGTH_SHORT).show();
+        mChat.Reset();
         return true;
     }
 }
@@ -168,7 +172,7 @@ public void run() {
             System.out.println("[MNN_DEBUG] start response\n");
             while (!last_response.contains("<eop>")) {
                 try {
-                    Thread.sleep(200);
+                    Thread.sleep(50);
                 } catch (Exception e) {}
                 String response = new String(mChat.Response());
                 if (response.equals(last_response)) {
@@ -177,6 +181,7 @@
                     last_response = response;
                 }
                 Message msg = new Message();
+                System.out.println("[MNN_DEBUG] " + response);
                 msg.obj = response.replaceFirst("<eop>", "");
                 mHandler.sendMessage(msg);
             }

diff --git a/android/app/src/main/java/com/mnn/llm/DownloadModel.java b/android/app/src/main/java/com/mnn/llm/DownloadModel.java
index cb1f25b1..cbe8372e 100644
--- a/android/app/src/main/java/com/mnn/llm/DownloadModel.java
+++ b/android/app/src/main/java/com/mnn/llm/DownloadModel.java
@@ -2,11 +2,8 @@
 
 import android.app.AlertDialog;
 import android.content.DialogInterface;
-import android.content.Intent;
 import android.graphics.Color;
 import android.os.Bundle;
-import android.os.Handler;
-import android.os.Message;
 import android.view.View;
 import android.widget.Button;
 
@@ -29,8 +26,7 @@ protected void onCreate(Bundle savedInstanceState) {
         mDownloadAll = (Button)findViewById(R.id.download_all);
         // init Data
         String[] modelArray = this.getResources().getStringArray(R.array.model_list);
-        int[] modelSize = this.getResources().getIntArray(R.array.model_size);
-        mAdapter = new DownloadRecyclerView(this, modelArray, modelSize);
+        mAdapter = new DownloadRecyclerView(this, modelArray);
         mRecyclerView.setAdapter(mAdapter);
     }
     public void downloadAll(View view) {

diff --git a/android/app/src/main/java/com/mnn/llm/MainActivity.java b/android/app/src/main/java/com/mnn/llm/MainActivity.java
index 194ed5d3..696e35d2 100644
--- a/android/app/src/main/java/com/mnn/llm/MainActivity.java
+++ b/android/app/src/main/java/com/mnn/llm/MainActivity.java
@@ -13,9 +13,9 @@
 import android.widget.RelativeLayout;
 import android.widget.TextView;
 
-import org.w3c.dom.Text;
-
 import java.io.File;
+import java.io.IOException;
+
 
 public class MainActivity extends AppCompatActivity {
     private Chat mChat;
@@ -28,8 +28,9 @@ public class MainActivity extends AppCompatActivity {
     private TextView mProcessName;
     private TextView mProcessPercent;
     // resource files
-    private String mModelDir = "/data/local/tmp/model";
-    private boolean mModelNeedDownload = true;
+    private String mModelName = "qwen-1.8b-int4";
+    private String mModelDir = "/data/local/tmp/chat/" + mModelName; // default dir
+    private boolean mModelReady = true;
     @Override
     protected void onCreate(Bundle savedInstanceState) {
         super.onCreate(savedInstanceState);
@@ -41,7 +42,8 @@ protected void onCreate(Bundle savedInstanceState) {
         mProcessBar = (ProgressBar)findViewById(R.id.process_bar);
         mProcessName = (TextView)findViewById(R.id.process_name);
         mProcessPercent = (TextView)findViewById(R.id.process_percent);
-        mModelDir = this.getCacheDir().toString() + "/model";
+        // using asset file
+        mModelDir = this.getCacheDir() + "/" + mModelName;
         mProcessHandler = new Handler() {
             @Override
             public void handleMessage(Message msg) {
@@ -58,55 +60,56 @@ public void handleMessage(Message msg) {
                 }
             }
         };
-        /*
-        File model = new File(mModelDir, "glm_block_0.mnn");
-        if (model.exists()) {
-            model.delete();
-        }
-        */
-        onCheckModels();
     }
 
     @Override
     protected void onResume() {
         super.onResume();
-        onCheckModels();
     }
 
     public void onCheckModels() {
-        mModelNeedDownload = checkModelsNeedDownload();
-        if (mModelNeedDownload) {
+        mModelReady = checkModelsReady();
+        // try copy from asset file
+        if (!mModelReady) {
+            try {
+                mModelDir = Common.copyAssetResource2File(this, mModelName);
+            } catch (IOException e) {
+                throw new RuntimeException(e);
+            } catch (InterruptedException e) {
+                throw new RuntimeException(e);
+            }
+            mModelReady = checkModelsReady();
+        }
+        // download models
+        if (!mModelReady) {
             mModelInfo.setVisibility(View.VISIBLE);
-            mModelInfo.setText("使用前请先下载模型!");
+            mModelInfo.setText("请下载模型文件");
             mLoadButton.setText("下载模型");
         } else {
             mModelInfo.setVisibility(View.VISIBLE);
-            mModelInfo.setText("模型下载完毕,请加载模型!");
+            mModelInfo.setText(mModelName + "模型文件就绪,模型加载中");
             mLoadButton.setText("加载模型");
         }
     }
 
-    public boolean checkModelsNeedDownload() {
+    public boolean checkModelsReady() {
         System.out.println("### Check Models!");
         File dir = new File(mModelDir);
         if (!dir.exists()) {
-            return true;
+            return false;
         }
         String[] modelArray = this.getResources().getStringArray(R.array.model_list);
-        int[] modelSize = this.getResources().getIntArray(R.array.model_size);
         for (int i = 0; i < modelArray.length; i++) {
             File model = new File(mModelDir, modelArray[i]);
             if (!model.exists()) {
-                return true;
-            }
-            if (model.length() != modelSize[i]) {
-                return true;
+                return false;
             }
         }
-        return false;
+        return true;
     }
 
     public void loadModel(View view) {
-        if (mModelNeedDownload) {
+        onCheckModels();
+        if (!mModelReady) {
            startActivity(new Intent(this, DownloadModel.class));
            return;
         }
@@ -115,7 +118,6 @@ public void loadModel(View view) {
         mLoadButton.setText("模型加载中 ...");
         mProcessView.setVisibility(View.VISIBLE);
         mChat = new Chat();
-        System.out.println("[MNN_DEBUG] is chat Ready: " + mChat.Ready());
         Handler handler = new Handler() {
             @Override
             public void handleMessage(Message msg) {

diff --git a/android/app/src/main/java/com/mnn/llm/recyclerdownload/DownloadRecyclerView.java b/android/app/src/main/java/com/mnn/llm/recyclerdownload/DownloadRecyclerView.java
index 76a33d05..7d077187 100644
--- a/android/app/src/main/java/com/mnn/llm/recyclerdownload/DownloadRecyclerView.java
+++ b/android/app/src/main/java/com/mnn/llm/recyclerdownload/DownloadRecyclerView.java
@@ -25,7 +25,7 @@ public class DownloadRecyclerView extends RecyclerView.Adapter
 
-    public DownloadRecyclerView(Context context, String[] models, int[] modelSize) {
+    public DownloadRecyclerView(Context context, String[] models) {
         mContext = context;
         this.mItems = new ArrayList();
         final String modelDir = context.getCacheDir().toString() + "/model";
@@ -48,7 +48,7 @@ public void handleMessage(Message msg) {
             }
         };
         for (int i = 0; i < models.length; i++) {
-            this.mItems.add(new DownloadData(mHandler, modelDir, models[i], i, modelSize[i]));
+            this.mItems.add(new DownloadData(mHandler, modelDir, models[i], i, 25751300));
         }
     }

diff --git a/android/app/src/main/java/com/mnn/llm/recylcerchat/ConversationRecyclerView.java b/android/app/src/main/java/com/mnn/llm/recylcerchat/ConversationRecyclerView.java
index a3378536..580bb7c9 100644
--- a/android/app/src/main/java/com/mnn/llm/recylcerchat/ConversationRecyclerView.java
+++ b/android/app/src/main/java/com/mnn/llm/recylcerchat/ConversationRecyclerView.java
@@ -101,7 +101,6 @@ private void configureViewHolder3(HolderMe vh1, int position) {
     }
 
     private void configureViewHolder2(HolderYou vh1, int position) {
-        vh1.getTime().setText(items.get(position).getTime());
         vh1.getChatText().setText(items.get(position).getText());
     }
     private void configureViewHolder1(HolderDate vh1, int position) {

diff --git a/android/app/src/main/java/com/mnn/llm/recylcerchat/HolderYou.java b/android/app/src/main/java/com/mnn/llm/recylcerchat/HolderYou.java
index c52c03d5..44e56ebf 100644
--- a/android/app/src/main/java/com/mnn/llm/recylcerchat/HolderYou.java
+++ b/android/app/src/main/java/com/mnn/llm/recylcerchat/HolderYou.java
@@ -9,22 +9,13 @@
 
 public class HolderYou extends RecyclerView.ViewHolder {
 
-    private TextView time, chatText;
+    private TextView chatText;
 
     public HolderYou(View v) {
         super(v);
-        time = (TextView) v.findViewById(R.id.tv_time);
         chatText = (TextView) v.findViewById(R.id.tv_chat_text);
     }
 
-    public TextView getTime() {
-        return time;
-    }
-
-    public void setTime(TextView time) {
-        this.time = time;
-    }
-
     public TextView getChatText() {
         return chatText;
     }

diff --git a/android/app/src/main/jni/CMakeLists.txt b/android/app/src/main/jni/CMakeLists.txt
index cb067922..246aa329 100644
--- a/android/app/src/main/jni/CMakeLists.txt
+++ b/android/app/src/main/jni/CMakeLists.txt
@@ -7,6 +7,7 @@ cmake_minimum_required(VERSION 3.10)
 
 include_directories(${CMAKE_CURRENT_LIST_DIR}/../../../../../include/)
 link_directories(${CMAKE_CURRENT_LIST_DIR}/libs/arm64-v8a)
 
+add_definitions(-DUSING_DISK_EMBED)
 FILE(GLOB SRCS ../../../../../src/*.cpp)
 add_library(llm_mnn SHARED llm_mnn_jni.cpp ${SRCS})

diff --git a/android/app/src/main/jni/llm_mnn_jni.cpp b/android/app/src/main/jni/llm_mnn_jni.cpp
index ecd570c7..a1481b82 100644
--- a/android/app/src/main/jni/llm_mnn_jni.cpp
+++ b/android/app/src/main/jni/llm_mnn_jni.cpp
@@ -10,7 +10,7 @@
 
 #include "llm.hpp"
 
-static Llm* llm;
+static std::unique_ptr<Llm> llm(nullptr);
 static std::stringstream response_buffer;
 
 extern "C" {
@@ -25,32 +25,33 @@ JNIEXPORT void JNI_OnUnload(JavaVM* vm, void* reserved) {
 }
 
 JNIEXPORT jboolean JNICALL Java_com_mnn_llm_Chat_Init(JNIEnv* env, jobject thiz, jstring modelDir) {
-    if (llm->load_progress() < 100) {
-        const char* model_dir = env->GetStringUTFChars(modelDir, 0);
-        llm = Llm::createLLM(model_dir);
+    const char* model_dir = env->GetStringUTFChars(modelDir, 0);
+    if (!llm.get()) {
+        llm.reset(Llm::createLLM(model_dir));
         llm->load(model_dir);
     }
     return JNI_TRUE;
 }
 
 JNIEXPORT jboolean JNICALL Java_com_mnn_llm_Chat_Ready(JNIEnv* env, jobject thiz) {
-    if (llm->load_progress() >= 100) {
+    if (llm.get() && llm->load_progress() >= 100) {
         return JNI_TRUE;
     }
     return JNI_FALSE;
 }
 
 JNIEXPORT jfloat JNICALL Java_com_mnn_llm_Chat_Progress(JNIEnv* env, jobject thiz) {
+    if (!llm.get()) return jfloat(0);
     return jfloat(llm->load_progress());
 }
 
 JNIEXPORT jstring JNICALL Java_com_mnn_llm_Chat_Submit(JNIEnv* env, jobject thiz, jstring inputStr) {
-    if (llm->load_progress() < 100) {
+    if (!llm.get()) {
         return env->NewStringUTF("Failed, Chat is not ready!");
     }
     const char* input_str = env->GetStringUTFChars(inputStr, 0);
     auto chat = [&](std::string str) {
-        llm->response(str, &response_buffer);
+        llm->response(str, &response_buffer, "<eop>");
     };
     std::thread chat_thread(chat, input_str);
     chat_thread.detach();

diff --git a/android/app/src/main/res/layout/activity_main.xml b/android/app/src/main/res/layout/activity_main.xml
index 7b8df790..4b0b5b84 100644
--- a/android/app/src/main/res/layout/activity_main.xml
+++ b/android/app/src/main/res/layout/activity_main.xml
@@ -28,7 +28,7 @@
         android:layout_height="wrap_content"
         android:layout_gravity="center"
         android:textSize="8pt"
-        android:text="ChatGLM-MNN" />
+        android:text="mnn-llm" />

diff --git a/android/app/src/main/res/layout/layout_holder_you.xml b/android/app/src/main/res/layout/layout_holder_you.xml
+        android:orientation="horizontal">
-
\ No newline at end of file

diff --git a/android/app/src/main/res/values/models.xml b/android/app/src/main/res/values/models.xml
index 16713288..d5e7d4c5 100644
--- a/android/app/src/main/res/values/models.xml
+++ b/android/app/src/main/res/values/models.xml
@@ -1,67 +1,32 @@
-    glm_block_0.mnn
-    glm_block_1.mnn
-    glm_block_2.mnn
-    glm_block_3.mnn
-    glm_block_4.mnn
-    glm_block_5.mnn
-    glm_block_6.mnn
-    glm_block_7.mnn
-    glm_block_8.mnn
-    glm_block_9.mnn
-    glm_block_10.mnn
-    glm_block_11.mnn
-    glm_block_12.mnn
-    glm_block_13.mnn
-    glm_block_14.mnn
-    glm_block_15.mnn
-    glm_block_16.mnn
-    glm_block_17.mnn
-    glm_block_18.mnn
-    glm_block_19.mnn
-    glm_block_20.mnn
-    glm_block_21.mnn
-    glm_block_22.mnn
-    glm_block_23.mnn
-    glm_block_24.mnn
-    glm_block_25.mnn
-    glm_block_26.mnn
-    glm_block_27.mnn
+    block_0.mnn
+    block_1.mnn
+    block_2.mnn
+    block_3.mnn
+    block_4.mnn
+    block_5.mnn
+    block_6.mnn
+    block_7.mnn
+    block_8.mnn
+    block_9.mnn
+    block_10.mnn
+    block_11.mnn
+    block_12.mnn
+    block_13.mnn
+    block_14.mnn
+    block_15.mnn
+    block_16.mnn
+    block_17.mnn
+    block_18.mnn
+    block_19.mnn
+    block_20.mnn
+    block_21.mnn
+    block_22.mnn
+    block_23.mnn
+    embeddings_bf16.bin
     lm.mnn
-    slim_word_embeddings_bf16.bin
+    tokenizer.txt
-
-    101104204
-    101104300
-    101104300
-    101104300
-    101104300
-    101104300
-    101104300
-    101104300
-    101104300
-    101104300
-    101104300
-    101104300
-    101104300
-    101104300
-    101104300
-    101104300
-    101104300
-    101104300
-    101104300
-    101104300
-    101104300
-    101104300
-    101104300
-    101104300
-    101104300
-    101104300
-    101104300
-    100507772
-    268366612
-    1069285376
-
\ No newline at end of file

diff --git a/android/app/src/main/res/values/strings.xml b/android/app/src/main/res/values/strings.xml
index 6d5c1e3e..c61c1498 100644
--- a/android/app/src/main/res/values/strings.xml
+++ b/android/app/src/main/res/values/strings.xml
@@ -1,8 +1,8 @@
-    ChatGLM-MNN
+    mnn-llm
     MainActivity
     Open navigation drawer
     Close navigation drawer
     Settings
-    本Demo基于ChatGLM-6B模型开发,对ChatGLM-6B模型进行了4bit量化,使用MNN进行推理加速。模型加载略慢,请稍作等待,点击下方按钮加载模型。
+    mnn-llm

diff --git a/android/gradle.properties b/android/gradle.properties
index 199d16ed..d41b0e88 100644
--- a/android/gradle.properties
+++ b/android/gradle.properties
@@ -6,7 +6,7 @@
 # http://www.gradle.org/docs/current/userguide/build_environment.html
 # Specifies the JVM arguments used for the daemon process.
 # The setting is particularly useful for tweaking memory settings.
-org.gradle.jvmargs=-Xmx1536m
+org.gradle.jvmargs=-Xmx4g
 # When configured, Gradle will run in incubating parallel mode.
 # This option should only be used with decoupled projects. More details, visit
 # http://www.gradle.org/docs/current/userguide/multi_project_builds.html#sec:decoupled_projects

diff --git a/include/llm.hpp b/include/llm.hpp
index 48dc6515..456d440a 100644
--- a/include/llm.hpp
+++ b/include/llm.hpp
@@ -32,7 +32,7 @@ class Llm {
         tokenizer_.reset(new Sentencepiece);
     }
     virtual ~Llm() = default;
-    static Llm* createLLM(const std::string& path);
+    static Llm* createLLM(const std::string& path, std::string model_type = "auto");
     VARP disk_embedding(const std::vector<int>& input_ids);
     void load(const std::string& model_dir);
     int forward(const std::vector<int>& input_ids);
@@ -40,13 +40,14 @@
     std::string decode(int id);
     void chat();
     void warmup();
-    std::string response(const std::string& input_str, std::ostream* os = &std::cout);
+    std::string response(const std::string& input_str, std::ostream* os = &std::cout, const char* end_with = nullptr);
     float load_progress() { return load_progress_; }
     void reset();
     void print_speed();
 public:
+    std::vector<int> history_;
     // forward info
-    int max_seq_len_ = 256;
+    int max_seq_len_ = 1024;
     int prompt_len_ = 0;
     int gen_seq_len_ = 0;
     int all_seq_len_ = 0;

diff --git a/src/llm.cpp b/src/llm.cpp
index f8094f99..c2856f7a 100644
--- a/src/llm.cpp
+++ b/src/llm.cpp
@@ -15,7 +15,7 @@
 
 #include
 
-Llm* Llm::createLLM(const std::string& path) {
+Llm* Llm::createLLM(const std::string& path, std::string model_type) {
     auto size = path.size();
     // end with '.mnn' is single model file, otherwise split block models
     bool is_single = (size > 4 &&
@@ -24,27 +24,30 @@
                       path[size - 3] == 'm' &&
                       path[size - 2] == 'n' &&
                       path[size - 1] == 'n');
     Llm* llm = nullptr;
-    if (path.find("chatglm") != std::string::npos) {
-        if (path.find("chatglm2") != std::string::npos) {
+    if (model_type == "auto") {
+        model_type = path;
+    }
+    if (model_type.find("chatglm") != std::string::npos) {
+        if (model_type.find("chatglm2") != std::string::npos) {
             llm = new Chatglm2_6b;
-        } else if (path.find("chatglm3") != std::string::npos) {
+        } else if (model_type.find("chatglm3") != std::string::npos) {
             llm = new Chatglm2_6b;
             llm->model_name_ = "Chatglm3_6b";
         } else {
             llm = new Chatglm_6b;
         }
-    } else if (path.find("codegeex2") != std::string::npos) {
+    } else if (model_type.find("codegeex2") != std::string::npos) {
         llm = new Chatglm2_6b;
         llm->model_name_ = "Codegeex2_6b";
-    } else if (path.find("qwen") != std::string::npos) {
-        if (path.find("1.8") != std::string::npos) {
+    } else if (model_type.find("qwen") != std::string::npos) {
+        if (model_type.find("1.8") != std::string::npos) {
             llm = new Qwen_1_8b;
         } else {
             llm = new Qwen_7b;
         }
-    } else if (path.find("llama2") != std::string::npos) {
+    } else if (model_type.find("llama2") != std::string::npos) {
         llm = new Llama2_7b;
-    } else if (path.find("baichuan") != std::string::npos) {
+    } else if (model_type.find("baichuan") != std::string::npos) {
         llm = new Llama2_7b;
         llm->model_name_ = "Baichuan2_7b";
     }
@@ -64,13 +67,20 @@ void Llm::chat() {
         std::cin >> input_str;
         std::cout << "\nA: " << std::flush;
         response(input_str);
-        reset();
         std::cout << std::endl;
     }
 }
 
-std::string Llm::response(const std::string& query, std::ostream* os) {
+std::string Llm::response(const std::string& query, std::ostream* os, const char* end_with) {
+    if (!end_with) {
+        end_with = "\n";
+    }
     // init status
+    gen_seq_len_ = 0;
+    all_seq_len_ = 0;
+    prefill_us_ = 0;
+    decode_us_ = 0;
+    past_key_values_.clear();
     if (is_single_) {
         past_key_values_.push_back(_Input(key_value_shape_, NCHW));
     } else {
@@ -80,12 +90,20 @@
         }
     }
     // response
     auto input_ids = tokenizer(query);
+    if (!history_.empty()) {
+        std::copy(input_ids.begin(), input_ids.end(), std::back_inserter(history_));
+        input_ids = history_;
+    } else {
+        history_ = input_ids;
+    }
+
     prompt_len_ = input_ids.size();
     // printf("token_num : %lu\n", input_ids.size());
     auto st = std::chrono::system_clock::now();
     int token = forward(input_ids);
-    std::string output_str = decode(token);
     auto et = std::chrono::system_clock::now();
+    history_.push_back(token);
+    std::string output_str = decode(token);
     prefill_us_ = std::chrono::duration_cast<std::chrono::microseconds>(et - st).count();
     *os << output_str << std::flush;
     while (gen_seq_len_ < max_seq_len_) {
@@ -94,9 +112,10 @@
         et = std::chrono::system_clock::now();
         decode_us_ += std::chrono::duration_cast<std::chrono::microseconds>(et - st).count();
         if (is_stop(token)) {
-            *os << std::endl << std::flush;
+            *os << end_with << std::flush;
             break;
         }
+        history_.push_back(token);
         auto word = decode(token);
         *os << word << std::flush;
         output_str += word;
@@ -105,7 +124,8 @@
     print_speed();
 #endif
     // update Cache
-    runtime_manager_->updateCache();
+    // runtime_manager_->updateCache();
+    // reset forward info
     return output_str;
 }
@@ -128,13 +148,7 @@ void Llm::print_speed() {
 }
 
 void Llm::reset() {
-    past_key_values_.clear();
-    past_key_values_.shrink_to_fit();
-    prompt_len_ = 0;
-    gen_seq_len_ = 0;
-    all_seq_len_ = 0;
-    prefill_us_ = 0;
-    decode_us_ = 0;
+    history_.clear();
 }
 
 void Llm::load(const std::string& model_dir) {
@@ -153,54 +167,55 @@ void Llm::load(const std::string& model_dir) {
         const char* cacheFileName = ".tempcache";
         runtime_manager_->setCache(cacheFileName);
     }
+    load_progress_ = 0.f;
     // 1. load vocab
     std::string tokenizer_path = model_dir + "/tokenizer.txt";
+    load_progress_ += 5.f;
     tokenizer_->load(tokenizer_path);
+    load_progress_ += 5.f;
     // 2. load model
     Module::Config module_config;
     module_config.shapeMutable = true;
     module_config.rearrange = true;
-    load_progress_ = 0.f;
     if (is_single_) {
         key_value_shape_.insert(key_value_shape_.begin(), layer_nums_);
         modules_.resize(1);
         std::string model_path = model_dir;
         std::string external_path = model_dir + ".weight";
-        printf("load %s ... ", model_path.c_str());
+        MNN_PRINT("load %s ... ", model_path.c_str());
         runtime_manager_->setExternalFile(external_path);
         modules_[0].reset(Module::load(
             {"input_ids", "attention_mask", "position_ids", "past_key_values"},
             {"token_id", "presents"}, model_path.c_str(), runtime_manager_, &module_config));
-        printf("Done!\n");
-        fflush(stdout);
+        MNN_PRINT("Done!\n");
+        load_progress_ += 90.f;
     } else {
         // 2. load models
         modules_.resize(layer_nums_ + 2);
-        float step = 100.0 / modules_.size();
+        float step = 90.0 / modules_.size();
         char buffer[50];
         // load lm model
         std::string lm_model_path = model_dir + "/lm.mnn";
         std::string embedding_model_path = model_dir + "/embedding.mnn";
-        printf("[%3.0f%% ] load %s model ... ", load_progress_, lm_model_path.c_str());
+        MNN_PRINT("[%3.0f%% ] load %s model ... ", load_progress_, lm_model_path.c_str());
         modules_[layer_nums_].reset(Module::load({}, {}, lm_model_path.c_str(), runtime_manager_, &module_config));
-        printf("Done!\n");
+        MNN_PRINT("Done!\n");
         load_progress_ += step;
 #ifndef USING_DISK_EMBED
-        printf("[%3.0f%% ] load %s model ... ", load_progress_, embedding_model_path.c_str());fflush(stdout);
+        MNN_PRINT("[%3.0f%% ] load %s model ... ", load_progress_, embedding_model_path.c_str());
         modules_[layer_nums_ + 1].reset(Module::load({}, {}, embedding_model_path.c_str(), runtime_manager_, &module_config));
-        printf("Done!\n");
+        MNN_PRINT("Done!\n");
         load_progress_ += step;
 #endif
         // load glm_block models
         for (int i = 0; i < layer_nums_; i++) {
             load_progress_ += step;
             std::string model_path = model_dir + "/block_" + std::to_string(i) + ".mnn";
-            printf("[%3.0f%% ] load %s model ... ", load_progress_, model_path.c_str());
+            MNN_PRINT("[%3.0f%% ] load %s model ... ", load_progress_, model_path.c_str());
             modules_[i].reset(Module::load(
                 {"inputs_embeds", "attention_mask", "position_ids", "past_key_values"},
                 {"hidden_states", "presents"}, model_path.c_str(), runtime_manager_, &module_config));
-            printf("Done!\n");
-            fflush(stdout);
+            MNN_PRINT("Done!\n");
         }
     }
     if (config.type == MNN_FORWARD_OPENCL) {
@@ -210,13 +225,18 @@
 
 void Llm::warmup() {
     // warmup
-    printf("### warmup ... ");
-    for (int i = 0; i < layer_nums_; i++) {
+    MNN_PRINT("### warmup ... ");
+    if (is_single_) {
         past_key_values_.push_back(_Input(key_value_shape_, NCHW));
+    } else {
+        for (int i = 0; i < layer_nums_; i++) {
+            past_key_values_.push_back(_Input(key_value_shape_, NCHW));
+        }
     }
-    std::vector<int> tmp_0(1, 0);
-    forward(tmp_0);
-    reset();
+    std::vector<int> tmp(1, 0);
+    forward(tmp);
+    all_seq_len_ = 0;
+    gen_seq_len_ = 0;
    printf("Done\n");
}
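
Note on the core change in `src/llm.cpp`: `Llm::response()` now carries conversation state itself. Every prompt token and every generated token is appended to `history_`, the whole history is re-fed as the prompt of the next turn, and `Llm::reset()` merely clears that vector — the new "清空记忆" (clear memory) menu item in `Conversation.java` reaches it through `mChat.Reset()`. Below is a minimal standalone sketch of that scheme, with toy stand-ins for the tokenizer and the forward pass; none of the real MNN types or APIs appear, and all names here are illustrative only:

```cpp
// Toy illustration of the token-level history mechanism added by this patch.
// The tokenizer and forward pass are fakes; only the history bookkeeping
// mirrors the patched Llm::response()/Llm::reset().
#include <iostream>
#include <string>
#include <vector>

struct ToyLlm {
    std::vector<int> history_;   // all tokens from every turn, prompts and replies
    int max_seq_len_ = 1024;     // cap on total sequence length, as in the patch

    // Stand-in tokenizer: one token per character.
    std::vector<int> tokenize(const std::string& s) {
        return std::vector<int>(s.begin(), s.end());
    }

    // Stand-in forward pass: derive the "next token" from the last input token.
    int forward(const std::vector<int>& ids) {
        return ids.empty() ? 0 : (ids.back() + 1) % 128;
    }

    std::string respond(const std::string& query) {
        auto input_ids = tokenize(query);
        // Core of the patch: fold this turn's tokens into the running
        // history, and feed the whole history back as the prompt.
        if (!history_.empty()) {
            history_.insert(history_.end(), input_ids.begin(), input_ids.end());
            input_ids = history_;
        } else {
            history_ = input_ids;
        }
        std::string out;
        for (int step = 0; step < 8 && (int)history_.size() < max_seq_len_; ++step) {
            int token = forward(input_ids);
            history_.push_back(token);  // generated tokens join the history too
            out += static_cast<char>(token);
            input_ids = {token};        // then decode one token at a time
        }
        return out;
    }

    // What the new "清空记忆" menu item triggers via Chat.Reset().
    void reset() { history_.clear(); }
};

int main() {
    ToyLlm llm;
    std::cout << llm.respond("hi") << "\n";    // turn 1: history starts here
    std::cout << llm.respond("again") << "\n"; // turn 2 sees turn 1's tokens
    llm.reset();                               // forget the whole conversation
    std::cout << llm.respond("hi") << "\n";    // same output as turn 1 again
    return 0;
}
```

One consequence visible in the sketch: the history only grows until `reset()` is called, so each turn's prefill gets longer and a long session can approach `max_seq_len_` (raised from 256 to 1024 by this patch); a caller that wants unbounded chats would have to truncate or summarize the history itself.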