diff --git a/Source/MLX/ErrorHandler.swift b/Source/MLX/ErrorHandler.swift index 4d5f9235..98fa689b 100644 --- a/Source/MLX/ErrorHandler.swift +++ b/Source/MLX/ErrorHandler.swift @@ -214,7 +214,7 @@ public func withError(_ body: () async throws -> R) async throws -> R { } /// Error type for caught errors during ``withError(_:)-6g4wn``. -public enum MLXError: LocalizedError { +public enum MLXError: LocalizedError, Sendable, Equatable { case caught(String) public var errorDescription: String? { diff --git a/Source/MLX/MLXFast.swift b/Source/MLX/MLXFast.swift index f72b870f..ea5acc5f 100644 --- a/Source/MLX/MLXFast.swift +++ b/Source/MLX/MLXFast.swift @@ -37,6 +37,37 @@ public enum MLXFast { return MLXArray(result) } + /// Optimized implementation of `NN.RoPE` with array offset for batched inference. + /// + /// This overload accepts an array offset, allowing different position offsets for each + /// sequence in a batch. The offset can be a scalar array or a vector with length + /// matching the batch size. + /// + /// - Parameters: + /// - array: input array + /// - dimensions: The feature dimensions to be rotated. If the input feature is larger + /// than dims then the rest is left unchanged. + /// - traditional: If `true` choose the traditional implementation which is slightly less efficient. + /// - base: The base used to compute angular frequency for each dimension in the positional encodings. + /// - scale: The scale used to scale the positions. + /// - offset: The position offset as an array. Can be a scalar or a vector of offsets for each batch element. + /// - freqs: Optional frequencies to use with RoPE. + /// - stream: stream or device to evaluate on + /// - Returns: The input with rotary positional encoding applied. + public static func RoPE( + _ array: MLXArray, dimensions: Int, traditional: Bool, base: Float?, scale: Float, + offset: MLXArray, + freqs: MLXArray? = nil, stream: StreamOrDevice = .default + ) -> MLXArray { + var result = mlx_array_new() + let base = mlx_optional_float(value: base ?? 0, has_value: base != nil) + mlx_fast_rope_offset_array( + &result, + array.ctx, Int32(dimensions), traditional, base, scale, offset.ctx, + (freqs ?? .mlxNone).ctx, stream.ctx) + return MLXArray(result) + } + /// A fast implementation of multi-head attention: `O = softmax(Q @ K.T, dim=-1) @ V` /// /// Supports [Multi-Head Attention](https://arxiv.org/abs/1706.03762), [Grouped Query Attention](https://arxiv.org/abs/2305.13245), and [Multi-Query Attention](https://arxiv.org/abs/1911.02150). @@ -245,6 +276,19 @@ public func RoPE( offset: offset, freqs: freqs, stream: stream) } +/// Optimized implementation of `NN.RoPE` with array offset for batched inference. +/// +/// > Note: `MLXNN.RoPE` uses this implementation internally. +public func RoPE( + _ array: MLXArray, dimensions: Int, traditional: Bool, base: Float?, scale: Float, + offset: MLXArray, + freqs: MLXArray? = nil, stream: StreamOrDevice = .default +) -> MLXArray { + return MLXFast.RoPE( + array, dimensions: dimensions, traditional: traditional, base: base, scale: scale, + offset: offset, freqs: freqs, stream: stream) +} + /// A fast implementation of multi-head attention: `O = softmax(Q @ K.T, dim=-1) @ V` /// /// Supports [Multi-Head Attention](https://arxiv.org/abs/1706.03762), [Grouped Query Attention](https://arxiv.org/abs/2305.13245), and [Multi-Query Attention](https://arxiv.org/abs/1911.02150). diff --git a/Source/MLXNN/PositionalEncoding.swift b/Source/MLXNN/PositionalEncoding.swift index c833e57f..7423a7dc 100644 --- a/Source/MLXNN/PositionalEncoding.swift +++ b/Source/MLXNN/PositionalEncoding.swift @@ -45,6 +45,21 @@ final public class RoPE: Module, UnaryLayer { return x.reshaped(shape) } + /// Evaluate with array offset for batched inference with different positions per sequence. + /// + /// - Parameters: + /// - x: input array + /// - offset: position offset array (scalar or vector with length matching batch size) + /// - Returns: the input with rotary positional encoding applied + public func callAsFunction(_ x: MLXArray, offset: MLXArray) -> MLXArray { + let shape = x.shape + var x = x.reshaped(-1, x.dim(-2), x.dim(-1)) + x = MLXFast.RoPE( + x, dimensions: dimensions, traditional: traditional, base: base, scale: scale, + offset: offset) + return x.reshaped(shape) + } + /// Evaluate with `offset` of `0`. public func callAsFunction(_ x: MLXArray) -> MLXArray { callAsFunction(x, offset: 0) diff --git a/Tests/MLXTests/IntegrationTests.swift b/Tests/MLXTests/IntegrationTests.swift index b975816d..0e2fde94 100644 --- a/Tests/MLXTests/IntegrationTests.swift +++ b/Tests/MLXTests/IntegrationTests.swift @@ -6996,6 +6996,62 @@ class MLXIntegrationTests: XCTestCase { accuracy: 2.3360192871093752) } + func testRoPEArrayOffset() { + MLXRandom.seed(42) + let batch = MLXRandom.uniform(0.0 ..< 1.0, [3, 8, 16]) + XCTAssertEqual(batch.shape, [3, 8, 16]) + + let offsets = MLXArray([50, 20, 0]) + + // Test MLXFast.RoPE with array offset + let result = MLXFast.RoPE( + batch, dimensions: 8, traditional: false, + base: 10000, scale: 1.0, offset: offsets) + XCTAssertEqual(result.shape, [3, 8, 16]) + + // Verify against individual scalar offset calls + for i in 0..<3 { + let single = batch[i].expandedDimensions(axis: 0) + let offsetValue = [50, 20, 0][i] + let expected = MLXFast.RoPE( + single, dimensions: 8, traditional: false, + base: 10000, scale: 1.0, offset: offsetValue) + XCTAssert(allClose(result[i], expected[0]).all().item()) + } + } + + func testRoPEArrayOffsetModule() { + MLXRandom.seed(123) + let batch = MLXRandom.uniform(0.0 ..< 1.0, [3, 8, 16]) + let offsets = MLXArray([10, 5, 0]) + + let rope = RoPE(dimensions: 8) + + // Test MLXNN RoPE module with array offset + let result = rope(batch, offset: offsets) + XCTAssertEqual(result.shape, [3, 8, 16]) + + // Verify shape and dtype preserved + XCTAssertEqual(result.dtype, batch.dtype) + } + + func testRoPEScalarArrayOffset() { + MLXRandom.seed(99) + let a = MLXRandom.uniform(0.0 ..< 1.0, [2, 8, 16]) + + // Scalar array offset should work the same as int offset + let scalarOffset = MLXArray(5) + let resultArray = MLXFast.RoPE( + a, dimensions: 8, traditional: false, + base: 10000, scale: 1.0, offset: scalarOffset) + + let resultInt = MLXFast.RoPE( + a, dimensions: 8, traditional: false, + base: 10000, scale: 1.0, offset: 5) + + XCTAssert(allClose(resultArray, resultInt).all().item()) + } + func testSinusoidalPositionalEncoding() { MLXRandom.seed(226) let a = MLXRandom.uniform(0.0 ..< 1.0, [2, 8, 16])