Skip to content

Commit 2fd1c69

Browse files
committed
Expose mlx_fast_rope_offset_array
1 parent e413553 commit 2fd1c69

File tree

3 files changed

+115
-0
lines changed

3 files changed

+115
-0
lines changed

Source/MLX/MLXFast.swift

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,37 @@ public enum MLXFast {
3737
return MLXArray(result)
3838
}
3939

40+
/// Optimized implementation of `NN.RoPE` with array offset for batched inference.
41+
///
42+
/// This overload accepts an array offset, allowing different position offsets for each
43+
/// sequence in a batch. The offset can be a scalar array or a vector with length
44+
/// matching the batch size.
45+
///
46+
/// - Parameters:
47+
/// - array: input array
48+
/// - dimensions: The feature dimensions to be rotated. If the input feature is larger
49+
/// than dims then the rest is left unchanged.
50+
/// - traditional: If `true` choose the traditional implementation which is slightly less efficient.
51+
/// - base: The base used to compute angular frequency for each dimension in the positional encodings.
52+
/// - scale: The scale used to scale the positions.
53+
/// - offset: The position offset as an array. Can be a scalar or a vector of offsets for each batch element.
54+
/// - freqs: Optional frequencies to use with RoPE.
55+
/// - stream: stream or device to evaluate on
56+
/// - Returns: The input with rotary positional encoding applied.
57+
public static func RoPE(
58+
_ array: MLXArray, dimensions: Int, traditional: Bool, base: Float?, scale: Float,
59+
offset: MLXArray,
60+
freqs: MLXArray? = nil, stream: StreamOrDevice = .default
61+
) -> MLXArray {
62+
var result = mlx_array_new()
63+
let base = mlx_optional_float(value: base ?? 0, has_value: base != nil)
64+
mlx_fast_rope_offset_array(
65+
&result,
66+
array.ctx, Int32(dimensions), traditional, base, scale, offset.ctx,
67+
(freqs ?? .mlxNone).ctx, stream.ctx)
68+
return MLXArray(result)
69+
}
70+
4071
/// A fast implementation of multi-head attention: `O = softmax(Q @ K.T, dim=-1) @ V`
4172
///
4273
/// Supports [Multi-Head Attention](https://arxiv.org/abs/1706.03762), [Grouped Query Attention](https://arxiv.org/abs/2305.13245), and [Multi-Query Attention](https://arxiv.org/abs/1911.02150).
@@ -245,6 +276,19 @@ public func RoPE(
245276
offset: offset, freqs: freqs, stream: stream)
246277
}
247278

279+
/// Optimized implementation of `NN.RoPE` with array offset for batched inference.
280+
///
281+
/// > Note: `MLXNN.RoPE` uses this implementation internally.
282+
public func RoPE(
283+
_ array: MLXArray, dimensions: Int, traditional: Bool, base: Float?, scale: Float,
284+
offset: MLXArray,
285+
freqs: MLXArray? = nil, stream: StreamOrDevice = .default
286+
) -> MLXArray {
287+
return MLXFast.RoPE(
288+
array, dimensions: dimensions, traditional: traditional, base: base, scale: scale,
289+
offset: offset, freqs: freqs, stream: stream)
290+
}
291+
248292
/// A fast implementation of multi-head attention: `O = softmax(Q @ K.T, dim=-1) @ V`
249293
///
250294
/// Supports [Multi-Head Attention](https://arxiv.org/abs/1706.03762), [Grouped Query Attention](https://arxiv.org/abs/2305.13245), and [Multi-Query Attention](https://arxiv.org/abs/1911.02150).

Source/MLXNN/PositionalEncoding.swift

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,21 @@ final public class RoPE: Module, UnaryLayer {
4545
return x.reshaped(shape)
4646
}
4747

48+
/// Evaluate with array offset for batched inference with different positions per sequence.
49+
///
50+
/// - Parameters:
51+
/// - x: input array
52+
/// - offset: position offset array (scalar or vector with length matching batch size)
53+
/// - Returns: the input with rotary positional encoding applied
54+
public func callAsFunction(_ x: MLXArray, offset: MLXArray) -> MLXArray {
55+
let shape = x.shape
56+
var x = x.reshaped(-1, x.dim(-2), x.dim(-1))
57+
x = MLXFast.RoPE(
58+
x, dimensions: dimensions, traditional: traditional, base: base, scale: scale,
59+
offset: offset)
60+
return x.reshaped(shape)
61+
}
62+
4863
/// Evaluate with `offset` of `0`.
4964
public func callAsFunction(_ x: MLXArray) -> MLXArray {
5065
callAsFunction(x, offset: 0)

Tests/MLXTests/IntegrationTests.swift

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6996,6 +6996,62 @@ class MLXIntegrationTests: XCTestCase {
69966996
accuracy: 2.3360192871093752)
69976997
}
69986998

6999+
func testRoPEArrayOffset() {
7000+
MLXRandom.seed(42)
7001+
let batch = MLXRandom.uniform(0.0 ..< 1.0, [3, 8, 16])
7002+
XCTAssertEqual(batch.shape, [3, 8, 16])
7003+
7004+
let offsets = MLXArray([50, 20, 0])
7005+
7006+
// Test MLXFast.RoPE with array offset
7007+
let result = MLXFast.RoPE(
7008+
batch, dimensions: 8, traditional: false,
7009+
base: 10000, scale: 1.0, offset: offsets)
7010+
XCTAssertEqual(result.shape, [3, 8, 16])
7011+
7012+
// Verify against individual scalar offset calls
7013+
for i in 0..<3 {
7014+
let single = batch[i].expandedDimensions(axis: 0)
7015+
let offsetValue = [50, 20, 0][i]
7016+
let expected = MLXFast.RoPE(
7017+
single, dimensions: 8, traditional: false,
7018+
base: 10000, scale: 1.0, offset: offsetValue)
7019+
XCTAssert(allClose(result[i], expected[0]).all().item())
7020+
}
7021+
}
7022+
7023+
func testRoPEArrayOffsetModule() {
7024+
MLXRandom.seed(123)
7025+
let batch = MLXRandom.uniform(0.0 ..< 1.0, [3, 8, 16])
7026+
let offsets = MLXArray([10, 5, 0])
7027+
7028+
let rope = RoPE(dimensions: 8)
7029+
7030+
// Test MLXNN RoPE module with array offset
7031+
let result = rope(batch, offset: offsets)
7032+
XCTAssertEqual(result.shape, [3, 8, 16])
7033+
7034+
// Verify shape and dtype preserved
7035+
XCTAssertEqual(result.dtype, batch.dtype)
7036+
}
7037+
7038+
func testRoPEScalarArrayOffset() {
7039+
MLXRandom.seed(99)
7040+
let a = MLXRandom.uniform(0.0 ..< 1.0, [2, 8, 16])
7041+
7042+
// Scalar array offset should work the same as int offset
7043+
let scalarOffset = MLXArray(5)
7044+
let resultArray = MLXFast.RoPE(
7045+
a, dimensions: 8, traditional: false,
7046+
base: 10000, scale: 1.0, offset: scalarOffset)
7047+
7048+
let resultInt = MLXFast.RoPE(
7049+
a, dimensions: 8, traditional: false,
7050+
base: 10000, scale: 1.0, offset: 5)
7051+
7052+
XCTAssert(allClose(resultArray, resultInt).all().item())
7053+
}
7054+
69997055
func testSinusoidalPositionalEncoding() {
70007056
MLXRandom.seed(226)
70017057
let a = MLXRandom.uniform(0.0 ..< 1.0, [2, 8, 16])

0 commit comments

Comments
 (0)