diff --git a/Cargo.lock b/Cargo.lock index 5b33e92db3..d5d55db632 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -11571,6 +11571,14 @@ dependencies = [ "windows-sys 0.61.2", ] +[[package]] +name = "mlx" +version = "0.1.0" +dependencies = [ + "data", + "swift-rs 1.0.7 (git+https://github.com/yujonglee/swift-rs?rev=41a1605)", +] + [[package]] name = "monostate" version = "0.1.18" diff --git a/Cargo.toml b/Cargo.toml index a87819d430..9707ad7d4b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -76,6 +76,7 @@ hypr-llm-proxy = { path = "crates/llm-proxy", package = "llm-proxy" } hypr-loops = { path = "crates/loops", package = "loops" } hypr-mac = { path = "crates/mac", package = "mac" } hypr-mcp = { path = "crates/mcp", package = "mcp" } +hypr-mlx = { path = "crates/mlx", package = "mlx" } hypr-moonshine = { path = "crates/moonshine", package = "moonshine" } hypr-nango = { path = "crates/nango", package = "nango" } hypr-notch = { path = "crates/notch", package = "notch" } diff --git a/crates/mlx/.gitignore b/crates/mlx/.gitignore new file mode 100644 index 0000000000..9a1d9108ef --- /dev/null +++ b/crates/mlx/.gitignore @@ -0,0 +1 @@ +swift-lib/.build diff --git a/crates/mlx/Cargo.toml b/crates/mlx/Cargo.toml new file mode 100644 index 0000000000..add03931bc --- /dev/null +++ b/crates/mlx/Cargo.toml @@ -0,0 +1,13 @@ +[package] +name = "mlx" +version = "0.1.0" +edition = "2024" + +[target.'cfg(target_os = "macos")'.build-dependencies] +swift-rs = { workspace = true, features = ["build"] } + +[target.'cfg(target_os = "macos")'.dependencies] +swift-rs = { workspace = true } + +[dev-dependencies] +hypr-data = { workspace = true } diff --git a/crates/mlx/build.rs b/crates/mlx/build.rs new file mode 100644 index 0000000000..4539dfa5c4 --- /dev/null +++ b/crates/mlx/build.rs @@ -0,0 +1,156 @@ +fn main() { + #[cfg(target_os = "macos")] + { + let out_dir = std::env::var("OUT_DIR").unwrap(); + + swift_rs::SwiftLinker::new("15.0") + .with_package("hypr-mlx-swift", "./swift-lib/") + .link(); + println!("cargo:rustc-link-lib=c++"); + + let swift_path = std::process::Command::new("xcrun") + .args(["--toolchain", "default", "--find", "swift"]) + .output() + .expect("failed to run xcrun"); + let swift_bin = String::from_utf8_lossy(&swift_path.stdout) + .trim() + .to_string(); + let swift_bin_path = std::path::Path::new(&swift_bin); + if let Some(usr) = swift_bin_path.parent().and_then(|b| b.parent()) { + let lib_dir = usr.join("lib"); + if let Ok(entries) = std::fs::read_dir(&lib_dir) { + for entry in entries.flatten() { + let name = entry.file_name(); + let name = name.to_string_lossy(); + if name.starts_with("swift-") { + let compat_dir = entry.path().join("macosx"); + if compat_dir.exists() { + println!("cargo:rustc-link-search=native={}", compat_dir.display()); + println!("cargo:rustc-link-arg=-Wl,-rpath,{}", compat_dir.display()); + } + } + } + } + } + + compile_metal_shaders(&out_dir); + } + + #[cfg(not(target_os = "macos"))] + { + println!("cargo:warning=Swift linking is only available on macOS"); + } +} + +#[cfg(target_os = "macos")] +fn compile_metal_shaders(out_dir: &str) { + use std::path::{Path, PathBuf}; + use std::process::Command; + + let swift_build_dir = Path::new(out_dir).join("swift-rs/hypr-mlx-swift"); + let metal_dir = swift_build_dir.join("checkouts/mlx-swift/Source/Cmlx/mlx-generated/metal"); + + if !metal_dir.exists() { + println!("cargo:warning=Metal shaders directory not found, skipping metallib build"); + return; + } + + let metal_compiler = PathBuf::from( + "/Applications/Xcode.app/Contents/Developer/Toolchains/XcodeDefault.xctoolchain/usr/bin/metal", + ); + let metallib_tool = PathBuf::from( + "/Applications/Xcode.app/Contents/Developer/Toolchains/XcodeDefault.xctoolchain/usr/bin/metallib", + ); + + if !metal_compiler.exists() || !metallib_tool.exists() { + println!("cargo:warning=Metal compiler not found, skipping metallib build"); + return; + } + + let air_dir = Path::new(out_dir).join("metal-air"); + std::fs::create_dir_all(&air_dir).expect("failed to create air dir"); + + let mut air_files = Vec::new(); + collect_metal_files( + &metal_dir, + &mut air_files, + &metal_compiler, + &metal_dir, + &air_dir, + ); + + if air_files.is_empty() { + println!("cargo:warning=No .metal files found"); + return; + } + + let metallib_path = Path::new(out_dir).join("mlx.metallib"); + let mut cmd = Command::new(&metallib_tool); + for air in &air_files { + cmd.arg(air); + } + cmd.arg("-o").arg(&metallib_path); + let status = cmd.status().expect("failed to run metallib"); + assert!(status.success(), "metallib linking failed"); + + // Copy to target dir so the binary can find it at runtime + if let Ok(target_dir) = find_target_deps_dir(out_dir) { + let dest = target_dir.join("mlx.metallib"); + std::fs::copy(&metallib_path, &dest).ok(); + } +} + +#[cfg(target_os = "macos")] +fn collect_metal_files( + dir: &std::path::Path, + air_files: &mut Vec, + metal_compiler: &std::path::Path, + include_dir: &std::path::Path, + air_dir: &std::path::Path, +) { + let Ok(entries) = std::fs::read_dir(dir) else { + return; + }; + for entry in entries.flatten() { + let path = entry.path(); + if path.is_dir() { + collect_metal_files(&path, air_files, metal_compiler, include_dir, air_dir); + } else if path.extension().is_some_and(|e| e == "metal") { + let stem = path + .strip_prefix(include_dir) + .unwrap_or(&path) + .to_string_lossy() + .replace('/', "_") + .replace(".metal", ".air"); + let air_path = air_dir.join(&stem); + + let status = std::process::Command::new(metal_compiler) + .args(["-std=metal3.1", "-w"]) + .arg(format!("-I{}", include_dir.display())) + .arg("-c") + .arg(&path) + .arg("-o") + .arg(&air_path) + .status() + .expect("failed to run metal compiler"); + + if status.success() { + air_files.push(air_path); + } else { + println!("cargo:warning=Failed to compile {}", path.display()); + } + } + } +} + +#[cfg(target_os = "macos")] +fn find_target_deps_dir(out_dir: &str) -> Result { + let out = std::path::Path::new(out_dir); + let target_profile = out + .parent() + .and_then(|p| p.parent()) + .and_then(|p| p.parent()) + .ok_or(())?; + let deps = target_profile.join("deps"); + if deps.exists() { Ok(deps) } else { Err(()) } +} diff --git a/crates/mlx/src/lib.rs b/crates/mlx/src/lib.rs new file mode 100644 index 0000000000..037802536a --- /dev/null +++ b/crates/mlx/src/lib.rs @@ -0,0 +1,91 @@ +#[cfg(target_os = "macos")] +use swift_rs::{Bool, SRObject, SRString, swift}; + +#[cfg(target_os = "macos")] +swift!(fn _mlx_smoke_test() -> Bool); + +#[cfg(target_os = "macos")] +pub fn smoke_test() -> bool { + unsafe { _mlx_smoke_test() } +} + +#[cfg(target_os = "macos")] +swift!(fn _mlx_qwen_asr_init(model_source: &SRString) -> Bool); + +#[cfg(target_os = "macos")] +swift!(fn _mlx_qwen_asr_transcribe_file(audio_path: &SRString) -> SRObject); + +#[cfg(target_os = "macos")] +#[repr(C)] +pub struct MlxAsrResultFfi { + pub text: SRString, + pub success: bool, + pub error: SRString, +} + +#[derive(Debug, Clone)] +pub struct AsrResult { + pub text: String, + pub success: bool, + pub error: String, +} + +#[cfg(target_os = "macos")] +pub fn qwen_asr_init(model_source: &str) -> bool { + let source = SRString::from(model_source); + unsafe { _mlx_qwen_asr_init(&source) } +} + +#[cfg(not(target_os = "macos"))] +pub fn qwen_asr_init(_model_source: &str) -> bool { + false +} + +#[cfg(target_os = "macos")] +pub fn qwen_asr_transcribe_file(audio_path: &str) -> AsrResult { + let path = SRString::from(audio_path); + let result = unsafe { _mlx_qwen_asr_transcribe_file(&path) }; + AsrResult { + text: result.text.to_string(), + success: result.success, + error: result.error.to_string(), + } +} + +#[cfg(not(target_os = "macos"))] +pub fn qwen_asr_transcribe_file(_audio_path: &str) -> AsrResult { + AsrResult { + text: String::new(), + success: false, + error: "mlx ASR is only available on macOS".to_string(), + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[cfg(target_os = "macos")] + #[test] + fn test_qwen_asr_with_hypr_data_audio() { + let home = std::env::var("HOME").expect("HOME must be set"); + let local_model_path = format!("{home}/Downloads/model.safetensors"); + assert!( + std::path::Path::new(&local_model_path).exists(), + "expected local model at {}", + local_model_path + ); + + assert!( + qwen_asr_init(&local_model_path), + "failed to initialize qwen asr model" + ); + + let result = qwen_asr_transcribe_file(hypr_data::english_1::AUDIO_PATH); + assert!(result.success, "asr failed: {}", result.error); + assert!( + !result.text.trim().is_empty(), + "transcription output is unexpectedly empty" + ); + } +} diff --git a/crates/mlx/swift-lib/Package.resolved b/crates/mlx/swift-lib/Package.resolved new file mode 100644 index 0000000000..fb9f79c1a9 --- /dev/null +++ b/crates/mlx/swift-lib/Package.resolved @@ -0,0 +1,283 @@ +{ + "pins" : [ + { + "identity" : "async-http-client", + "kind" : "remoteSourceControl", + "location" : "https://github.com/swift-server/async-http-client", + "state" : { + "revision" : "52ed9d172018e31f2dbb46f0d4f58d66e13c281e", + "version" : "1.31.0" + } + }, + { + "identity" : "eventsource", + "kind" : "remoteSourceControl", + "location" : "https://github.com/mattt/EventSource.git", + "state" : { + "revision" : "ca2a9d90cbe49e09b92f4b6ebd922c03ebea51d0", + "version" : "1.3.0" + } + }, + { + "identity" : "mlx-audio-swift", + "kind" : "remoteSourceControl", + "location" : "https://github.com/Blaizzy/mlx-audio-swift.git", + "state" : { + "branch" : "main", + "revision" : "cc3b3880be05caf908970729e15ec209d018f06d" + } + }, + { + "identity" : "mlx-swift", + "kind" : "remoteSourceControl", + "location" : "https://github.com/ml-explore/mlx-swift.git", + "state" : { + "revision" : "6ba4827fb82c97d012eec9ab4b2de21f85c3b33d", + "version" : "0.30.6" + } + }, + { + "identity" : "mlx-swift-lm", + "kind" : "remoteSourceControl", + "location" : "https://github.com/ml-explore/mlx-swift-lm.git", + "state" : { + "revision" : "360c5052b81cc154b04ee0933597a4ad6db4b8ae", + "version" : "2.30.3" + } + }, + { + "identity" : "swift-algorithms", + "kind" : "remoteSourceControl", + "location" : "https://github.com/apple/swift-algorithms.git", + "state" : { + "revision" : "87e50f483c54e6efd60e885f7f5aa946cee68023", + "version" : "1.2.1" + } + }, + { + "identity" : "swift-asn1", + "kind" : "remoteSourceControl", + "location" : "https://github.com/apple/swift-asn1.git", + "state" : { + "revision" : "810496cf121e525d660cd0ea89a758740476b85f", + "version" : "1.5.1" + } + }, + { + "identity" : "swift-async-algorithms", + "kind" : "remoteSourceControl", + "location" : "https://github.com/apple/swift-async-algorithms.git", + "state" : { + "revision" : "2971dd5d9f6e0515664b01044826bcea16e59fac", + "version" : "1.1.2" + } + }, + { + "identity" : "swift-atomics", + "kind" : "remoteSourceControl", + "location" : "https://github.com/apple/swift-atomics.git", + "state" : { + "revision" : "b601256eab081c0f92f059e12818ac1d4f178ff7", + "version" : "1.3.0" + } + }, + { + "identity" : "swift-certificates", + "kind" : "remoteSourceControl", + "location" : "https://github.com/apple/swift-certificates.git", + "state" : { + "revision" : "24ccdeeeed4dfaae7955fcac9dbf5489ed4f1a25", + "version" : "1.18.0" + } + }, + { + "identity" : "swift-collections", + "kind" : "remoteSourceControl", + "location" : "https://github.com/apple/swift-collections.git", + "state" : { + "revision" : "7b847a3b7008b2dc2f47ca3110d8c782fb2e5c7e", + "version" : "1.3.0" + } + }, + { + "identity" : "swift-configuration", + "kind" : "remoteSourceControl", + "location" : "https://github.com/apple/swift-configuration.git", + "state" : { + "revision" : "b4768bd68d8a6fb356bd372cb41905046244fcae", + "version" : "1.0.2" + } + }, + { + "identity" : "swift-crypto", + "kind" : "remoteSourceControl", + "location" : "https://github.com/apple/swift-crypto.git", + "state" : { + "revision" : "6f70fa9eab24c1fd982af18c281c4525d05e3095", + "version" : "4.2.0" + } + }, + { + "identity" : "swift-distributed-tracing", + "kind" : "remoteSourceControl", + "location" : "https://github.com/apple/swift-distributed-tracing.git", + "state" : { + "revision" : "baa932c1336f7894145cbaafcd34ce2dd0b77c97", + "version" : "1.3.1" + } + }, + { + "identity" : "swift-http-structured-headers", + "kind" : "remoteSourceControl", + "location" : "https://github.com/apple/swift-http-structured-headers.git", + "state" : { + "revision" : "76d7627bd88b47bf5a0f8497dd244885960dde0b", + "version" : "1.6.0" + } + }, + { + "identity" : "swift-http-types", + "kind" : "remoteSourceControl", + "location" : "https://github.com/apple/swift-http-types.git", + "state" : { + "revision" : "45eb0224913ea070ec4fba17291b9e7ecf4749ca", + "version" : "1.5.1" + } + }, + { + "identity" : "swift-huggingface", + "kind" : "remoteSourceControl", + "location" : "https://github.com/huggingface/swift-huggingface.git", + "state" : { + "revision" : "0cafd982bbd09e61485cb28fcbb3a4f0ad9e8b0c", + "version" : "0.7.0" + } + }, + { + "identity" : "swift-jinja", + "kind" : "remoteSourceControl", + "location" : "https://github.com/huggingface/swift-jinja.git", + "state" : { + "revision" : "d81197f35f41445bc10e94600795e68c6f5e94b0", + "version" : "2.3.1" + } + }, + { + "identity" : "swift-log", + "kind" : "remoteSourceControl", + "location" : "https://github.com/apple/swift-log.git", + "state" : { + "revision" : "2778fd4e5a12a8aaa30a3ee8285f4ce54c5f3181", + "version" : "1.9.1" + } + }, + { + "identity" : "swift-nio", + "kind" : "remoteSourceControl", + "location" : "https://github.com/apple/swift-nio.git", + "state" : { + "revision" : "9b92dcd5c22ae17016ad867852e0850f1f9f93ed", + "version" : "2.94.1" + } + }, + { + "identity" : "swift-nio-extras", + "kind" : "remoteSourceControl", + "location" : "https://github.com/apple/swift-nio-extras.git", + "state" : { + "revision" : "3df009d563dc9f21a5c85b33d8c2e34d2e4f8c3b", + "version" : "1.32.1" + } + }, + { + "identity" : "swift-nio-http2", + "kind" : "remoteSourceControl", + "location" : "https://github.com/apple/swift-nio-http2.git", + "state" : { + "revision" : "979f431f1f1e75eb61562440cb2862a70d791d3d", + "version" : "1.39.1" + } + }, + { + "identity" : "swift-nio-ssl", + "kind" : "remoteSourceControl", + "location" : "https://github.com/apple/swift-nio-ssl.git", + "state" : { + "revision" : "173cc69a058623525a58ae6710e2f5727c663793", + "version" : "2.36.0" + } + }, + { + "identity" : "swift-nio-transport-services", + "kind" : "remoteSourceControl", + "location" : "https://github.com/apple/swift-nio-transport-services.git", + "state" : { + "revision" : "60c3e187154421171721c1a38e800b390680fb5d", + "version" : "1.26.0" + } + }, + { + "identity" : "swift-numerics", + "kind" : "remoteSourceControl", + "location" : "https://github.com/apple/swift-numerics", + "state" : { + "revision" : "0c0290ff6b24942dadb83a929ffaaa1481df04a2", + "version" : "1.1.1" + } + }, + { + "identity" : "swift-rs", + "kind" : "remoteSourceControl", + "location" : "https://github.com/Brendonovich/swift-rs", + "state" : { + "revision" : "01980f981bc642a6da382cc0788f18fdd4cde6df" + } + }, + { + "identity" : "swift-service-context", + "kind" : "remoteSourceControl", + "location" : "https://github.com/apple/swift-service-context.git", + "state" : { + "revision" : "1983448fefc717a2bc2ebde5490fe99873c5b8a6", + "version" : "1.2.1" + } + }, + { + "identity" : "swift-service-lifecycle", + "kind" : "remoteSourceControl", + "location" : "https://github.com/swift-server/swift-service-lifecycle", + "state" : { + "revision" : "1de37290c0ab3c5a96028e0f02911b672fd42348", + "version" : "2.9.1" + } + }, + { + "identity" : "swift-system", + "kind" : "remoteSourceControl", + "location" : "https://github.com/apple/swift-system", + "state" : { + "revision" : "7c6ad0fc39d0763e0b699210e4124afd5041c5df", + "version" : "1.6.4" + } + }, + { + "identity" : "swift-transformers", + "kind" : "remoteSourceControl", + "location" : "https://github.com/huggingface/swift-transformers.git", + "state" : { + "revision" : "573e5c9036c2f136b3a8a071da8e8907322403d0", + "version" : "1.1.6" + } + }, + { + "identity" : "swift-xet", + "kind" : "remoteSourceControl", + "location" : "https://github.com/mattt/swift-xet.git", + "state" : { + "revision" : "299ed22a2c8df8e91b3d0c6d929d96e37ee88911", + "version" : "0.2.2" + } + } + ], + "version" : 2 +} diff --git a/crates/mlx/swift-lib/Package.swift b/crates/mlx/swift-lib/Package.swift new file mode 100644 index 0000000000..a7fedc2344 --- /dev/null +++ b/crates/mlx/swift-lib/Package.swift @@ -0,0 +1,33 @@ +// swift-tools-version:5.9 + +import PackageDescription + +let package = Package( + name: "hypr-mlx-swift", + platforms: [.macOS("15.0")], + products: [ + .library( + name: "hypr-mlx-swift", + type: .static, + targets: ["swift-lib"]) + ], + dependencies: [ + .package( + url: "https://github.com/Brendonovich/swift-rs", + revision: "01980f981bc642a6da382cc0788f18fdd4cde6df"), + .package( + url: "https://github.com/Blaizzy/mlx-audio-swift.git", + branch: "main"), + ], + targets: [ + .target( + name: "swift-lib", + dependencies: [ + .product(name: "SwiftRs", package: "swift-rs"), + .product(name: "MLXAudioCore", package: "mlx-audio-swift"), + .product(name: "MLXAudioSTT", package: "mlx-audio-swift"), + ], + path: "src" + ) + ] +) diff --git a/crates/mlx/swift-lib/src/EntryPoint.swift b/crates/mlx/swift-lib/src/EntryPoint.swift new file mode 100644 index 0000000000..69fa0515f6 --- /dev/null +++ b/crates/mlx/swift-lib/src/EntryPoint.swift @@ -0,0 +1,142 @@ +import Foundation +import MLXAudioCore +import MLXAudioSTT +import SwiftRs + +@_cdecl("_mlx_smoke_test") +public func _mlxSmokeTest() -> Bool { + return true +} + +public final class MlxAsrResult: NSObject { + public var text: SRString + public var success: Bool + public var error: SRString + + public init(text: String, success: Bool, error: String) { + self.text = SRString(text) + self.success = success + self.error = SRString(error) + } +} + +private let qwenRepoID = "mlx-community/Qwen3-ASR-0.6B-8bit" +private var qwenAsrModel: Qwen3ASRModel? + +private func expandTilde(_ path: String) -> String { + if path == "~" { + return NSHomeDirectory() + } + if path.hasPrefix("~/") { + return NSString(string: NSHomeDirectory()).appendingPathComponent(String(path.dropFirst(2))) + } + return path +} + +private func ensureCacheFromLocal(safetensorsPath: String, repoID: String) async throws { + let modelSubdir = repoID.replacingOccurrences(of: "/", with: "_") + let cacheDir = URL.cachesDirectory + .appendingPathComponent("mlx-audio") + .appendingPathComponent(modelSubdir) + + let fm = FileManager.default + try fm.createDirectory(at: cacheDir, withIntermediateDirectories: true) + + let linkPath = cacheDir.appendingPathComponent("model.safetensors").path + if !fm.fileExists(atPath: linkPath) { + try fm.createSymbolicLink(atPath: linkPath, withDestinationPath: safetensorsPath) + } + + let configPath = cacheDir.appendingPathComponent("config.json") + guard !fm.fileExists(atPath: configPath.path) else { return } + + let metadataFiles = [ + "config.json", + "tokenizer_config.json", + "vocab.json", + "merges.txt", + "generation_config.json", + ] + + for filename in metadataFiles { + let filePath = cacheDir.appendingPathComponent(filename) + if fm.fileExists(atPath: filePath.path) { continue } + + guard let url = URL(string: "https://huggingface.co/\(repoID)/resolve/main/\(filename)") else { + continue + } + let (data, response) = try await URLSession.shared.data(from: url) + if let httpResponse = response as? HTTPURLResponse, httpResponse.statusCode == 200 { + try data.write(to: filePath) + } + } +} + +@_cdecl("_mlx_qwen_asr_init") +public func _mlxQwenAsrInit(modelSource: SRString) -> Bool { + let semaphore = DispatchSemaphore(value: 0) + var success = false + + Task { + do { + let rawSource = modelSource.toString().trimmingCharacters(in: .whitespacesAndNewlines) + let source = rawSource.isEmpty ? qwenRepoID : rawSource + let expandedSource = expandTilde(source) + + if source.hasSuffix(".safetensors") && FileManager.default.fileExists(atPath: expandedSource) { + try await ensureCacheFromLocal(safetensorsPath: expandedSource, repoID: qwenRepoID) + } + + qwenAsrModel = try await Qwen3ASRModel.fromPretrained(qwenRepoID) + success = true + } catch { + print("mlx init error: \(error)") + qwenAsrModel = nil + success = false + } + semaphore.signal() + } + + semaphore.wait() + return success +} + +@_cdecl("_mlx_qwen_asr_transcribe_file") +public func _mlxQwenAsrTranscribeFile(audioPath: SRString) -> MlxAsrResult { + let semaphore = DispatchSemaphore(value: 0) + var output = MlxAsrResult(text: "", success: false, error: "unknown error") + + Task { + do { + guard let model = qwenAsrModel else { + output = MlxAsrResult(text: "", success: false, error: "model is not initialized") + semaphore.signal() + return + } + + let path = expandTilde(audioPath.toString()) + let url = URL(fileURLWithPath: path) + + let (sampleRate, audioData) = try loadAudioArray(from: url) + if Int(sampleRate) != model.sampleRate { + output = MlxAsrResult( + text: "", + success: false, + error: "unsupported sample rate \(sampleRate), expected \(model.sampleRate)" + ) + semaphore.signal() + return + } + + let result = model.generate(audio: audioData) + output = MlxAsrResult(text: result.text, success: true, error: "") + } catch { + output = MlxAsrResult(text: "", success: false, error: String(describing: error)) + } + + semaphore.signal() + } + + semaphore.wait() + return output +}