Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ permissions:
env:
# Bump to invalidate every cache entry without source surgery (e.g., after a
# known-bad cache or an Xcode toolchain upgrade we want to flush manually).
CACHE_SALT: v2-vmlx-5b84387
CACHE_SALT: v3-pr-cold-deriveddata
# Pin Xcode so cache keys are stable across runner image bumps. When you
# need to upgrade, change here AND in setup-xcode below.
XCODE_VERSION: "26.4.1"
Expand Down
158 changes: 127 additions & 31 deletions Packages/OsaurusCore/Services/ModelRuntime.swift
Original file line number Diff line number Diff line change
Expand Up @@ -417,7 +417,7 @@ public actor ModelRuntime {
/// by vmlx-swift-lm's `OSAURUS-INTEGRATION.md` (Coordinator-owned KV
/// sizing) plus osaurus's per-environment disk-path config. See the
/// file-level comment for rationale on each knob.
private nonisolated static func buildCacheCoordinatorConfig(
nonisolated private static func buildCacheCoordinatorConfig(
modelName: String
) -> CacheCoordinatorConfig {
let diskCacheDir = OsaurusPaths.diskKVCache()
Expand Down Expand Up @@ -478,7 +478,7 @@ public actor ModelRuntime {
/// Best-effort writability probe for the disk cache directory. Uses a
/// tempfile round-trip rather than `FileManager.isWritableFile(atPath:)`
/// so symlinks / ACLs / out-of-disk conditions are caught.
private nonisolated static func isDirectoryWritable(_ url: URL) -> Bool {
nonisolated private static func isDirectoryWritable(_ url: URL) -> Bool {
let probe = url.appendingPathComponent(".osaurus_write_probe_\(UUID().uuidString)")
do {
try Data().write(to: probe)
Expand Down Expand Up @@ -536,15 +536,13 @@ public actor ModelRuntime {
nonisolated static func isKnownHybridModel(name: String) -> Bool {
let lower = name.lowercased()
// Mamba+Attn+MoE — Nemotron-3 / Cascade-2 / Hyper.
if lower.contains("nemotron-3") || lower.contains("nemotron-cascade")
|| lower.contains("nemotron_h")
{
let nemotronMarkers = ["nemotron-3", "nemotron-cascade", "nemotron_h"]
if nemotronMarkers.contains(where: lower.contains) {
return true
}
// Qwen 3.5 / 3.6 MoE family (qwen3_5_moe model_type) covers Holo3 too.
if lower.contains("qwen3.5") || lower.contains("qwen3.6") || lower.contains("holo3")
|| lower.contains("holo-3")
{
let qwenMoEMarkers = ["qwen3.5", "qwen3.6", "holo3", "holo-3"]
if qwenMoEMarkers.contains(where: lower.contains) {
return true
}
// MiniMax M2 / M2.7 — gated SSM in some layers.
Expand Down Expand Up @@ -687,8 +685,12 @@ public actor ModelRuntime {
var accumulated = ""
var pendingTools: [ServiceToolInvocation] = []
let augmented = ModelRuntime.applyJSONMode(messages, jsonMode: parameters.jsonMode)
let templateMessages = ModelRuntime.applyLocalTemplateCompatibility(
augmented,
modelName: modelName
)
let events = try await generateEventStream(
chatBuilder: { ModelRuntime.mapOpenAIChatToMLX(augmented) },
chatBuilder: { ModelRuntime.mapOpenAIChatToMLX(templateMessages) },
parameters: parameters,
stopSequences: stopSequences,
tools: tools,
Expand Down Expand Up @@ -731,8 +733,12 @@ public actor ModelRuntime {
modelName: String
) async throws -> AsyncThrowingStream<String, Error> {
let augmented = ModelRuntime.applyJSONMode(messages, jsonMode: parameters.jsonMode)
let templateMessages = ModelRuntime.applyLocalTemplateCompatibility(
augmented,
modelName: modelName
)
let events = try await generateEventStream(
chatBuilder: { ModelRuntime.mapOpenAIChatToMLX(augmented) },
chatBuilder: { ModelRuntime.mapOpenAIChatToMLX(templateMessages) },
parameters: parameters,
stopSequences: stopSequences,
tools: tools,
Expand Down Expand Up @@ -837,7 +843,7 @@ public actor ModelRuntime {

/// Computes a deterministic hash from system content and tool names.
/// Used by the HTTP API to expose a prefix_hash field in responses.
public nonisolated static func computePrefixHash(
nonisolated public static func computePrefixHash(
systemContent: String,
toolNames: [String]
) -> String {
Expand Down Expand Up @@ -934,6 +940,98 @@ public actor ModelRuntime {
return out
}

/// Chat-template compatibility shim for locally run models.
///
/// Shipped Gemma-family MLX templates have handled the `system` role
/// unevenly. For that family only, the system instructions are folded into
/// the first user turn and every standalone `system` message is removed, so
/// the model still sees the instructions when its template ignores
/// `role == system`.
///
/// - Parameters:
///   - messages: The conversation in OpenAI chat format.
///   - modelName: Model identifier handed to `ModelFamilyGuidance.family(for:)`.
/// - Returns: `messages` unchanged when the model is not in the Gemma family
///   or when no system message has non-blank content. Otherwise a copy with
///   all `system` messages removed and a "System instructions:" preamble
///   merged into the first user message, or placed in a new leading user
///   message when the conversation has no user turn.
nonisolated static func applyLocalTemplateCompatibility(
    _ messages: [ChatMessage],
    modelName: String
) -> [ChatMessage] {
    if ModelFamilyGuidance.family(for: modelName) != .googleGemma {
        return messages
    }

    // Single pass: set aside the trimmed, non-blank system texts and keep
    // every non-system message in its original order.
    var instructionBlocks: [String] = []
    var remaining: [ChatMessage] = []
    for message in messages {
        if message.role != "system" {
            remaining.append(message)
            continue
        }
        if let trimmed = message.content?.trimmingCharacters(in: .whitespacesAndNewlines),
            !trimmed.isEmpty
        {
            instructionBlocks.append(trimmed)
        }
    }

    // Nothing worth forwarding: hand back the input as-is, which also leaves
    // any blank system messages in place.
    if instructionBlocks.isEmpty {
        return messages
    }

    let instructionText = instructionBlocks.joined(separator: "\n\n")
    let preamble = """
        System instructions:
        \(instructionText)
        """

    // Without a user turn to merge into, the preamble becomes its own
    // leading user message.
    guard let targetIndex = remaining.firstIndex(where: { $0.role == "user" }) else {
        return [ChatMessage(role: "user", content: preamble)] + remaining
    }

    // Rebuild the first user message with the preamble applied to both the
    // flat content and the structured parts; the other fields carry over.
    let target = remaining[targetIndex]
    let combinedText = mergeSystemPreamble(preamble, withUserContent: target.content)
    let combinedParts = prependSystemPreamble(
        preamble,
        mergedContent: combinedText,
        to: target.contentParts
    )
    remaining[targetIndex] = ChatMessage(
        role: target.role,
        content: combinedText,
        contentParts: combinedParts,
        tool_calls: target.tool_calls,
        tool_call_id: target.tool_call_id,
        reasoning_content: target.reasoning_content
    )
    return remaining
}

/// Combines the system preamble with the text of a user message.
///
/// - Parameters:
///   - preamble: The already formatted system-instruction block.
///   - content: The user's text, if any.
/// - Returns: `preamble` alone when `content` is nil or only whitespace and
///   newlines. Otherwise the preamble, a blank line, a "User message:"
///   header, and then `content` exactly as given (not trimmed).
nonisolated private static func mergeSystemPreamble(
    _ preamble: String,
    withUserContent content: String?
) -> String {
    let userText = content ?? ""
    // Blank user text would only add an empty "User message:" section.
    if userText.trimmingCharacters(in: .whitespacesAndNewlines).isEmpty {
        return preamble
    }
    return "\(preamble)\n\nUser message:\n\(userText)"
}

/// Applies the system preamble to a message's structured content parts.
///
/// The first `.text` part is rewritten as the preamble merged with that
/// part's *own* text. Previously it was overwritten with `mergedContent`,
/// which is built from the message's flat `content`; when `content` was nil,
/// blank, or different from the part's text, the text carried by the part
/// was lost.
///
/// - Parameters:
///   - preamble: The already formatted system-instruction block.
///   - mergedContent: The preamble merged with the message's flat content.
///     Used only as a fallback when the first text part is blank.
///   - contentParts: The message's structured parts, if any.
/// - Returns: `nil` when `contentParts` is nil; a single preamble text part
///   when it is empty; otherwise the parts with the preamble merged into the
///   first text part, or inserted as a new leading text part when no text
///   part exists. Non-text parts keep their position and value.
nonisolated private static func prependSystemPreamble(
    _ preamble: String,
    mergedContent: String,
    to contentParts: [MessageContentPart]?
) -> [MessageContentPart]? {
    guard var parts = contentParts else { return nil }
    guard !parts.isEmpty else { return [.text(preamble)] }

    for index in parts.indices {
        guard case .text(let existingText) = parts[index] else { continue }

        let partHasText = !existingText
            .trimmingCharacters(in: .whitespacesAndNewlines)
            .isEmpty
        // Prefer the part's own text so it is never discarded; a blank part
        // has nothing to keep, so fall back to the flat-content merge.
        parts[index] = .text(
            partHasText
                ? mergeSystemPreamble(preamble, withUserContent: existingText)
                : mergedContent
        )
        return parts
    }

    // No text part at all (e.g. image-only message): lead with the preamble.
    parts.insert(.text(preamble), at: 0)
    return parts
}

/// Map OpenAI-format chat messages to MLX `Chat.Message`s.
///
/// Assistant tool calls and tool-role responses flow through
Expand Down Expand Up @@ -1050,11 +1148,11 @@ public actor ModelRuntime {
if urlString.hasPrefix("data:image/") {
if let commaIndex = urlString.firstIndex(of: ",") {
let base64String = String(urlString[urlString.index(after: commaIndex)...])
if let imageData = Data(base64Encoded: base64String),
guard
let imageData = Data(base64Encoded: base64String),
let ciImage = CIImage(data: imageData)
{
sources.append(.ciImage(ciImage))
}
else { continue }
sources.append(.ciImage(ciImage))
}
} else if let url = URL(string: urlString) {
sources.append(.url(url))
Expand Down Expand Up @@ -1187,14 +1285,13 @@ public actor ModelRuntime {
)
else { return 0 }
var total: Int64 = 0
for case let fileURL as URL in enumerator {
if fileURL.pathExtension.lowercased() == "safetensors" {
if let attrs = try? fm.attributesOfItem(atPath: fileURL.path),
let size = attrs[.size] as? NSNumber
{
total += size.int64Value
}
}
for case let fileURL as URL in enumerator
where fileURL.pathExtension.lowercased() == "safetensors" {
guard
let attrs = try? fm.attributesOfItem(atPath: fileURL.path),
let size = attrs[.size] as? NSNumber
else { continue }
total += size.int64Value
}
return total
}
Expand Down Expand Up @@ -1315,9 +1412,8 @@ public actor ModelRuntime {
do {
try validateJANGTQSidecarIfRequired(at: directory, name: name)
return
} catch let error as NSError
where error.domain == "ModelRuntime" && error.code == 2
{
} catch let error as NSError {
guard error.domain == "ModelRuntime", error.code == 2 else { throw error }
// Forward mismatch: stamp says mxtq, sidecar missing. Try one HF fetch.
// Build the candidate id list: canonical `<org>/<repo>` first,
// then — for flat-layout local ids that aren't directly mappable
Expand Down Expand Up @@ -1546,7 +1642,7 @@ public actor ModelRuntime {
/// global, and so each test's override is naturally scoped to its own
/// task tree via `withValue { ... }`.
@TaskLocal
static var sidecarFetcherForTests: (@Sendable (_ url: URL, _ dest: URL) async throws -> Void)? = nil
static var sidecarFetcherForTests: (@Sendable (_ url: URL, _ dest: URL) async throws -> Void)?

/// Pure, testable sibling of `findLocalDirectory` that takes the root
/// explicitly. Exposed at module scope so the symlink-resolution
Expand All @@ -1569,10 +1665,10 @@ public actor ModelRuntime {
// that discovery path already resolves symlinks per-level, so keeping
// the two symmetric here closes the asymmetry.
let resolved = url.resolvingSymlinksInPath()
let hasConfig = fm.fileExists(atPath: resolved.appendingPathComponent("config.json").path)
if let items = try? fm.contentsOfDirectory(at: resolved, includingPropertiesForKeys: nil),
hasConfig && items.contains(where: { $0.pathExtension == "safetensors" })
{
guard fm.fileExists(atPath: resolved.appendingPathComponent("config.json").path),
let items = try? fm.contentsOfDirectory(at: resolved, includingPropertiesForKeys: nil)
else { return nil }
if items.contains(where: { $0.pathExtension == "safetensors" }) {
return resolved
}
return nil
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
//
// LocalTemplateCompatibilityTests.swift
// osaurusTests
//
// Regression coverage for local MLX chat-template compatibility shims.
//

import Foundation
import Testing

@testable import OsaurusCore

/// Regression coverage for `ModelRuntime.applyLocalTemplateCompatibility`.
///
/// Each test checks the shape of the result with `#require` before
/// subscripting it, so a regression is recorded as a test failure instead of
/// crashing the test process with an out-of-bounds access.
struct LocalTemplateCompatibilityTests {

    /// A model name outside the Gemma family leaves the conversation untouched.
    @Test func nonGemmaKeepsSystemRoleUntouched() throws {
        let messages = [
            ChatMessage(role: "system", content: "Your name is Gerald."),
            ChatMessage(role: "user", content: "Who are you?"),
        ]

        let adapted = ModelRuntime.applyLocalTemplateCompatibility(
            messages,
            modelName: "qwen3-32b-mlx"
        )

        // Gate the subscripts below on the expected two-message shape.
        try #require(adapted.map(\.role) == ["system", "user"])
        #expect(adapted[0].content == "Your name is Gerald.")
        #expect(adapted[1].content == "Who are you?")
    }

    /// For a Gemma model the system message is removed and its text is merged
    /// into the first user turn.
    @Test func gemmaMovesSystemInstructionsIntoFirstUserTurn() throws {
        let messages = [
            ChatMessage(role: "system", content: "Your name is Gerald."),
            ChatMessage(role: "user", content: "Who are you?"),
        ]

        let adapted = ModelRuntime.applyLocalTemplateCompatibility(
            messages,
            modelName: "OsaurusAI/gemma-4-E4B-it-8bit"
        )

        try #require(adapted.map(\.role) == ["user"])
        let content = try #require(adapted[0].content)
        #expect(content.contains("System instructions:"))
        #expect(content.contains("Your name is Gerald."))
        #expect(content.contains("User message:"))
        #expect(content.contains("Who are you?"))
    }

    /// Image parts on the user message are kept while the instructions are
    /// merged into its text part.
    @Test func gemmaPreservesUserImagePartsWhileAddingInstructions() throws {
        let imageData = Data([0x89, 0x50, 0x4E, 0x47])
        let messages = [
            ChatMessage(role: "system", content: "Describe images tersely."),
            ChatMessage(role: "user", text: "What is in this image?", imageData: [imageData]),
        ]

        let adapted = ModelRuntime.applyLocalTemplateCompatibility(
            messages,
            modelName: "gemma-4-26b-a4b-it"
        )

        try #require(adapted.count == 1)
        #expect(adapted[0].role == "user")
        #expect(adapted[0].content?.contains("Describe images tersely.") == true)
        #expect(adapted[0].content?.contains("What is in this image?") == true)

        // Both parts are subscripted below, so the count must hold first.
        let parts = try #require(adapted[0].contentParts)
        try #require(parts.count == 2)

        if case .text(let text) = parts[0] {
            #expect(text.contains("System instructions:"))
            #expect(text.contains("What is in this image?"))
        } else {
            Issue.record("first content part should be text")
        }

        if case .imageUrl(let url, _) = parts[1] {
            #expect(url.hasPrefix("data:image/png;base64,"))
        } else {
            Issue.record("second content part should preserve the image")
        }
    }
}
Loading