@@ -43,11 +43,36 @@ struct mlxsharp_session {
4343 std::string chat_model;
4444 std::string embedding_model;
4545 std::string image_model;
46- mlxsharp_session (mlxsharp_context_t * ctx, std::string chat, std::string embed, std::string image)
46+ std::string native_model_directory;
47+ std::string tokenizer_path;
48+ bool enable_native_runner;
49+ int max_generated_tokens;
50+ float temperature;
51+ float top_p;
52+ int top_k;
53+ mlxsharp_session (
54+ mlxsharp_context_t * ctx,
55+ std::string chat,
56+ std::string embed,
57+ std::string image,
58+ std::string native_dir,
59+ std::string tokenizer,
60+ bool enable_runner,
61+ int max_tokens,
62+ float temperature_value,
63+ float top_p_value,
64+ int top_k_value)
4765 : context(ctx),
4866 chat_model (std::move(chat)),
4967 embedding_model(std::move(embed)),
50- image_model(std::move(image)) {}
68+ image_model(std::move(image)),
69+ native_model_directory(std::move(native_dir)),
70+ tokenizer_path(std::move(tokenizer)),
71+ enable_native_runner(enable_runner),
72+ max_generated_tokens(max_tokens),
73+ temperature(temperature_value),
74+ top_p(top_p_value),
75+ top_k(top_k_value) {}
5176};
5277
5378namespace {
@@ -57,6 +82,7 @@ thread_local std::string g_last_error;
5782constexpr const char * kNullContext = " Context pointer is null." ;
5883constexpr const char * kNullArray = " Array pointer is null." ;
5984constexpr const char * kNullOutParameter = " Output parameter is null." ;
85+ constexpr const char * kNullSessionOptions = " Session options pointer is null." ;
6086constexpr const char * kShapeMismatch = " Element count does not match provided shape." ;
6187constexpr const char * kNonContiguous = " Array data is not contiguous." ;
6288constexpr const char * kUnsupportedDType = " Unsupported dtype." ;
@@ -316,8 +342,26 @@ mlxsharp_session_t* make_session_ptr(
316342 mlxsharp_context_t * context,
317343 std::string chat_model,
318344 std::string embedding_model,
319- std::string image_model) {
320- auto * handle = new (std::nothrow) mlxsharp_session (context, std::move (chat_model), std::move (embedding_model), std::move (image_model));
345+ std::string image_model,
346+ std::string native_model_directory,
347+ std::string tokenizer_path,
348+ bool enable_native_runner,
349+ int max_generated_tokens,
350+ float temperature,
351+ float top_p,
352+ int top_k) {
353+ auto * handle = new (std::nothrow) mlxsharp_session (
354+ context,
355+ std::move (chat_model),
356+ std::move (embedding_model),
357+ std::move (image_model),
358+ std::move (native_model_directory),
359+ std::move (tokenizer_path),
360+ enable_native_runner,
361+ max_generated_tokens,
362+ temperature,
363+ top_p,
364+ top_k);
321365 if (handle == nullptr ) {
322366 throw std::bad_alloc ();
323367 }
@@ -356,22 +400,43 @@ void ensure_contiguous(const mlx::core::array& arr) {
356400extern " C" {
357401
358402int mlxsharp_create_session (
359- const char * chat_model_id,
360- const char * embedding_model_id,
361- const char * image_model_id,
403+ const mlxsharp_session_options* options,
362404 void ** session) {
363405 if (session == nullptr ) {
364406 return set_error (MLXSHARP_STATUS_INVALID_ARGUMENT, " Session output pointer is null." );
365407 }
366408
367409 return invoke ([&]() -> int {
368- auto chat = chat_model_id != nullptr ? std::string (chat_model_id) : std::string{};
369- auto embed = embedding_model_id != nullptr ? std::string (embedding_model_id) : std::string{};
370- auto image = image_model_id != nullptr ? std::string (image_model_id) : std::string{};
410+ if (options == nullptr ) {
411+ return set_error (MLXSHARP_STATUS_INVALID_ARGUMENT, kNullSessionOptions );
412+ }
413+
414+ auto chat = options->chat_model_id != nullptr ? std::string (options->chat_model_id ) : std::string{};
415+ auto embed = options->embedding_model_id != nullptr ? std::string (options->embedding_model_id ) : std::string{};
416+ auto image = options->image_model_id != nullptr ? std::string (options->image_model_id ) : std::string{};
417+ auto native_dir = options->native_model_directory != nullptr ? std::string (options->native_model_directory ) : std::string{};
418+ auto tokenizer = options->tokenizer_path != nullptr ? std::string (options->tokenizer_path ) : std::string{};
419+ const bool enable_runner = options->enable_native_runner != 0 ;
420+ const int max_tokens = options->max_generated_tokens ;
421+ const float temperature = options->temperature ;
422+ const float top_p = options->top_p ;
423+ const int top_k = options->top_k ;
371424
372425 auto device = mlx::core::default_device ();
426+ mlx::core::set_default_device (device);
373427 auto * context = make_context_ptr (device);
374- auto * handle = make_session_ptr (context, std::move (chat), std::move (embed), std::move (image));
428+ auto * handle = make_session_ptr (
429+ context,
430+ std::move (chat),
431+ std::move (embed),
432+ std::move (image),
433+ std::move (native_dir),
434+ std::move (tokenizer),
435+ enable_runner,
436+ max_tokens,
437+ temperature,
438+ top_p,
439+ top_k);
375440 *session = handle;
376441 return MLXSHARP_STATUS_SUCCESS;
377442 });
0 commit comments