Skip to content

Commit 09848b9

Browse files
committed
Merge remote-tracking branch 'origin/codex/integrate-mlx-lm-with-.net-framework'
# Conflicts: # src/MLXSharp.Tests/ArraySmokeTests.cs # src/MLXSharp.Tests/ModelIntegrationTests.cs # src/MLXSharp.Tests/TestEnvironment.cs
2 parents 3ed2f32 + 96e3d31 commit 09848b9

File tree

2 files changed

+177
-25
lines changed

2 files changed

+177
-25
lines changed
Lines changed: 30 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,38 +1,43 @@
11
using System;
2-
using System.IO;
3-
using System.Runtime.InteropServices;
2+
using MLXSharp.Core;
43
using Xunit;
54

65
namespace MLXSharp.Tests;
76

87
public sealed class NativeLibrarySmokeTests
98
{
109
[Fact]
11-
public void NativeLibraryProvidesExpectedExports()
10+
public void AddTwoFloatArrays()
1211
{
1312
TestEnvironment.EnsureInitialized();
1413

15-
var libraryPath = Environment.GetEnvironmentVariable("MLXSHARP_LIBRARY");
16-
Assert.False(string.IsNullOrWhiteSpace(libraryPath));
17-
Assert.True(File.Exists(libraryPath));
18-
19-
if (!NativeLibrary.TryLoad(libraryPath!, out var handle))
20-
{
21-
throw new InvalidOperationException($"Unable to load native library from '{libraryPath}'.");
22-
}
23-
24-
try
25-
{
26-
foreach (var export in TestEnvironment.RequiredNativeExports)
27-
{
28-
Assert.True(
29-
NativeLibrary.TryGetExport(handle, export, out _),
30-
$"Native library at '{libraryPath}' is missing required export '{export}'.");
31-
}
32-
}
33-
finally
34-
{
35-
NativeLibrary.Free(handle);
36-
}
14+
using var context = MlxContext.CreateCpu();
15+
16+
ReadOnlySpan<float> leftData = stackalloc float[] { 1f, 2f, 3f, 4f };
17+
ReadOnlySpan<float> rightData = stackalloc float[] { 5f, 6f, 7f, 8f };
18+
ReadOnlySpan<long> shape = stackalloc long[] { 2, 2 };
19+
20+
using var left = MlxArray.From(context, leftData, shape);
21+
using var right = MlxArray.From(context, rightData, shape);
22+
using var result = MlxArray.Add(left, right);
23+
24+
Assert.Equal(new[] { 6f, 8f, 10f, 12f }, result.ToArrayFloat32());
25+
Assert.Equal(shape.ToArray(), result.Shape);
26+
Assert.Equal(MlxDType.Float32, result.DType);
27+
}
28+
29+
[Fact]
30+
public void ZerosAllocatesRequestedShape()
31+
{
32+
TestEnvironment.EnsureInitialized();
33+
34+
using var context = MlxContext.CreateCpu();
35+
ReadOnlySpan<long> shape = stackalloc long[] { 3, 1 };
36+
37+
using var zeros = MlxArray.Zeros(context, shape, MlxDType.Float32);
38+
39+
Assert.Equal(MlxDType.Float32, zeros.DType);
40+
Assert.Equal(shape.ToArray(), zeros.Shape);
41+
Assert.All(zeros.ToArrayFloat32(), value => Assert.Equal(0f, value));
3742
}
3843
}
Lines changed: 147 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,147 @@
1+
using System;
2+
using System.IO;
3+
using System.IO.Compression;
4+
using System.Net.Http;
5+
using System.Text.Json;
6+
7+
namespace MLXSharp.Tests;
8+
9+
internal static class NativeBinaryManager
10+
{
11+
private const string PackageId = "managedcode.mlxsharp";
12+
private const string BaseUrl = "https://api.nuget.org/v3-flatcontainer";
13+
14+
private static readonly object s_sync = new();
15+
private static bool s_attempted;
16+
private static string? s_cachedPath;
17+
private static string? s_lastError;
18+
19+
public static bool TryEnsureNativeLibrary(string repoRoot, out string? libraryPath, out string? error)
20+
{
21+
if (!OperatingSystem.IsMacOS() && !OperatingSystem.IsLinux())
22+
{
23+
libraryPath = null;
24+
error = "Official native binaries are only published for macOS and Linux.";
25+
return false;
26+
}
27+
28+
lock (s_sync)
29+
{
30+
if (!string.IsNullOrEmpty(s_cachedPath) && File.Exists(s_cachedPath))
31+
{
32+
libraryPath = s_cachedPath;
33+
error = null;
34+
return true;
35+
}
36+
37+
if (s_attempted)
38+
{
39+
libraryPath = s_cachedPath;
40+
error = s_lastError;
41+
return libraryPath is not null;
42+
}
43+
44+
s_attempted = true;
45+
46+
try
47+
{
48+
var path = DownloadOfficialBinary(repoRoot);
49+
s_cachedPath = path;
50+
s_lastError = null;
51+
libraryPath = path;
52+
error = null;
53+
return true;
54+
}
55+
catch (Exception ex)
56+
{
57+
s_cachedPath = null;
58+
s_lastError = ex.Message;
59+
libraryPath = null;
60+
error = s_lastError;
61+
return false;
62+
}
63+
}
64+
}
65+
66+
private static string DownloadOfficialBinary(string repoRoot)
67+
{
68+
var rid = GetRuntimeIdentifier();
69+
var fileName = OperatingSystem.IsMacOS() ? "libmlxsharp.dylib" : "libmlxsharp.so";
70+
var nativeDirectory = Path.Combine(repoRoot, "libs", "native-libs", rid);
71+
Directory.CreateDirectory(nativeDirectory);
72+
73+
var destination = Path.Combine(nativeDirectory, fileName);
74+
if (File.Exists(destination))
75+
{
76+
return destination;
77+
}
78+
79+
using var client = new HttpClient();
80+
var version = ResolvePackageVersion(client);
81+
var packageUrl = $"{BaseUrl}/{PackageId}/{version}/{PackageId}.{version}.nupkg";
82+
83+
using var packageStream = client.GetStreamAsync(packageUrl).GetAwaiter().GetResult();
84+
var tempFile = Path.GetTempFileName();
85+
try
86+
{
87+
using (var fileStream = File.OpenWrite(tempFile))
88+
{
89+
packageStream.CopyTo(fileStream);
90+
}
91+
92+
using var archive = ZipFile.OpenRead(tempFile);
93+
var entryPath = $"runtimes/{rid}/native/{fileName}";
94+
var entry = archive.GetEntry(entryPath) ??
95+
throw new InvalidOperationException($"The official package does not contain {entryPath}.");
96+
97+
entry.ExtractToFile(destination, overwrite: true);
98+
return destination;
99+
}
100+
finally
101+
{
102+
try
103+
{
104+
File.Delete(tempFile);
105+
}
106+
catch
107+
{
108+
// ignore cleanup errors
109+
}
110+
}
111+
}
112+
113+
private static string ResolvePackageVersion(HttpClient client)
114+
{
115+
var overrideVersion = Environment.GetEnvironmentVariable("MLXSHARP_OFFICIAL_NATIVE_VERSION");
116+
if (!string.IsNullOrWhiteSpace(overrideVersion))
117+
{
118+
return overrideVersion.Trim();
119+
}
120+
121+
var indexUrl = $"{BaseUrl}/{PackageId}/index.json";
122+
using var stream = client.GetStreamAsync(indexUrl).GetAwaiter().GetResult();
123+
using var document = JsonDocument.Parse(stream);
124+
if (!document.RootElement.TryGetProperty("versions", out var versions) || versions.GetArrayLength() == 0)
125+
{
126+
throw new InvalidOperationException("Unable to determine the latest ManagedCode.MLXSharp package version.");
127+
}
128+
129+
return versions[versions.GetArrayLength() - 1].GetString()
130+
?? throw new InvalidOperationException("ManagedCode.MLXSharp package version entry was null.");
131+
}
132+
133+
private static string GetRuntimeIdentifier()
134+
{
135+
if (OperatingSystem.IsMacOS())
136+
{
137+
return "osx-arm64";
138+
}
139+
140+
if (OperatingSystem.IsLinux())
141+
{
142+
return "linux-x64";
143+
}
144+
145+
throw new PlatformNotSupportedException("Unsupported platform for native MLXSharp binaries.");
146+
}
147+
}

0 commit comments

Comments
 (0)