Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions .github/workflows/pull_request.yml
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ jobs:

mac_build_and_test:
needs: lint
if: github.repository == 'ml-explore/mlx-swift-lm'
if: always() && (github.repository == 'osaurus-ai/vmlx-swift-lm' || (github.repository == 'ml-explore/mlx-swift-lm' && needs.lint.result == 'success'))
runs-on: [self-hosted, macos]
steps:
- uses: actions/checkout@v6
Expand All @@ -89,12 +89,12 @@ jobs:
xcrun --show-sdk-build-version
swift --version
rm -rf ~/Library/Developer/Xcode/DerivedData/*
xcodebuild build-for-testing -scheme mlx-swift-lm-Package -destination 'platform=macOS'
xcodebuild build-for-testing -scheme vmlx-swift-lm-Package -destination 'platform=macOS'

- name: Run Tests (Xcode, macOS)
shell: sh
run: |
xcrun xctest ~/Library/Developer/Xcode/DerivedData/mlx-swift-lm-*/Build/Products/Debug/MLXLMTests.xctest
xcrun xctest ~/Library/Developer/Xcode/DerivedData/vmlx-swift-lm-*/Build/Products/Debug/MLXLMTests.xctest

- name: Upload test results
if: failure()
Expand Down
67 changes: 57 additions & 10 deletions Libraries/MLXLMCommon/Chat.swift
Original file line number Diff line number Diff line change
Expand Up @@ -73,17 +73,19 @@ public enum Chat {
_ content: String, images: [UserInput.Image] = [],
videos: [UserInput.Video] = [], audios: [UserInput.Audio] = []
) -> Self {
Self(role: .system, content: content,
images: images, videos: videos, audios: audios)
Self(
role: .system, content: content,
images: images, videos: videos, audios: audios)
}

/// Build an assistant message with plain text content.
public static func assistant(
_ content: String, images: [UserInput.Image] = [],
videos: [UserInput.Video] = [], audios: [UserInput.Audio] = []
) -> Self {
Self(role: .assistant, content: content,
images: images, videos: videos, audios: audios)
Self(
role: .assistant, content: content,
images: images, videos: videos, audios: audios)
}

/// Build an assistant message that issued one or more tool
Expand All @@ -108,8 +110,9 @@ public enum Chat {
_ content: String, images: [UserInput.Image] = [],
videos: [UserInput.Video] = [], audios: [UserInput.Audio] = []
) -> Self {
Self(role: .user, content: content,
images: images, videos: videos, audios: audios)
Self(
role: .user, content: content,
images: images, videos: videos, audios: audios)
}

/// Build a tool-role message carrying the result of a tool call.
Expand Down Expand Up @@ -239,16 +242,60 @@ public struct DefaultMessageGenerator: MessageGenerator {
}
}

/// Implementation of ``MessageGenerator`` that produces the default
/// dict shape but omits `system` roles.
/// Implementation of ``MessageGenerator`` for templates that cannot accept
/// `system` roles.
///
/// System instructions are preserved by folding them into the first user
/// message. This keeps unsupported roles out of the rendered chat template
/// without silently discarding host instructions.
public struct NoSystemMessageGenerator: MessageGenerator {
public init() {}

public func generate(messages: [Chat.Message]) -> [Message] {
messages
.filter { $0.role != .system }
foldSystemMessagesIntoFirstUser(messages)
.map { generate(message: $0) }
}

private func foldSystemMessagesIntoFirstUser(_ messages: [Chat.Message]) -> [Chat.Message] {
var systemTexts: [String] = []
var remaining: [Chat.Message] = []

for message in messages {
if message.role == .system {
if message.content.contains(where: { !$0.isWhitespace }) {
systemTexts.append(message.content)
}
} else {
remaining.append(message)
}
}

guard !systemTexts.isEmpty else { return remaining }

let systemPreamble = """
System instructions:
\(systemTexts.joined(separator: "\n\n"))
"""

guard let firstUserIndex = remaining.firstIndex(where: { $0.role == .user }) else {
remaining.insert(.user(systemPreamble), at: 0)
return remaining
}

var user = remaining[firstUserIndex]
if user.content.contains(where: { !$0.isWhitespace }) {
user.content = """
\(systemPreamble)

User message:
\(user.content)
"""
} else {
user.content = systemPreamble
}
remaining[firstUserIndex] = user
return remaining
}
}

// MARK: - Default dict construction
Expand Down
39 changes: 26 additions & 13 deletions Package.swift
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,14 @@
// The swift-tools-version declares the minimum version of Swift required to build this package.

import CompilerPluginSupport
import Foundation
import PackageDescription

let packageDirectory = URL(fileURLWithPath: #filePath).deletingLastPathComponent().path
let hasLocalRunBench = FileManager.default.fileExists(
atPath: "\(packageDirectory)/RunBench"
)

let package = Package(
name: "vmlx-swift-lm",
platforms: [
Expand Down Expand Up @@ -36,7 +42,9 @@ let package = Package(
targets: ["IntegrationTestHelpers"]),
],
dependencies: [
.package(url: "https://github.com/osaurus-ai/mlx-swift", revision: "0a56f9041d56b4b8161f67a6cbd540ae66efc9fd"),
.package(
url: "https://github.com/osaurus-ai/mlx-swift",
revision: "0a56f9041d56b4b8161f67a6cbd540ae66efc9fd"),
.package(url: "https://github.com/swiftlang/swift-syntax.git", from: "600.0.0-latest"),
// swift-transformers 1.0.0+ transitively uses huggingface/
// swift-jinja 2.x which already contains the three root-cause
Expand Down Expand Up @@ -135,18 +143,6 @@ let package = Package(
],
path: "CompileBench"
),
.executableTarget(
name: "RunBench",
dependencies: [
"MLXLMCommon",
"MLXLLM",
"MLXVLM",
"MLXHuggingFace",
.product(name: "MLX", package: "mlx-swift"),
.product(name: "Transformers", package: "swift-transformers"),
],
path: "RunBench"
),
.testTarget(
name: "MLXLMTests",
dependencies: [
Expand Down Expand Up @@ -185,6 +181,23 @@ let package = Package(
]
)

if hasLocalRunBench {
package.targets.append(
.executableTarget(
name: "RunBench",
dependencies: [
"MLXLMCommon",
"MLXLLM",
"MLXVLM",
"MLXHuggingFace",
.product(name: "MLX", package: "mlx-swift"),
.product(name: "Transformers", package: "swift-transformers"),
],
path: "RunBench"
)
)
}

if Context.environment["MLX_SWIFT_BUILD_DOC"] == "1"
|| Context.environment["SPI_GENERATE_DOCS"] == "1"
{
Expand Down
58 changes: 41 additions & 17 deletions Tests/MLXLMTests/ChatMessageToolCallTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -117,14 +117,16 @@ struct ChatMessageToolCallTests {
@Test("multiple tool calls all get emitted with distinct ids")
func multipleToolCalls() {
let calls = [
ToolCall(function: .init(
name: "get_weather",
arguments: ["city": .string("NYC")]
)),
ToolCall(function: .init(
name: "get_time",
arguments: ["tz": .string("America/New_York")]
)),
ToolCall(
function: .init(
name: "get_weather",
arguments: ["city": .string("NYC")]
)),
ToolCall(
function: .init(
name: "get_time",
arguments: ["tz": .string("America/New_York")]
)),
]
let msg = Chat.Message.assistant("", toolCalls: calls)
let dict = defaultMessageDict(for: msg)
Expand All @@ -142,24 +144,46 @@ struct ChatMessageToolCallTests {

@Test("DefaultMessageGenerator passes tool_calls through")
func defaultGeneratorTransit() {
let call = ToolCall(function: .init(
name: "search", arguments: ["q": .string("swift")]))
let call = ToolCall(
function: .init(
name: "search", arguments: ["q": .string("swift")]))
let msg = Chat.Message.assistant("", toolCalls: [call])
let gen = DefaultMessageGenerator()
let dict = gen.generate(message: msg)
#expect(dict["tool_calls"] != nil)
}

@Test("NoSystemMessageGenerator drops system but preserves tool_calls")
func noSystemGeneratorPreservesToolCalls() {
let call = ToolCall(function: .init(
name: "f", arguments: [:]))
@Test("NoSystemMessageGenerator folds system into user and preserves tool_calls")
func noSystemGeneratorPreservesSystemAndToolCalls() {
let call = ToolCall(
function: .init(
name: "f", arguments: [:]))
let messages: [Chat.Message] = [
.system("ignored"),
.system("follow system instructions"),
.user("hello"),
.assistant("", toolCalls: [call]),
]
let out = NoSystemMessageGenerator().generate(messages: messages)
#expect(out.count == 1)
#expect(out.first?["tool_calls"] != nil)
#expect(out.count == 2)
#expect(out[0]["role"] as? String == "user")
#expect((out[0]["content"] as? String)?.contains("System instructions:") == true)
#expect((out[0]["content"] as? String)?.contains("follow system instructions") == true)
#expect((out[0]["content"] as? String)?.contains("User message:") == true)
#expect((out[0]["content"] as? String)?.contains("hello") == true)
#expect(out[1]["tool_calls"] != nil)
}

@Test("NoSystemMessageGenerator inserts user turn when only system is present")
func noSystemGeneratorPreservesSystemWithoutUser() {
let messages: [Chat.Message] = [
.system("system only"),
.assistant("already answered"),
]
let out = NoSystemMessageGenerator().generate(messages: messages)
#expect(out.count == 2)
#expect(out[0]["role"] as? String == "user")
#expect((out[0]["content"] as? String)?.contains("system only") == true)
#expect(out[1]["role"] as? String == "assistant")
#expect(out[1]["content"] as? String == "already answered")
}
}
64 changes: 39 additions & 25 deletions Tests/MLXLMTests/EvalTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,38 @@ import MLXNN
import MLXOptimizers
import XCTest

private actor EvalTestSamplingGate {
static let shared = EvalTestSamplingGate()

func sampleToken(vocabSize: Int, samplerId: Int) -> Int {
let logits = MLXRandom.normal([1, vocabSize])
return withRandomState(MLXRandom.RandomState(seed: UInt64(samplerId))) {
if samplerId % 2 == 0 {
return categorical(logits).item(Int.self)
} else {
return logits.argMax(axis: -1).item(Int.self)
}
}
}

func sampleSequence(samplerId: Int, samplesPerTask: Int) -> [Int] {
let logits = MLXArray.ones([1, 50])
var taskResults: [Int] = []
let sampler = CategoricalSampler(temperature: 1.0)

for sampleId in 0 ..< samplesPerTask {
let token = withRandomState(
MLXRandom.RandomState(seed: UInt64(samplerId * 1000 + sampleId))
) {
sampler.sample(logits: logits)
}
taskResults.append(token.item(Int.self))
}

return taskResults
}
}

public class EvalTests: XCTestCase {

func testLlamaEval() throws {
Expand Down Expand Up @@ -112,14 +144,8 @@ public class EvalTests: XCTestCase {

for samplerId in 0 ..< numSamplers {
group.addTask {
let logits = MLXRandom.normal([1, vocabSize])
return withRandomState(MLXRandom.RandomState(seed: UInt64(samplerId))) {
if samplerId % 2 == 0 {
return categorical(logits).item(Int.self)
} else {
return logits.argMax(axis: -1).item(Int.self)
}
}
await EvalTestSamplingGate.shared.sampleToken(
vocabSize: vocabSize, samplerId: samplerId)
}
}

Expand All @@ -146,7 +172,7 @@ public class EvalTests: XCTestCase {
quantize(model: model, groupSize: 64, bits: 4)
eval(model)

let prompt = MLXArray(Array(0..<20))[.newAxis, .ellipsis]
let prompt = MLXArray(Array(0 ..< 20))[.newAxis, .ellipsis]
let input = LMInput(text: .init(tokens: prompt))
let maxTokens = 100

Expand All @@ -156,7 +182,7 @@ public class EvalTests: XCTestCase {

let baselineStart = CFAbsoluteTimeGetCurrent()
var baselineCount = 0
while let _ = baselineIterator.next() {
while baselineIterator.next() != nil {
baselineCount += 1
}
let baselineElapsed = CFAbsoluteTimeGetCurrent() - baselineStart
Expand All @@ -175,7 +201,7 @@ public class EvalTests: XCTestCase {

let compiledStart = CFAbsoluteTimeGetCurrent()
var compiledCount = 0
while let _ = compiledIterator.next() {
while compiledIterator.next() != nil {
compiledCount += 1
}
let compiledElapsed = CFAbsoluteTimeGetCurrent() - compiledStart
Expand All @@ -199,20 +225,8 @@ public class EvalTests: XCTestCase {

for samplerId in 0 ..< numSamplers {
group.addTask {
let logits = MLXArray.ones([1, 50])
var taskResults: [Int] = []
let sampler = CategoricalSampler(temperature: 1.0)

for sampleId in 0 ..< samplesPerTask {
let token = withRandomState(
MLXRandom.RandomState(seed: UInt64(samplerId * 1000 + sampleId))
) {
return sampler.sample(logits: logits)
}
taskResults.append(token.item(Int.self))
}

return taskResults
await EvalTestSamplingGate.shared.sampleSequence(
samplerId: samplerId, samplesPerTask: samplesPerTask)
}
}

Expand Down