diff --git a/crates/llm-chain-gemma-sys/a.out b/crates/llm-chain-gemma-sys/a.out
deleted file mode 100755
index 32335178..00000000
Binary files a/crates/llm-chain-gemma-sys/a.out and /dev/null differ
diff --git a/crates/llm-chain-gemma-sys/gemma.cpp b/crates/llm-chain-gemma-sys/gemma.cpp
index 0508e2c2..0221956b 160000
--- a/crates/llm-chain-gemma-sys/gemma.cpp
+++ b/crates/llm-chain-gemma-sys/gemma.cpp
@@ -1 +1 @@
-Subproject commit 0508e2c2e1c3a2ed63564886f4b5468dce5c9871
+Subproject commit 0221956b2e4fb5ec65d3685fad09f257cf5700e7
diff --git a/crates/llm-chain-gemma-sys/src/bindings.cc b/crates/llm-chain-gemma-sys/src/bindings.cc
index 329353b4..73b54d62 100644
--- a/crates/llm-chain-gemma-sys/src/bindings.cc
+++ b/crates/llm-chain-gemma-sys/src/bindings.cc
@@ -2,58 +2,6 @@
 
 extern "C" {
 
-gcpp::LoaderArgs* gcpp_LoaderArgs_LoaderArgs(int argc, char* argv[]) {
-  return new gcpp::LoaderArgs(argc, argv);
-}
-
-void gcpp_LoaderArgs_destructor(gcpp::LoaderArgs* args) {
-  delete args;
-}
-
-const char* gcpp_LoaderArgs_Validate(gcpp::LoaderArgs* args) {
-  return args->Validate();
-}
-
-gcpp::Model gcpp_LoaderArgs_ModelType(const gcpp::LoaderArgs* args) {
-  return args->ModelType();
-}
-
-gcpp::ModelTraining gcpp_LoaderArgs_ModelTraining(const gcpp::LoaderArgs* args) {
-  return args->ModelTraining();
-}
-
-void gcpp_LoaderArgs_SetTokenizer(gcpp::LoaderArgs* args, char* path, size_t n) {
-  args->tokenizer.path = std::string(path, n);
-}
-
-const char* gcpp_LoaderArgs_Tokenizer(gcpp::LoaderArgs* args) {
-  return args->tokenizer.path.c_str();
-}
-
-void gcpp_LoaderArgs_SetModel(gcpp::LoaderArgs* args, char* path, size_t n) {
-  args->model.path = std::string(path, n);
-}
-
-const char* gcpp_LoaderArgs_Model(gcpp::LoaderArgs* args) {
-  return args->model.path.c_str();
-}
-
-void gcpp_LoaderArgs_SetCache(gcpp::LoaderArgs* args, char* path, size_t n) {
-  args->cache.path = std::string(path, n);
-}
-
-const char* gcpp_LoaderArgs_Cache(gcpp::LoaderArgs* args) {
-  return args->cache.path.c_str();
-}
-
-void gcpp_LoaderArgs_SetModelTypeValue(gcpp::LoaderArgs* args, char* v, size_t n) {
-  args->model_type = std::string(v, n);
-}
-
-const char* gcpp_LoaderArgs_ModelTypeValue(gcpp::LoaderArgs* args) {
-  return args->model_type.c_str();
-}
-
 hwy::ThreadPool* hwy_ThreadPool_ThreadPool(size_t num_threads) {
   return new hwy::ThreadPool(num_threads);
 }
@@ -62,14 +10,38 @@ void hwy_ThreadPool_destructor(hwy::ThreadPool* pool) {
   delete pool;
 }
 
-gcpp::Gemma* gcpp_Gemma_Gemma(const gcpp::LoaderArgs* args, hwy::ThreadPool* pool) {
-  return new gcpp::Gemma(*args, *pool);
+gcpp::Gemma* gcpp_Gemma_Gemma(
+    const char* tokenizer_path, size_t tokenizer_path_len,
+    const char* compressed_weights_path, size_t compressed_weights_path_len,
+    const char* weights_path, size_t weights_path_len,
+    gcpp::Model model_type, hwy::ThreadPool* pool) {
+  gcpp::Path tpath;
+  tpath.path = std::string(tokenizer_path, tokenizer_path_len);
+  gcpp::Path cwpath;
+  cwpath.path = std::string(compressed_weights_path, compressed_weights_path_len);
+  gcpp::Path wpath;
+  wpath.path = std::string(weights_path, weights_path_len);
+  return new gcpp::Gemma(tpath, cwpath, wpath, model_type, *pool);
 }
 
 void gcpp_Gemma_destructor(gcpp::Gemma* gemma) {
   delete gemma;
 }
 
+void gcpp_Gemma_SetModelTraining(gcpp::Gemma* gemma, gcpp::ModelTraining training) {
+  gemma->model_training = training;
+}
+
+gcpp::KVCache* gcpp_CreateKVCache(gcpp::Model model_type) {
+  gcpp::KVCache* cache = new gcpp::KVCache{};
+  *cache = gcpp::CreateKVCache(model_type);
+  return cache;
+}
+
+void gcpp_KVCache_destructor(gcpp::KVCache* kvcache) {
+  delete kvcache;
+}
+
 std::vector<int>* std_vector_int_vector() {
   return new std::vector<int>();
 }
@@ -99,11 +71,11 @@ const char* std_string_c_str(const std::string* s) {
 }
 
 bool gcpp_Gemma_Encode(gcpp::Gemma* gemma, const char* input, size_t len, std::vector<int>* out) {
-  return gemma->Tokenizer().Encode(std::string(input, len), out).ok();
+  return gemma->Tokenizer()->Encode(std::string(input, len), out).ok();
 }
 
 bool gcpp_Gemma_Decode(gcpp::Gemma* gemma, int token, std::string* out) {
-  return gemma->Tokenizer().Decode(std::vector<int>{token}, out).ok();
+  return gemma->Tokenizer()->Decode(std::vector<int>{token}, out).ok();
 }
 
 bool gcpp_Gemma_Decodes(gcpp::Gemma* gemma, const int* tokens, int num_tokens, std::string* out) {
@@ -112,59 +84,7 @@ bool gcpp_Gemma_Decodes(gcpp::Gemma* gemma, const int* tokens, int num_tokens, s
   for (int i = 0; i < num_tokens; i++) {
     v.push_back(tokens[i]);
   }
-  return gemma->Tokenizer().Decode(v, out).ok();
-}
-
-gcpp::InferenceArgs* gcpp_InferenceArgs_InferenceArgs(int argc, char* argv[]) {
-  return new gcpp::InferenceArgs(argc, argv);
-}
-
-void gcpp_InferenceArgs_destructor(gcpp::InferenceArgs* args) {
-  delete args;
-}
-
-const char* gcpp_InferenceArgs_Validate(gcpp::InferenceArgs* args) {
-  return args->Validate();
-}
-
-size_t gcpp_InferenceArgs_MaxTokens(gcpp::InferenceArgs* args) {
-  return args->max_tokens;
-}
-
-void gcpp_InferenceArgs_SetMaxTokens(gcpp::InferenceArgs* args, size_t mt) {
-  args->max_tokens = mt;
-}
-
-size_t gcpp_InferenceArgs_MaxGeneratedTokens(gcpp::InferenceArgs* args) {
-  return args->max_generated_tokens;
-}
-
-void gcpp_InferenceArgs_SetMaxGeneratedTokens(gcpp::InferenceArgs* args, size_t mgt) {
-  args->max_generated_tokens = mgt;
-}
-
-float gcpp_InferenceArgs_Temperature(gcpp::InferenceArgs* args) {
-  return args->temperature;
-}
-
-void gcpp_InferenceArgs_SetTemperature(gcpp::InferenceArgs* args, float t) {
-  args->temperature = t;
-}
-
-bool gcpp_InferenceArgs_Deterministic(gcpp::InferenceArgs* args) {
-  return args->deterministic;
-}
-
-void gcpp_InferenceArgs_SetDeterministic(gcpp::InferenceArgs* args, bool d) {
-  args->deterministic = d;
-}
-
-bool gcpp_InferenceArgs_Multiturn(gcpp::InferenceArgs* args) {
-  return args->multiturn;
-}
-
-void gcpp_InferenceArgs_SetMultiturn(gcpp::InferenceArgs* args, bool mt) {
-  args->multiturn = mt;
+  return gemma->Tokenizer()->Decode(v, out).ok();
 }
 
 std::mt19937* std_mt19937_mt19937() {
@@ -188,24 +108,17 @@ typedef bool (*stream_callback)(void*, int, float);
 typedef bool (*accept_callback)(void*, int);
 
 void gcpp_GenerateGemma(
-    gcpp::Gemma* gemma, const gcpp::InferenceArgs* args,
+    gcpp::Gemma* gemma, const gcpp::RuntimeConfig* config,
     const std::vector<int>* prompt, size_t start_pos,
-    hwy::ThreadPool* pool, hwy::ThreadPool* inner_pool,
-    void* stream_context,
-    stream_callback stream_token,
-    void* accept_context,
-    accept_callback accept_token,
-    std::mt19937* gen, int verbosity) {
+    gcpp::KVCache* kvcache, hwy::ThreadPool* pool,
+    void* stream_context, stream_callback stream_token,
+    std::mt19937* gen) {
   gcpp::GenerateGemma(
-      *gemma, *args, *prompt, start_pos,
-      *pool, *inner_pool,
+      *gemma, *config, *prompt, start_pos, *kvcache, *pool,
      [&stream_context, &stream_token](int token, float value) {
        return stream_token(stream_context, token, value);
      },
-      [&accept_context, &accept_token](int token) {
-        return accept_token(accept_context, token);
-      },
-      *gen, verbosity);
+      *gen);
 }
 
 }
\ No newline at end of file
diff --git a/crates/llm-chain-gemma-sys/src/bindings.rs b/crates/llm-chain-gemma-sys/src/bindings.rs
index e0922ecc..ab5c2cd1 100644
--- a/crates/llm-chain-gemma-sys/src/bindings.rs
+++ b/crates/llm-chain-gemma-sys/src/bindings.rs
@@ -11,26 +11,11 @@ pub const gcpp_ModelTraining_GEMMA_PT: gcpp_ModelTraining = 1;
 pub const EOS_ID: i32 = 1;
 
 #[repr(C)]
-pub struct gcpp_LoaderArgs {
-    _data: [u8; 0],
-    _marker:
-        core::marker::PhantomData<(*mut u8, core::marker::PhantomPinned)>,
-}
-
-extern "C" {
-    pub fn gcpp_LoaderArgs_LoaderArgs(argc: ffi::c_int, argv: *mut *mut ffi::c_char) -> *mut gcpp_LoaderArgs;
-    pub fn gcpp_LoaderArgs_destructor(largs: *mut gcpp_LoaderArgs);
-    pub fn gcpp_LoaderArgs_Validate(largs: *mut gcpp_LoaderArgs) -> *const ffi::c_char;
-    pub fn gcpp_LoaderArgs_ModelType(largs: *const gcpp_LoaderArgs) -> gcpp_Model;
-    pub fn gcpp_LoaderArgs_ModelTraining(largs: *const gcpp_LoaderArgs) -> gcpp_ModelTraining;
-    pub fn gcpp_LoaderArgs_SetTokenizer(largs: *mut gcpp_LoaderArgs, path: *const ffi::c_char, n: ffi::c_uint);
-    pub fn gcpp_LoaderArgs_Tokenizer(largs: *const gcpp_LoaderArgs) -> *mut ffi::c_char;
-    pub fn gcpp_LoaderArgs_SetModel(largs: *mut gcpp_LoaderArgs, path: *const ffi::c_char, n: ffi::c_uint);
-    pub fn gcpp_LoaderArgs_Model(largs: *const gcpp_LoaderArgs) -> *mut ffi::c_char;
-    pub fn gcpp_LoaderArgs_SetCache(largs: *mut gcpp_LoaderArgs, path: *const ffi::c_char, n: ffi::c_uint);
-    pub fn gcpp_LoaderArgs_Cache(largs: *const gcpp_LoaderArgs) -> *mut ffi::c_char;
-    pub fn gcpp_LoaderArgs_SetModelTypeValue(largs: *mut gcpp_LoaderArgs, s: *const ffi::c_char, n: ffi::c_uint);
-    pub fn gcpp_LoaderArgs_ModelTypeValue(largs: *const gcpp_LoaderArgs) -> *mut ffi::c_char;
+pub struct gcpp_RuntimeConfig {
+    pub max_tokens: ffi::c_uint,
+    pub max_generated_tokens: ffi::c_uint,
+    pub temperature: ffi::c_float,
+    pub verbosity: ffi::c_int,
 }
 
 #[repr(C)]
@@ -53,24 +38,39 @@ pub struct gcpp_Gemma {
 }
 
 extern "C" {
-    pub fn gcpp_Gemma_Gemma(args: *mut gcpp_LoaderArgs, pool: *mut hwy_ThreadPool) -> *mut gcpp_Gemma;
+    pub fn gcpp_Gemma_Gemma(
+        tokenizer_path: *const ffi::c_char, tokenizer_path_len: ffi::c_uint,
+        compressed_weights_path: *const ffi::c_char, compressed_weights_path_len: ffi::c_uint,
+        weights_path: *const ffi::c_char, weights_path_len: ffi::c_uint,
+        model_type: gcpp_Model, pool: *mut hwy_ThreadPool) -> *mut gcpp_Gemma;
     pub fn gcpp_Gemma_destructor(gemma: *mut gcpp_Gemma);
+    pub fn gcpp_Gemma_SetModelTraining(gemma: *mut gcpp_Gemma, training: gcpp_ModelTraining);
     pub fn gcpp_Gemma_Encode(gemma: *mut gcpp_Gemma, input: *mut ffi::c_char, len: ffi::c_uint, out: *mut std_vector_int) -> ffi::c_char;
     pub fn gcpp_Gemma_Decode(gemma: *mut gcpp_Gemma, token: ffi::c_int, out: *mut std_string) -> ffi::c_char;
     pub fn gcpp_Gemma_Decodes(gemma: *mut gcpp_Gemma, tokens: *const ffi::c_int, num_tokens: ffi::c_int, out: *mut std_string) -> ffi::c_char;
 
     pub fn gcpp_GenerateGemma(
-        gemma: *mut gcpp_Gemma, args: *mut gcpp_InferenceArgs,
+        gemma: *mut gcpp_Gemma, config: *const gcpp_RuntimeConfig,
         prompt: *const std_vector_int, start_pos: ffi::c_uint,
-        pool: *mut hwy_ThreadPool, inner_pool: *mut hwy_ThreadPool,
+        kvcache: *mut gcpp_KVCache, pool: *mut hwy_ThreadPool,
         stream_context: *mut ffi::c_void,
        stream_token: extern fn(*mut ffi::c_void, ffi::c_int, ffi::c_float) -> ffi::c_char,
-        accept_context: *mut ffi::c_void,
-        accept_token: extern fn(*mut ffi::c_void, ffi::c_int) -> ffi::c_char,
-        gen: *mut std_mt19937, verbosity: ffi::c_int,
+        gen: *mut std_mt19937,
     );
 }
 
+#[repr(C)]
+pub struct gcpp_KVCache {
+    _data: [u8; 0],
+    _marker:
+        core::marker::PhantomData<(*mut u8, core::marker::PhantomPinned)>,
+}
+
+extern "C" {
+    pub fn gcpp_CreateKVCache(model_type: gcpp_Model) -> *mut gcpp_KVCache;
+    pub fn gcpp_KVCache_destructor(cache: *mut gcpp_KVCache);
+}
+
 #[repr(C)]
 pub struct std_vector_int {
     _data: [u8; 0],
@@ -134,30 +134,6 @@ extern "C" {
     pub fn std_string_c_str(s: *const std_string) -> *mut ffi::c_char;
 }
 
-#[repr(C)]
-pub struct gcpp_InferenceArgs {
-    _data: [u8; 0],
-    _marker:
-        core::marker::PhantomData<(*mut u8, core::marker::PhantomPinned)>,
-}
-
-extern "C" {
-    pub fn gcpp_InferenceArgs_InferenceArgs(argc: ffi::c_int, argv: *mut *mut ffi::c_char) -> *mut gcpp_InferenceArgs;
-    pub fn gcpp_InferenceArgs_destructor(args: *mut gcpp_InferenceArgs);
-    pub fn gcpp_InferenceArgs_Validate(args: *mut gcpp_InferenceArgs) -> *const ffi::c_char;
-
-    pub fn gcpp_InferenceArgs_MaxTokens(args: *const gcpp_InferenceArgs) -> ffi::c_uint;
-    pub fn gcpp_InferenceArgs_SetMaxTokens(args: *mut gcpp_InferenceArgs, mt: ffi::c_uint);
-    pub fn gcpp_InferenceArgs_MaxGeneratedTokens(args: *const gcpp_InferenceArgs) -> ffi::c_uint;
-    pub fn gcpp_InferenceArgs_SetMaxGeneratedTokens(args: *mut gcpp_InferenceArgs, mgt: ffi::c_uint);
-    pub fn gcpp_InferenceArgs_Temperature(args: *const gcpp_InferenceArgs) -> ffi::c_float;
-    pub fn gcpp_InferenceArgs_SetTemperature(args: *mut gcpp_InferenceArgs, t: ffi::c_float);
-    pub fn gcpp_InferenceArgs_Deterministic(args: *const gcpp_InferenceArgs) -> ffi::c_char;
-    pub fn gcpp_InferenceArgs_SetDeterministic(args: *mut gcpp_InferenceArgs, d: ffi::c_char);
-    pub fn gcpp_InferenceArgs_Multiturn(args: *const gcpp_InferenceArgs) -> ffi::c_char;
-    pub fn gcpp_InferenceArgs_SetMultiturn(args: *mut gcpp_InferenceArgs, mt: ffi::c_char);
-}
-
 #[repr(C)]
 pub struct std_mt19937 {
     _data: [u8; 0],
@@ -176,65 +152,6 @@ extern "C" {
 mod test {
     use crate::*;
 
-    #[test]
-    fn create_and_delete_largs() {
-        let args = vec![
-            "prog",
-            "--tokenizer", "tokenizer.spm",
-            "--model", "2b-pt",
-            "--compressed_weights", "2b-pt.sbs",
-        ];
-        unsafe {
-            let largs = gcpp_LoaderArgs_LoaderArgs(
-                args.len() as ffi::c_int,
-                Vec::from_iter(args.into_iter().map(|arg|
-                    ffi::CString::new(arg).unwrap().into_raw()
-                ).into_iter()).as_mut_ptr(),
-            );
-            assert_eq!(gcpp_Model_GEMMA_2B, gcpp_LoaderArgs_ModelType(largs));
-            assert_eq!(gcpp_ModelTraining_GEMMA_PT, gcpp_LoaderArgs_ModelTraining(largs));
-            let tp = gcpp_LoaderArgs_Tokenizer(largs);
-            let s = ffi::CStr::from_ptr(tp).to_str().unwrap();
-            assert_eq!(s, "tokenizer.spm");
-            gcpp_LoaderArgs_destructor(largs);
-        }
-    }
-
-    #[test]
-    fn create_and_delete_largs_direct() {
-        let tokenizer_path = "tokenizer.spm";
-        let compressed_weights = "2b-pt.sbs";
-        let model = "2b-pt";
-        unsafe {
-            let largs = gcpp_LoaderArgs_LoaderArgs(0, std::ptr::null_mut());
-            gcpp_LoaderArgs_SetTokenizer(largs, tokenizer_path.as_ptr() as *const i8, tokenizer_path.len() as ffi::c_uint);
-            gcpp_LoaderArgs_SetCache(largs, compressed_weights.as_ptr() as *const i8, compressed_weights.len() as ffi::c_uint);
-            gcpp_LoaderArgs_SetModelTypeValue(largs, model.as_ptr() as *const i8, model.len() as ffi::c_uint);
-            let err = gcpp_LoaderArgs_Validate(largs);
-            if err != std::ptr::null_mut() {
-                println!("{}", ffi::CStr::from_ptr(err).to_str().unwrap());
-            }
-            assert_eq!(std::ptr::null(), err);
-        }
-    }
-
-    #[test]
-    fn create_and_delete_iargs_direct() {
-        unsafe {
-            let iargs = gcpp_InferenceArgs_InferenceArgs(0, std::ptr::null_mut());
-            assert_eq!(gcpp_InferenceArgs_Validate(iargs), std::ptr::null());
-
-            assert_eq!(gcpp_InferenceArgs_MaxGeneratedTokens(iargs), 2048);
-            assert_eq!(gcpp_InferenceArgs_MaxTokens(iargs), 3072);
-
-            gcpp_InferenceArgs_SetMaxGeneratedTokens(iargs, 4096);
-
-            assert_ne!(gcpp_InferenceArgs_Validate(iargs), std::ptr::null());
-
-            gcpp_InferenceArgs_destructor(iargs);
-        }
-    }
-
     #[test]
     fn create_and_delete_pool() {
         unsafe {
diff --git a/crates/llm-chain-gemma-sys/src/check.cc b/crates/llm-chain-gemma-sys/src/check.cc
deleted file mode 100644
index dd8b023e..00000000
--- a/crates/llm-chain-gemma-sys/src/check.cc
+++ /dev/null
@@ -1,42 +0,0 @@
-#include
-#include
-#include
-
-int main(int argc, char* argv[]) {
-  gcpp::LoaderArgs largs(argc, argv);
-  gcpp::InferenceArgs iargs(argc, argv);
-  gcpp::AppArgs aargs(argc, argv);
-
-  largs.Validate();
-
-  hwy::ThreadPool pool(1);
-  hwy::ThreadPool inner_pool(0);
-
-  gcpp::Gemma gemma(largs, pool);
-
-  std::mt19937 gen;
-  gen.seed(42);
-
-  std::vector<int> tokens;
-  gemma.Tokenizer().Encode(
-      "<start_of_turn>user\nWhat is a gemma?<end_of_turn>\n<start_of_turn>model\n", &tokens);
-  for (auto token : tokens) {
-    std::cout << "token: " << token << std::endl;
-  }
-
-  gcpp::GenerateGemma(
-      gemma, iargs, tokens, 0,
-      pool, inner_pool,
-      [&gemma](int token, float value) {
-        std::string decoded;
-        gemma.Tokenizer().Decode(std::vector<int>{token}, &decoded);
-        std::cout << decoded;
-        return true;
-      },
-      [](int token) { return true; },
-      gen,
-      10
-  );
-  std::cout << std::endl;
-  return 0;
-}
diff --git a/crates/llm-chain-gemma-sys/wrapper.h b/crates/llm-chain-gemma-sys/wrapper.h
deleted file mode 100644
index 02084e24..00000000
--- a/crates/llm-chain-gemma-sys/wrapper.h
+++ /dev/null
@@ -1 +0,0 @@
-#include
\ No newline at end of file
diff --git a/crates/llm-chain-gemma/src/context.rs b/crates/llm-chain-gemma/src/context.rs
index 03f78890..3c8ecdf7 100644
--- a/crates/llm-chain-gemma/src/context.rs
+++ b/crates/llm-chain-gemma/src/context.rs
@@ -3,18 +3,15 @@ use llm_chain::output::StreamSegment;
 use llm_chain::tokens::{TokenCollection, Tokenizer, TokenizerError};
 use llm_chain::traits::ExecutorCreationError;
 use llm_chain_gemma_sys::{
-    gcpp_Gemma, gcpp_Gemma_Decode, gcpp_Gemma_Decodes, gcpp_Gemma_Encode, gcpp_Gemma_Gemma,
-    gcpp_Gemma_destructor, gcpp_GenerateGemma, gcpp_InferenceArgs,
-    gcpp_InferenceArgs_InferenceArgs, gcpp_InferenceArgs_MaxGeneratedTokens,
-    gcpp_InferenceArgs_Multiturn, gcpp_InferenceArgs_SetMaxTokens,
-    gcpp_InferenceArgs_SetTemperature, gcpp_InferenceArgs_Validate, gcpp_InferenceArgs_destructor,
-    gcpp_LoaderArgs_LoaderArgs, gcpp_LoaderArgs_ModelTraining, gcpp_LoaderArgs_SetCache,
-    gcpp_LoaderArgs_SetModelTypeValue, gcpp_LoaderArgs_SetTokenizer, gcpp_LoaderArgs_Validate,
-    gcpp_LoaderArgs_destructor, gcpp_ModelTraining, gcpp_ModelTraining_GEMMA_IT, hwy_ThreadPool,
-    hwy_ThreadPool_ThreadPool, hwy_ThreadPool_destructor, std_mt19937, std_mt19937_destructor,
-    std_mt19937_mt19937, std_mt19937_random_seed, std_string_c_str, std_string_destructor,
-    std_string_string, std_vector_int_destructor, std_vector_int_iter, std_vector_int_size,
-    std_vector_int_vector, EOS_ID,
+    gcpp_CreateKVCache, gcpp_Gemma, gcpp_Gemma_Decode, gcpp_Gemma_Decodes, gcpp_Gemma_Encode,
+    gcpp_Gemma_Gemma, gcpp_Gemma_SetModelTraining, gcpp_Gemma_destructor, gcpp_GenerateGemma,
+    gcpp_KVCache, gcpp_KVCache_destructor, gcpp_Model, gcpp_ModelTraining,
+    gcpp_ModelTraining_GEMMA_IT, gcpp_ModelTraining_GEMMA_PT, gcpp_Model_GEMMA_2B,
+    gcpp_Model_GEMMA_7B, gcpp_RuntimeConfig, hwy_ThreadPool, hwy_ThreadPool_ThreadPool,
+    hwy_ThreadPool_destructor, std_mt19937, std_mt19937_destructor, std_mt19937_mt19937,
+    std_mt19937_random_seed, std_string_c_str, std_string_destructor, std_string_string,
+    std_vector_int_destructor, std_vector_int_iter, std_vector_int_size, std_vector_int_vector,
+    EOS_ID,
 };
 use std::ffi;
 use std::path::Path;
@@ -24,85 +21,76 @@ pub struct GemmaContext {
     gemma: *mut gcpp_Gemma,
     model_training: gcpp_ModelTraining,
     gen: *mut std_mt19937,
-    pub iargs: *mut gcpp_InferenceArgs,
+    pub config: gcpp_RuntimeConfig,
+    kvcache: *mut gcpp_KVCache,
     pool: *mut hwy_ThreadPool,
-    inner_pool: *mut hwy_ThreadPool,
     pos: u32,
 }
 
 impl GemmaContext {
     pub fn new(options: &Options) -> Result<GemmaContext, ExecutorCreationError> {
-        unsafe {
-            let largs = gcpp_LoaderArgs_LoaderArgs(0, std::ptr::null_mut());
-            if let Some(Opt::ModelType(mt)) = options.get(OptDiscriminants::ModelType) {
-                gcpp_LoaderArgs_SetModelTypeValue(
-                    largs,
-                    mt.clone().into_bytes().as_ptr() as *const i8,
-                    mt.len() as ffi::c_uint,
-                );
+        let mut model_type: gcpp_Model = gcpp_Model_GEMMA_2B;
+        let mut model_training: gcpp_ModelTraining = gcpp_ModelTraining_GEMMA_IT;
+        let mut tokenizer_path = String::new();
+        let mut compressed_weights_path = String::new();
+        let mut config = gcpp_RuntimeConfig {
+            max_tokens: 3072,
+            max_generated_tokens: 2048,
+            temperature: 1.0,
+            verbosity: 0,
+        };
+        if let Some(Opt::ModelType(mt)) = options.get(OptDiscriminants::ModelType) {
+            let parts = Vec::from_iter(mt.split("-").into_iter());
+            if parts.len() != 2 {
+                return Err(ExecutorCreationError::InvalidValue(format!(
+                    "model type {} is invalid",
+                    mt
+                )));
             }
-            if let Some(Opt::Model(m)) = options.get(OptDiscriminants::Model) {
-                // Typically the downloaded model data is compressed and set as cache.
-                // TODO: consider the case of non-compressed one?
-                let path = m.to_path();
-                gcpp_LoaderArgs_SetCache(
-                    largs,
-                    path.as_ptr() as *const i8,
-                    path.len() as ffi::c_uint,
-                );
-                // TODO: consider adding the option for tokenizer file.
-                let parent = Path::new(&path).parent();
-                if parent.is_none() {
-                    return Err(ExecutorCreationError::InvalidValue(String::from(
-                        "no parent for path",
-                    )));
+            match parts[0] {
+                "2b" => {}
+                "7b" => {
+                    model_type = gcpp_Model_GEMMA_7B;
                 }
-                if let Some(tokenizer_path) = parent.unwrap().join("tokenizer.spm").to_str() {
-                    gcpp_LoaderArgs_SetTokenizer(
-                        largs,
-                        tokenizer_path.as_ptr() as *const i8,
-                        tokenizer_path.len() as ffi::c_uint,
-                    );
-                } else {
-                    return Err(ExecutorCreationError::InvalidValue(String::from(
-                        "conversion from path to str for tokenizer",
+                _ => {
+                    return Err(ExecutorCreationError::InvalidValue(format!(
+                        "model type {} must be 2b or 7b",
+                        parts[0]
                     )));
                 }
             }
-
-            let err = gcpp_LoaderArgs_Validate(largs);
-            if err != std::ptr::null_mut() {
-                let msg = ffi::CString::from_raw(err as *mut ffi::c_char).into_string();
-                if msg.is_err() {
-                    return Err(ExecutorCreationError::InnerError(Box::new(
-                        msg.unwrap_err(),
+            match parts[1] {
+                "it" => {}
+                "pt" => {
+                    model_training = gcpp_ModelTraining_GEMMA_PT;
+                }
+                _ => {
+                    return Err(ExecutorCreationError::InvalidValue(format!(
+                        "model training {} must be it or pt",
+                        parts[1]
                    )));
                }
-                gcpp_LoaderArgs_destructor(largs);
-                return Err(ExecutorCreationError::InvalidValue(msg.unwrap()));
-            }
-
-            let iargs = gcpp_InferenceArgs_InferenceArgs(0, std::ptr::null_mut());
-            if let Some(Opt::Temperature(t)) = options.get(OptDiscriminants::Temperature) {
-                gcpp_InferenceArgs_SetTemperature(iargs, *t);
             }
-            if let Some(Opt::MaxTokens(m)) = options.get(OptDiscriminants::MaxTokens) {
-                gcpp_InferenceArgs_SetMaxTokens(iargs, *m as ffi::c_uint);
+        }
+        if let Some(Opt::Model(m)) = options.get(OptDiscriminants::Model) {
+            compressed_weights_path = m.to_path();
+            let parent = Path::new(&compressed_weights_path).parent();
+            if parent.is_none() {
+                return Err(ExecutorCreationError::InvalidValue(String::from(
+                    "no parent for path",
+                )));
             }
-
-            let err = gcpp_InferenceArgs_Validate(iargs);
-            if err != std::ptr::null_mut() {
-                let msg = ffi::CString::from_raw(err as *mut ffi::c_char).into_string();
-                if msg.is_err() {
-                    return Err(ExecutorCreationError::InnerError(Box::new(
-                        msg.unwrap_err(),
-                    )));
-                }
-                gcpp_LoaderArgs_destructor(largs);
-                gcpp_InferenceArgs_destructor(iargs);
-                return Err(ExecutorCreationError::InvalidValue(msg.unwrap()));
+            if let Some(tpath) = parent.unwrap().join("tokenizer.spm").to_str() {
+                tokenizer_path = String::from(tpath);
             }
-
+        }
+        if let Some(Opt::Temperature(t)) = options.get(OptDiscriminants::Temperature) {
+            config.temperature = *t;
+        }
+        if let Some(Opt::MaxTokens(m)) = options.get(OptDiscriminants::MaxTokens) {
+            config.max_tokens = *m as ffi::c_uint;
+        }
+        unsafe {
             let pool = hwy_ThreadPool_ThreadPool(
                 if let Some(Opt::NThreads(nt)) = options.get(OptDiscriminants::NThreads) {
                     *nt as ffi::c_uint
@@ -110,23 +98,29 @@ impl GemmaContext {
                     0
                 },
             );
-            let inner_pool = hwy_ThreadPool_ThreadPool(1);
-            let gemma = gcpp_Gemma_Gemma(largs, pool);
+            let gemma = gcpp_Gemma_Gemma(
+                tokenizer_path.as_ptr() as *const i8,
+                tokenizer_path.len() as ffi::c_uint,
+                compressed_weights_path.as_ptr() as *const i8,
+                compressed_weights_path.len() as ffi::c_uint,
+                std::ptr::null(),
+                0,
+                model_type,
+                pool,
+            );
+            gcpp_Gemma_SetModelTraining(gemma, model_training);
+
             let gen = std_mt19937_mt19937();
             std_mt19937_random_seed(gen);
 
-            let model_training = gcpp_LoaderArgs_ModelTraining(largs);
-
-            gcpp_LoaderArgs_destructor(largs);
-
             Ok(GemmaContext {
                 gemma: gemma,
                 gen: gen,
                 model_training: model_training as gcpp_ModelTraining,
-                iargs: iargs,
+                config: config,
+                kvcache: gcpp_CreateKVCache(model_type),
                 pool: pool,
-                inner_pool: inner_pool,
                 pos: 0,
             })
         }
     }
@@ -138,9 +132,8 @@ impl Drop for GemmaContext {
         unsafe {
             gcpp_Gemma_destructor(self.gemma);
             std_mt19937_destructor(self.gen);
-            gcpp_InferenceArgs_destructor(self.iargs);
+            gcpp_KVCache_destructor(self.kvcache);
             hwy_ThreadPool_destructor(self.pool);
-            hwy_ThreadPool_destructor(self.inner_pool);
         }
     }
 }
@@ -184,14 +177,10 @@ extern "C" fn stream_token(
     }
 }
 
-extern "C" fn accept_token(_ctx: *mut ffi::c_void, _token: ffi::c_int) -> ffi::c_char {
-    true as ffi::c_char
-}
-
 impl GemmaContext {
     pub fn generate<'a>(&mut self, prompt: String, out: mpsc::UnboundedSender<StreamSegment>) {
         unsafe {
-            if gcpp_InferenceArgs_Multiturn(self.iargs) != 0 {
+            if self.model_training != gcpp_ModelTraining_GEMMA_IT {
                 self.pos = 0
             }
             let mut prompt_text = if self.model_training == gcpp_ModelTraining_GEMMA_IT {
@@ -218,17 +207,14 @@
             };
             gcpp_GenerateGemma(
                 self.gemma,
-                self.iargs,
+                &mut self.config,
                 tokens,
                 self.pos,
+                self.kvcache,
                 self.pool,
-                self.inner_pool,
                 (&mut genctx as *mut GenerateContext) as *mut ffi::c_void,
                 stream_token,
-                std::ptr::null_mut(),
-                accept_token,
                 self.gen,
-                0,
             );
             self.pos = genctx.pos;
             std_vector_int_destructor(tokens);
@@ -236,7 +222,7 @@
     }
 
     pub fn max_generated_tokens(&self) -> u32 {
-        unsafe { gcpp_InferenceArgs_MaxGeneratedTokens(self.iargs) }
+        self.config.max_generated_tokens as u32
     }
 }