diff --git a/.gitmodules b/.gitmodules
index 7029b7ed..38f0d41d 100644
--- a/.gitmodules
+++ b/.gitmodules
@@ -1,3 +1,6 @@
 [submodule "crates/llm-chain-llama/sys/llama.cpp"]
 	path = crates/llm-chain-llama-sys/llama.cpp
 	url = https://github.com/ggerganov/llama.cpp.git
+[submodule "crates/llm-chain-gemma-sys/gemma.cpp"]
+	path = crates/llm-chain-gemma-sys/gemma.cpp
+	url = https://github.com/google/gemma.cpp.git
diff --git a/Cargo.lock b/Cargo.lock
index 29d6965b..ce8e99f8 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -698,12 +698,9 @@ dependencies = [
 
 [[package]]
 name = "cc"
-version = "1.0.83"
+version = "1.0.87"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "f1174fb0b6ec23863f8b971027804a42614e347eafb0a95bf0b12cdae21fc4d0"
-dependencies = [
- "libc",
-]
+checksum = "3286b845d0fccbdd15af433f61c5970e711987036cb468f437ff6badd70f4e24"
 
 [[package]]
 name = "cexpr"
@@ -1800,6 +1797,23 @@ dependencies = [
  "uuid",
 ]
 
+[[package]]
+name = "llm-chain-gemma"
+version = "0.1.0"
+dependencies = [
+ "async-trait",
+ "llm-chain",
+ "llm-chain-gemma-sys",
+ "tokio",
+]
+
+[[package]]
+name = "llm-chain-gemma-sys"
+version = "0.1.0"
+dependencies = [
+ "cc",
+]
+
 [[package]]
 name = "llm-chain-hnsw"
 version = "0.13.0"
diff --git a/crates/llm-chain-gemma-sys/Cargo.toml b/crates/llm-chain-gemma-sys/Cargo.toml
new file mode 100644
index 00000000..58c22a20
--- /dev/null
+++ b/crates/llm-chain-gemma-sys/Cargo.toml
@@ -0,0 +1,20 @@
+[package]
+name = "llm-chain-gemma-sys"
+description = "A library with bindings for gemma.cpp"
+version = "0.1.0"
+edition = "2021"
+license = "MIT"
+keywords = ["llm", "langchain", "gemma", "chain"]
+categories = ["science"]
+authors = [
+    "Jun Mukai",
+]
+repository = "https://github.com/sobelio/llm-chain/"
+readme = "README.md"
+
+# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
+
+[dependencies]
+
+[build-dependencies]
+cc = "1.0.87"
diff --git a/crates/llm-chain-gemma-sys/a.out b/crates/llm-chain-gemma-sys/a.out
new file mode 100755
index 00000000..32335178
Binary files /dev/null and b/crates/llm-chain-gemma-sys/a.out differ
diff --git a/crates/llm-chain-gemma-sys/build.rs b/crates/llm-chain-gemma-sys/build.rs
new file mode 100644
index 00000000..9280e7fb
--- /dev/null
+++ b/crates/llm-chain-gemma-sys/build.rs
@@ -0,0 +1,123 @@
+#![allow(clippy::uninlined_format_args)]
+
+extern crate cc;
+
+use std::env;
+
+fn main() {
+    let target = env::var("TARGET").unwrap();
+    // Link C++ standard library
+    if let Some(cpp_stdlib) = get_cpp_link_stdlib(&target) {
+        println!("cargo:rustc-link-lib=dylib={}", cpp_stdlib);
+        println!("cargo:rustc-link-arg=-l{}", cpp_stdlib);
+    }
+    // Link macOS Accelerate framework for matrix calculations
+    if target.contains("apple") {
+        println!("cargo:rustc-link-lib=framework=Accelerate");
+    }
+    println!("cargo:rustc-link-search={}", env::var("OUT_DIR").unwrap());
+    println!("cargo:rustc-link-lib=static=gemma");
+    println!("cargo:rustc-link-lib=static=hwy");
+    println!("cargo:rustc-link-lib=static=hwy_contrib");
+    println!("cargo:rustc-link-lib=static=sentencepiece");
+    println!("cargo:rustc-link-lib=static=bindings");
+    println!("cargo:rerun-if-changed=wrapper.h");
+
+    // stop if we're on docs.rs
+    if env::var("DOCS_RS").is_ok() {
+        return;
+    }
+
+    // Run cmake to generate build files.
+ env::set_current_dir("gemma.cpp").expect("Unable to change directory to gemma.cpp"); + env::set_current_dir("build").expect("Unable to change directory to gemma.cpp build"); + + env::set_var("CXXFLAGS", "-fPIC"); + env::set_var("CFLAGS", "-fPIC"); + + let mut code = std::process::Command::new("cmake"); + let code = code + .arg("..") + .arg("-DCMAKE_BUILD_TYPE=Release") + .arg("-DBUILD_SHARED_LIBS=OFF") + .arg("-DWEIGHT_TYPE=hwy::bfloat16_t") + .arg("-DSPM_ENABLE_SHARED=OFF"); + let code = code.status().expect("Failed to generate build script"); + if code.code() != Some(0) { + panic!("Failed to generate build script"); + } + + // Build binary. + #[allow(clippy::suspicious_command_arg_space)] + let code = std::process::Command::new("cmake") + .arg("--build") + .arg(".") + .arg("--config Release") + .arg("--") + .arg("libgemma") + .status() + .expect("Failed to build lib"); + if code.code() != Some(0) { + panic!("Failed to build lib"); + } + + // move libllama.a to where Cargo expects it (OUT_DIR) + #[cfg(target_os = "windows")] + { + // I haven't tested windows, so it's not supported yet. + } + + #[cfg(not(target_os = "windows"))] + { + std::fs::copy( + "libgemma.a", + format!("{}/libgemma.a", env::var("OUT_DIR").unwrap()), + ) + .expect("Failed to copy lib"); + + std::fs::copy( + "_deps/highway-build/libhwy.a", + format!("{}/libhwy.a", env::var("OUT_DIR").unwrap()), + ) + .expect("Failed to copy libwhy.a"); + + std::fs::copy( + "_deps/highway-build/libhwy_contrib.a", + format!("{}/libhwy_contrib.a", env::var("OUT_DIR").unwrap()), + ) + .expect("Failed to copy libwhy_contrib.a"); + + std::fs::copy( + "_deps/sentencepiece-build/src/libsentencepiece.a", + format!("{}/libsentencepiece.a", env::var("OUT_DIR").unwrap()), + ) + .expect("Failed to copy libsentencepiece.a"); + } + + // Finally, build bindings.cc to allow access for gemma.cpp. + // So far, bindgen does not correctly generate buildable rust file, + // so I manually wrote bindings.rs for hand-written src/bindings.cc file. 
+ env::set_current_dir("..").expect("Unlable to change directory back to gemma.cpp"); + env::set_current_dir("..").expect("Unlable to change directory back to crate top"); + + cc::Build::new() + .cpp(true) + .file("src/bindings.cc") + .include("./gemma.cpp") + .include("./gemma.cpp/build/_deps/highway-src") + .include("./gemma.cpp/build/_deps/sentencepiece-src") + .compile("bindings"); +} + +// From https://github.com/alexcrichton/cc-rs/blob/fba7feded71ee4f63cfe885673ead6d7b4f2f454/src/lib.rs#L2462 +fn get_cpp_link_stdlib(target: &str) -> Option<&'static str> { + if target.contains("msvc") { + None + } else if target.contains("apple") || target.contains("freebsd") || target.contains("openbsd") { + Some("c++") + } else if target.contains("android") { + Some("c++_shared") + } else { + Some("stdc++") + } +} diff --git a/crates/llm-chain-gemma-sys/gemma.cpp b/crates/llm-chain-gemma-sys/gemma.cpp new file mode 160000 index 00000000..0508e2c2 --- /dev/null +++ b/crates/llm-chain-gemma-sys/gemma.cpp @@ -0,0 +1 @@ +Subproject commit 0508e2c2e1c3a2ed63564886f4b5468dce5c9871 diff --git a/crates/llm-chain-gemma-sys/src/bindings.cc b/crates/llm-chain-gemma-sys/src/bindings.cc new file mode 100644 index 00000000..3fecf508 --- /dev/null +++ b/crates/llm-chain-gemma-sys/src/bindings.cc @@ -0,0 +1,211 @@ +#include + +extern "C" { + +gcpp::LoaderArgs* gcpp_LoaderArgs_LoaderArgs(int argc, char* argv[]) { + return new gcpp::LoaderArgs(argc, argv); +} + +void gcpp_LoaderArgs_destructor(gcpp::LoaderArgs* args) { + delete args; +} + +const char* gcpp_LoaderArgs_Validate(gcpp::LoaderArgs* args) { + return args->Validate(); +} + +gcpp::Model gcpp_LoaderArgs_ModelType(const gcpp::LoaderArgs* args) { + return args->ModelType(); +} + +gcpp::ModelTraining gcpp_LoaderArgs_ModelTraining(const gcpp::LoaderArgs* args) { + return args->ModelTraining(); +} + +void gcpp_LoaderArgs_SetTokenizer(gcpp::LoaderArgs* args, char* path) { + args->tokenizer.path = std::string(path); +} + +const char* gcpp_LoaderArgs_Tokenizer(gcpp::LoaderArgs* args) { + return args->tokenizer.path.c_str(); +} + +void gcpp_LoaderArgs_SetModel(gcpp::LoaderArgs* args, char* path) { + args->model.path = std::string(path); +} + +const char* gcpp_LoaderArgs_Model(gcpp::LoaderArgs* args) { + return args->model.path.c_str(); +} + +void gcpp_LoaderArgs_SetCache(gcpp::LoaderArgs* args, char* path) { + args->cache.path = std::string(path); +} + +const char* gcpp_LoaderArgs_Cache(gcpp::LoaderArgs* args) { + return args->cache.path.c_str(); +} + +void gcpp_LoaderArgs_SetModelTypeValue(gcpp::LoaderArgs* args, char* v) { + args->model_type = std::string(v); +} + +const char* gcpp_LoaderArgs_ModelTypeValue(gcpp::LoaderArgs* args) { + return args->model_type.c_str(); +} + +hwy::ThreadPool* hwy_ThreadPool_ThreadPool(size_t num_threads) { + return new hwy::ThreadPool(num_threads); +} + +void hwy_ThreadPool_destructor(hwy::ThreadPool* pool) { + delete pool; +} + +gcpp::Gemma* gcpp_Gemma_Gemma(const gcpp::LoaderArgs* args, hwy::ThreadPool* pool) { + return new gcpp::Gemma(*args, *pool); +} + +void gcpp_Gemma_destructor(gcpp::Gemma* gemma) { + delete gemma; +} + +std::vector* std_vector_int_vector() { + return new std::vector(); +} + +void std_vector_int_destructor(std::vector* v) { + delete v; +} + +size_t std_vector_int_size(const std::vector* v) { + return v->size(); +} + +int std_vector_int_at(const std::vector* v, size_t i) { + return v->at(i); +} + +std::string* std_string_string() { + return new std::string(); +} + +void std_string_destructor(std::string* s) { + 
+  delete s;
+}
+
+const char* std_string_c_str(const std::string* s) {
+  return s->c_str();
+}
+
+bool gcpp_Gemma_Encode(gcpp::Gemma* gemma, const char* input, size_t len, std::vector<int>* out) {
+  return gemma->Tokenizer().Encode(std::string(input, len), out).ok();
+}
+
+bool gcpp_Gemma_Decode(gcpp::Gemma* gemma, int token, std::string* out) {
+  return gemma->Tokenizer().Decode(std::vector<int>{token}, out).ok();
+}
+
+bool gcpp_Gemma_Decodes(gcpp::Gemma* gemma, const int* tokens, int num_tokens, std::string* out) {
+  std::vector<int> v;
+  v.reserve(num_tokens);
+  for (int i = 0; i < num_tokens; i++) {
+    v.push_back(tokens[i]);
+  }
+  return gemma->Tokenizer().Decode(v, out).ok();
+}
+
+gcpp::InferenceArgs* gcpp_InferenceArgs_InferenceArgs(int argc, char* argv[]) {
+  return new gcpp::InferenceArgs(argc, argv);
+}
+
+void gcpp_InferenceArgs_destructor(gcpp::InferenceArgs* args) {
+  delete args;
+}
+
+const char* gcpp_InferenceArgs_Validate(gcpp::InferenceArgs* args) {
+  return args->Validate();
+}
+
+size_t gcpp_InferenceArgs_MaxTokens(gcpp::InferenceArgs* args) {
+  return args->max_tokens;
+}
+
+void gcpp_InferenceArgs_SetMaxTokens(gcpp::InferenceArgs* args, size_t mt) {
+  args->max_tokens = mt;
+}
+
+size_t gcpp_InferenceArgs_MaxGeneratedTokens(gcpp::InferenceArgs* args) {
+  return args->max_generated_tokens;
+}
+
+void gcpp_InferenceArgs_SetMaxGeneratedTokens(gcpp::InferenceArgs* args, size_t mgt) {
+  args->max_generated_tokens = mgt;
+}
+
+float gcpp_InferenceArgs_Temperature(gcpp::InferenceArgs* args) {
+  return args->temperature;
+}
+
+void gcpp_InferenceArgs_SetTemperature(gcpp::InferenceArgs* args, float t) {
+  args->temperature = t;
+}
+
+bool gcpp_InferenceArgs_Deterministic(gcpp::InferenceArgs* args) {
+  return args->deterministic;
+}
+
+void gcpp_InferenceArgs_SetDeterministic(gcpp::InferenceArgs* args, bool d) {
+  args->deterministic = d;
+}
+
+bool gcpp_InferenceArgs_Multiturn(gcpp::InferenceArgs* args) {
+  return args->multiturn;
+}
+
+void gcpp_InferenceArgs_SetMultiturn(gcpp::InferenceArgs* args, bool mt) {
+  args->multiturn = mt;
+}
+
+std::mt19937* std_mt19937_mt19937() {
+  return new std::mt19937();
+}
+
+void std_mt19937_destructor(std::mt19937* gen) {
+  delete gen;
+}
+
+void std_mt19937_seed(std::mt19937* gen, int seed) {
+  gen->seed(seed);
+}
+
+void std_mt19937_random_seed(std::mt19937* gen) {
+  std::random_device rd;
+  gen->seed(rd());
+}
+
+typedef bool (*stream_callback)(void*, int, float);
+typedef bool (*accept_callback)(void*, int);
+
+void gcpp_GenerateGemma(
+    gcpp::Gemma* gemma, const gcpp::InferenceArgs* args,
+    const std::vector<int>* prompt, size_t start_pos,
+    hwy::ThreadPool* pool, hwy::ThreadPool* inner_pool,
+    void* stream_context,
+    stream_callback stream_token,
+    void* accept_context,
+    accept_callback accept_token,
+    std::mt19937* gen, int verbosity) {
+  gcpp::GenerateGemma(
+    *gemma, *args, *prompt, start_pos,
+    *pool, *inner_pool,
+    [&stream_context, &stream_token](int token, float value) {
+      return stream_token(stream_context, token, value);
+    },
+    [&accept_context, &accept_token](int token) {
+      return accept_token(accept_context, token);
+    },
+    *gen, verbosity);
+}
+
+}
\ No newline at end of file
diff --git a/crates/llm-chain-gemma-sys/src/bindings.rs b/crates/llm-chain-gemma-sys/src/bindings.rs
new file mode 100644
index 00000000..ebb4428f
--- /dev/null
+++ b/crates/llm-chain-gemma-sys/src/bindings.rs
@@ -0,0 +1,245 @@
+use std::ffi;
+
+pub type gcpp_Model = ffi::c_int;
+pub const gcpp_Model_GEMMA_2B: gcpp_Model = 0;
+pub const gcpp_Model_GEMMA_7B: gcpp_Model = 1;
+
+pub type gcpp_ModelTraining = ffi::c_int;
+pub const gcpp_ModelTraining_GEMMA_IT: gcpp_ModelTraining = 0;
+pub const gcpp_ModelTraining_GEMMA_PT: gcpp_ModelTraining = 1;
+
+pub const EOS_ID: i32 = 1;
+
+#[repr(C)]
+pub struct gcpp_LoaderArgs {
+    _data: [u8; 0],
+    _marker:
+        core::marker::PhantomData<(*mut u8, core::marker::PhantomPinned)>,
+}
+
+extern "C" {
+    pub fn gcpp_LoaderArgs_LoaderArgs(argc: ffi::c_int, argv: *mut *mut ffi::c_char) -> *mut gcpp_LoaderArgs;
+    pub fn gcpp_LoaderArgs_destructor(largs: *mut gcpp_LoaderArgs);
+    pub fn gcpp_LoaderArgs_Validate(largs: *mut gcpp_LoaderArgs) -> *const ffi::c_char;
+    pub fn gcpp_LoaderArgs_ModelType(largs: *const gcpp_LoaderArgs) -> gcpp_Model;
+    pub fn gcpp_LoaderArgs_ModelTraining(largs: *const gcpp_LoaderArgs) -> gcpp_ModelTraining;
+    pub fn gcpp_LoaderArgs_SetTokenizer(largs: *mut gcpp_LoaderArgs, path: *const ffi::c_char);
+    pub fn gcpp_LoaderArgs_Tokenizer(largs: *const gcpp_LoaderArgs) -> *mut ffi::c_char;
+    pub fn gcpp_LoaderArgs_SetModel(largs: *mut gcpp_LoaderArgs, path: *const ffi::c_char);
+    pub fn gcpp_LoaderArgs_Model(largs: *const gcpp_LoaderArgs) -> *mut ffi::c_char;
+    pub fn gcpp_LoaderArgs_SetCache(largs: *mut gcpp_LoaderArgs, path: *const ffi::c_char);
+    pub fn gcpp_LoaderArgs_Cache(largs: *const gcpp_LoaderArgs) -> *mut ffi::c_char;
+    pub fn gcpp_LoaderArgs_SetModelTypeValue(largs: *mut gcpp_LoaderArgs, s: *const ffi::c_char);
+    pub fn gcpp_LoaderArgs_ModelTypeValue(largs: *const gcpp_LoaderArgs) -> *mut ffi::c_char;
+}
+
+#[repr(C)]
+pub struct hwy_ThreadPool {
+    _data: [u8; 0],
+    _marker:
+        core::marker::PhantomData<(*mut u8, core::marker::PhantomPinned)>,
+}
+
+extern "C" {
+    pub fn hwy_ThreadPool_ThreadPool(num_threads: ffi::c_uint) -> *mut hwy_ThreadPool;
+    pub fn hwy_ThreadPool_destructor(pool: *mut hwy_ThreadPool);
+}
+
+#[repr(C)]
+pub struct gcpp_Gemma {
+    _data: [u8; 0],
+    _marker:
+        core::marker::PhantomData<(*mut u8, core::marker::PhantomPinned)>,
+}
+
+extern "C" {
+    pub fn gcpp_Gemma_Gemma(args: *mut gcpp_LoaderArgs, pool: *mut hwy_ThreadPool) -> *mut gcpp_Gemma;
+    pub fn gcpp_Gemma_destructor(gemma: *mut gcpp_Gemma);
+    pub fn gcpp_Gemma_Encode(gemma: *mut gcpp_Gemma, input: *mut ffi::c_char, len: ffi::c_uint, out: *mut std_vector_int) -> ffi::c_char;
+    pub fn gcpp_Gemma_Decode(gemma: *mut gcpp_Gemma, token: ffi::c_int, out: *mut std_string) -> ffi::c_char;
+    pub fn gcpp_Gemma_Decodes(gemma: *mut gcpp_Gemma, tokens: *const ffi::c_int, num_tokens: ffi::c_int, out: *mut std_string) -> ffi::c_char;
+
+    pub fn gcpp_GenerateGemma(
+        gemma: *mut gcpp_Gemma, args: *mut gcpp_InferenceArgs,
+        prompt: *const std_vector_int, start_pos: ffi::c_uint,
+        pool: *mut hwy_ThreadPool, inner_pool: *mut hwy_ThreadPool,
+        stream_context: *mut ffi::c_void,
+        stream_token: extern fn(*mut ffi::c_void, ffi::c_int, ffi::c_float) -> ffi::c_char,
+        accept_context: *mut ffi::c_void,
+        accept_token: extern fn(*mut ffi::c_void, ffi::c_int) -> ffi::c_char,
+        gen: *mut std_mt19937, verbosity: ffi::c_int,
+    );
+}
+
+#[repr(C)]
+pub struct std_vector_int {
+    _data: [u8; 0],
+    _marker:
+        core::marker::PhantomData<(*mut u8, core::marker::PhantomPinned)>,
+}
+
+extern "C" {
+    pub fn std_vector_int_vector() -> *mut std_vector_int;
+    pub fn std_vector_int_destructor(v: *mut std_vector_int);
+    pub fn std_vector_int_size(v: *const std_vector_int) -> ffi::c_uint;
+    pub fn std_vector_int_at(v: *const std_vector_int, i: ffi::c_uint) -> ffi::c_int;
+}
+
+pub struct std_vector_int_iter {
+    v: *mut std_vector_int,
+    i: ffi::c_uint,
+}
+
+impl std_vector_int_iter {
+    pub fn new(v: *mut std_vector_int) -> std_vector_int_iter {
+        std_vector_int_iter{
+            v: v,
+            i: 0,
+        }
+    }
+}
+
+impl ExactSizeIterator for std_vector_int_iter {
+    fn len(&self) -> usize {
+        unsafe { std_vector_int_size(self.v) as usize }
+    }
+}
+
+impl Iterator for std_vector_int_iter {
+    type Item = i32;
+
+    fn next(&mut self) -> Option<Self::Item> {
+        unsafe {
+            if self.i < std_vector_int_size(self.v) {
+                let v = std_vector_int_at(self.v, self.i);
+                self.i += 1;
+                Some(v as i32)
+            } else {
+                None
+            }
+        }
+    }
+}
+
+#[repr(C)]
+pub struct std_string {
+    _data: [u8; 0],
+    _marker:
+        core::marker::PhantomData<(*mut u8, core::marker::PhantomPinned)>,
+}
+
+extern "C" {
+    pub fn std_string_string() -> *mut std_string;
+    pub fn std_string_destructor(s: *mut std_string);
+    pub fn std_string_c_str(s: *const std_string) -> *mut ffi::c_char;
+}
+
+#[repr(C)]
+pub struct gcpp_InferenceArgs {
+    _data: [u8; 0],
+    _marker:
+        core::marker::PhantomData<(*mut u8, core::marker::PhantomPinned)>,
+}
+
+extern "C" {
+    pub fn gcpp_InferenceArgs_InferenceArgs(argc: ffi::c_int, argv: *mut *mut ffi::c_char) -> *mut gcpp_InferenceArgs;
+    pub fn gcpp_InferenceArgs_destructor(args: *mut gcpp_InferenceArgs);
+    pub fn gcpp_InferenceArgs_Validate(args: *mut gcpp_InferenceArgs) -> *const ffi::c_char;
+
+    pub fn gcpp_InferenceArgs_MaxTokens(args: *const gcpp_InferenceArgs) -> ffi::c_uint;
+    pub fn gcpp_InferenceArgs_SetMaxTokens(args: *mut gcpp_InferenceArgs, mt: ffi::c_uint);
+    pub fn gcpp_InferenceArgs_MaxGeneratedTokens(args: *const gcpp_InferenceArgs) -> ffi::c_uint;
+    pub fn gcpp_InferenceArgs_SetMaxGeneratedTokens(args: *mut gcpp_InferenceArgs, mgt: ffi::c_uint);
+    pub fn gcpp_InferenceArgs_Temperature(args: *const gcpp_InferenceArgs) -> ffi::c_float;
+    pub fn gcpp_InferenceArgs_SetTemperature(args: *mut gcpp_InferenceArgs, t: ffi::c_float);
+    pub fn gcpp_InferenceArgs_Deterministic(args: *const gcpp_InferenceArgs) -> ffi::c_char;
+    pub fn gcpp_InferenceArgs_SetDeterministic(args: *mut gcpp_InferenceArgs, d: ffi::c_char);
+    pub fn gcpp_InferenceArgs_Multiturn(args: *const gcpp_InferenceArgs) -> ffi::c_char;
+    pub fn gcpp_InferenceArgs_SetMultiturn(args: *mut gcpp_InferenceArgs, mt: ffi::c_char);
+}
+
+#[repr(C)]
+pub struct std_mt19937 {
+    _data: [u8; 0],
+    _marker:
+        core::marker::PhantomData<(*mut u8, core::marker::PhantomPinned)>,
+}
+
+extern "C" {
+    pub fn std_mt19937_mt19937() -> *mut std_mt19937;
+    pub fn std_mt19937_destructor(gen: *mut std_mt19937);
+    pub fn std_mt19937_seed(gen: *mut std_mt19937, seed: ffi::c_int);
+    pub fn std_mt19937_random_seed(gen: *mut std_mt19937);
+}
+
+#[cfg(test)]
+mod test {
+    use crate::*;
+
+    #[test]
+    fn create_and_delete_largs() {
+        let args = vec![
+            "prog",
+            "--tokenizer", "tokenizer.spm",
+            "--model", "2b-pt",
+            "--compressed_weights", "2b-pt.sbs",
+        ];
+        unsafe {
+            let largs = gcpp_LoaderArgs_LoaderArgs(
+                args.len() as ffi::c_int,
+                Vec::from_iter(args.into_iter().map(|arg|
+                    ffi::CString::new(arg).unwrap().into_raw()
+                ).into_iter()).as_mut_ptr(),
+            );
+            assert_eq!(gcpp_Model_GEMMA_2B, gcpp_LoaderArgs_ModelType(largs));
+            assert_eq!(gcpp_ModelTraining_GEMMA_PT, gcpp_LoaderArgs_ModelTraining(largs));
+            let tp = gcpp_LoaderArgs_Tokenizer(largs);
+            let s = ffi::CStr::from_ptr(tp).to_str().unwrap();
+            assert_eq!(s, "tokenizer.spm");
+            gcpp_LoaderArgs_destructor(largs);
+        }
+    }
+
+    #[test]
+    fn create_and_delete_largs_direct() {
+        let tokenizer_path = "tokenizer.spm";
+        let compressed_weights = "2b-pt.sbs";
+        let model = "2b-pt";
+        unsafe {
+            let largs = gcpp_LoaderArgs_LoaderArgs(0, std::ptr::null_mut());
+            gcpp_LoaderArgs_SetTokenizer(largs, ffi::CString::new(tokenizer_path).unwrap().as_ptr());
+            gcpp_LoaderArgs_SetCache(largs, ffi::CString::new(compressed_weights).unwrap().as_ptr());
+            gcpp_LoaderArgs_SetModelTypeValue(largs, ffi::CString::new(model).unwrap().as_ptr());
+            let err = gcpp_LoaderArgs_Validate(largs);
+            if err != std::ptr::null_mut() {
+                println!("{}", ffi::CStr::from_ptr(err).to_str().unwrap());
+            }
+            assert_eq!(std::ptr::null(), err);
+        }
+    }
+
+    #[test]
+    fn create_and_delete_iargs_direct() {
+        unsafe {
+            let iargs = gcpp_InferenceArgs_InferenceArgs(0, std::ptr::null_mut());
+            assert_eq!(gcpp_InferenceArgs_Validate(iargs), std::ptr::null());
+
+            assert_eq!(gcpp_InferenceArgs_MaxGeneratedTokens(iargs), 2048);
+            assert_eq!(gcpp_InferenceArgs_MaxTokens(iargs), 3072);
+
+            gcpp_InferenceArgs_SetMaxGeneratedTokens(iargs, 4096);
+
+            assert_ne!(gcpp_InferenceArgs_Validate(iargs), std::ptr::null());
+
+            gcpp_InferenceArgs_destructor(iargs);
+        }
+    }
+
+    #[test]
+    fn create_and_delete_pool() {
+        unsafe {
+            let pool = hwy_ThreadPool_ThreadPool(1);
+            hwy_ThreadPool_destructor(pool);
+        }
+    }
+}
\ No newline at end of file
diff --git a/crates/llm-chain-gemma-sys/src/check.cc b/crates/llm-chain-gemma-sys/src/check.cc
new file mode 100644
index 00000000..dd8b023e
--- /dev/null
+++ b/crates/llm-chain-gemma-sys/src/check.cc
@@ -0,0 +1,42 @@
+#include <gemma.h>
+#include <iostream>
+#include <random>
+
+int main(int argc, char* argv[]) {
+  gcpp::LoaderArgs largs(argc, argv);
+  gcpp::InferenceArgs iargs(argc, argv);
+  gcpp::AppArgs aargs(argc, argv);
+
+  largs.Validate();
+
+  hwy::ThreadPool pool(1);
+  hwy::ThreadPool inner_pool(0);
+
+  gcpp::Gemma gemma(largs, pool);
+
+  std::mt19937 gen;
+  gen.seed(42);
+
+  std::vector<int> tokens;
+  gemma.Tokenizer().Encode(
+    "<start_of_turn>user\nWhat is a gemma?<end_of_turn>\n<start_of_turn>model\n", &tokens);
+  for (auto token : tokens) {
+    std::cout << "token: " << token << std::endl;
+  }
+
+  gcpp::GenerateGemma(
+    gemma, iargs, tokens, 0,
+    pool, inner_pool,
+    [&gemma](int token, float value) {
+      std::string decoded;
+      gemma.Tokenizer().Decode(std::vector<int>{token}, &decoded);
+      std::cout << decoded;
+      return true;
+    },
+    [](int token) { return true; },
+    gen,
+    10
+  );
+  std::cout << std::endl;
+  return 0;
+}
diff --git a/crates/llm-chain-gemma-sys/src/lib.rs b/crates/llm-chain-gemma-sys/src/lib.rs
new file mode 100644
index 00000000..5bbe75e9
--- /dev/null
+++ b/crates/llm-chain-gemma-sys/src/lib.rs
@@ -0,0 +1,5 @@
+#![allow(non_upper_case_globals)]
+#![allow(non_camel_case_types)]
+#![allow(non_snake_case)]
+
+include!("./bindings.rs");
\ No newline at end of file
diff --git a/crates/llm-chain-gemma-sys/wrapper.h b/crates/llm-chain-gemma-sys/wrapper.h
new file mode 100644
index 00000000..02084e24
--- /dev/null
+++ b/crates/llm-chain-gemma-sys/wrapper.h
@@ -0,0 +1 @@
+#include <gemma.h>
\ No newline at end of file
diff --git a/crates/llm-chain-gemma/Cargo.toml b/crates/llm-chain-gemma/Cargo.toml
new file mode 100644
index 00000000..2eaffb08
--- /dev/null
+++ b/crates/llm-chain-gemma/Cargo.toml
@@ -0,0 +1,21 @@
+[package]
+name = "llm-chain-gemma"
+description = "The llm-chain implementation for Gemma."
+version = "0.1.0" +edition = "2021" +license = "MIT" +keywords = ["llm", "langchain", "gemma", "chain"] +categories = ["science"] +authors = [ + "Jun Mukai ", +] +readme = "./README.md" +repository = "https://github.com/sobelio/llm-chain/" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +async-trait.workspace = true +llm-chain = {path = "../llm-chain", version="0.13.0"} +llm-chain-gemma-sys = {path = "../llm-chain-gemma-sys", version="0.1.0"} +tokio.workspace = true diff --git a/crates/llm-chain-gemma/examples/simple.rs b/crates/llm-chain-gemma/examples/simple.rs new file mode 100644 index 00000000..f9a53cdb --- /dev/null +++ b/crates/llm-chain-gemma/examples/simple.rs @@ -0,0 +1,54 @@ +use llm_chain::options; +use llm_chain::options::ModelRef; +use llm_chain::{executor, parameters, prompt}; +use std::env::args; +use std::path::Path; + +/// This example demonstrates how to use the llm-chain-gemma crate to generate text using a +/// Gemma. +/// +/// Usage: cargo run --example simple path/to/model prompt +/// +/// Note: gemma requires 2 files to load, one for the model itself and the other is for +/// sentencepiece. Currently it assumes both resides in the same directory, and the +/// sentencepiece file name is tokenizer.sbm + +fn get_model_type(model_path: &str) -> &str { + let p = Path::new(model_path); + if let Some(stem) = p.file_stem() { + if let Some(model_type) = stem.to_str() { + return model_type + } + } + "2b-it" +} + +#[tokio::main(flavor = "current_thread")] +async fn main() -> Result<(), Box> { + let raw_args: Vec = args().collect(); + let args = match &raw_args.len() { + 2 => ( + raw_args[1].as_str(), + "Rust is a cool programming language because", + ), + 3 => (raw_args[1].as_str(), raw_args[2].as_str()), + _ => { + panic!("Usage: cargo run --release --example simple ") + } + }; + + let model_path = args.0; + let prompt = args.1; + let opts = options!( + Model: ModelRef::from_path(model_path), + ModelType: get_model_type(model_path), + ModelType: "gemma", + Temperature: 0.8 + ); + let exec = executor!(gemma, opts.clone())?; + + let res = prompt!(prompt).run(¶meters!(), &exec).await?; + + println!("{}", res.to_immediate().await?); + Ok(()) +} diff --git a/crates/llm-chain-gemma/src/context.rs b/crates/llm-chain-gemma/src/context.rs new file mode 100644 index 00000000..f266fd28 --- /dev/null +++ b/crates/llm-chain-gemma/src/context.rs @@ -0,0 +1,222 @@ +use std::ffi; +use std::path::Path; +use llm_chain::output::StreamSegment; +use llm_chain_gemma_sys::{ + gcpp_Gemma, gcpp_Gemma_Decode, gcpp_Gemma_Decodes, gcpp_Gemma_Encode, gcpp_Gemma_Gemma, gcpp_Gemma_destructor, gcpp_GenerateGemma, gcpp_InferenceArgs, gcpp_InferenceArgs_InferenceArgs, gcpp_InferenceArgs_MaxGeneratedTokens, gcpp_InferenceArgs_Multiturn, gcpp_InferenceArgs_SetMaxTokens, gcpp_InferenceArgs_SetTemperature, gcpp_InferenceArgs_Validate, gcpp_InferenceArgs_destructor, gcpp_LoaderArgs_LoaderArgs, gcpp_LoaderArgs_ModelTraining, gcpp_LoaderArgs_SetCache, gcpp_LoaderArgs_SetModelTypeValue, gcpp_LoaderArgs_SetTokenizer, gcpp_LoaderArgs_Validate, gcpp_LoaderArgs_destructor, gcpp_ModelTraining, gcpp_ModelTraining_GEMMA_IT, hwy_ThreadPool, hwy_ThreadPool_ThreadPool, hwy_ThreadPool_destructor, std_mt19937, std_mt19937_destructor, std_mt19937_mt19937, std_mt19937_random_seed, std_string_c_str, std_string_destructor, std_string_string, std_vector_int_destructor, std_vector_int_iter, std_vector_int_size, std_vector_int_vector, EOS_ID +}; +use 
+use llm_chain::tokens::{TokenCollection, Tokenizer, TokenizerError};
+use llm_chain::traits::ExecutorCreationError;
+use tokio::sync::mpsc;
+
+pub struct GemmaContext {
+    gemma: *mut gcpp_Gemma,
+    model_training: gcpp_ModelTraining,
+    gen: *mut std_mt19937,
+    pub iargs: *mut gcpp_InferenceArgs,
+    pool: *mut hwy_ThreadPool,
+    inner_pool: *mut hwy_ThreadPool,
+    pos: u32,
+}
+
+impl GemmaContext {
+    pub fn new(options: &Options) -> Result<GemmaContext, ExecutorCreationError> {
+        unsafe {
+            let largs = gcpp_LoaderArgs_LoaderArgs(0, std::ptr::null_mut());
+            if let Some(Opt::ModelType(mt)) = options.get(OptDiscriminants::ModelType) {
+                gcpp_LoaderArgs_SetModelTypeValue(largs, mt.clone().into_bytes().as_ptr() as *const i8);
+            }
+            if let Some(Opt::Model(m)) = options.get(OptDiscriminants::Model) {
+                // Typically the downloaded model data is compressed and set as cache.
+                // TODO: consider the case of non-compressed one?
+                let path = m.to_path();
+                gcpp_LoaderArgs_SetCache(largs, path.as_ptr() as *const i8);
+                // TODO: consider adding the option for tokenizer file.
+                let parent = Path::new(&path).parent();
+                if parent.is_none() {
+                    return Err(ExecutorCreationError::InvalidValue(String::from("no parent for path")));
+                }
+                if let Some(tokenizer_path) = parent.unwrap().join("tokenizer.spm").to_str() {
+                    gcpp_LoaderArgs_SetTokenizer(largs, tokenizer_path.as_ptr() as *const i8);
+                } else {
+                    return Err(ExecutorCreationError::InvalidValue(String::from("conversion from path to str for tokenizer")));
+                }
+            }
+
+            let err = gcpp_LoaderArgs_Validate(largs);
+            if err != std::ptr::null_mut() {
+                let msg = ffi::CString::from_raw(err as *mut ffi::c_char).into_string();
+                if msg.is_err() {
+                    return Err(ExecutorCreationError::InnerError(Box::new(msg.unwrap_err())));
+                }
+                gcpp_LoaderArgs_destructor(largs);
+                return Err(ExecutorCreationError::InvalidValue(msg.unwrap()));
+            }
+
+            let iargs = gcpp_InferenceArgs_InferenceArgs(0, std::ptr::null_mut());
+            if let Some(Opt::Temperature(t)) = options.get(OptDiscriminants::Temperature) {
+                gcpp_InferenceArgs_SetTemperature(iargs, *t);
+            }
+            if let Some(Opt::MaxTokens(m)) = options.get(OptDiscriminants::MaxTokens) {
+                gcpp_InferenceArgs_SetMaxTokens(iargs, *m as ffi::c_uint);
+            }
+
+            let err = gcpp_InferenceArgs_Validate(iargs);
+            if err != std::ptr::null_mut() {
+                let msg = ffi::CString::from_raw(err as *mut ffi::c_char).into_string();
+                if msg.is_err() {
+                    return Err(ExecutorCreationError::InnerError(Box::new(msg.unwrap_err())));
+                }
+                gcpp_LoaderArgs_destructor(largs);
+                gcpp_InferenceArgs_destructor(iargs);
+                return Err(ExecutorCreationError::InvalidValue(msg.unwrap()));
+            }
+
+            let pool = hwy_ThreadPool_ThreadPool(
+                if let Some(Opt::NThreads(nt)) = options.get(OptDiscriminants::NThreads) {
+                    *nt as ffi::c_uint
+                } else {
+                    0
+                });
+            let inner_pool = hwy_ThreadPool_ThreadPool(1);
+
+            let gemma = gcpp_Gemma_Gemma(largs, pool);
+            let gen = std_mt19937_mt19937();
+            std_mt19937_random_seed(gen);
+
+            let model_training = gcpp_LoaderArgs_ModelTraining(largs);
+
+            gcpp_LoaderArgs_destructor(largs);
+
+            Ok(GemmaContext{
+                gemma: gemma,
+                gen: gen,
+                model_training: model_training as gcpp_ModelTraining,
+                iargs: iargs,
+                pool: pool,
+                inner_pool: inner_pool,
+                pos: 0,
+            })
+        }
+    }
+}
+
+impl Drop for GemmaContext {
+    fn drop(&mut self) {
+        unsafe {
+            gcpp_Gemma_destructor(self.gemma);
+            std_mt19937_destructor(self.gen);
+            gcpp_InferenceArgs_destructor(self.iargs);
+            hwy_ThreadPool_destructor(self.pool);
+            hwy_ThreadPool_destructor(self.inner_pool);
+        }
+    }
+}
+
+#[repr(C)]
+struct GenerateContext {
+    gemma: *mut gcpp_Gemma,
+    pos: u32,
+    tokens_processed: u32,
+    input_tokens: u32,
+    out: mpsc::UnboundedSender<StreamSegment>,
+}
+
+extern fn stream_token(ctx: *mut ffi::c_void, token: ffi::c_int, _: ffi::c_float) -> ffi::c_char {
+    unsafe {
+        let gctx = ctx as *mut GenerateContext;
+        (*gctx).pos += 1;
+        (*gctx).tokens_processed += 1;
+        if (*gctx).tokens_processed < (*gctx).input_tokens {
+            return true as ffi::c_char;
+        }
+        if token == EOS_ID {
+            return true as ffi::c_char;
+        }
+        let s = std_string_string();
+        if gcpp_Gemma_Decode((*gctx).gemma, token, s) == 0 {
+            return false as ffi::c_char;
+        }
+        let decoded = ffi::CString::from_raw(std_string_c_str(s)).into_string();
+        if decoded.is_err() {
+            return false as ffi::c_char;
+        }
+        (*gctx).out.send(StreamSegment::Content(decoded.unwrap())).is_ok() as ffi::c_char
+    }
+}
+
+extern fn accept_token(_ctx: *mut ffi::c_void, _token: ffi::c_int) -> ffi::c_char {
+    true as ffi::c_char
+}
+
+impl GemmaContext {
+    pub fn generate<'a>(&mut self, prompt: String, out: mpsc::UnboundedSender<StreamSegment>) {
+        unsafe {
+            if gcpp_InferenceArgs_Multiturn(self.iargs) != 0 {
+                self.pos = 0
+            }
+            let mut prompt_text = if self.model_training == gcpp_ModelTraining_GEMMA_IT {
+                format!("{prompt}model\n")
+            } else {
+                prompt
+            };
+            if self.pos > 0 {
+                prompt_text = format!("{prompt_text}");
+            }
+            let tokens = std_vector_int_vector();
+            gcpp_Gemma_Encode(self.gemma, prompt_text.as_mut_ptr() as *mut ffi::c_char, prompt_text.len() as ffi::c_uint, tokens);
+            let mut genctx = GenerateContext{
+                gemma: self.gemma,
+                pos: self.pos,
+                tokens_processed: 0,
+                input_tokens: std_vector_int_size(tokens) as u32,
+                out: out,
+            };
+            gcpp_GenerateGemma(
+                self.gemma, self.iargs,
+                tokens, self.pos, self.pool, self.inner_pool,
+                (&mut genctx as *mut GenerateContext) as *mut ffi::c_void, stream_token,
+                std::ptr::null_mut(), accept_token, self.gen, 0);
+            self.pos = genctx.pos;
+            std_vector_int_destructor(tokens);
+        }
+    }
+
+    pub fn max_generated_tokens(&self) -> u32 {
+        unsafe {
+            gcpp_InferenceArgs_MaxGeneratedTokens(self.iargs)
+        }
+    }
+}
+
+impl Tokenizer for GemmaContext {
+    fn tokenize_str(&self, doc: &str) -> Result<TokenCollection, TokenizerError> {
+        unsafe {
+            let mut doc_copied = String::from(doc);
+            let tokens = std_vector_int_vector();
+            let result = gcpp_Gemma_Encode(self.gemma, doc_copied.as_mut_ptr() as *mut ffi::c_char, doc.len() as ffi::c_uint, tokens);
+            if result == 0 {
+                return Err(TokenizerError::ToStringError);
+            }
+            Ok(TokenCollection::from(Vec::from_iter(std_vector_int_iter::new(tokens))))
+        }
+    }
+
+    fn to_string(&self, tokens: TokenCollection) -> Result<String, TokenizerError> {
+        let ts = tokens.as_i32()?;
+        unsafe {
+            let out = std_string_string();
+            let ok = gcpp_Gemma_Decodes(self.gemma, ts.as_ptr(), ts.len() as ffi::c_int, out);
+            if ok == 0 {
+                std_string_destructor(out);
+                return Err(TokenizerError::ToStringError);
+            }
+            let out_str = ffi::CString::from_raw(std_string_c_str(out)).into_string();
+            std_string_destructor(out);
+            out_str.map_err(|_| TokenizerError::ToStringError)
+        }
+    }
+}
+
+unsafe impl Sync for GemmaContext {}
+unsafe impl Send for GemmaContext {}
\ No newline at end of file
diff --git a/crates/llm-chain-gemma/src/executor.rs b/crates/llm-chain-gemma/src/executor.rs
new file mode 100644
index 00000000..5b5f367a
--- /dev/null
+++ b/crates/llm-chain-gemma/src/executor.rs
@@ -0,0 +1,95 @@
+use llm_chain::options::{Opt, OptDiscriminants, Options};
+use llm_chain::prompt::Prompt;
+use llm_chain::tokens::{PromptTokensError, TokenCollection, TokenCount, Tokenizer, TokenizerError};
+use llm_chain::output::Output;
+use llm_chain::traits::{Executor as ExecutorTrait, ExecutorCreationError, ExecutorError};
+use std::sync::{Arc, Mutex};
+use async_trait::async_trait;
+use crate::context::GemmaContext;
+use tokio;
+
+pub struct Executor {
+    context: Arc<Mutex<GemmaContext>>,
+    stream: bool,
+}
+
+#[async_trait]
+impl ExecutorTrait for Executor {
+    type StepTokenizer<'a> = GemmaTokenizer;
+
+    fn new_with_options(options: Options) -> Result<Self, ExecutorCreationError> {
+        let gemma_context = GemmaContext::new(&options)?;
+        Ok(Executor{
+            context: Arc::new(Mutex::new(gemma_context)),
+            stream: if let Some(Opt::Stream(s)) = options.get(OptDiscriminants::Stream) {
+                *s
+            } else {
+                false
+            },
+        })
+    }
+
+    async fn execute(&self, options: &Options, prompt: &Prompt) -> Result<Output, ExecutorError> {
+        let is_stream = if let Some(Opt::Stream(s)) = options.get(OptDiscriminants::Stream) {
+            *s
+        } else {
+            self.stream
+        };
+        let (sender, stream) = Output::new_stream();
+        let context = self.context.clone();
+        let prompt_text = prompt.to_string();
+        if is_stream {
+            tokio::task::spawn_blocking(move || {
+                if let Ok(mut ctx) = context.lock() {
+                    ctx.generate(prompt_text, sender);
+                }
+            });
+            return Ok(stream);
+        } else {
+            let mut ctx = context.lock().map_err(|_| ExecutorError::InvalidOptions)?;
+            ctx.generate(prompt_text, sender);
+        }
+        stream.to_immediate().await.map(|imm| Output::Immediate(imm))
+    }
+
+    fn tokens_used(
+        &self,
+        options: &Options,
+        prompt: &Prompt,
+    ) -> Result<TokenCount, PromptTokensError> {
+        let tokenizer = self.get_tokenizer(options)?;
+        let tokens = tokenizer.tokenize_str(prompt.to_string().as_str())?;
+        Ok(TokenCount::new(self.max_tokens_allowed(options), tokens.len() as i32))
+    }
+
+    fn max_tokens_allowed(&self, options: &Options) -> i32 {
+        if let Some(Opt::MaxTokens(mt)) = options.get(OptDiscriminants::MaxTokens) {
+            return *mt as i32;
+        }
+        self.context.lock().unwrap().max_generated_tokens() as i32
+    }
+
+    fn get_tokenizer(&self, _options: &Options) -> Result<Self::StepTokenizer<'_>, TokenizerError> {
+        Ok(GemmaTokenizer{context: self.context.clone()})
+    }
+
+    fn answer_prefix(&self, _prompt: &Prompt) -> Option<String> {
+        None
+    }
+}
+
+pub struct GemmaTokenizer {
+    context: Arc<Mutex<GemmaContext>>,
+}
+
+impl Tokenizer for GemmaTokenizer {
+    fn tokenize_str(&self, doc: &str) -> Result<TokenCollection, TokenizerError> {
+        let ctx = self.context.lock().map_err(|_| TokenizerError::TokenizationError)?;
+        ctx.tokenize_str(doc)
+    }
+
+    fn to_string(&self, tokens: TokenCollection) -> Result<String, TokenizerError> {
+        let ctx = self.context.lock().map_err(|_| TokenizerError::ToStringError)?;
+        ctx.to_string(tokens)
+    }
+}
\ No newline at end of file
diff --git a/crates/llm-chain-gemma/src/lib.rs b/crates/llm-chain-gemma/src/lib.rs
new file mode 100644
index 00000000..cf1f029c
--- /dev/null
+++ b/crates/llm-chain-gemma/src/lib.rs
@@ -0,0 +1,4 @@
+mod executor;
+mod context;
+
+pub use executor::Executor;
\ No newline at end of file
diff --git a/crates/llm-chain/src/executor.rs b/crates/llm-chain/src/executor.rs
index dbf044cb..6de5fe52 100644
--- a/crates/llm-chain/src/executor.rs
+++ b/crates/llm-chain/src/executor.rs
@@ -56,6 +56,10 @@ macro_rules! executor {
         use llm_chain::traits::Executor;
         llm_chain_openai::chatgpt::Executor::new_with_options($options)
     }};
+    (gemma, $options:expr) => {{
+        use llm_chain::traits::Executor;
+        llm_chain_gemma::Executor::new_with_options($options)
+    }};
     (llama) => {{
         use llm_chain::traits::Executor;
         llm_chain_llama::Executor::new()
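
For reference, a minimal sketch (not part of the patch) of driving the new `gemma` executor arm once the crate and the gemma.cpp submodule are built. It mirrors crates/llm-chain-gemma/examples/simple.rs from the diff; the weights path, the "2b-it" model type, and the temperature are illustrative assumptions, and a tokenizer.spm file is expected next to the weights file as the crate currently requires:

use llm_chain::options;
use llm_chain::options::ModelRef;
use llm_chain::{executor, parameters, prompt};

#[tokio::main(flavor = "current_thread")]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Hypothetical local path; adjust to where the Gemma weights actually live.
    let opts = options!(
        Model: ModelRef::from_path("./models/2b-it.sbs"),
        ModelType: "2b-it",
        Temperature: 0.8
    );
    // Expands to llm_chain_gemma::Executor::new_with_options(opts) via the macro arm added above.
    let exec = executor!(gemma, opts.clone())?;
    let res = prompt!("Rust is a cool programming language because")
        .run(&parameters!(), &exec)
        .await?;
    println!("{}", res.to_immediate().await?);
    Ok(())
}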