Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Source/MLX/ErrorHandler.swift
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,7 @@ public func withError<R>(_ body: () async throws -> R) async throws -> R {
}

/// Error type for caught errors during ``withError(_:)-6g4wn``.
public enum MLXError: LocalizedError {
public enum MLXError: LocalizedError, Sendable, Equatable {
case caught(String)

public var errorDescription: String? {
Expand Down
44 changes: 44 additions & 0 deletions Source/MLX/MLXFast.swift
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,37 @@ public enum MLXFast {
return MLXArray(result)
}

/// Optimized implementation of `NN.RoPE` with array offset for batched inference.
///
/// This overload accepts an array offset, allowing different position offsets for each
/// sequence in a batch. The offset can be a scalar array or a vector with length
/// matching the batch size.
///
/// - Parameters:
/// - array: input array
/// - dimensions: The feature dimensions to be rotated. If the input feature is larger
/// than dims then the rest is left unchanged.
/// - traditional: If `true` choose the traditional implementation which is slightly less efficient.
/// - base: The base used to compute angular frequency for each dimension in the positional encodings.
/// - scale: The scale used to scale the positions.
/// - offset: The position offset as an array. Can be a scalar or a vector of offsets for each batch element.
/// - freqs: Optional frequencies to use with RoPE.
/// - stream: stream or device to evaluate on
/// - Returns: The input with rotary positional encoding applied.
public static func RoPE(
_ array: MLXArray, dimensions: Int, traditional: Bool, base: Float?, scale: Float,
offset: MLXArray,
freqs: MLXArray? = nil, stream: StreamOrDevice = .default
) -> MLXArray {
var result = mlx_array_new()
let base = mlx_optional_float(value: base ?? 0, has_value: base != nil)
mlx_fast_rope_offset_array(
&result,
array.ctx, Int32(dimensions), traditional, base, scale, offset.ctx,
(freqs ?? .mlxNone).ctx, stream.ctx)
return MLXArray(result)
}

/// A fast implementation of multi-head attention: `O = softmax(Q @ K.T, dim=-1) @ V`
///
/// Supports [Multi-Head Attention](https://arxiv.org/abs/1706.03762), [Grouped Query Attention](https://arxiv.org/abs/2305.13245), and [Multi-Query Attention](https://arxiv.org/abs/1911.02150).
Expand Down Expand Up @@ -245,6 +276,19 @@ public func RoPE(
offset: offset, freqs: freqs, stream: stream)
}

/// Optimized implementation of `NN.RoPE` with array offset for batched inference.
///
/// > Note: `MLXNN.RoPE` uses this implementation internally.
public func RoPE(
_ array: MLXArray, dimensions: Int, traditional: Bool, base: Float?, scale: Float,
offset: MLXArray,
freqs: MLXArray? = nil, stream: StreamOrDevice = .default
) -> MLXArray {
return MLXFast.RoPE(
array, dimensions: dimensions, traditional: traditional, base: base, scale: scale,
offset: offset, freqs: freqs, stream: stream)
}

/// A fast implementation of multi-head attention: `O = softmax(Q @ K.T, dim=-1) @ V`
///
/// Supports [Multi-Head Attention](https://arxiv.org/abs/1706.03762), [Grouped Query Attention](https://arxiv.org/abs/2305.13245), and [Multi-Query Attention](https://arxiv.org/abs/1911.02150).
Expand Down
15 changes: 15 additions & 0 deletions Source/MLXNN/PositionalEncoding.swift
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,21 @@ final public class RoPE: Module, UnaryLayer {
return x.reshaped(shape)
}

/// Evaluate with array offset for batched inference with different positions per sequence.
///
/// - Parameters:
/// - x: input array
/// - offset: position offset array (scalar or vector with length matching batch size)
/// - Returns: the input with rotary positional encoding applied
public func callAsFunction(_ x: MLXArray, offset: MLXArray) -> MLXArray {
let shape = x.shape
var x = x.reshaped(-1, x.dim(-2), x.dim(-1))
x = MLXFast.RoPE(
x, dimensions: dimensions, traditional: traditional, base: base, scale: scale,
offset: offset)
return x.reshaped(shape)
}

/// Evaluate with `offset` of `0`.
public func callAsFunction(_ x: MLXArray) -> MLXArray {
callAsFunction(x, offset: 0)
Expand Down
56 changes: 56 additions & 0 deletions Tests/MLXTests/IntegrationTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -6996,6 +6996,62 @@ class MLXIntegrationTests: XCTestCase {
accuracy: 2.3360192871093752)
}

func testRoPEArrayOffset() {
MLXRandom.seed(42)
let batch = MLXRandom.uniform(0.0 ..< 1.0, [3, 8, 16])
XCTAssertEqual(batch.shape, [3, 8, 16])

let offsets = MLXArray([50, 20, 0])

// Test MLXFast.RoPE with array offset
let result = MLXFast.RoPE(
batch, dimensions: 8, traditional: false,
base: 10000, scale: 1.0, offset: offsets)
XCTAssertEqual(result.shape, [3, 8, 16])

// Verify against individual scalar offset calls
for i in 0..<3 {
let single = batch[i].expandedDimensions(axis: 0)
let offsetValue = [50, 20, 0][i]
let expected = MLXFast.RoPE(
single, dimensions: 8, traditional: false,
base: 10000, scale: 1.0, offset: offsetValue)
XCTAssert(allClose(result[i], expected[0]).all().item())
}
}

func testRoPEArrayOffsetModule() {
MLXRandom.seed(123)
let batch = MLXRandom.uniform(0.0 ..< 1.0, [3, 8, 16])
let offsets = MLXArray([10, 5, 0])

let rope = RoPE(dimensions: 8)

// Test MLXNN RoPE module with array offset
let result = rope(batch, offset: offsets)
XCTAssertEqual(result.shape, [3, 8, 16])

// Verify shape and dtype preserved
XCTAssertEqual(result.dtype, batch.dtype)
}

func testRoPEScalarArrayOffset() {
MLXRandom.seed(99)
let a = MLXRandom.uniform(0.0 ..< 1.0, [2, 8, 16])

// Scalar array offset should work the same as int offset
let scalarOffset = MLXArray(5)
let resultArray = MLXFast.RoPE(
a, dimensions: 8, traditional: false,
base: 10000, scale: 1.0, offset: scalarOffset)

let resultInt = MLXFast.RoPE(
a, dimensions: 8, traditional: false,
base: 10000, scale: 1.0, offset: 5)

XCTAssert(allClose(resultArray, resultInt).all().item())
}

func testSinusoidalPositionalEncoding() {
MLXRandom.seed(226)
let a = MLXRandom.uniform(0.0 ..< 1.0, [2, 8, 16])
Expand Down