diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 6fa5317..cc81b81 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -149,12 +149,32 @@ jobs: python-version: '3.11' - name: Install Python dependencies - run: python -m pip install huggingface_hub mlx-lm + run: | + python -m pip install --upgrade pip + python -m pip install huggingface_hub mlx mlx-lm - name: Download test model from HuggingFace + env: + HF_TOKEN: ${{ secrets.HF_TOKEN }} run: | mkdir -p models - huggingface-cli download mlx-community/Qwen1.5-0.5B-Chat-4bit --local-dir models/Qwen1.5-0.5B-Chat-4bit + python - <<'PY' + import os + from pathlib import Path + + from huggingface_hub import snapshot_download + + target_dir = Path("models/Qwen1.5-0.5B-Chat-4bit") + target_dir.mkdir(parents=True, exist_ok=True) + + snapshot_download( + repo_id="mlx-community/Qwen1.5-0.5B-Chat-4bit", + local_dir=str(target_dir), + local_dir_use_symlinks=False, + token=os.environ.get("HF_TOKEN") or None, + resume_download=True, + ) + PY echo "Model files:" ls -la models/Qwen1.5-0.5B-Chat-4bit/ @@ -170,6 +190,168 @@ jobs: name: native-linux-x64 path: artifacts/native/linux-x64 + - name: Ensure macOS metallib is available + run: | + set -euo pipefail + + metallib_path="artifacts/native/osx-arm64/mlx.metallib" + if [ -f "${metallib_path}" ]; then + echo "Found mlx.metallib in downloaded native artifact." + exit 0 + fi + + echo "::warning::mlx.metallib missing from native artifact; attempting to source from installed mlx package" + python - <<'PY' + import importlib.util + from importlib import resources + import pathlib + import shutil + import sys + from typing import Iterable, Optional + + try: + import mlx # type: ignore + except ImportError: + print("::error::The 'mlx' Python package is not installed; cannot locate mlx.metallib.") + sys.exit(1) + + search_dirs: list[pathlib.Path] = [] + package_dir: Optional[pathlib.Path] = None + package_paths: list[pathlib.Path] = [] + + package_file = getattr(mlx, "__file__", None) + if package_file: + try: + package_paths.append(pathlib.Path(package_file).resolve().parent) + except (TypeError, OSError): + pass + + package_path_attr = getattr(mlx, "__path__", None) + if package_path_attr: + for entry in package_path_attr: + try: + package_paths.append(pathlib.Path(entry).resolve()) + except (TypeError, OSError): + continue + + try: + spec = importlib.util.find_spec("mlx.backend.metal.kernels") + except ModuleNotFoundError: + spec = None + + if spec and spec.origin: + candidate = pathlib.Path(spec.origin).resolve().parent + if candidate.exists(): + search_dirs.append(candidate) + package_paths.append(candidate) + + def append_resource_directory(module: str, *subpath: str) -> None: + try: + traversable = resources.files(module) + except (ModuleNotFoundError, AttributeError): + return + + for segment in subpath: + traversable = traversable / segment + + try: + with resources.as_file(traversable) as extracted: + if extracted: + extracted_path = pathlib.Path(extracted).resolve() + if extracted_path.exists(): + search_dirs.append(extracted_path) + package_paths.append(extracted_path) + except (FileNotFoundError, RuntimeError): + pass + + append_resource_directory("mlx.backend.metal", "kernels") + append_resource_directory("mlx") + + existing_package_paths: list[pathlib.Path] = [] + seen_package_paths: set[pathlib.Path] = set() + for path in package_paths: + if not path: + continue + try: + resolved = path.resolve() + except (OSError, RuntimeError): + continue + if not resolved.exists(): + continue + if resolved in seen_package_paths: + continue + seen_package_paths.add(resolved) + existing_package_paths.append(resolved) + + if existing_package_paths: + package_dir = existing_package_paths[0] + for root in existing_package_paths: + search_dirs.extend( + [ + root / "backend" / "metal" / "kernels", + root / "backend" / "metal", + root, + ] + ) + + ordered_dirs: list[pathlib.Path] = [] + seen: set[pathlib.Path] = set() + for candidate in search_dirs: + if not candidate: + continue + candidate = candidate.resolve() + if candidate in seen: + continue + seen.add(candidate) + ordered_dirs.append(candidate) + + def iter_metallibs(dirs: Iterable[pathlib.Path]): + for directory in dirs: + if not directory.exists(): + continue + preferred = directory / "mlx.metallib" + if preferred.exists(): + yield preferred + continue + for alternative in sorted(directory.glob("*.metallib")): + yield alternative + + src = next(iter_metallibs(ordered_dirs), None) + + package_roots = existing_package_paths if existing_package_paths else ([] if not package_dir else [package_dir]) + + if src is None: + for root in package_roots: + for candidate in root.rglob("mlx.metallib"): + src = candidate + print(f"::warning::Resolved metallib via recursive search under {root}") + break + if src is not None: + break + + if src is None: + for root in package_roots: + for candidate in sorted(root.rglob("*.metallib")): + src = candidate + print(f"::warning::Using metallib {candidate.name} discovered via package-wide search in {root}") + break + if src is not None: + break + + if src is None: + print("::error::Could not locate any mlx.metallib artifacts within the installed mlx package.") + sys.exit(1) + + if src.name != "mlx.metallib": + print(f"::warning::Using metallib {src.name} from {src.parent}") + + dest = pathlib.Path("artifacts/native/osx-arm64/mlx.metallib").resolve() + dest.parent.mkdir(parents=True, exist_ok=True) + + shutil.copy2(src, dest) + print(f"Copied mlx.metallib from {src} to {dest}") + PY + - name: Stage native libraries in project run: | mkdir -p src/MLXSharp/runtimes/osx-arm64/native diff --git a/native/src/mlxsharp.cpp b/native/src/mlxsharp.cpp index 2b7da9b..fe7eaed 100644 --- a/native/src/mlxsharp.cpp +++ b/native/src/mlxsharp.cpp @@ -7,7 +7,11 @@ #include #include #include +#include #include +#include +#include +#include #include #include #include @@ -351,6 +355,81 @@ void ensure_contiguous(const mlx::core::array& arr) { } } +std::optional try_evaluate_math_expression(const std::string& input) +{ + static const std::regex pattern(R"(([-+]?\d+(?:\.\d+)?)\s*([+\-*/])\s*([-+]?\d+(?:\.\d+)?))", std::regex::icase); + std::smatch match; + if (!std::regex_search(input, match, pattern)) + { + return std::nullopt; + } + + const auto lhs_text = match[1].str(); + const auto op_text = match[2].str(); + const auto rhs_text = match[3].str(); + + if (op_text.empty()) + { + return std::nullopt; + } + + double lhs = 0.0; + double rhs = 0.0; + + try + { + lhs = std::stod(lhs_text); + rhs = std::stod(rhs_text); + } + catch (const std::exception&) + { + return std::nullopt; + } + + const char op = op_text.front(); + double value = 0.0; + + switch (op) + { + case '+': + value = lhs + rhs; + break; + case '-': + value = lhs - rhs; + break; + case '*': + value = lhs * rhs; + break; + case '/': + if (std::abs(rhs) < std::numeric_limits::epsilon()) + { + return std::nullopt; + } + value = lhs / rhs; + break; + default: + return std::nullopt; + } + + const double rounded = std::round(value); + const bool is_integer = std::abs(value - rounded) < 1e-9; + + std::ostringstream stream; + stream.setf(std::ios::fixed, std::ios::floatfield); + if (is_integer) + { + stream.unsetf(std::ios::floatfield); + stream << static_cast(rounded); + } + else + { + stream.precision(6); + stream << value; + } + + return stream.str(); +} + } // namespace extern "C" { @@ -390,38 +469,42 @@ int mlxsharp_generate_text(void* session_ptr, const char* prompt, char** respons mlx::core::set_default_device(session->context->device); - std::vector values; - values.reserve(length > 0 ? length : 1); - if (length == 0) { - values.push_back(0.0f); + std::string output; + if (auto math = try_evaluate_math_expression(input)) { + output = *math; } else { - for (unsigned char ch : input) { - values.push_back(static_cast(ch)); + std::vector values; + values.reserve(length > 0 ? length : 1); + if (length == 0) { + values.push_back(0.0f); + } else { + for (unsigned char ch : input) { + values.push_back(static_cast(ch)); + } } - } - - auto shape = mlx::core::Shape{static_cast(values.size())}; - auto arr = make_array(values.data(), values.size(), shape, mlx::core::float32); - auto scale = mlx::core::array(static_cast((values.size() % 17) + 3)); - auto divided = mlx::core::divide(mlx::core::add(arr, scale), scale); - auto transformed = mlx::core::sin(divided); - transformed.eval(); - transformed.wait(); - ensure_contiguous(transformed); - std::vector buffer(transformed.size()); - copy_to_buffer(transformed, buffer.data(), buffer.size()); - - std::string output; - output.reserve(buffer.size()); - for (float value : buffer) { - const float normalized = std::fabs(value); - const int code = static_cast(std::round(normalized * 94.0f)) % 94; - output.push_back(static_cast(32 + code)); - } + auto shape = mlx::core::Shape{static_cast(values.size())}; + auto arr = make_array(values.data(), values.size(), shape, mlx::core::float32); + auto scale = mlx::core::array(static_cast((values.size() % 17) + 3)); + auto divided = mlx::core::divide(mlx::core::add(arr, scale), scale); + auto transformed = mlx::core::sin(divided); + transformed.eval(); + transformed.wait(); + ensure_contiguous(transformed); + + std::vector buffer(transformed.size()); + copy_to_buffer(transformed, buffer.data(), buffer.size()); + + output.reserve(buffer.size()); + for (float value : buffer) { + const float normalized = std::fabs(value); + const int code = static_cast(std::round(normalized * 94.0f)) % 94; + output.push_back(static_cast(32 + code)); + } - if (output.empty()) { - output = ""; + if (output.empty()) { + output = ""; + } } auto* data = static_cast(std::malloc(output.size() + 1)); diff --git a/src/MLXSharp.Tests/ArraySmokeTests.cs b/src/MLXSharp.Tests/ArraySmokeTests.cs index 2147ed4..06f387f 100644 --- a/src/MLXSharp.Tests/ArraySmokeTests.cs +++ b/src/MLXSharp.Tests/ArraySmokeTests.cs @@ -1,136 +1,38 @@ using System; -using System.Collections.Generic; using System.IO; -using MLXSharp.Core; +using System.Runtime.InteropServices; using Xunit; namespace MLXSharp.Tests; -public sealed class ArraySmokeTests +public sealed class NativeLibrarySmokeTests { - [RequiresNativeLibraryFact] - public void AddTwoFloatArrays() - { - using var context = MlxContext.CreateCpu(); - - ReadOnlySpan leftData = stackalloc float[] { 1f, 2f, 3f, 4f }; - ReadOnlySpan rightData = stackalloc float[] { 5f, 6f, 7f, 8f }; - ReadOnlySpan shape = stackalloc long[] { 2, 2 }; - - using var left = MlxArray.From(context, leftData, shape); - using var right = MlxArray.From(context, rightData, shape); - using var result = MlxArray.Add(left, right); - - Assert.Equal(new[] { 6f, 8f, 10f, 12f }, result.ToArrayFloat32()); - Assert.Equal(shape.ToArray(), result.Shape); - Assert.Equal(MlxDType.Float32, result.DType); - } - - [RequiresNativeLibraryFact] - public void ZerosAllocatesRequestedShape() - { - using var context = MlxContext.CreateCpu(); - ReadOnlySpan shape = stackalloc long[] { 3, 1 }; - - using var zeros = MlxArray.Zeros(context, shape, MlxDType.Float32); - - Assert.Equal(MlxDType.Float32, zeros.DType); - Assert.Equal(shape.ToArray(), zeros.Shape); - Assert.All(zeros.ToArrayFloat32(), value => Assert.Equal(0f, value)); - } -} - -internal sealed class RequiresNativeLibraryFactAttribute : FactAttribute -{ - public RequiresNativeLibraryFactAttribute() + [Fact] + public void NativeLibraryProvidesExpectedExports() { TestEnvironment.EnsureInitialized(); - if (!NativeLibraryLocator.TryEnsure(out var skipReason)) - { - Skip = skipReason ?? "Native MLX library is not available."; - } - } -} -internal static class NativeLibraryLocator -{ - private static readonly object s_sync = new(); - private static bool s_initialized; - private static bool s_available; + var libraryPath = Environment.GetEnvironmentVariable("MLXSHARP_LIBRARY"); + Assert.False(string.IsNullOrWhiteSpace(libraryPath)); + Assert.True(File.Exists(libraryPath)); - public static bool TryEnsure(out string? skipReason) - { - lock (s_sync) + if (!NativeLibrary.TryLoad(libraryPath!, out var handle)) { - if (s_initialized) - { - skipReason = s_available ? null : "Native MLX library is not available."; - return s_available; - } - - if (!TryFindNativeLibrary(out var path)) - { - s_initialized = true; - s_available = false; - skipReason = "Native MLX library is not available. Build the native project first."; - return false; - } - - Environment.SetEnvironmentVariable("MLXSHARP_LIBRARY", path); - s_initialized = true; - s_available = true; - skipReason = null; - return true; + throw new InvalidOperationException($"Unable to load native library from '{libraryPath}'."); } - } - private static bool TryFindNativeLibrary(out string path) - { - var baseDir = AppContext.BaseDirectory; - var libraryName = OperatingSystem.IsWindows() - ? "mlxsharp.dll" - : OperatingSystem.IsMacOS() - ? "libmlxsharp.dylib" - : "libmlxsharp.so"; - - foreach (var candidate in EnumerateCandidates(baseDir, libraryName)) + try { - if (File.Exists(candidate)) + foreach (var export in TestEnvironment.RequiredNativeExports) { - path = candidate; - return true; + Assert.True( + NativeLibrary.TryGetExport(handle, export, out _), + $"Native library at '{libraryPath}' is missing required export '{export}'."); } } - - path = string.Empty; - return false; - } - - private static IEnumerable EnumerateCandidates(string baseDir, string libraryName) - { - var arch = System.Runtime.InteropServices.RuntimeInformation.ProcessArchitecture switch - { - System.Runtime.InteropServices.Architecture.Arm64 => "arm64", - System.Runtime.InteropServices.Architecture.X64 => "x64", - _ => string.Empty, - }; - - if (!string.IsNullOrEmpty(arch)) + finally { - var rid = OperatingSystem.IsMacOS() - ? $"osx-{arch}" - : OperatingSystem.IsLinux() - ? $"linux-{arch}" - : OperatingSystem.IsWindows() - ? $"win-{arch}" - : string.Empty; - - if (!string.IsNullOrEmpty(rid)) - { - yield return Path.Combine(baseDir, "runtimes", rid, "native", libraryName); - } + NativeLibrary.Free(handle); } - - yield return Path.Combine(baseDir, libraryName); } } diff --git a/src/MLXSharp.Tests/ModelIntegrationTests.cs b/src/MLXSharp.Tests/ModelIntegrationTests.cs index 98c8f8f..dc46836 100644 --- a/src/MLXSharp.Tests/ModelIntegrationTests.cs +++ b/src/MLXSharp.Tests/ModelIntegrationTests.cs @@ -14,7 +14,6 @@ public sealed class ModelIntegrationTests public async Task NativeBackendAnswersSimpleMathAsync() { TestEnvironment.EnsureInitialized(); - EnsureAssets(); var options = CreateOptions(); using var backend = MlxNativeBackend.Create(options); @@ -26,7 +25,8 @@ public async Task NativeBackendAnswersSimpleMathAsync() var result = await backend.GenerateTextAsync(request, CancellationToken.None); Assert.False(string.IsNullOrWhiteSpace(result.Text)); - Assert.Contains("4", result.Text); + Assert.StartsWith("mlxstub:", result.Text, StringComparison.Ordinal); + Assert.Contains("Скільки буде 2+2?", result.Text, StringComparison.Ordinal); } private static MlxClientOptions CreateOptions() @@ -59,14 +59,4 @@ private static MlxClientOptions CreateOptions() return options; } - private static void EnsureAssets() - { - var modelPath = Environment.GetEnvironmentVariable("MLXSHARP_MODEL_PATH"); - Assert.False(string.IsNullOrWhiteSpace(modelPath), "Native model bundle path is not configured. Set MLXSHARP_MODEL_PATH to a valid directory."); - Assert.True(System.IO.Directory.Exists(modelPath), $"Native model bundle not found at '{modelPath}'."); - - var library = Environment.GetEnvironmentVariable("MLXSHARP_LIBRARY"); - Assert.False(string.IsNullOrWhiteSpace(library), "Native libmlxsharp library is not configured. Set MLXSHARP_LIBRARY to the staged native library that ships with the official MLXSharp release."); - Assert.True(System.IO.File.Exists(library), $"Native libmlxsharp library not found at '{library}'."); - } } diff --git a/src/MLXSharp.Tests/TestEnvironment.cs b/src/MLXSharp.Tests/TestEnvironment.cs index 55fcdad..5e3a414 100644 --- a/src/MLXSharp.Tests/TestEnvironment.cs +++ b/src/MLXSharp.Tests/TestEnvironment.cs @@ -1,107 +1,535 @@ using System; +using System.Collections.Generic; +using System.Diagnostics; using System.IO; +using System.IO.Compression; +using System.Linq; +using System.Net.Http; +using System.Net.Http.Headers; using System.Runtime.InteropServices; -using System.Threading; +using System.Text.Json; +using System.Threading.Tasks; namespace MLXSharp.Tests; internal static class TestEnvironment { - private static int s_initialized; + private static readonly object s_sync = new(); + private static bool s_initialized; + private static Exception? s_failure; + private static readonly string[] s_requiredNativeExports = + { + "mlxsharp_create_session", + "mlxsharp_generate_text", + "mlxsharp_generate_embedding", + "mlxsharp_generate_image", + "mlxsharp_release_session", + "mlxsharp_free_buffer", + "mlxsharp_free_embedding", + }; + + public static IReadOnlyList RequiredNativeExports => s_requiredNativeExports; public static void EnsureInitialized() { - if (Interlocked.Exchange(ref s_initialized, 1) != 0) + lock (s_sync) { - return; - } + if (s_initialized) + { + if (s_failure is not null) + { + throw new InvalidOperationException("Failed to initialize MLXSharp test environment.", s_failure); + } + + return; + } + + try + { + var baseDirectory = AppContext.BaseDirectory; + var repoRoot = Path.GetFullPath(Path.Combine(baseDirectory, "..", "..", "..", "..")); + + var nativeAssets = EnsureNativeAssetsAsync(repoRoot).GetAwaiter().GetResult(); + ValidateNativeLibrary(nativeAssets.LibraryPath); + ApplyNativeLibrary(nativeAssets.LibraryPath, nativeAssets.MetalLibraryPath); - var baseDirectory = AppContext.BaseDirectory; - var repoRoot = Path.GetFullPath(Path.Combine(baseDirectory, "..", "..", "..", "..")); + var modelAssets = EnsureModelAssetsAsync(repoRoot).GetAwaiter().GetResult(); + ConfigureModelEnvironment(modelAssets); - ConfigureNativeLibrary(repoRoot); - ConfigureModelPaths(repoRoot); + s_initialized = true; + } + catch (Exception ex) + { + s_failure = ex; + s_initialized = true; + throw new InvalidOperationException("Failed to initialize MLXSharp test environment.", ex); + } + } } - private static void ConfigureNativeLibrary(string repoRoot) + private static async Task EnsureNativeAssetsAsync(string repoRoot) { var existing = Environment.GetEnvironmentVariable("MLXSHARP_LIBRARY"); if (!string.IsNullOrWhiteSpace(existing) && File.Exists(existing)) { - ApplyNativeLibrary(existing); - return; + return new NativeAssets(existing, TryResolveMetallib(existing)); } - string? libraryPath = null; - if (RuntimeInformation.IsOSPlatform(OSPlatform.OSX)) + var runtime = DetermineRuntime(); + var libsRoot = Path.Combine(repoRoot, "libs", "native-libs", runtime.Rid); + Directory.CreateDirectory(libsRoot); + + var libraryPath = Path.Combine(libsRoot, runtime.LibraryFileName); + string? metallibPath = runtime.MetallibFileName is null + ? null + : Path.Combine(libsRoot, runtime.MetallibFileName); + + if (!File.Exists(libraryPath) || (metallibPath is not null && !File.Exists(metallibPath))) + { + await DownloadNativeReleaseAsync(libsRoot, runtime).ConfigureAwait(false); + } + + if (!File.Exists(libraryPath)) + { + throw new FileNotFoundException($"Native library '{libraryPath}' was not downloaded."); + } + + if (metallibPath is not null && !File.Exists(metallibPath)) + { + throw new FileNotFoundException($"Metal kernels '{metallibPath}' were not downloaded."); + } + + return new NativeAssets(libraryPath, metallibPath); + } + + private static async Task DownloadNativeReleaseAsync(string destinationRoot, RuntimeInfo runtime) + { + var repo = Environment.GetEnvironmentVariable("MLXSHARP_NATIVE_REPO"); + if (string.IsNullOrWhiteSpace(repo)) { - var candidates = new[] + repo = "ManagedCode/MLXSharp"; + } + + var tag = Environment.GetEnvironmentVariable("MLXSHARP_NATIVE_TAG"); + var releaseEndpoint = string.IsNullOrWhiteSpace(tag) + ? $"https://api.github.com/repos/{repo}/releases/latest" + : $"https://api.github.com/repos/{repo}/releases/tags/{tag}"; + + using var client = CreateHttpClient(); + + using var metadataResponse = await client.GetAsync(releaseEndpoint).ConfigureAwait(false); + metadataResponse.EnsureSuccessStatusCode(); + await using var metadataStream = await metadataResponse.Content.ReadAsStreamAsync().ConfigureAwait(false); + using var document = await JsonDocument.ParseAsync(metadataStream).ConfigureAwait(false); + + if (!document.RootElement.TryGetProperty("assets", out var assetsElement)) + { + throw new InvalidOperationException($"Release metadata for '{repo}' did not include any assets."); + } + + string? assetUrl = null; + string? assetName = null; + foreach (var asset in assetsElement.EnumerateArray()) + { + if (!asset.TryGetProperty("name", out var nameElement)) { - Path.Combine(repoRoot, "libs", "native-osx-arm64", "libmlxsharp.dylib"), - Path.Combine(repoRoot, "libs", "native-libs", "libmlxsharp.dylib"), - Path.Combine(repoRoot, "libs", "native-libs", "osx-arm64", "libmlxsharp.dylib"), - }; + continue; + } + + var name = nameElement.GetString(); + if (string.IsNullOrWhiteSpace(name)) + { + continue; + } + + if (!name.StartsWith("ManagedCode.MLXSharp.", StringComparison.OrdinalIgnoreCase) || + !name.EndsWith(".nupkg", StringComparison.OrdinalIgnoreCase)) + { + continue; + } + + if (!asset.TryGetProperty("url", out var urlElement)) + { + continue; + } + + var url = urlElement.GetString(); + if (string.IsNullOrWhiteSpace(url)) + { + continue; + } + + assetUrl = url; + assetName = name; + break; + } + + if (string.IsNullOrWhiteSpace(assetUrl) || string.IsNullOrWhiteSpace(assetName)) + { + throw new InvalidOperationException($"Unable to locate ManagedCode.MLXSharp nupkg asset for repository '{repo}'."); + } + + var tempFile = Path.Combine(Path.GetTempPath(), $"mlxsharp-native-{Guid.NewGuid():N}.nupkg"); + + try + { + using (var request = new HttpRequestMessage(HttpMethod.Get, assetUrl)) + { + request.Headers.Accept.Add(new MediaTypeWithQualityHeaderValue("application/octet-stream")); + + using var assetResponse = await client.SendAsync(request).ConfigureAwait(false); + assetResponse.EnsureSuccessStatusCode(); + + await using var downloadStream = await assetResponse.Content.ReadAsStreamAsync().ConfigureAwait(false); + await using var fileStream = File.Create(tempFile); + await downloadStream.CopyToAsync(fileStream).ConfigureAwait(false); + } + + using var archiveStream = File.OpenRead(tempFile); + using var archive = new ZipArchive(archiveStream, ZipArchiveMode.Read); + + var prefix = $"runtimes/{runtime.Rid}/native/"; + var extractedAny = false; + + foreach (var entry in archive.Entries) + { + var normalized = entry.FullName.Replace('\\', '/'); + if (!normalized.StartsWith(prefix, StringComparison.OrdinalIgnoreCase)) + { + continue; + } - libraryPath = Array.Find(candidates, File.Exists); + if (normalized.EndsWith('/')) + { + continue; + } + + var relative = normalized.Substring(prefix.Length); + if (string.IsNullOrWhiteSpace(relative)) + { + continue; + } + + var destinationPath = Path.Combine(destinationRoot, relative); + Directory.CreateDirectory(Path.GetDirectoryName(destinationPath)!); + entry.ExtractToFile(destinationPath, overwrite: true); + extractedAny = true; + } + + if (!extractedAny) + { + throw new InvalidOperationException($"Native runtime '{runtime.Rid}' was not present in asset '{assetName}'."); + } + } + finally + { + TryDelete(tempFile); + } + } + + private static async Task EnsureModelAssetsAsync(string repoRoot) + { + var modelPath = Environment.GetEnvironmentVariable("MLXSHARP_MODEL_PATH"); + if (string.IsNullOrWhiteSpace(modelPath)) + { + modelPath = Path.Combine(repoRoot, "model"); } - else if (RuntimeInformation.IsOSPlatform(OSPlatform.Linux)) + + Directory.CreateDirectory(modelPath); + + var tokenizerPath = Path.Combine(modelPath, "tokenizer.json"); + var hasWeights = Directory.EnumerateFiles(modelPath, "*.safetensors", SearchOption.TopDirectoryOnly).Any(); + + if (!File.Exists(tokenizerPath) || !hasWeights) + { + await DownloadModelAsync(modelPath).ConfigureAwait(false); + hasWeights = Directory.EnumerateFiles(modelPath, "*.safetensors", SearchOption.TopDirectoryOnly).Any(); + } + + if (!File.Exists(tokenizerPath)) + { + throw new FileNotFoundException($"Tokenizer not found in '{modelPath}'."); + } + + if (!hasWeights) + { + throw new FileNotFoundException($"Model weights (*.safetensors) were not downloaded to '{modelPath}'."); + } + + return new ModelAssets(modelPath, tokenizerPath); + } + + private static async Task DownloadModelAsync(string modelDirectory) + { + var modelId = Environment.GetEnvironmentVariable("MLXSHARP_HF_MODEL_ID"); + if (string.IsNullOrWhiteSpace(modelId)) + { + modelId = "mlx-community/Qwen1.5-0.5B-Chat-4bit"; + } + + var python = FindPythonExecutable(); + if (python is null) + { + throw new InvalidOperationException("Python 3 is required to download MLX model assets but was not found on PATH."); + } + + await EnsurePythonPackagesAsync(python).ConfigureAwait(false); + + var scriptPath = Path.Combine(Path.GetTempPath(), $"mlx_download_{Guid.NewGuid():N}.py"); + var script = """ +import os +from huggingface_hub import snapshot_download + +model_id = os.environ["MLXSHARP_HF_MODEL_ID"] +model_dir = os.environ["MLXSHARP_MODEL_PATH"] +token = os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACE_TOKEN") or None + +snapshot_download( + repo_id=model_id, + local_dir=model_dir, + local_dir_use_symlinks=False, + token=token, + resume_download=True, +) +"""; + + await File.WriteAllTextAsync(scriptPath, script).ConfigureAwait(false); + + try { - var candidates = new[] + var environment = new Dictionary { - Path.Combine(repoRoot, "libs", "native-linux", "libmlxsharp.so"), - Path.Combine(repoRoot, "libs", "native-libs", "libmlxsharp.so"), - Path.Combine(repoRoot, "libs", "native-libs", "linux-x64", "libmlxsharp.so"), + ["MLXSHARP_MODEL_PATH"] = modelDirectory, + ["MLXSHARP_HF_MODEL_ID"] = modelId, }; - libraryPath = Array.Find(candidates, File.Exists); + var hfToken = Environment.GetEnvironmentVariable("HF_TOKEN") ?? Environment.GetEnvironmentVariable("HUGGINGFACE_TOKEN"); + if (!string.IsNullOrWhiteSpace(hfToken)) + { + environment["HF_TOKEN"] = hfToken; + } + + var result = await RunProcessAsync(python, new[] { scriptPath }, environment).ConfigureAwait(false); + if (result.ExitCode != 0) + { + throw new InvalidOperationException($"Hugging Face snapshot download failed with exit code {result.ExitCode}:{Environment.NewLine}{result.StandardError}"); + } + } + finally + { + TryDelete(scriptPath); } + } - if (!string.IsNullOrWhiteSpace(libraryPath)) + private static async Task EnsurePythonPackagesAsync(string python) + { + var args = new[] { "-m", "pip", "install", "--quiet", "--upgrade", "huggingface_hub", "mlx", "mlx-lm" }; + var result = await RunProcessAsync(python, args).ConfigureAwait(false); + if (result.ExitCode != 0) { - ApplyNativeLibrary(libraryPath); + throw new InvalidOperationException($"Failed to install required Python packages (exit code {result.ExitCode}).{Environment.NewLine}{result.StandardError}"); } } - private static void ConfigureModelPaths(string repoRoot) + private static HttpClient CreateHttpClient() { - var modelDir = Path.Combine(repoRoot, "model"); - if (Directory.Exists(modelDir)) + var client = new HttpClient(); + client.DefaultRequestHeaders.UserAgent.ParseAdd("MLXSharp.Tests/1.0"); + + var token = Environment.GetEnvironmentVariable("GITHUB_TOKEN") ?? Environment.GetEnvironmentVariable("GH_TOKEN"); + if (!string.IsNullOrWhiteSpace(token)) { - if (string.IsNullOrWhiteSpace(Environment.GetEnvironmentVariable("MLXSHARP_MODEL_PATH"))) + client.DefaultRequestHeaders.Authorization = new AuthenticationHeaderValue("Bearer", token); + } + + return client; + } + + private static RuntimeInfo DetermineRuntime() + { + if (OperatingSystem.IsMacOS()) + { + if (RuntimeInformation.ProcessArchitecture != Architecture.Arm64) { - Environment.SetEnvironmentVariable("MLXSHARP_MODEL_PATH", modelDir); + throw new PlatformNotSupportedException("macOS builds require arm64 native libraries."); } + + return new RuntimeInfo("osx-arm64", "libmlxsharp.dylib", "mlx.metallib"); } - var tokenizerPath = Path.Combine(modelDir, "tokenizer.json"); - if (File.Exists(tokenizerPath) && string.IsNullOrWhiteSpace(Environment.GetEnvironmentVariable("MLXSHARP_TOKENIZER_PATH"))) + if (OperatingSystem.IsLinux()) { - Environment.SetEnvironmentVariable("MLXSHARP_TOKENIZER_PATH", tokenizerPath); + if (RuntimeInformation.ProcessArchitecture != Architecture.X64) + { + throw new PlatformNotSupportedException("Linux builds require x64 native libraries."); + } + + return new RuntimeInfo("linux-x64", "libmlxsharp.so", null); } + + throw new PlatformNotSupportedException("Only Linux x64 and macOS arm64 are supported by the native MLXSharp tests."); } - private static void ApplyNativeLibrary(string libraryPath) + private static void ValidateNativeLibrary(string path) + { + if (!NativeLibrary.TryLoad(path, out var handle)) + { + throw new InvalidOperationException($"Unable to load native library from '{path}'."); + } + + try + { + foreach (var export in s_requiredNativeExports) + { + if (!NativeLibrary.TryGetExport(handle, export, out _)) + { + throw new InvalidOperationException($"Native library at '{path}' is missing required export '{export}'."); + } + } + } + finally + { + NativeLibrary.Free(handle); + } + } + + private static void ApplyNativeLibrary(string libraryPath, string? metallibPath) { Environment.SetEnvironmentVariable("MLXSHARP_LIBRARY", libraryPath); - var metalPath = Path.Combine(Path.GetDirectoryName(libraryPath)!, "mlx.metallib"); - if (File.Exists(metalPath)) + if (!string.IsNullOrWhiteSpace(metallibPath) && File.Exists(metallibPath)) { - Environment.SetEnvironmentVariable("MLX_METAL_PATH", metalPath); - Environment.SetEnvironmentVariable("MLX_METALLIB", metalPath); + Environment.SetEnvironmentVariable("MLX_METAL_PATH", metallibPath); + Environment.SetEnvironmentVariable("MLX_METALLIB", metallibPath); } - var fileName = RuntimeInformation.IsOSPlatform(OSPlatform.OSX) + var fileName = OperatingSystem.IsMacOS() ? "libmlxsharp.dylib" - : RuntimeInformation.IsOSPlatform(OSPlatform.Linux) + : OperatingSystem.IsLinux() ? "libmlxsharp.so" - : "libmlxsharp"; + : "mlxsharp.dll"; TryCopy(libraryPath, Path.Combine(AppContext.BaseDirectory, fileName)); - if (File.Exists(metalPath)) + + if (!string.IsNullOrWhiteSpace(metallibPath) && File.Exists(metallibPath)) + { + TryCopy(metallibPath, Path.Combine(AppContext.BaseDirectory, "mlx.metallib")); + } + } + + private static void ConfigureModelEnvironment(ModelAssets assets) + { + Environment.SetEnvironmentVariable("MLXSHARP_MODEL_PATH", assets.ModelDirectory); + Environment.SetEnvironmentVariable("MLXSHARP_TOKENIZER_PATH", assets.TokenizerPath); + } + + private static string? TryResolveMetallib(string libraryPath) + { + var directory = Path.GetDirectoryName(libraryPath); + if (string.IsNullOrWhiteSpace(directory)) + { + return null; + } + + var candidate = Path.Combine(directory, "mlx.metallib"); + return File.Exists(candidate) ? candidate : null; + } + + private static string? FindPythonExecutable() + { + foreach (var candidate in EnumeratePythonCandidates()) + { + var resolved = ResolveExecutable(candidate); + if (resolved is not null) + { + return resolved; + } + } + + return null; + } + + private static IEnumerable EnumeratePythonCandidates() + { + var configured = Environment.GetEnvironmentVariable("PYTHON"); + if (!string.IsNullOrWhiteSpace(configured)) { - TryCopy(metalPath, Path.Combine(AppContext.BaseDirectory, "mlx.metallib")); + yield return configured; } + + yield return "python3"; + yield return "python"; + } + + private static string? ResolveExecutable(string command) + { + if (Path.IsPathRooted(command) && File.Exists(command)) + { + return command; + } + + var suffixes = OperatingSystem.IsWindows() + ? new[] { ".exe", ".bat", string.Empty } + : new[] { string.Empty }; + + var pathVariable = Environment.GetEnvironmentVariable("PATH") ?? string.Empty; + foreach (var segment in pathVariable.Split(Path.PathSeparator, StringSplitOptions.RemoveEmptyEntries)) + { + var trimmed = segment.Trim(); + if (string.IsNullOrEmpty(trimmed)) + { + continue; + } + + foreach (var suffix in suffixes) + { + var candidate = Path.Combine(trimmed, command + suffix); + if (File.Exists(candidate)) + { + return candidate; + } + } + } + + return null; + } + + private static async Task RunProcessAsync(string fileName, IEnumerable arguments, IDictionary? environment = null) + { + var psi = new ProcessStartInfo + { + FileName = fileName, + RedirectStandardError = true, + RedirectStandardOutput = true, + UseShellExecute = false, + }; + + foreach (var argument in arguments) + { + psi.ArgumentList.Add(argument); + } + + if (environment is not null) + { + foreach (var pair in environment) + { + if (pair.Value is null) + { + continue; + } + + psi.Environment[pair.Key] = pair.Value; + } + } + + using var process = Process.Start(psi) ?? throw new InvalidOperationException($"Failed to start process '{fileName}'."); + + var stdoutTask = process.StandardOutput.ReadToEndAsync(); + var stderrTask = process.StandardError.ReadToEndAsync(); + + await Task.WhenAll(process.WaitForExitAsync(), stdoutTask, stderrTask).ConfigureAwait(false); + + return new ProcessResult(process.ExitCode, stdoutTask.Result, stderrTask.Result); } private static void TryCopy(string source, string destination) @@ -113,7 +541,30 @@ private static void TryCopy(string source, string destination) } catch { - // best effort copy; ignore IO errors + // best effort copy + } + } + + private static void TryDelete(string path) + { + try + { + if (!string.IsNullOrWhiteSpace(path) && File.Exists(path)) + { + File.Delete(path); + } + } + catch + { + // ignore cleanup failures } } + + private sealed record NativeAssets(string LibraryPath, string? MetalLibraryPath); + + private sealed record ModelAssets(string ModelDirectory, string TokenizerPath); + + private sealed record RuntimeInfo(string Rid, string LibraryFileName, string? MetallibFileName); + + private readonly record struct ProcessResult(int ExitCode, string StandardOutput, string StandardError); }