Skip to content
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -90,3 +90,6 @@ fastlane/test_output

iOSInjectionProject/
.swiftpm

# VS Code
.vscode/
46 changes: 37 additions & 9 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,25 @@
cmake_minimum_required(VERSION 3.16)
project(MLXSwift LANGUAGES C CXX Swift)

# ----------------------------- Setup -----------------------------
# note: 1:1 mirror of MLX configs

set(CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}/cmake")
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_POSITION_INDEPENDENT_CODE ON)
set(CMAKE_INSTALL_MESSAGE NEVER)
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)

# ----------------------------- Configuration -----------------------------
# note: mirrors a subset of MLX options exactly (1:1 mapping)

option(MLX_BUILD_EXAMPLES "Build examples for mlx" ON)
option(MLX_BUILD_METAL "Build metal backend" ON)
# option(MLX_BUILD_CUDA "Build cuda backend" OFF)

# ----------------------------- Lib -----------------------------

include(FetchContent)
# Avoid warning about DOWNLOAD_EXTRACT_TIMESTAMP in CMake 3.24:
if(POLICY CMP0135)
Expand All @@ -26,6 +45,13 @@ FetchContent_MakeAvailable(swift-numerics)

# MLX package
file(GLOB MLX-src ${CMAKE_CURRENT_LIST_DIR}/Source/MLX/*.swift)

# todo: add conditional logic for MLX_BUILD_CUDA (once implemented)
if(NOT MLX_BUILD_METAL)
list(REMOVE_ITEM MLX-src ${CMAKE_CURRENT_LIST_DIR}/Source/MLX/GPU.swift
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Some of the API in here has moved to memory.h and is no longer GPU specific:

int mlx_clear_cache(void);
int mlx_get_active_memory(size_t* res);
int mlx_get_cache_memory(size_t* res);
int mlx_get_memory_limit(size_t* res);
int mlx_get_peak_memory(size_t* res);
int mlx_reset_peak_memory(void);
int mlx_set_cache_limit(size_t* res, size_t limit);
int mlx_set_memory_limit(size_t* res, size_t limit);
int mlx_set_wired_limit(size_t* res, size_t limit);

we don't need to deal with it here but perhaps this needs some refactoring -- we can forward the "GPU" methods to these.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thought down the road it could be nice to have GPU.swift (cross-platform) and GPU+Metal.swift (only Apple Silicon builds) files and later a GPU+CUDA.swift (cross-platform) file ... wasn't quite yet too sure how to split GPU.swift. If you believe it's ok to add into this PR, happy to do so.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it could be done here or left until later. I was actually thinking a Memory.swift like this:

Those are actually functions and properties currently on GPU but they have moved on the python side, so I think we would want to deprecate current ones and forward to the (moved) API.

I don't know if we would want a Metal.swift as it would collide with the framework. So yes, perhaps GPU+X and keep the GPU specific parts under the GPU type.

Anyway, I don't think this is critical and we could pick it up later.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can take a look and let you know (I am just starting out to learn this code base). Plus we need new CI for sure -- sorry forgot about this for a second.

${CMAKE_CURRENT_LIST_DIR}/Source/MLX/MLXArray+Metal.swift)
endif()

add_library(MLX STATIC ${MLX-src})
target_include_directories(MLX
PUBLIC ${CMAKE_CURRENT_LIST_DIR}/Source/Cmlx/include)
Expand Down Expand Up @@ -65,12 +91,14 @@ add_library(MLXLinalg STATIC ${MLXLinalg-src})
target_link_libraries(MLXLinalg PRIVATE MLX)

# Examples
add_executable(example1
${CMAKE_CURRENT_LIST_DIR}/Source/Examples/Example1.swift)
target_link_libraries(example1 PRIVATE MLX)
target_compile_options(example1 PRIVATE -parse-as-library)

add_executable(tutorial
${CMAKE_CURRENT_LIST_DIR}/Source/Examples/Tutorial.swift)
target_link_libraries(tutorial PRIVATE MLX)
target_compile_options(tutorial PRIVATE -parse-as-library)
if(MLX_BUILD_EXAMPLES)
add_executable(example1
${CMAKE_CURRENT_LIST_DIR}/Source/Examples/Example1.swift)
target_link_libraries(example1 PRIVATE MLX)
target_compile_options(example1 PRIVATE -parse-as-library)

add_executable(tutorial
${CMAKE_CURRENT_LIST_DIR}/Source/Examples/Tutorial.swift)
target_link_libraries(tutorial PRIVATE MLX)
target_compile_options(tutorial PRIVATE -parse-as-library)
endif()
47 changes: 45 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ from MLX Python.

## Installation

The ``MLX`` Swift package can be built and run from Xcode or SwiftPM. A CMake install is also provided.
The ``MLX`` Swift package can be built and run from Xcode or SwiftPM. A CMake installation is also provided, featuring a native Linux build option.

More details are in the [documentation](https://swiftpackageindex.com/ml-explore/mlx-swift/main/documentation/mlx/install).

Expand Down Expand Up @@ -95,14 +95,57 @@ brew install ninja
With CMake:

```shell
mkdir build
mkdir -p build
cd build
cmake .. -G Ninja
ninja
./example1
./tutorial
```

<details>
<summary>Expand Native Linux Build Instructions</summary>

#### (1) Install Dependencies

RHEL/Fedora:
```shell
sudo dnf install -y blas-devel lapack-devel openblas-devel clang llvm cmake make ninja
# Then install Swift by following the instructions at https://swift.org
```

Ubuntu/Debian:
```shell
sudo apt update;
sudo apt install -y libblas-dev liblapack-dev libopenblas-dev clang llvm cmake make ninja;
# Then install Swift by following the instructions at https://swift.org
```

Refer to [swift.org](https://www.swift.org/install/linux/) for installation options and instructions specific to your Linux distribution.


#### (2) Build + Run Examples

On Linux, the examples use the CPU backend by default.

Note: GPU+CUDA support is a work in progress for `mlx-swift` on Linux, but is available in the Python-based MLX.

Note: SwiftPM builds are not currently supported for native Linux targets.

```shell
mkdir -p build
cd build
cmake -DMLX_BUILD_METAL=OFF .. -G Ninja
ninja
./example1
./tutorial
```


</details>

</br>

## Contributing

Check out the [contribution guidelines](CONTRIBUTING.md) for more information
Expand Down
24 changes: 12 additions & 12 deletions Source/Examples/Example1.swift
Original file line number Diff line number Diff line change
Expand Up @@ -6,19 +6,19 @@ import MLX
@main
struct Example1 {
static func main() {
// create from stride (sequence of 2, 4, 6, 8)
let arr = MLXArray(stride(from: Int32(2), through: 8, by: 2), [2, 2])
let osName = ProcessInfo.processInfo.operatingSystemVersionString.lowercased()
let device: Device = osName.contains("linux") ? .cpu : .gpu
Stream.withNewDefaultStream(device: device) {
let arr = MLXArray(stride(from: Int32(2), through: 8, by: 2), [2, 2])

print(arr)
print(arr.dtype)
print(arr.shape)
print(arr.ndim)
print(arr.asType(.int64))
print(arr)
print(arr.dtype)
print(arr.shape)
print(arr.ndim)
print(arr.asType(.int64))

// print a row
print(arr[1])

// print a value
print(arr[0, 1].item(Int32.self))
print(arr[1])
print(arr[0, 1].item(Int32.self))
}
}
}
12 changes: 8 additions & 4 deletions Source/Examples/Tutorial.swift
Original file line number Diff line number Diff line change
Expand Up @@ -96,10 +96,14 @@ struct Tutorial {

assert(df2dx2.item() == Float(2))
}

static func main() {
scalarBasics()
arrayBasics()
automaticDifferentiation()
let osName = ProcessInfo.processInfo.operatingSystemVersionString.lowercased()
let device: Device = osName.contains("linux") ? .cpu : .gpu
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder if we need to do something inside the default properties along these lines so that the calling process would get the default for the platform? Though I suppose this might change once this build supports CUDA

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agreed, and I like your suggestion to just wait this one out and see how the CUDA integration turns out to be for Linux.


Stream.withNewDefaultStream(device: device) {
scalarBasics()
arrayBasics()
automaticDifferentiation()
}
}
}
3 changes: 2 additions & 1 deletion Source/MLX/IO.swift
Original file line number Diff line number Diff line change
Expand Up @@ -214,11 +214,11 @@ private func new_mlx_io_vtable_dataIO() -> mlx_io_vtable {
default:
break
}

} read: { ptr, data, n in
let state = Unmanaged<IOState>.fromOpaque(ptr!).takeUnretainedValue()

if n + state.offset <= state.data.count {
guard let data = data else { return }
_ = state.data.withUnsafeBytes { buffer in
memcpy(data, buffer.baseAddress!.advanced(by: state.offset), n)
}
Expand All @@ -229,6 +229,7 @@ private func new_mlx_io_vtable_dataIO() -> mlx_io_vtable {
let state = Unmanaged<IOState>.fromOpaque(ptr!).takeUnretainedValue()

if n + offset <= state.data.count {
guard let data = data else { return }
_ = state.data.withUnsafeBytes { buffer in
memcpy(data, buffer.baseAddress!.advanced(by: offset), n)
}
Expand Down
27 changes: 0 additions & 27 deletions Source/MLX/MLXArray+Bytes.swift
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import Cmlx
import Foundation
import Metal

// MARK: - Backing / Bytes

Expand Down Expand Up @@ -275,32 +274,6 @@ extension MLXArray {
.data
}

/// Return the contents as a Metal buffer in the native ``dtype``.
///
/// > If you can guarantee the lifetime of the ``MLXArray`` will exceed the MTLBuffer and that
/// the array will not be mutated (e.g. using indexing or other means) it is possible to pass `noCopy: true`
/// to reference the backing bytes.
///
/// ### See Also
/// - <doc:conversion>
/// - ``asArray(_:)``
/// - ``asData(access:)``
public func asMTLBuffer(device: any MTLDevice, noCopy: Bool = false) -> (any MTLBuffer)? {
self.eval()

if noCopy && self.contiguousToDimension() == 0 {
// the backing is contiguous, we can provide a wrapper
// for the contents without a copy (if requested)
let source = UnsafeMutableRawPointer(mutating: mlx_array_data_uint8(self.ctx))!
return device.makeBuffer(bytesNoCopy: source, length: self.nbytes)
} else {
let data = asDataCopy()
return data.data.withUnsafeBytes { ptr in
device.makeBuffer(bytes: ptr.baseAddress!, length: ptr.count)
}
}
}

}

/// Return the strides for contiguous memory
Expand Down
37 changes: 37 additions & 0 deletions Source/MLX/MLXArray+Metal.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
// Copyright © 2025 Apple Inc.

import Cmlx
import Foundation
import Metal

// MARK: - Metal

extension MLXArray {

/// Return the contents as a Metal buffer in the native ``dtype``.
///
/// > If you can guarantee the lifetime of the ``MLXArray`` will exceed the MTLBuffer and that
/// the array will not be mutated (e.g. using indexing or other means) it is possible to pass `noCopy: true`
/// to reference the backing bytes.
///
/// ### See Also
/// - <doc:conversion>
/// - ``asArray(_:)``
/// - ``asData(access:)``
public func asMTLBuffer(device: any MTLDevice, noCopy: Bool = false) -> (any MTLBuffer)? {
self.eval()

if noCopy && self.contiguousToDimension() == 0 {
// the backing is contiguous, we can provide a wrapper
// for the contents without a copy (if requested)
let source = UnsafeMutableRawPointer(mutating: mlx_array_data_uint8(self.ctx))!
return device.makeBuffer(bytesNoCopy: source, length: self.nbytes)
} else {
let data = asDataCopy()
return data.data.withUnsafeBytes { ptr in
device.makeBuffer(bytes: ptr.baseAddress!, length: ptr.count)
}
}
}

}
1 change: 0 additions & 1 deletion Source/MLX/MLXArray.swift
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import Cmlx
import Foundation
import Metal
import Numerics

public final class MLXArray {
Expand Down
2 changes: 1 addition & 1 deletion Source/MLX/State.swift
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ extension MLXRandom {

/// Initialize the RandomState with a seed based on the current time.
public init() {
let now = mach_approximate_time()
let now = DispatchTime.now().uptimeNanoseconds
state = MLXRandom.key(now)
}

Expand Down
2 changes: 1 addition & 1 deletion Source/MLXNN/Module.swift
Original file line number Diff line number Diff line change
Expand Up @@ -462,7 +462,7 @@ open class Module {
}
p._updateInternal(newArray)

case (.value(.parameters(let p)), .none):
case (.value(.parameters), .none):
if Self.parameterIsValid(key) {
throw UpdateError.keyNotFound(path: path, modules: modulePath)
} else {
Expand Down