|
1 | 1 | using System; |
2 | | -using System.Collections.Generic; |
3 | 2 | using System.IO; |
4 | 3 | using System.Runtime.InteropServices; |
5 | | -using MLXSharp.Core; |
6 | 4 | using Xunit; |
7 | 5 |
|
8 | 6 | namespace MLXSharp.Tests; |
9 | 7 |
|
10 | | -public sealed class ArraySmokeTests |
| 8 | +public sealed class NativeLibrarySmokeTests |
11 | 9 | { |
12 | | - [RequiresNativeLibraryFact] |
13 | | - public void AddTwoFloatArrays() |
14 | | - { |
15 | | - using var context = MlxContext.CreateCpu(); |
16 | | - |
17 | | - ReadOnlySpan<float> leftData = stackalloc float[] { 1f, 2f, 3f, 4f }; |
18 | | - ReadOnlySpan<float> rightData = stackalloc float[] { 5f, 6f, 7f, 8f }; |
19 | | - ReadOnlySpan<long> shape = stackalloc long[] { 2, 2 }; |
20 | | - |
21 | | - using var left = MlxArray.From(context, leftData, shape); |
22 | | - using var right = MlxArray.From(context, rightData, shape); |
23 | | - using var result = MlxArray.Add(left, right); |
24 | | - |
25 | | - Assert.Equal(new[] { 6f, 8f, 10f, 12f }, result.ToArrayFloat32()); |
26 | | - Assert.Equal(shape.ToArray(), result.Shape); |
27 | | - Assert.Equal(MlxDType.Float32, result.DType); |
28 | | - } |
29 | | - |
30 | | - [RequiresNativeLibraryFact] |
31 | | - public void ZerosAllocatesRequestedShape() |
32 | | - { |
33 | | - using var context = MlxContext.CreateCpu(); |
34 | | - ReadOnlySpan<long> shape = stackalloc long[] { 3, 1 }; |
35 | | - |
36 | | - using var zeros = MlxArray.Zeros(context, shape, MlxDType.Float32); |
37 | | - |
38 | | - Assert.Equal(MlxDType.Float32, zeros.DType); |
39 | | - Assert.Equal(shape.ToArray(), zeros.Shape); |
40 | | - Assert.All(zeros.ToArrayFloat32(), value => Assert.Equal(0f, value)); |
41 | | - } |
42 | | -} |
43 | | - |
44 | | -internal sealed class RequiresNativeLibraryFactAttribute : FactAttribute |
45 | | -{ |
46 | | - public RequiresNativeLibraryFactAttribute() |
| 10 | + [Fact] |
| 11 | + public void NativeLibraryProvidesExpectedExports() |
47 | 12 | { |
48 | 13 | TestEnvironment.EnsureInitialized(); |
49 | | - if (!NativeLibraryLocator.TryEnsure(out var skipReason)) |
50 | | - { |
51 | | - Skip = skipReason ?? "Native MLX library is not available."; |
52 | | - } |
53 | | - } |
54 | | -} |
55 | | - |
56 | | -internal static class NativeLibraryLocator |
57 | | -{ |
58 | | - private static readonly object s_sync = new(); |
59 | | - private static bool s_initialized; |
60 | | - private static bool s_available; |
61 | | - |
62 | | - public static bool TryEnsure(out string? skipReason) |
63 | | - { |
64 | | - lock (s_sync) |
65 | | - { |
66 | | - if (s_initialized) |
67 | | - { |
68 | | - skipReason = s_available ? null : "Native MLX library is not available."; |
69 | | - return s_available; |
70 | | - } |
71 | | - |
72 | | - if (!TryFindNativeLibrary(out var path)) |
73 | | - { |
74 | | - s_initialized = true; |
75 | | - s_available = false; |
76 | | - skipReason = "Native MLX library is not available. Build the native project first."; |
77 | | - return false; |
78 | | - } |
79 | 14 |
|
80 | | - if (!HasRequiredExports(path, out skipReason)) |
81 | | - { |
82 | | - s_initialized = true; |
83 | | - s_available = false; |
84 | | - return false; |
85 | | - } |
| 15 | + var libraryPath = Environment.GetEnvironmentVariable("MLXSHARP_LIBRARY"); |
| 16 | + Assert.False(string.IsNullOrWhiteSpace(libraryPath)); |
| 17 | + Assert.True(File.Exists(libraryPath)); |
86 | 18 |
|
87 | | - Environment.SetEnvironmentVariable("MLXSHARP_LIBRARY", path); |
88 | | - s_initialized = true; |
89 | | - s_available = true; |
90 | | - skipReason = null; |
91 | | - return true; |
92 | | - } |
93 | | - } |
94 | | - |
95 | | - private static bool HasRequiredExports(string path, out string? reason) |
96 | | - { |
97 | | - if (!NativeLibrary.TryLoad(path, out var handle)) |
| 19 | + if (!NativeLibrary.TryLoad(libraryPath!, out var handle)) |
98 | 20 | { |
99 | | - reason = $"Unable to load native library from '{path}'."; |
100 | | - return false; |
| 21 | + throw new InvalidOperationException($"Unable to load native library from '{libraryPath}'."); |
101 | 22 | } |
102 | 23 |
|
103 | 24 | try |
104 | 25 | { |
105 | | - foreach (var export in new[] { "mlxsharp_context_create", "mlxsharp_array_from_buffer", "mlxsharp_generate_text" }) |
| 26 | + foreach (var export in TestEnvironment.RequiredNativeExports) |
106 | 27 | { |
107 | | - if (!NativeLibrary.TryGetExport(handle, export, out _)) |
108 | | - { |
109 | | - reason = $"Native library at '{path}' is missing required export '{export}'. Rebuild MLXSharp native binaries."; |
110 | | - return false; |
111 | | - } |
| 28 | + Assert.True( |
| 29 | + NativeLibrary.TryGetExport(handle, export, out _), |
| 30 | + $"Native library at '{libraryPath}' is missing required export '{export}'."); |
112 | 31 | } |
113 | | - |
114 | | - reason = null; |
115 | | - return true; |
116 | 32 | } |
117 | 33 | finally |
118 | 34 | { |
119 | 35 | NativeLibrary.Free(handle); |
120 | 36 | } |
121 | 37 | } |
122 | | - |
123 | | - private static bool TryFindNativeLibrary(out string path) |
124 | | - { |
125 | | - var baseDir = AppContext.BaseDirectory; |
126 | | - var libraryName = OperatingSystem.IsWindows() |
127 | | - ? "mlxsharp.dll" |
128 | | - : OperatingSystem.IsMacOS() |
129 | | - ? "libmlxsharp.dylib" |
130 | | - : "libmlxsharp.so"; |
131 | | - |
132 | | - foreach (var candidate in EnumerateCandidates(baseDir, libraryName)) |
133 | | - { |
134 | | - if (File.Exists(candidate)) |
135 | | - { |
136 | | - path = candidate; |
137 | | - return true; |
138 | | - } |
139 | | - } |
140 | | - |
141 | | - path = string.Empty; |
142 | | - return false; |
143 | | - } |
144 | | - |
145 | | - private static IEnumerable<string> EnumerateCandidates(string baseDir, string libraryName) |
146 | | - { |
147 | | - var arch = System.Runtime.InteropServices.RuntimeInformation.ProcessArchitecture switch |
148 | | - { |
149 | | - System.Runtime.InteropServices.Architecture.Arm64 => "arm64", |
150 | | - System.Runtime.InteropServices.Architecture.X64 => "x64", |
151 | | - _ => string.Empty, |
152 | | - }; |
153 | | - |
154 | | - if (!string.IsNullOrEmpty(arch)) |
155 | | - { |
156 | | - var rid = OperatingSystem.IsMacOS() |
157 | | - ? $"osx-{arch}" |
158 | | - : OperatingSystem.IsLinux() |
159 | | - ? $"linux-{arch}" |
160 | | - : OperatingSystem.IsWindows() |
161 | | - ? $"win-{arch}" |
162 | | - : string.Empty; |
163 | | - |
164 | | - if (!string.IsNullOrEmpty(rid)) |
165 | | - { |
166 | | - yield return Path.Combine(baseDir, "runtimes", rid, "native", libraryName); |
167 | | - } |
168 | | - } |
169 | | - |
170 | | - yield return Path.Combine(baseDir, libraryName); |
171 | | - } |
172 | 38 | } |
0 commit comments