diff --git a/README.md b/README.md
index c4a7d574..1764e906 100644
--- a/README.md
+++ b/README.md
@@ -25,6 +25,7 @@ llm模型导出onnx模型请使用[llm-export](https://github.com/wangzhaode/llm
其他版本:
- Qwen-1_8B-Chat-int8:[![Download][download-qwen-1.8b-mnn-int8]][release-qwen-1.8b-mnn-int8]
+- Android APK: [![Download][download-qwen-1.8b-apk]][release-qwen-1.8b-apk]
[download-chatglm-6b-onnx]: https://img.shields.io/github/downloads/wangzhaode/llm-export/chatglm-6b-onnx/total
[download-chatglm2-6b-onnx]: https://img.shields.io/github/downloads/wangzhaode/llm-export/chatglm2-6b-onnx/total
@@ -51,6 +52,8 @@ llm模型导出onnx模型请使用[llm-export](https://github.com/wangzhaode/llm
[download-llama2-7b-chat-mnn]: https://img.shields.io/github/downloads/wangzhaode/mnn-llm/llama2-7b-chat-mnn/total
[download-qwen-1.8b-mnn]: https://img.shields.io/github/downloads/wangzhaode/mnn-llm/qwen-1.8b-mnn/total
[download-qwen-1.8b-mnn-int8]: https://img.shields.io/github/downloads/wangzhaode/mnn-llm/qwen-1.8b-mnn-int8/total
+[download-qwen-1.8b-apk]: https://img.shields.io/github/downloads/wangzhaode/mnn-llm/qwen-1.8b-apk/total
+
[release-chatglm-6b-mnn]: https://github.com/wangzhaode/mnn-llm/releases/tag/chatglm-6b-mnn
[release-chatglm2-6b-mnn]: https://github.com/wangzhaode/mnn-llm/releases/tag/chatglm2-6b-mnn
[release-chatglm3-6b-mnn]: https://github.com/wangzhaode/mnn-llm/releases/tag/chatglm3-6b-mnn
@@ -60,6 +63,8 @@ llm模型导出onnx模型请使用[llm-export](https://github.com/wangzhaode/llm
[release-llama2-7b-chat-mnn]: https://github.com/wangzhaode/mnn-llm/releases/tag/llama2-7b-chat-mnn
[release-qwen-1.8b-mnn]: https://github.com/wangzhaode/mnn-llm/releases/tag/qwen-1.8b-mnn
[release-qwen-1.8b-mnn-int8]: https://github.com/wangzhaode/mnn-llm/releases/tag/qwen-1.8b-mnn-int8
+[release-qwen-1.8b-apk]: https://github.com/wangzhaode/mnn-llm/releases/tag/qwen-1.8b-apk
+
### 速度
@@ -157,12 +162,13 @@ adb shell "cd /data/local/tmp && export LD_LIBRARY_PATH=. && ./cli_demo -m model
## Reference
-- [chatglm-6b](https://huggingface.co/THUDM/chatglm-6b)
-- [chatglm2-6b](https://huggingface.co/THUDM/chatglm2-6b)
-- [codegeex2-6b](https://huggingface.co/THUDM/codegeex2-6b)
-- [Baichuan2-7B-Chat](https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat)
-- [Qwen-7B-Chat](https://huggingface.co/tangger/Qwen-7B-Chat)
+- [chatglm-6b](https://modelscope.cn/models/ZhipuAI/chatglm-6b/summary)
+- [chatglm2-6b](https://modelscope.cn/models/ZhipuAI/chatglm2-6b/summary)
+- [chatglm3-6b](https://modelscope.cn/models/ZhipuAI/chatglm3-6b/summary)
+- [codegeex2-6b](https://modelscope.cn/models/ZhipuAI/codegeex2-6b/summary)
+- [Baichuan2-7B-Chat](https://modelscope.cn/models/baichuan-inc/Baichuan2-7B-Chat/summary)
+- [Qwen-7B-Chat](https://modelscope.cn/models/qwen/Qwen-7B-Chat/summary)
+- [Qwen-1.8B-Chat](https://modelscope.cn/models/qwen/Qwen-1_8B-Chat/summary)
- [cpp-httplib](https://github.com/yhirose/cpp-httplib)
- [chatgpt-web](https://github.com/xqdoo00o/chatgpt-web)
-- [cppjieba](https://github.com/yanyiwu/cppjieba)
- [ChatViewDemo](https://github.com/BrettFX/ChatViewDemo)
diff --git a/android/.idea/compiler.xml b/android/.idea/compiler.xml
index 443b5d2b..61a9130c 100644
--- a/android/.idea/compiler.xml
+++ b/android/.idea/compiler.xml
@@ -1,6 +1,6 @@
-
+
\ No newline at end of file
diff --git a/android/app/src/main/assets/tokenizer b/android/app/src/main/assets/tokenizer
deleted file mode 120000
index 7496190d..00000000
--- a/android/app/src/main/assets/tokenizer
+++ /dev/null
@@ -1 +0,0 @@
-../../../../../resource/tokenizer
\ No newline at end of file
diff --git a/android/app/src/main/java/com/mnn/llm/Conversation.java b/android/app/src/main/java/com/mnn/llm/Conversation.java
index e45e568f..8a017513 100644
--- a/android/app/src/main/java/com/mnn/llm/Conversation.java
+++ b/android/app/src/main/java/com/mnn/llm/Conversation.java
@@ -31,7 +31,7 @@ public class Conversation extends BaseActivity {
private Button send;
private DateFormat mDateFormat;
private Chat mChat;
- private boolean mHistory = false;
+ private boolean mHistory = true;
@Override
protected void onCreate(Bundle savedInstanceState) {
@@ -138,12 +138,16 @@ public boolean onCreateOptionsMenu(Menu menu) {
@Override
public boolean onOptionsItemSelected(MenuItem item) {
+ /*
if (mHistory) {
Toast.makeText(getBaseContext(), "关闭上下文", Toast.LENGTH_SHORT).show();
} else {
Toast.makeText(getBaseContext(), "打开上下文", Toast.LENGTH_SHORT).show();
}
mHistory = !mHistory;
+ */
+ Toast.makeText(getBaseContext(), "清空记忆", Toast.LENGTH_SHORT).show();
+ mChat.Reset();
return true;
}
}
@@ -168,7 +172,7 @@ public void run() {
System.out.println("[MNN_DEBUG] start response\n");
while (!last_response.contains("")) {
try {
- Thread.sleep(200);
+ Thread.sleep(50);
} catch (Exception e) {}
String response = new String(mChat.Response());
if (response.equals(last_response)) {
@@ -177,6 +181,7 @@ public void run() {
last_response = response;
}
Message msg = new Message();
+ System.out.println("[MNN_DEBUG] " + response);
msg.obj = response.replaceFirst("", "");
mHandler.sendMessage(msg);
}
diff --git a/android/app/src/main/java/com/mnn/llm/DownloadModel.java b/android/app/src/main/java/com/mnn/llm/DownloadModel.java
index cb1f25b1..cbe8372e 100644
--- a/android/app/src/main/java/com/mnn/llm/DownloadModel.java
+++ b/android/app/src/main/java/com/mnn/llm/DownloadModel.java
@@ -2,11 +2,8 @@
import android.app.AlertDialog;
import android.content.DialogInterface;
-import android.content.Intent;
import android.graphics.Color;
import android.os.Bundle;
-import android.os.Handler;
-import android.os.Message;
import android.view.View;
import android.widget.Button;
@@ -29,8 +26,7 @@ protected void onCreate(Bundle savedInstanceState) {
mDownloadAll = (Button)findViewById(R.id.download_all);
// init Data
String[] modelArray = this.getResources().getStringArray(R.array.model_list);
- int[] modelSize = this.getResources().getIntArray(R.array.model_size);
- mAdapter = new DownloadRecyclerView(this, modelArray, modelSize);
+ mAdapter = new DownloadRecyclerView(this, modelArray);
mRecyclerView.setAdapter(mAdapter);
}
public void downloadAll(View view) {
diff --git a/android/app/src/main/java/com/mnn/llm/MainActivity.java b/android/app/src/main/java/com/mnn/llm/MainActivity.java
index 194ed5d3..696e35d2 100644
--- a/android/app/src/main/java/com/mnn/llm/MainActivity.java
+++ b/android/app/src/main/java/com/mnn/llm/MainActivity.java
@@ -13,9 +13,9 @@
import android.widget.RelativeLayout;
import android.widget.TextView;
-import org.w3c.dom.Text;
-
import java.io.File;
+import java.io.IOException;
+
public class MainActivity extends AppCompatActivity {
private Chat mChat;
@@ -28,8 +28,9 @@ public class MainActivity extends AppCompatActivity {
private TextView mProcessName;
private TextView mProcessPercent;
// resource files
- private String mModelDir = "/data/local/tmp/model";
- private boolean mModelNeedDownload = true;
+ private String mModelName = "qwen-1.8b-int4";
+ private String mModelDir = "/data/local/tmp/chat/" + mModelName; // default dir
+ private boolean mModelReady = true;
@Override
protected void onCreate(Bundle savedInstanceState) {
super.onCreate(savedInstanceState);
@@ -41,7 +42,8 @@ protected void onCreate(Bundle savedInstanceState) {
mProcessBar = (ProgressBar)findViewById(R.id.process_bar);
mProcessName = (TextView)findViewById(R.id.process_name);
mProcessPercent = (TextView)findViewById(R.id.process_percent);
- mModelDir = this.getCacheDir().toString() + "/model";
+ // using asset file
+ mModelDir = this.getCacheDir() + "/" + mModelName;
mProcessHandler = new Handler() {
@Override
public void handleMessage(Message msg) {
@@ -58,55 +60,56 @@ public void handleMessage(Message msg) {
}
}
};
- /*
- File model = new File(mModelDir, "glm_block_0.mnn");
- if (model.exists()) {
- model.delete();
- }
- */
- onCheckModels();
}
@Override
protected void onResume() {
super.onResume();
- onCheckModels();
}
public void onCheckModels() {
- mModelNeedDownload = checkModelsNeedDownload();
- if (mModelNeedDownload) {
+ mModelReady = checkModelsReady();
+ // try copy from asset file
+ if (!mModelReady) {
+ try {
+ mModelDir = Common.copyAssetResource2File(this, mModelName);
+ } catch (IOException e) {
+ throw new RuntimeException(e);
+ } catch (InterruptedException e) {
+ throw new RuntimeException(e);
+ }
+ mModelReady = checkModelsReady();
+ }
+ // download models
+ if (!mModelReady) {
mModelInfo.setVisibility(View.VISIBLE);
- mModelInfo.setText("使用前请先下载模型!");
+ mModelInfo.setText("请下载模型文件");
mLoadButton.setText("下载模型");
} else {
mModelInfo.setVisibility(View.VISIBLE);
- mModelInfo.setText("模型下载完毕,请加载模型!");
+ mModelInfo.setText(mModelName + "模型文件就绪,模型加载中");
mLoadButton.setText("加载模型");
}
}
- public boolean checkModelsNeedDownload() {
+ public boolean checkModelsReady() {
System.out.println("### Check Models!");
File dir = new File(mModelDir);
if (!dir.exists()) {
- return true;
+ return false;
}
String[] modelArray = this.getResources().getStringArray(R.array.model_list);
- int[] modelSize = this.getResources().getIntArray(R.array.model_size);
for (int i = 0; i < modelArray.length; i++) {
File model = new File(mModelDir, modelArray[i]);
if (!model.exists()) {
- return true;
- }
- if (model.length() != modelSize[i]) {
- return true;
+ return false;
}
}
- return false;
+ return true;
}
public void loadModel(View view) {
- if (mModelNeedDownload) {
+ onCheckModels();
+ if (!mModelReady) {
startActivity(new Intent(this, DownloadModel.class));
return;
}
@@ -115,7 +118,6 @@ public void loadModel(View view) {
mLoadButton.setText("模型加载中 ...");
mProcessView.setVisibility(View.VISIBLE);
mChat = new Chat();
- System.out.println("[MNN_DEBUG] is chat Ready: " + mChat.Ready());
Handler handler = new Handler() {
@Override
public void handleMessage(Message msg) {
diff --git a/android/app/src/main/java/com/mnn/llm/recyclerdownload/DownloadRecyclerView.java b/android/app/src/main/java/com/mnn/llm/recyclerdownload/DownloadRecyclerView.java
index 76a33d05..7d077187 100644
--- a/android/app/src/main/java/com/mnn/llm/recyclerdownload/DownloadRecyclerView.java
+++ b/android/app/src/main/java/com/mnn/llm/recyclerdownload/DownloadRecyclerView.java
@@ -25,7 +25,7 @@ public class DownloadRecyclerView extends RecyclerView.Adapter();
final String modelDir = context.getCacheDir().toString() + "/model";
@@ -48,7 +48,7 @@ public void handleMessage(Message msg) {
}
};
for (int i = 0; i < models.length; i++) {
- this.mItems.add(new DownloadData(mHandler, modelDir, models[i], i, modelSize[i]));
+ this.mItems.add(new DownloadData(mHandler, modelDir, models[i], i, 25751300));
}
}
diff --git a/android/app/src/main/java/com/mnn/llm/recylcerchat/ConversationRecyclerView.java b/android/app/src/main/java/com/mnn/llm/recylcerchat/ConversationRecyclerView.java
index a3378536..580bb7c9 100644
--- a/android/app/src/main/java/com/mnn/llm/recylcerchat/ConversationRecyclerView.java
+++ b/android/app/src/main/java/com/mnn/llm/recylcerchat/ConversationRecyclerView.java
@@ -101,7 +101,6 @@ private void configureViewHolder3(HolderMe vh1, int position) {
}
private void configureViewHolder2(HolderYou vh1, int position) {
- vh1.getTime().setText(items.get(position).getTime());
vh1.getChatText().setText(items.get(position).getText());
}
private void configureViewHolder1(HolderDate vh1, int position) {
diff --git a/android/app/src/main/java/com/mnn/llm/recylcerchat/HolderYou.java b/android/app/src/main/java/com/mnn/llm/recylcerchat/HolderYou.java
index c52c03d5..44e56ebf 100644
--- a/android/app/src/main/java/com/mnn/llm/recylcerchat/HolderYou.java
+++ b/android/app/src/main/java/com/mnn/llm/recylcerchat/HolderYou.java
@@ -9,22 +9,13 @@
public class HolderYou extends RecyclerView.ViewHolder {
- private TextView time, chatText;
+ private TextView chatText;
public HolderYou(View v) {
super(v);
- time = (TextView) v.findViewById(R.id.tv_time);
chatText = (TextView) v.findViewById(R.id.tv_chat_text);
}
- public TextView getTime() {
- return time;
- }
-
- public void setTime(TextView time) {
- this.time = time;
- }
-
public TextView getChatText() {
return chatText;
}
diff --git a/android/app/src/main/jni/CMakeLists.txt b/android/app/src/main/jni/CMakeLists.txt
index cb067922..246aa329 100644
--- a/android/app/src/main/jni/CMakeLists.txt
+++ b/android/app/src/main/jni/CMakeLists.txt
@@ -7,6 +7,7 @@ cmake_minimum_required(VERSION 3.10)
include_directories(${CMAKE_CURRENT_LIST_DIR}/../../../../../include/)
link_directories(${CMAKE_CURRENT_LIST_DIR}/libs/arm64-v8a)
+add_definitions(-DUSING_DISK_EMBED)
FILE(GLOB SRCS ../../../../../src/*.cpp)
add_library(llm_mnn SHARED llm_mnn_jni.cpp ${SRCS})
diff --git a/android/app/src/main/jni/llm_mnn_jni.cpp b/android/app/src/main/jni/llm_mnn_jni.cpp
index ecd570c7..a1481b82 100644
--- a/android/app/src/main/jni/llm_mnn_jni.cpp
+++ b/android/app/src/main/jni/llm_mnn_jni.cpp
@@ -10,7 +10,7 @@
#include "llm.hpp"
-static Llm* llm;
+static std::unique_ptr llm(nullptr);
static std::stringstream response_buffer;
extern "C" {
@@ -25,32 +25,33 @@ JNIEXPORT void JNI_OnUnload(JavaVM* vm, void* reserved) {
}
JNIEXPORT jboolean JNICALL Java_com_mnn_llm_Chat_Init(JNIEnv* env, jobject thiz, jstring modelDir) {
- if (llm->load_progress() < 100) {
- const char* model_dir = env->GetStringUTFChars(modelDir, 0);
- llm = Llm::createLLM(model_dir);
+ const char* model_dir = env->GetStringUTFChars(modelDir, 0);
+ if (!llm.get()) {
+ llm.reset(Llm::createLLM(model_dir));
llm->load(model_dir);
}
return JNI_TRUE;
}
JNIEXPORT jboolean JNICALL Java_com_mnn_llm_Chat_Ready(JNIEnv* env, jobject thiz) {
- if (llm->load_progress() >= 100) {
+ if (llm.get() && llm->load_progress() >= 100) {
return JNI_TRUE;
}
return JNI_FALSE;
}
JNIEXPORT jfloat JNICALL Java_com_mnn_llm_Chat_Progress(JNIEnv* env, jobject thiz) {
+ if (!llm.get()) return jfloat(0);
return jfloat(llm->load_progress());
}
JNIEXPORT jstring JNICALL Java_com_mnn_llm_Chat_Submit(JNIEnv* env, jobject thiz, jstring inputStr) {
- if (llm->load_progress() < 100) {
+ if (!llm.get()) {
return env->NewStringUTF("Failed, Chat is not ready!");
}
const char* input_str = env->GetStringUTFChars(inputStr, 0);
auto chat = [&](std::string str) {
- llm->response(str, &response_buffer);
+ llm->response(str, &response_buffer, "");
};
std::thread chat_thread(chat, input_str);
chat_thread.detach();
diff --git a/android/app/src/main/res/layout/activity_main.xml b/android/app/src/main/res/layout/activity_main.xml
index 7b8df790..4b0b5b84 100644
--- a/android/app/src/main/res/layout/activity_main.xml
+++ b/android/app/src/main/res/layout/activity_main.xml
@@ -28,7 +28,7 @@
android:layout_height="wrap_content"
android:layout_gravity="center"
android:textSize="8pt"
- android:text="ChatGLM-MNN" />
+ android:text="mnn-llm" />
+ android:orientation="horizontal">
-
\ No newline at end of file
diff --git a/android/app/src/main/res/values/models.xml b/android/app/src/main/res/values/models.xml
index 16713288..d5e7d4c5 100644
--- a/android/app/src/main/res/values/models.xml
+++ b/android/app/src/main/res/values/models.xml
@@ -1,67 +1,32 @@
- - glm_block_0.mnn
- - glm_block_1.mnn
- - glm_block_2.mnn
- - glm_block_3.mnn
- - glm_block_4.mnn
- - glm_block_5.mnn
- - glm_block_6.mnn
- - glm_block_7.mnn
- - glm_block_8.mnn
- - glm_block_9.mnn
- - glm_block_10.mnn
- - glm_block_11.mnn
- - glm_block_12.mnn
- - glm_block_13.mnn
- - glm_block_14.mnn
- - glm_block_15.mnn
- - glm_block_16.mnn
- - glm_block_17.mnn
- - glm_block_18.mnn
- - glm_block_19.mnn
- - glm_block_20.mnn
- - glm_block_21.mnn
- - glm_block_22.mnn
- - glm_block_23.mnn
- - glm_block_24.mnn
- - glm_block_25.mnn
- - glm_block_26.mnn
- - glm_block_27.mnn
+ - block_0.mnn
+ - block_1.mnn
+ - block_2.mnn
+ - block_3.mnn
+ - block_4.mnn
+ - block_5.mnn
+ - block_6.mnn
+ - block_7.mnn
+ - block_8.mnn
+ - block_9.mnn
+ - block_10.mnn
+ - block_11.mnn
+ - block_12.mnn
+ - block_13.mnn
+ - block_14.mnn
+ - block_15.mnn
+ - block_16.mnn
+ - block_17.mnn
+ - block_18.mnn
+ - block_19.mnn
+ - block_20.mnn
+ - block_21.mnn
+ - block_22.mnn
+ - block_23.mnn
+ - embeddings_bf16.bin
- lm.mnn
- - slim_word_embeddings_bf16.bin
+ - tokenizer.txt
-
- - 101104204
- - 101104300
- - 101104300
- - 101104300
- - 101104300
- - 101104300
- - 101104300
- - 101104300
- - 101104300
- - 101104300
- - 101104300
- - 101104300
- - 101104300
- - 101104300
- - 101104300
- - 101104300
- - 101104300
- - 101104300
- - 101104300
- - 101104300
- - 101104300
- - 101104300
- - 101104300
- - 101104300
- - 101104300
- - 101104300
- - 101104300
- - 100507772
- - 268366612
- - 1069285376
-
\ No newline at end of file
diff --git a/android/app/src/main/res/values/strings.xml b/android/app/src/main/res/values/strings.xml
index 6d5c1e3e..c61c1498 100644
--- a/android/app/src/main/res/values/strings.xml
+++ b/android/app/src/main/res/values/strings.xml
@@ -1,8 +1,8 @@
- ChatGLM-MNN
+ mnn-llm
MainActivity
Open navigation drawer
Close navigation drawer
Settings
- 本Demo基于ChatGLM-6B模型开发,对ChatGLM-6B模型进行了4bit量化,使用MNN进行推理加速。模型加载略慢,请稍作等待,点击下方按钮加载模型。
+ mnn-llm
diff --git a/android/gradle.properties b/android/gradle.properties
index 199d16ed..d41b0e88 100644
--- a/android/gradle.properties
+++ b/android/gradle.properties
@@ -6,7 +6,7 @@
# http://www.gradle.org/docs/current/userguide/build_environment.html
# Specifies the JVM arguments used for the daemon process.
# The setting is particularly useful for tweaking memory settings.
-org.gradle.jvmargs=-Xmx1536m
+org.gradle.jvmargs=-Xmx4g
# When configured, Gradle will run in incubating parallel mode.
# This option should only be used with decoupled projects. More details, visit
# http://www.gradle.org/docs/current/userguide/multi_project_builds.html#sec:decoupled_projects
diff --git a/include/llm.hpp b/include/llm.hpp
index 48dc6515..456d440a 100644
--- a/include/llm.hpp
+++ b/include/llm.hpp
@@ -32,7 +32,7 @@ class Llm {
tokenizer_.reset(new Sentencepiece);
}
virtual ~Llm() = default;
- static Llm* createLLM(const std::string& path);
+ static Llm* createLLM(const std::string& path, std::string model_type = "auto");
VARP disk_embedding(const std::vector& input_ids);
void load(const std::string& model_dir);
int forward(const std::vector& input_ids);
@@ -40,13 +40,14 @@ class Llm {
std::string decode(int id);
void chat();
void warmup();
- std::string response(const std::string& input_str, std::ostream* os = &std::cout);
+ std::string response(const std::string& input_str, std::ostream* os = &std::cout, const char* end_with = nullptr);
float load_progress() { return load_progress_; }
void reset();
void print_speed();
public:
+ std::vector history_;
// forward info
- int max_seq_len_ = 256;
+ int max_seq_len_ = 1024;
int prompt_len_ = 0;
int gen_seq_len_ = 0;
int all_seq_len_ = 0;
diff --git a/src/llm.cpp b/src/llm.cpp
index f8094f99..c2856f7a 100644
--- a/src/llm.cpp
+++ b/src/llm.cpp
@@ -15,7 +15,7 @@
#include
-Llm* Llm::createLLM(const std::string& path) {
+Llm* Llm::createLLM(const std::string& path, std::string model_type) {
auto size = path.size();
// end with '.mnn' is single model file, otherwise split block models
bool is_single = (size > 4 &&
@@ -24,27 +24,30 @@ Llm* Llm::createLLM(const std::string& path) {
path[size - 2] == 'n' &&
path[size - 1] == 'n');
Llm* llm = nullptr;
- if (path.find("chatglm") != std::string::npos) {
- if (path.find("chatglm2") != std::string::npos) {
+ if (model_type == "auto") {
+ model_type = path;
+ }
+ if (model_type.find("chatglm") != std::string::npos) {
+ if (model_type.find("chatglm2") != std::string::npos) {
llm = new Chatglm2_6b;
- } else if (path.find("chatglm3") != std::string::npos) {
+ } else if (model_type.find("chatglm3") != std::string::npos) {
llm = new Chatglm2_6b;
llm->model_name_ = "Chatglm3_6b";
} else {
llm = new Chatglm_6b;
}
- } else if (path.find("codegeex2") != std::string::npos) {
+ } else if (model_type.find("codegeex2") != std::string::npos) {
llm = new Chatglm2_6b;
llm->model_name_ = "Codegeex2_6b";
- } else if (path.find("qwen") != std::string::npos) {
- if (path.find("1.8") != std::string::npos) {
+ } else if (model_type.find("qwen") != std::string::npos) {
+ if (model_type.find("1.8") != std::string::npos) {
llm = new Qwen_1_8b;
} else {
llm = new Qwen_7b;
}
- } else if (path.find("llama2") != std::string::npos) {
+ } else if (model_type.find("llama2") != std::string::npos) {
llm = new Llama2_7b;
- } else if (path.find("baichuan") != std::string::npos) {
+ } else if (model_type.find("baichuan") != std::string::npos) {
llm = new Llama2_7b;
llm->model_name_ = "Baichuan2_7b";
}
@@ -64,13 +67,20 @@ void Llm::chat() {
std::cin >> input_str;
std::cout << "\nA: " << std::flush;
response(input_str);
- reset();
std::cout << std::endl;
}
}
-std::string Llm::response(const std::string& query, std::ostream* os) {
+std::string Llm::response(const std::string& query, std::ostream* os, const char* end_with) {
+ if (!end_with) {
+ end_with = "\n";
+ }
// init status
+ gen_seq_len_ = 0;
+ all_seq_len_ = 0;
+ prefill_us_ = 0;
+ decode_us_ = 0;
+ past_key_values_.clear();
if (is_single_) {
past_key_values_.push_back(_Input(key_value_shape_, NCHW));
} else {
@@ -80,12 +90,20 @@ std::string Llm::response(const std::string& query, std::ostream* os) {
}
// response
auto input_ids = tokenizer(query);
+ if (!history_.empty()) {
+ std::copy(input_ids.begin(), input_ids.end(), std::back_inserter(history_));
+ input_ids = history_;
+ } else {
+ history_ = input_ids;
+ }
+
prompt_len_ = input_ids.size();
// printf("token_num : %lu\n", input_ids.size());
auto st = std::chrono::system_clock::now();
int token = forward(input_ids);
- std::string output_str = decode(token);
auto et = std::chrono::system_clock::now();
+ history_.push_back(token);
+ std::string output_str = decode(token);
prefill_us_ = std::chrono::duration_cast(et - st).count();
*os << output_str << std::flush;
while (gen_seq_len_ < max_seq_len_) {
@@ -94,9 +112,10 @@ std::string Llm::response(const std::string& query, std::ostream* os) {
et = std::chrono::system_clock::now();
decode_us_ += std::chrono::duration_cast(et - st).count();
if (is_stop(token)) {
- *os << std::endl << std::flush;
+ *os << end_with << std::flush;
break;
}
+ history_.push_back(token);
auto word = decode(token);
*os << word << std::flush;
output_str += word;
@@ -105,7 +124,8 @@ std::string Llm::response(const std::string& query, std::ostream* os) {
print_speed();
#endif
// update Cache
- runtime_manager_->updateCache();
+ // runtime_manager_->updateCache();
+ // reset forward info
return output_str;
}
@@ -128,13 +148,7 @@ void Llm::print_speed() {
}
void Llm::reset() {
- past_key_values_.clear();
- past_key_values_.shrink_to_fit();
- prompt_len_ = 0;
- gen_seq_len_ = 0;
- all_seq_len_ = 0;
- prefill_us_ = 0;
- decode_us_ = 0;
+ history_.clear();
}
void Llm::load(const std::string& model_dir) {
@@ -153,54 +167,55 @@ void Llm::load(const std::string& model_dir) {
const char* cacheFileName = ".tempcache";
runtime_manager_->setCache(cacheFileName);
}
+ load_progress_ = 0.f;
// 1. load vocab
std::string tokenizer_path = model_dir + "/tokenizer.txt";
+ load_progress_ += 5.f;
tokenizer_->load(tokenizer_path);
+ load_progress_ += 5.f;
// 2. load model
Module::Config module_config;
module_config.shapeMutable = true;
module_config.rearrange = true;
- load_progress_ = 0.f;
if (is_single_) {
key_value_shape_.insert(key_value_shape_.begin(), layer_nums_);
modules_.resize(1);
std::string model_path = model_dir;
std::string external_path = model_dir + ".weight";
- printf("load %s ... ", model_path.c_str());
+ MNN_PRINT("load %s ... ", model_path.c_str());
runtime_manager_->setExternalFile(external_path);
modules_[0].reset(Module::load(
{"input_ids", "attention_mask", "position_ids", "past_key_values"},
{"token_id", "presents"}, model_path.c_str(), runtime_manager_, &module_config));
- printf("Done!\n");
- fflush(stdout);
+ MNN_PRINT("Done!\n");
+ load_progress_ += 90.f;
} else {
// 2. load models
modules_.resize(layer_nums_ + 2);
- float step = 100.0 / modules_.size();
+ float step = 90.0 / modules_.size();
char buffer[50];
// load lm model
std::string lm_model_path = model_dir + "/lm.mnn";
std::string embedding_model_path = model_dir + "/embedding.mnn";
- printf("[%3.0f%% ] load %s model ... ", load_progress_, lm_model_path.c_str());
+ MNN_PRINT("[%3.0f%% ] load %s model ... ", load_progress_, lm_model_path.c_str());
modules_[layer_nums_].reset(Module::load({}, {}, lm_model_path.c_str(), runtime_manager_, &module_config));
- printf("Done!\n");
+ MNN_PRINT("Done!\n");
load_progress_ += step;
#ifndef USING_DISK_EMBED
- printf("[%3.0f%% ] load %s model ... ", load_progress_, embedding_model_path.c_str());fflush(stdout);
+ MNN_PRINT("[%3.0f%% ] load %s model ... ", load_progress_, embedding_model_path.c_str());fflush(stdout);
modules_[layer_nums_ + 1].reset(Module::load({}, {}, embedding_model_path.c_str(), runtime_manager_, &module_config));
- printf("Done!\n");
+ MNN_PRINT("Done!\n");
load_progress_ += step;
#endif
// load glm_block models
for (int i = 0; i < layer_nums_; i++) {
load_progress_ += step;
std::string model_path = model_dir + "/block_" + std::to_string(i) + ".mnn";
- printf("[%3.0f%% ] load %s model ... ", load_progress_, model_path.c_str());
+ MNN_PRINT("[%3.0f%% ] load %s model ... ", load_progress_, model_path.c_str());
modules_[i].reset(Module::load(
{"inputs_embeds", "attention_mask", "position_ids", "past_key_values"},
{"hidden_states", "presents"}, model_path.c_str(), runtime_manager_, &module_config));
- printf("Done!\n");
- fflush(stdout);
+ MNN_PRINT("Done!\n");
}
}
if (config.type == MNN_FORWARD_OPENCL) {
@@ -210,13 +225,18 @@ void Llm::load(const std::string& model_dir) {
void Llm::warmup() {
// warmup
- printf("### warmup ... ");
- for (int i = 0; i < layer_nums_; i++) {
+ MNN_PRINT("### warmup ... ");
+ if (is_single_) {
past_key_values_.push_back(_Input(key_value_shape_, NCHW));
+ } else {
+ for (int i = 0; i < layer_nums_; i++) {
+ past_key_values_.push_back(_Input(key_value_shape_, NCHW));
+ }
}
- std::vector tmp_0(1, 0);
- forward(tmp_0);
- reset();
+ std::vector tmp(1, 0);
+ forward(tmp);
+ all_seq_len_ = 0;
+ gen_seq_len_ = 0;
printf("Done\n");
}