Skip to content

Commit

Permalink
Uprev gemma.cpp version
Browse files Browse the repository at this point in the history
  • Loading branch information
jmuk committed Mar 15, 2024
1 parent 3c616fb commit 94e4d11
Show file tree
Hide file tree
Showing 7 changed files with 144 additions and 371 deletions.
Binary file removed crates/llm-chain-gemma-sys/a.out
Binary file not shown.
157 changes: 35 additions & 122 deletions crates/llm-chain-gemma-sys/src/bindings.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2,58 +2,6 @@

extern "C" {

// Heap-allocates a gcpp::LoaderArgs parsed from an argv-style command line.
// Ownership passes to the caller; release with gcpp_LoaderArgs_destructor.
gcpp::LoaderArgs* gcpp_LoaderArgs_LoaderArgs(int argc, char* argv[]) {
  auto* largs = new gcpp::LoaderArgs(argc, argv);
  return largs;
}

void gcpp_LoaderArgs_destructor(gcpp::LoaderArgs* args) {
delete args;
}

const char* gcpp_LoaderArgs_Validate(gcpp::LoaderArgs* args) {
return args->Validate();
}

gcpp::Model gcpp_LoaderArgs_ModelType(const gcpp::LoaderArgs* args) {
return args->ModelType();
}

gcpp::ModelTraining gcpp_LoaderArgs_ModelTraining(const gcpp::LoaderArgs* args) {
return args->ModelTraining();
}

void gcpp_LoaderArgs_SetTokenizer(gcpp::LoaderArgs* args, char* path, size_t n) {
args->tokenizer.path = std::string(path, n);
}

const char* gcpp_LoaderArgs_Tokenizer(gcpp::LoaderArgs* args) {
return args->tokenizer.path.c_str();
}

void gcpp_LoaderArgs_SetModel(gcpp::LoaderArgs* args, char* path, size_t n) {
args->model.path = std::string(path, n);
}

const char* gcpp_LoaderArgs_Model(gcpp::LoaderArgs* args) {
return args->model.path.c_str();
}

void gcpp_LoaderArgs_SetCache(gcpp::LoaderArgs* args, char* path, size_t n) {
args->cache.path = std::string(path, n);
}

const char* gcpp_LoaderArgs_Cache(gcpp::LoaderArgs* args) {
return args->cache.path.c_str();
}

void gcpp_LoaderArgs_SetModelTypeValue(gcpp::LoaderArgs* args, char* v, size_t n) {
args->model_type = std::string(v, n);
}

const char* gcpp_LoaderArgs_ModelTypeValue(gcpp::LoaderArgs* args) {
return args->model_type.c_str();
}

// Allocates a hwy::ThreadPool with the requested worker count.
// Caller owns the pool; release it with hwy_ThreadPool_destructor.
hwy::ThreadPool* hwy_ThreadPool_ThreadPool(size_t num_threads) {
  auto* pool = new hwy::ThreadPool(num_threads);
  return pool;
}
Expand All @@ -62,14 +10,38 @@ void hwy_ThreadPool_destructor(hwy::ThreadPool* pool) {
delete pool;
}

// NOTE(review): the scraped diff interleaved the pre- and post-commit
// signatures of this function into invalid code; resolved here to the
// post-commit version, which matches the Rust-side declaration.
//
// Constructs a Gemma model from tokenizer / weight file paths passed as
// (pointer, length) pairs, so the strings need not be NUL-terminated.
// Ownership passes to the caller; release with gcpp_Gemma_destructor.
gcpp::Gemma* gcpp_Gemma_Gemma(
    const char* tokenizer_path, size_t tokenizer_path_len,
    const char* compressed_weights_path, size_t compressed_weights_path_len,
    const char* weights_path, size_t weights_path_len,
    gcpp::Model model_type, hwy::ThreadPool* pool) {
  gcpp::Path tpath;
  tpath.path = std::string(tokenizer_path, tokenizer_path_len);
  gcpp::Path cwpath;
  cwpath.path = std::string(compressed_weights_path, compressed_weights_path_len);
  gcpp::Path wpath;
  wpath.path = std::string(weights_path, weights_path_len);
  return new gcpp::Gemma(tpath, cwpath, wpath, model_type, *pool);
}

void gcpp_Gemma_destructor(gcpp::Gemma* gemma) {
delete gemma;
}

void gcpp_Gemma_SetModelTraining(gcpp::Gemma* gemma, gcpp::ModelTraining training) {
gemma->model_training = training;
}

// Heap-allocates a KVCache sized for `model_type`.
// Caller owns the result; release with gcpp_KVCache_destructor.
gcpp::KVCache* gcpp_CreateKVCache(gcpp::Model model_type) {
  // Construct directly from the factory result rather than
  // default-constructing an empty cache and then move-assigning into it.
  return new gcpp::KVCache(gcpp::CreateKVCache(model_type));
}

void gcpp_KVCache_destructor(gcpp::KVCache* kvcache) {
delete kvcache;
}

// Creates an empty std::vector<int> on the heap for use across the FFI
// boundary. Caller owns the result.
std::vector<int>* std_vector_int_vector() {
  auto* vec = new std::vector<int>();
  return vec;
}
Expand Down Expand Up @@ -99,11 +71,11 @@ const char* std_string_c_str(const std::string* s) {
}

// NOTE(review): the scraped diff kept both the pre-commit `Tokenizer().`
// and post-commit `Tokenizer()->` return lines; resolved to the post-commit
// pointer form used throughout the rest of this change.
//
// Tokenizes the `len`-byte string `input` into `out`; true on success.
bool gcpp_Gemma_Encode(gcpp::Gemma* gemma, const char* input, size_t len, std::vector<int>* out) {
  return gemma->Tokenizer()->Encode(std::string(input, len), out).ok();
}

// NOTE(review): duplicate old/new return lines from the scraped diff
// resolved to the post-commit `Tokenizer()->` form.
//
// Decodes a single token id into text appended to `out`; true on success.
bool gcpp_Gemma_Decode(gcpp::Gemma* gemma, int token, std::string* out) {
  return gemma->Tokenizer()->Decode(std::vector<int>{token}, out).ok();
}

bool gcpp_Gemma_Decodes(gcpp::Gemma* gemma, const int* tokens, int num_tokens, std::string* out) {
Expand All @@ -112,59 +84,7 @@ bool gcpp_Gemma_Decodes(gcpp::Gemma* gemma, const int* tokens, int num_tokens, s
for (int i = 0; i < num_tokens; i++) {
v.push_back(tokens[i]);
}
return gemma->Tokenizer().Decode(v, out).ok();
}

// Heap-allocates an InferenceArgs parsed from an argv-style command line.
// Caller owns the result; release with gcpp_InferenceArgs_destructor.
gcpp::InferenceArgs* gcpp_InferenceArgs_InferenceArgs(int argc, char* argv[]) {
return new gcpp::InferenceArgs(argc, argv);
}

// Frees an InferenceArgs created by gcpp_InferenceArgs_InferenceArgs.
void gcpp_InferenceArgs_destructor(gcpp::InferenceArgs* args) {
delete args;
}

// Forwards to InferenceArgs::Validate(); nullptr appears to mean the
// arguments are valid (the Rust-side tests assert that).
const char* gcpp_InferenceArgs_Validate(gcpp::InferenceArgs* args) {
return args->Validate();
}

// The getter/setter pairs below expose individual InferenceArgs fields to C.

size_t gcpp_InferenceArgs_MaxTokens(gcpp::InferenceArgs* args) {
return args->max_tokens;
}

void gcpp_InferenceArgs_SetMaxTokens(gcpp::InferenceArgs* args, size_t mt) {
args->max_tokens = mt;
}

size_t gcpp_InferenceArgs_MaxGeneratedTokens(gcpp::InferenceArgs* args) {
return args->max_generated_tokens;
}

void gcpp_InferenceArgs_SetMaxGeneratedTokens(gcpp::InferenceArgs* args, size_t mgt) {
args->max_generated_tokens = mgt;
}

float gcpp_InferenceArgs_Temperature(gcpp::InferenceArgs* args) {
return args->temperature;
}

void gcpp_InferenceArgs_SetTemperature(gcpp::InferenceArgs* args, float t) {
args->temperature = t;
}

bool gcpp_InferenceArgs_Deterministic(gcpp::InferenceArgs* args) {
return args->deterministic;
}

void gcpp_InferenceArgs_SetDeterministic(gcpp::InferenceArgs* args, bool d) {
args->deterministic = d;
}

bool gcpp_InferenceArgs_Multiturn(gcpp::InferenceArgs* args) {
return args->multiturn;
}

void gcpp_InferenceArgs_SetMultiturn(gcpp::InferenceArgs* args, bool mt) {
args->multiturn = mt;
return gemma->Tokenizer()->Decode(v, out).ok();
}

std::mt19937* std_mt19937_mt19937() {
Expand All @@ -188,24 +108,17 @@ typedef bool (*stream_callback)(void*, int, float);
typedef bool (*accept_callback)(void*, int);

void gcpp_GenerateGemma(
gcpp::Gemma* gemma, const gcpp::InferenceArgs* args,
gcpp::Gemma* gemma, const gcpp::RuntimeConfig* config,
const std::vector<int>* prompt, size_t start_pos,
hwy::ThreadPool* pool, hwy::ThreadPool* inner_pool,
void* stream_context,
stream_callback stream_token,
void* accept_context,
accept_callback accept_token,
std::mt19937* gen, int verbosity) {
gcpp::KVCache* kvcache, hwy::ThreadPool* pool,
void* stream_context, stream_callback stream_token,
std::mt19937* gen) {
gcpp::GenerateGemma(
*gemma, *args, *prompt, start_pos,
*pool, *inner_pool,
*gemma, *config, *prompt, start_pos, *kvcache, *pool,
[&stream_context, &stream_token](int token, float value) {
return stream_token(stream_context, token, value);
},
[&accept_context, &accept_token](int token) {
return accept_token(accept_context, token);
},
*gen, verbosity);
*gen);
}

}
135 changes: 26 additions & 109 deletions crates/llm-chain-gemma-sys/src/bindings.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,26 +11,11 @@ pub const gcpp_ModelTraining_GEMMA_PT: gcpp_ModelTraining = 1;
pub const EOS_ID: i32 = 1;

/// Opaque handle for the C++ `gcpp::LoaderArgs` type. Zero-sized, with a
/// raw-pointer `PhantomData`, so values are only ever used behind raw
/// pointers handed out by the C++ side — never constructed in Rust.
#[repr(C)]
pub struct gcpp_LoaderArgs {
_data: [u8; 0],
_marker:
core::marker::PhantomData<(*mut u8, core::marker::PhantomPinned)>,
}

extern "C" {
pub fn gcpp_LoaderArgs_LoaderArgs(argc: ffi::c_int, argv: *mut *mut ffi::c_char) -> *mut gcpp_LoaderArgs;
pub fn gcpp_LoaderArgs_destructor(largs: *mut gcpp_LoaderArgs);
pub fn gcpp_LoaderArgs_Validate(largs: *mut gcpp_LoaderArgs) -> *const ffi::c_char;
pub fn gcpp_LoaderArgs_ModelType(largs: *const gcpp_LoaderArgs) -> gcpp_Model;
pub fn gcpp_LoaderArgs_ModelTraining(largs: *const gcpp_LoaderArgs) -> gcpp_ModelTraining;
pub fn gcpp_LoaderArgs_SetTokenizer(largs: *mut gcpp_LoaderArgs, path: *const ffi::c_char, n: ffi::c_uint);
pub fn gcpp_LoaderArgs_Tokenizer(largs: *const gcpp_LoaderArgs) -> *mut ffi::c_char;
pub fn gcpp_LoaderArgs_SetModel(largs: *mut gcpp_LoaderArgs, path: *const ffi::c_char, n: ffi::c_uint);
pub fn gcpp_LoaderArgs_Model(largs: *const gcpp_LoaderArgs) -> *mut ffi::c_char;
pub fn gcpp_LoaderArgs_SetCache(largs: *mut gcpp_LoaderArgs, path: *const ffi::c_char, n: ffi::c_uint);
pub fn gcpp_LoaderArgs_Cache(largs: *const gcpp_LoaderArgs) -> *mut ffi::c_char;
pub fn gcpp_LoaderArgs_SetModelTypeValue(largs: *mut gcpp_LoaderArgs, s: *const ffi::c_char, n: ffi::c_uint);
pub fn gcpp_LoaderArgs_ModelTypeValue(largs: *const gcpp_LoaderArgs) -> *mut ffi::c_char;
// FFI mirror of the C++ `gcpp::RuntimeConfig` passed to gcpp_GenerateGemma.
// Field order and types must match the C++ struct exactly.
// NOTE(review): field semantics inferred from the field names — confirm
// against the gemma.cpp headers.
pub struct gcpp_RuntimeConfig {
pub max_tokens: ffi::c_uint,
pub max_generated_tokens: ffi::c_uint,
pub temperature: ffi::c_float,
pub verbosity: ffi::c_int,
}

#[repr(C)]
Expand All @@ -53,24 +38,39 @@ pub struct gcpp_Gemma {
}

extern "C" {
pub fn gcpp_Gemma_Gemma(args: *mut gcpp_LoaderArgs, pool: *mut hwy_ThreadPool) -> *mut gcpp_Gemma;
pub fn gcpp_Gemma_Gemma(
tokenizer_path: *const ffi::c_char, tokenizer_path_len: ffi::c_uint,
compressed_weights_path: *const ffi::c_char, compressed_weights_path_len: ffi::c_uint,
weights_path: *const ffi::c_char, weights_path_len: ffi::c_uint,
model_type: gcpp_Model, pool: *mut hwy_ThreadPool) -> *mut gcpp_Gemma;
pub fn gcpp_Gemma_destructor(gemma: *mut gcpp_Gemma);
pub fn gcpp_Gemma_SetModelTraining(gemma: *mut gcpp_Gemma, training: gcpp_ModelTraining);
pub fn gcpp_Gemma_Encode(gemma: *mut gcpp_Gemma, input: *mut ffi::c_char, len: ffi::c_uint, out: *mut std_vector_int) -> ffi::c_char;
pub fn gcpp_Gemma_Decode(gemma: *mut gcpp_Gemma, token: ffi::c_int, out: *mut std_string) -> ffi::c_char;
pub fn gcpp_Gemma_Decodes(gemma: *mut gcpp_Gemma, tokens: *const ffi::c_int, num_tokens: ffi::c_int, out: *mut std_string) -> ffi::c_char;

pub fn gcpp_GenerateGemma(
gemma: *mut gcpp_Gemma, args: *mut gcpp_InferenceArgs,
gemma: *mut gcpp_Gemma, config: *const gcpp_RuntimeConfig,
prompt: *const std_vector_int, start_pos: ffi::c_uint,
pool: *mut hwy_ThreadPool, inner_pool: *mut hwy_ThreadPool,
kvcache: *mut gcpp_KVCache, pool: *mut hwy_ThreadPool,
stream_context: *mut ffi::c_void,
stream_token: extern fn(*mut ffi::c_void, ffi::c_int, ffi::c_float) -> ffi::c_char,
accept_context: *mut ffi::c_void,
accept_token: extern fn(*mut ffi::c_void, ffi::c_int) -> ffi::c_char,
gen: *mut std_mt19937, verbosity: ffi::c_int,
gen: *mut std_mt19937,
);
}

/// Opaque handle for the C++ `gcpp::KVCache`; only used behind raw pointers
/// returned by `gcpp_CreateKVCache`.
#[repr(C)]
pub struct gcpp_KVCache {
_data: [u8; 0],
_marker:
core::marker::PhantomData<(*mut u8, core::marker::PhantomPinned)>,
}

extern "C" {
/// Allocates a KVCache sized for `model_type`.
/// Free with `gcpp_KVCache_destructor`.
pub fn gcpp_CreateKVCache(model_type: gcpp_Model) -> *mut gcpp_KVCache;
/// Frees a cache returned by `gcpp_CreateKVCache`.
pub fn gcpp_KVCache_destructor(cache: *mut gcpp_KVCache);
}

#[repr(C)]
pub struct std_vector_int {
_data: [u8; 0],
Expand Down Expand Up @@ -134,30 +134,6 @@ extern "C" {
pub fn std_string_c_str(s: *const std_string) -> *mut ffi::c_char;
}

/// Opaque handle for the C++ `gcpp::InferenceArgs`; only used behind raw
/// pointers returned by the C++ side.
#[repr(C)]
pub struct gcpp_InferenceArgs {
_data: [u8; 0],
_marker:
core::marker::PhantomData<(*mut u8, core::marker::PhantomPinned)>,
}

extern "C" {
// Constructor / destructor / validation for the C++ InferenceArgs.
pub fn gcpp_InferenceArgs_InferenceArgs(argc: ffi::c_int, argv: *mut *mut ffi::c_char) -> *mut gcpp_InferenceArgs;
pub fn gcpp_InferenceArgs_destructor(args: *mut gcpp_InferenceArgs);
// Returns null when the arguments are valid, else an error message.
pub fn gcpp_InferenceArgs_Validate(args: *mut gcpp_InferenceArgs) -> *const ffi::c_char;

// Field accessors. C++ `bool` crosses the boundary as `c_char`
// (0 = false, nonzero = true).
pub fn gcpp_InferenceArgs_MaxTokens(args: *const gcpp_InferenceArgs) -> ffi::c_uint;
pub fn gcpp_InferenceArgs_SetMaxTokens(args: *mut gcpp_InferenceArgs, mt: ffi::c_uint);
pub fn gcpp_InferenceArgs_MaxGeneratedTokens(args: *const gcpp_InferenceArgs) -> ffi::c_uint;
pub fn gcpp_InferenceArgs_SetMaxGeneratedTokens(args: *mut gcpp_InferenceArgs, mgt: ffi::c_uint);
pub fn gcpp_InferenceArgs_Temperature(args: *const gcpp_InferenceArgs) -> ffi::c_float;
pub fn gcpp_InferenceArgs_SetTemperature(args: *mut gcpp_InferenceArgs, t: ffi::c_float);
pub fn gcpp_InferenceArgs_Deterministic(args: *const gcpp_InferenceArgs) -> ffi::c_char;
pub fn gcpp_InferenceArgs_SetDeterministic(args: *mut gcpp_InferenceArgs, d: ffi::c_char);
pub fn gcpp_InferenceArgs_Multiturn(args: *const gcpp_InferenceArgs) -> ffi::c_char;
pub fn gcpp_InferenceArgs_SetMultiturn(args: *mut gcpp_InferenceArgs, mt: ffi::c_char);
}

#[repr(C)]
pub struct std_mt19937 {
_data: [u8; 0],
Expand All @@ -176,65 +152,6 @@ extern "C" {
mod test {
use crate::*;

#[test]
fn create_and_delete_largs() {
    // Build an argv-style array of raw C strings. The CStrings are
    // intentionally leaked; this is test-only code.
    let argv = [
        "prog",
        "--tokenizer", "tokenizer.spm",
        "--model", "2b-pt",
        "--compressed_weights", "2b-pt.sbs",
    ];
    let mut raw_argv: Vec<*mut ffi::c_char> = argv
        .iter()
        .map(|arg| ffi::CString::new(*arg).unwrap().into_raw())
        .collect();
    unsafe {
        let largs = gcpp_LoaderArgs_LoaderArgs(raw_argv.len() as ffi::c_int, raw_argv.as_mut_ptr());
        assert_eq!(gcpp_Model_GEMMA_2B, gcpp_LoaderArgs_ModelType(largs));
        assert_eq!(gcpp_ModelTraining_GEMMA_PT, gcpp_LoaderArgs_ModelTraining(largs));
        let tokenizer = gcpp_LoaderArgs_Tokenizer(largs);
        assert_eq!(ffi::CStr::from_ptr(tokenizer).to_str().unwrap(), "tokenizer.spm");
        gcpp_LoaderArgs_destructor(largs);
    }
}

#[test]
fn create_and_delete_largs_direct() {
    let tokenizer_path = "tokenizer.spm";
    let compressed_weights = "2b-pt.sbs";
    let model = "2b-pt";
    unsafe {
        // Start from empty args and populate fields directly via the setters.
        let largs = gcpp_LoaderArgs_LoaderArgs(0, std::ptr::null_mut());
        // Cast through ffi::c_char rather than i8: c_char is u8 on some targets.
        gcpp_LoaderArgs_SetTokenizer(largs, tokenizer_path.as_ptr() as *const ffi::c_char, tokenizer_path.len() as ffi::c_uint);
        gcpp_LoaderArgs_SetCache(largs, compressed_weights.as_ptr() as *const ffi::c_char, compressed_weights.len() as ffi::c_uint);
        gcpp_LoaderArgs_SetModelTypeValue(largs, model.as_ptr() as *const ffi::c_char, model.len() as ffi::c_uint);
        let err = gcpp_LoaderArgs_Validate(largs);
        if !err.is_null() {
            println!("{}", ffi::CStr::from_ptr(err).to_str().unwrap());
        }
        assert_eq!(std::ptr::null(), err);
        // Fix: the original test leaked `largs`; free it as the name promises.
        gcpp_LoaderArgs_destructor(largs);
    }
}

#[test]
fn create_and_delete_iargs_direct() {
    unsafe {
        // Default-constructed inference args should validate cleanly and
        // carry the default token budgets checked below.
        let iargs = gcpp_InferenceArgs_InferenceArgs(0, std::ptr::null_mut());
        assert!(gcpp_InferenceArgs_Validate(iargs).is_null());

        assert_eq!(2048, gcpp_InferenceArgs_MaxGeneratedTokens(iargs));
        assert_eq!(3072, gcpp_InferenceArgs_MaxTokens(iargs));

        // Raising max_generated_tokens to 4096 makes validation fail —
        // presumably because it then exceeds max_tokens.
        gcpp_InferenceArgs_SetMaxGeneratedTokens(iargs, 4096);
        assert!(!gcpp_InferenceArgs_Validate(iargs).is_null());

        gcpp_InferenceArgs_destructor(iargs);
    }
}

#[test]
fn create_and_delete_pool() {
unsafe {
Expand Down
Loading

0 comments on commit 94e4d11

Please sign in to comment.