Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 15 additions & 2 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -172,9 +172,17 @@ jobs:

- name: Stage native libraries in project
run: |
set -euo pipefail

mkdir -p src/MLXSharp/runtimes/osx-arm64/native
cp artifacts/native/osx-arm64/libmlxsharp.dylib src/MLXSharp/runtimes/osx-arm64/native/
Copy link

Copilot AI Oct 25, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Adding set -euo pipefail at line 175 makes the script more strict, but the existing cp command at line 178 will still fail if libmlxsharp.dylib is missing. Consider adding a similar conditional check for the required libmlxsharp.dylib file to maintain consistency with the optional mlx.metallib handling, or document why this file's absence should cause a hard failure while mlx.metallib's absence should not.

Suggested change
cp artifacts/native/osx-arm64/libmlxsharp.dylib src/MLXSharp/runtimes/osx-arm64/native/
if [ -f artifacts/native/osx-arm64/libmlxsharp.dylib ]; then
cp artifacts/native/osx-arm64/libmlxsharp.dylib src/MLXSharp/runtimes/osx-arm64/native/
else
echo "::error::libmlxsharp.dylib not found in macOS native artifact; cannot continue"
exit 1
fi

Copilot uses AI. Check for mistakes.
cp artifacts/native/osx-arm64/mlx.metallib src/MLXSharp/runtimes/osx-arm64/native/

if [ -f artifacts/native/osx-arm64/mlx.metallib ]; then
cp artifacts/native/osx-arm64/mlx.metallib src/MLXSharp/runtimes/osx-arm64/native/
else
echo "::warning::mlx.metallib not found in macOS native artifact; continuing without Metal shaders"
fi

mkdir -p src/MLXSharp/runtimes/linux-x64/native
cp artifacts/native/linux-x64/libmlxsharp.so src/MLXSharp/runtimes/linux-x64/native/

Expand All @@ -186,7 +194,12 @@ jobs:
TEST_OUTPUT="src/MLXSharp.Tests/bin/Release/net9.0"
mkdir -p "$TEST_OUTPUT/runtimes/osx-arm64/native"
cp src/MLXSharp/runtimes/osx-arm64/native/libmlxsharp.dylib "$TEST_OUTPUT/runtimes/osx-arm64/native/"
cp src/MLXSharp/runtimes/osx-arm64/native/mlx.metallib "$TEST_OUTPUT/runtimes/osx-arm64/native/"

if [ -f src/MLXSharp/runtimes/osx-arm64/native/mlx.metallib ]; then
cp src/MLXSharp/runtimes/osx-arm64/native/mlx.metallib "$TEST_OUTPUT/runtimes/osx-arm64/native/"
else
echo "::warning::mlx.metallib not staged; tests will continue without Metal shaders"
fi
ls -la "$TEST_OUTPUT/runtimes/osx-arm64/native/"

- name: Run tests
Expand Down
17 changes: 14 additions & 3 deletions native/include/mlxsharp/api.h
Original file line number Diff line number Diff line change
Expand Up @@ -137,10 +137,21 @@ typedef struct mlx_usage {
int output_tokens;
} mlx_usage;

typedef struct mlxsharp_session_options {
const char* chat_model_id;
const char* embedding_model_id;
const char* image_model_id;
const char* native_model_directory;
const char* tokenizer_path;
int enable_native_runner;
int max_generated_tokens;
float temperature;
float top_p;
int top_k;
} mlxsharp_session_options;

int mlxsharp_create_session(
const char* chat_model_id,
const char* embedding_model_id,
const char* image_model_id,
const mlxsharp_session_options* options,
void** session);

int mlxsharp_generate_text(
Expand Down
87 changes: 76 additions & 11 deletions native/src/mlxsharp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,11 +43,36 @@ struct mlxsharp_session {
std::string chat_model;
std::string embedding_model;
std::string image_model;
mlxsharp_session(mlxsharp_context_t* ctx, std::string chat, std::string embed, std::string image)
std::string native_model_directory;
std::string tokenizer_path;
bool enable_native_runner;
int max_generated_tokens;
float temperature;
float top_p;
int top_k;
mlxsharp_session(
mlxsharp_context_t* ctx,
std::string chat,
std::string embed,
std::string image,
std::string native_dir,
std::string tokenizer,
bool enable_runner,
int max_tokens,
float temperature_value,
float top_p_value,
int top_k_value)
: context(ctx),
chat_model(std::move(chat)),
embedding_model(std::move(embed)),
image_model(std::move(image)) {}
image_model(std::move(image)),
native_model_directory(std::move(native_dir)),
tokenizer_path(std::move(tokenizer)),
enable_native_runner(enable_runner),
max_generated_tokens(max_tokens),
temperature(temperature_value),
top_p(top_p_value),
top_k(top_k_value) {}
};

namespace {
Expand All @@ -57,6 +82,7 @@ thread_local std::string g_last_error;
constexpr const char* kNullContext = "Context pointer is null.";
constexpr const char* kNullArray = "Array pointer is null.";
constexpr const char* kNullOutParameter = "Output parameter is null.";
constexpr const char* kNullSessionOptions = "Session options pointer is null.";
constexpr const char* kShapeMismatch = "Element count does not match provided shape.";
constexpr const char* kNonContiguous = "Array data is not contiguous.";
constexpr const char* kUnsupportedDType = "Unsupported dtype.";
Expand Down Expand Up @@ -316,8 +342,26 @@ mlxsharp_session_t* make_session_ptr(
mlxsharp_context_t* context,
std::string chat_model,
std::string embedding_model,
std::string image_model) {
auto* handle = new (std::nothrow) mlxsharp_session(context, std::move(chat_model), std::move(embedding_model), std::move(image_model));
std::string image_model,
std::string native_model_directory,
std::string tokenizer_path,
bool enable_native_runner,
int max_generated_tokens,
float temperature,
float top_p,
int top_k) {
auto* handle = new (std::nothrow) mlxsharp_session(
context,
std::move(chat_model),
std::move(embedding_model),
std::move(image_model),
std::move(native_model_directory),
std::move(tokenizer_path),
enable_native_runner,
max_generated_tokens,
temperature,
top_p,
top_k);
if (handle == nullptr) {
throw std::bad_alloc();
}
Expand Down Expand Up @@ -356,22 +400,43 @@ void ensure_contiguous(const mlx::core::array& arr) {
extern "C" {

int mlxsharp_create_session(
const char* chat_model_id,
const char* embedding_model_id,
const char* image_model_id,
const mlxsharp_session_options* options,
void** session) {
if (session == nullptr) {
return set_error(MLXSHARP_STATUS_INVALID_ARGUMENT, "Session output pointer is null.");
}

return invoke([&]() -> int {
auto chat = chat_model_id != nullptr ? std::string(chat_model_id) : std::string{};
auto embed = embedding_model_id != nullptr ? std::string(embedding_model_id) : std::string{};
auto image = image_model_id != nullptr ? std::string(image_model_id) : std::string{};
if (options == nullptr) {
return set_error(MLXSHARP_STATUS_INVALID_ARGUMENT, kNullSessionOptions);
}

auto chat = options->chat_model_id != nullptr ? std::string(options->chat_model_id) : std::string{};
auto embed = options->embedding_model_id != nullptr ? std::string(options->embedding_model_id) : std::string{};
auto image = options->image_model_id != nullptr ? std::string(options->image_model_id) : std::string{};
auto native_dir = options->native_model_directory != nullptr ? std::string(options->native_model_directory) : std::string{};
auto tokenizer = options->tokenizer_path != nullptr ? std::string(options->tokenizer_path) : std::string{};
const bool enable_runner = options->enable_native_runner != 0;
const int max_tokens = options->max_generated_tokens;
const float temperature = options->temperature;
const float top_p = options->top_p;
const int top_k = options->top_k;

auto device = mlx::core::default_device();
mlx::core::set_default_device(device);
auto* context = make_context_ptr(device);
auto* handle = make_session_ptr(context, std::move(chat), std::move(embed), std::move(image));
auto* handle = make_session_ptr(
context,
std::move(chat),
std::move(embed),
std::move(image),
std::move(native_dir),
std::move(tokenizer),
enable_runner,
max_tokens,
temperature,
top_p,
top_k);
*session = handle;
return MLXSHARP_STATUS_SUCCESS;
});
Expand Down
47 changes: 46 additions & 1 deletion src/MLXSharp/Backends/MlxNativeBackend.cs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@ public static MlxNativeBackend Create(MlxClientOptions options)
ArgumentNullException.ThrowIfNull(options);
MlxNativeLibrary.EnsureLoaded(options.LibraryPath);

var status = MlxNativeMethods.CreateSession(options.ChatModelId, options.EmbeddingModelId, options.ImageModelId, out var session);
using var sessionOptions = new MarshaledSessionOptions(options);
var status = MlxNativeMethods.CreateSession(in sessionOptions.Value, out var session);
if (status != 0 || session.IsInvalid)
{
session.Dispose();
Expand Down Expand Up @@ -201,6 +202,50 @@ private MlxTextResult GenerateTextFallback(MlxTextRequest request)
return existing;
}

/// <summary>
/// Owns the unmanaged UTF-8 copies of the string settings in an
/// <see cref="MlxClientOptions"/> instance and exposes them as a blittable
/// <see cref="MlxSessionOptions"/> that can be passed to the native
/// <c>mlxsharp_create_session</c> entry point.
/// </summary>
/// <remarks>
/// The pointers inside <see cref="Value"/> are valid only until
/// <see cref="Dispose"/> is called. <see cref="Dispose"/> is idempotent: every
/// pointer is reset to zero after it is freed, so a repeated call (or a call
/// after a partially failed construction) never frees the same buffer twice.
/// </remarks>
private sealed class MarshaledSessionOptions : IDisposable
{
    /// <summary>
    /// The native-layout options. Exposed as a field (not a property) so callers
    /// can pass it by reference with <c>in</c> without copying.
    /// </summary>
    public MlxSessionOptions Value;

    /// <summary>
    /// Copies every string option into CoTaskMem-allocated UTF-8 buffers and
    /// copies the scalar options by value.
    /// </summary>
    /// <param name="options">Managed client options to marshal.</param>
    public MarshaledSessionOptions(MlxClientOptions options)
    {
        try
        {
            // Assign field-by-field so that, if a later allocation throws,
            // the buffers already allocated are reachable from Value and can
            // be released in the catch block instead of leaking.
            Value.ChatModelId = Allocate(options.ChatModelId);
            Value.EmbeddingModelId = Allocate(options.EmbeddingModelId);
            Value.ImageModelId = Allocate(options.ImageModelId);
            Value.NativeModelDirectory = Allocate(options.NativeModelDirectory);
            Value.TokenizerPath = Allocate(options.TokenizerPath);
        }
        catch
        {
            Dispose();
            throw;
        }

        // The native side takes the flag as an int: non-zero means enabled.
        Value.EnableNativeModelRunner = options.EnableNativeModelRunner ? 1 : 0;
        Value.MaxGeneratedTokens = options.MaxGeneratedTokens;
        Value.Temperature = options.Temperature;
        Value.TopP = options.TopP;
        Value.TopK = options.TopK;
    }

    /// <summary>
    /// Releases every unmanaged string buffer and zeroes the corresponding
    /// pointer. Safe to call more than once.
    /// </summary>
    public void Dispose()
    {
        Free(ref Value.ChatModelId);
        Free(ref Value.EmbeddingModelId);
        Free(ref Value.ImageModelId);
        Free(ref Value.NativeModelDirectory);
        Free(ref Value.TokenizerPath);
    }

    /// <summary>
    /// Returns a CoTaskMem-allocated, NUL-terminated UTF-8 copy of
    /// <paramref name="value"/>, or zero when the string is null.
    /// </summary>
    private static nint Allocate(string? value)
    {
        return value is null ? nint.Zero : Marshal.StringToCoTaskMemUTF8(value);
    }

    /// <summary>
    /// Frees the buffer if one is present and clears the pointer so a second
    /// call is a no-op.
    /// </summary>
    private static void Free(ref nint pointer)
    {
        if (pointer != nint.Zero)
        {
            Marshal.FreeCoTaskMem(pointer);
            pointer = nint.Zero;
        }
    }
}

private void ThrowIfDisposed()
{
if (_disposed)
Expand Down
19 changes: 17 additions & 2 deletions src/MLXSharp/Native/MlxNativeMethods.cs
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@ internal static partial class MlxNativeMethods
{
private const string LibraryName = "libmlxsharp";

[LibraryImport(LibraryName, EntryPoint = "mlxsharp_create_session", StringMarshalling = StringMarshalling.Utf8)]
public static partial int CreateSession(string chatModelId, string embeddingModelId, string imageModelId, out SafeMlxSessionHandle session);
[LibraryImport(LibraryName, EntryPoint = "mlxsharp_create_session")]
public static partial int CreateSession(in MlxSessionOptions options, out SafeMlxSessionHandle session);

[LibraryImport(LibraryName, EntryPoint = "mlxsharp_release_session")]
public static partial void ReleaseSession(nint session);
Expand Down Expand Up @@ -142,3 +142,18 @@ internal struct MlxUsage
public int InputTokens;
public int OutputTokens;
}

/// <summary>
/// Blittable mirror of the native <c>mlxsharp_session_options</c> struct that is
/// passed by reference to <c>mlxsharp_create_session</c>.
/// </summary>
/// <remarks>
/// Field order and types must stay identical to the native declaration: five
/// pointer-sized string fields, followed by int, int, float, float, int.
/// The <c>nint</c> fields hold pointers to NUL-terminated UTF-8 strings (or zero
/// for "not set"); this struct does not own or free them.
/// </remarks>
[StructLayout(LayoutKind.Sequential)]
internal struct MlxSessionOptions
{
// Native: const char* chat_model_id
public nint ChatModelId;
// Native: const char* embedding_model_id
public nint EmbeddingModelId;
// Native: const char* image_model_id
public nint ImageModelId;
// Native: const char* native_model_directory
public nint NativeModelDirectory;
// Native: const char* tokenizer_path
public nint TokenizerPath;
// Native: int enable_native_runner - the native side treats any non-zero value as true.
public int EnableNativeModelRunner;
// Native: int max_generated_tokens
public int MaxGeneratedTokens;
// Native: float temperature
public float Temperature;
// Native: float top_p
public float TopP;
// Native: int top_k
public int TopK;
}
Loading