Skip to content

Commit 53cc881

Browse files
authored
Merge pull request #8 from managedcode/codex/integrate-mlx-lm-with-.net-framework
Handle missing mlx.metallib in CI pipeline
2 parents 0503a45 + 6a1e823 commit 53cc881

File tree

5 files changed

+168
-19
lines changed

5 files changed

+168
-19
lines changed

.github/workflows/ci.yml

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -172,9 +172,17 @@ jobs:
172172

173173
- name: Stage native libraries in project
174174
run: |
175+
set -euo pipefail
176+
175177
mkdir -p src/MLXSharp/runtimes/osx-arm64/native
176178
cp artifacts/native/osx-arm64/libmlxsharp.dylib src/MLXSharp/runtimes/osx-arm64/native/
177-
cp artifacts/native/osx-arm64/mlx.metallib src/MLXSharp/runtimes/osx-arm64/native/
179+
180+
if [ -f artifacts/native/osx-arm64/mlx.metallib ]; then
181+
cp artifacts/native/osx-arm64/mlx.metallib src/MLXSharp/runtimes/osx-arm64/native/
182+
else
183+
echo "::warning::mlx.metallib not found in macOS native artifact; continuing without Metal shaders"
184+
fi
185+
178186
mkdir -p src/MLXSharp/runtimes/linux-x64/native
179187
cp artifacts/native/linux-x64/libmlxsharp.so src/MLXSharp/runtimes/linux-x64/native/
180188
@@ -186,7 +194,12 @@ jobs:
186194
TEST_OUTPUT="src/MLXSharp.Tests/bin/Release/net9.0"
187195
mkdir -p "$TEST_OUTPUT/runtimes/osx-arm64/native"
188196
cp src/MLXSharp/runtimes/osx-arm64/native/libmlxsharp.dylib "$TEST_OUTPUT/runtimes/osx-arm64/native/"
189-
cp src/MLXSharp/runtimes/osx-arm64/native/mlx.metallib "$TEST_OUTPUT/runtimes/osx-arm64/native/"
197+
198+
if [ -f src/MLXSharp/runtimes/osx-arm64/native/mlx.metallib ]; then
199+
cp src/MLXSharp/runtimes/osx-arm64/native/mlx.metallib "$TEST_OUTPUT/runtimes/osx-arm64/native/"
200+
else
201+
echo "::warning::mlx.metallib not staged; tests will continue without Metal shaders"
202+
fi
190203
ls -la "$TEST_OUTPUT/runtimes/osx-arm64/native/"
191204
192205
- name: Run tests

native/include/mlxsharp/api.h

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -137,10 +137,21 @@ typedef struct mlx_usage {
137137
int output_tokens;
138138
} mlx_usage;
139139

140+
/*
 * Options bundle passed by pointer to mlxsharp_create_session().
 *
 * String members may be NULL: mlxsharp_create_session() copies each non-NULL
 * string into its own std::string and substitutes an empty string for NULL.
 * The numeric members are copied into the session as-is.
 */
typedef struct mlxsharp_session_options {
141+
/* Identifier of the chat model; NULL is treated as an empty string. */
const char* chat_model_id;
142+
/* Identifier of the embedding model; NULL is treated as an empty string. */
const char* embedding_model_id;
143+
/* Identifier of the image model; NULL is treated as an empty string. */
const char* image_model_id;
144+
/* Directory of the native model; NULL is treated as an empty string. */
const char* native_model_directory;
145+
/* Path to the tokenizer; NULL is treated as an empty string. */
const char* tokenizer_path;
146+
/* Boolean flag: mlxsharp_create_session() treats any non-zero value as true. */
int enable_native_runner;
147+
/* Stored on the session unchanged; presumably a generation length cap - TODO confirm. */
int max_generated_tokens;
148+
/* Stored on the session unchanged; presumably the sampling temperature - TODO confirm. */
float temperature;
149+
/* Stored on the session unchanged; presumably the nucleus (top-p) cutoff - TODO confirm. */
float top_p;
150+
/* Stored on the session unchanged; presumably the top-k cutoff - TODO confirm. */
int top_k;
151+
} mlxsharp_session_options;
152+
140153
int mlxsharp_create_session(
141-
const char* chat_model_id,
142-
const char* embedding_model_id,
143-
const char* image_model_id,
154+
const mlxsharp_session_options* options,
144155
void** session);
145156

146157
int mlxsharp_generate_text(

native/src/mlxsharp.cpp

Lines changed: 76 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -43,11 +43,36 @@ struct mlxsharp_session {
4343
std::string chat_model;
4444
std::string embedding_model;
4545
std::string image_model;
46-
mlxsharp_session(mlxsharp_context_t* ctx, std::string chat, std::string embed, std::string image)
46+
std::string native_model_directory;
47+
std::string tokenizer_path;
48+
bool enable_native_runner;
49+
int max_generated_tokens;
50+
float temperature;
51+
float top_p;
52+
int top_k;
53+
mlxsharp_session(
54+
mlxsharp_context_t* ctx,
55+
std::string chat,
56+
std::string embed,
57+
std::string image,
58+
std::string native_dir,
59+
std::string tokenizer,
60+
bool enable_runner,
61+
int max_tokens,
62+
float temperature_value,
63+
float top_p_value,
64+
int top_k_value)
4765
: context(ctx),
4866
chat_model(std::move(chat)),
4967
embedding_model(std::move(embed)),
50-
image_model(std::move(image)) {}
68+
image_model(std::move(image)),
69+
native_model_directory(std::move(native_dir)),
70+
tokenizer_path(std::move(tokenizer)),
71+
enable_native_runner(enable_runner),
72+
max_generated_tokens(max_tokens),
73+
temperature(temperature_value),
74+
top_p(top_p_value),
75+
top_k(top_k_value) {}
5176
};
5277

5378
namespace {
@@ -57,6 +82,7 @@ thread_local std::string g_last_error;
5782
constexpr const char* kNullContext = "Context pointer is null.";
5883
constexpr const char* kNullArray = "Array pointer is null.";
5984
constexpr const char* kNullOutParameter = "Output parameter is null.";
85+
constexpr const char* kNullSessionOptions = "Session options pointer is null.";
6086
constexpr const char* kShapeMismatch = "Element count does not match provided shape.";
6187
constexpr const char* kNonContiguous = "Array data is not contiguous.";
6288
constexpr const char* kUnsupportedDType = "Unsupported dtype.";
@@ -316,8 +342,26 @@ mlxsharp_session_t* make_session_ptr(
316342
mlxsharp_context_t* context,
317343
std::string chat_model,
318344
std::string embedding_model,
319-
std::string image_model) {
320-
auto* handle = new (std::nothrow) mlxsharp_session(context, std::move(chat_model), std::move(embedding_model), std::move(image_model));
345+
std::string image_model,
346+
std::string native_model_directory,
347+
std::string tokenizer_path,
348+
bool enable_native_runner,
349+
int max_generated_tokens,
350+
float temperature,
351+
float top_p,
352+
int top_k) {
353+
auto* handle = new (std::nothrow) mlxsharp_session(
354+
context,
355+
std::move(chat_model),
356+
std::move(embedding_model),
357+
std::move(image_model),
358+
std::move(native_model_directory),
359+
std::move(tokenizer_path),
360+
enable_native_runner,
361+
max_generated_tokens,
362+
temperature,
363+
top_p,
364+
top_k);
321365
if (handle == nullptr) {
322366
throw std::bad_alloc();
323367
}
@@ -356,22 +400,43 @@ void ensure_contiguous(const mlx::core::array& arr) {
356400
extern "C" {
357401

358402
int mlxsharp_create_session(
359-
const char* chat_model_id,
360-
const char* embedding_model_id,
361-
const char* image_model_id,
403+
const mlxsharp_session_options* options,
362404
void** session) {
363405
if (session == nullptr) {
364406
return set_error(MLXSHARP_STATUS_INVALID_ARGUMENT, "Session output pointer is null.");
365407
}
366408

367409
return invoke([&]() -> int {
368-
auto chat = chat_model_id != nullptr ? std::string(chat_model_id) : std::string{};
369-
auto embed = embedding_model_id != nullptr ? std::string(embedding_model_id) : std::string{};
370-
auto image = image_model_id != nullptr ? std::string(image_model_id) : std::string{};
410+
if (options == nullptr) {
411+
return set_error(MLXSHARP_STATUS_INVALID_ARGUMENT, kNullSessionOptions);
412+
}
413+
414+
auto chat = options->chat_model_id != nullptr ? std::string(options->chat_model_id) : std::string{};
415+
auto embed = options->embedding_model_id != nullptr ? std::string(options->embedding_model_id) : std::string{};
416+
auto image = options->image_model_id != nullptr ? std::string(options->image_model_id) : std::string{};
417+
auto native_dir = options->native_model_directory != nullptr ? std::string(options->native_model_directory) : std::string{};
418+
auto tokenizer = options->tokenizer_path != nullptr ? std::string(options->tokenizer_path) : std::string{};
419+
const bool enable_runner = options->enable_native_runner != 0;
420+
const int max_tokens = options->max_generated_tokens;
421+
const float temperature = options->temperature;
422+
const float top_p = options->top_p;
423+
const int top_k = options->top_k;
371424

372425
auto device = mlx::core::default_device();
426+
mlx::core::set_default_device(device);
373427
auto* context = make_context_ptr(device);
374-
auto* handle = make_session_ptr(context, std::move(chat), std::move(embed), std::move(image));
428+
auto* handle = make_session_ptr(
429+
context,
430+
std::move(chat),
431+
std::move(embed),
432+
std::move(image),
433+
std::move(native_dir),
434+
std::move(tokenizer),
435+
enable_runner,
436+
max_tokens,
437+
temperature,
438+
top_p,
439+
top_k);
375440
*session = handle;
376441
return MLXSHARP_STATUS_SUCCESS;
377442
});

src/MLXSharp/Backends/MlxNativeBackend.cs

Lines changed: 46 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,8 @@ public static MlxNativeBackend Create(MlxClientOptions options)
2727
ArgumentNullException.ThrowIfNull(options);
2828
MlxNativeLibrary.EnsureLoaded(options.LibraryPath);
2929

30-
var status = MlxNativeMethods.CreateSession(options.ChatModelId, options.EmbeddingModelId, options.ImageModelId, out var session);
30+
using var sessionOptions = new MarshaledSessionOptions(options);
31+
var status = MlxNativeMethods.CreateSession(in sessionOptions.Value, out var session);
3132
if (status != 0 || session.IsInvalid)
3233
{
3334
session.Dispose();
@@ -201,6 +202,50 @@ private MlxTextResult GenerateTextFallback(MlxTextRequest request)
201202
return existing;
202203
}
203204

205+
// Builds an MlxSessionOptions value from MlxClientOptions, copying each
// non-null string into unmanaged UTF-8 memory, and frees those copies in
// Dispose(). Intended to be wrapped in a `using` around the native
// CreateSession call.
private sealed class MarshaledSessionOptions : IDisposable
206+
{
207+
// Public field (not a property): the caller passes it by reference with
// `in sessionOptions.Value`.
public MlxSessionOptions Value;
208+
209+
// Allocates unmanaged copies of the five string options and copies the
// scalar options across; the bool flag is converted to 1/0.
// NOTE(review): if a later Allocate call throws, pointers already
// allocated here appear to be leaked because Dispose is never reached -
// confirm whether that matters.
public MarshaledSessionOptions(MlxClientOptions options)
210+
{
211+
Value = new MlxSessionOptions
212+
{
213+
ChatModelId = Allocate(options.ChatModelId),
214+
EmbeddingModelId = Allocate(options.EmbeddingModelId),
215+
ImageModelId = Allocate(options.ImageModelId),
216+
NativeModelDirectory = Allocate(options.NativeModelDirectory),
217+
TokenizerPath = Allocate(options.TokenizerPath),
218+
EnableNativeModelRunner = options.EnableNativeModelRunner ? 1 : 0,
219+
MaxGeneratedTokens = options.MaxGeneratedTokens,
220+
Temperature = options.Temperature,
221+
TopP = options.TopP,
222+
TopK = options.TopK,
223+
};
224+
}
225+
226+
// Frees every string pointer held in Value.
// NOTE(review): the pointers are not reset to zero afterwards, so a
// second Dispose() call would free the same memory again.
public void Dispose()
227+
{
228+
Free(Value.ChatModelId);
229+
Free(Value.EmbeddingModelId);
230+
Free(Value.ImageModelId);
231+
Free(Value.NativeModelDirectory);
232+
Free(Value.TokenizerPath);
233+
}
234+
235+
// Returns nint.Zero for a null string, otherwise a CoTaskMem-allocated
// UTF-8 copy that must be released with Free.
private static nint Allocate(string? value)
236+
{
237+
return value is null ? nint.Zero : Marshal.StringToCoTaskMemUTF8(value);
238+
}
239+
240+
// Releases a pointer produced by Allocate; a zero pointer is ignored.
private static void Free(nint pointer)
241+
{
242+
if (pointer != nint.Zero)
243+
{
244+
Marshal.FreeCoTaskMem(pointer);
245+
}
246+
}
247+
}
248+
204249
private void ThrowIfDisposed()
205250
{
206251
if (_disposed)

src/MLXSharp/Native/MlxNativeMethods.cs

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,8 @@ internal static partial class MlxNativeMethods
1818
{
1919
private const string LibraryName = "libmlxsharp";
2020

21-
[LibraryImport(LibraryName, EntryPoint = "mlxsharp_create_session", StringMarshalling = StringMarshalling.Utf8)]
22-
public static partial int CreateSession(string chatModelId, string embeddingModelId, string imageModelId, out SafeMlxSessionHandle session);
21+
// Binding for the native mlxsharp_create_session entry point. The options
// struct is passed by read-only reference and the created session handle is
// returned through the out parameter. The int result is a status code; the
// caller in MlxNativeBackend.Create treats any non-zero value as failure.
[LibraryImport(LibraryName, EntryPoint = "mlxsharp_create_session")]
22+
public static partial int CreateSession(in MlxSessionOptions options, out SafeMlxSessionHandle session);
2323

2424
[LibraryImport(LibraryName, EntryPoint = "mlxsharp_release_session")]
2525
public static partial void ReleaseSession(nint session);
@@ -142,3 +142,18 @@ internal struct MlxUsage
142142
public int InputTokens;
143143
public int OutputTokens;
144144
}
145+
146+
// Managed counterpart of the native mlxsharp_session_options struct passed to
// mlxsharp_create_session. Fields are laid out sequentially in the same order
// as the C declaration: five string pointers, then int, int, float, float, int.
// Keep the order in sync with native/include/mlxsharp/api.h.
[StructLayout(LayoutKind.Sequential)]
147+
internal struct MlxSessionOptions
148+
{
149+
// Pointer to a UTF-8 string, or zero when the option is not set.
public nint ChatModelId;
150+
// Pointer to a UTF-8 string, or zero when the option is not set.
public nint EmbeddingModelId;
151+
// Pointer to a UTF-8 string, or zero when the option is not set.
public nint ImageModelId;
152+
// Pointer to a UTF-8 string, or zero when the option is not set.
public nint NativeModelDirectory;
153+
// Pointer to a UTF-8 string, or zero when the option is not set.
public nint TokenizerPath;
154+
// Boolean encoded as int: the native side treats any non-zero value as true.
// Corresponds to the C field enable_native_runner.
public int EnableNativeModelRunner;
155+
// Corresponds to the C field max_generated_tokens.
public int MaxGeneratedTokens;
156+
// Corresponds to the C field temperature.
public float Temperature;
157+
// Corresponds to the C field top_p.
public float TopP;
158+
// Corresponds to the C field top_k.
public int TopK;
159+
}

0 commit comments

Comments
 (0)