Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Source/MLX/MLXArray+Indexing.swift
Original file line number Diff line number Diff line change
Expand Up @@ -853,7 +853,7 @@ func updateSlice(
var strides = [Int32](repeating: 1, count: ndim)

// If it's just a simple slice, just do a slice update and return
if operations.count == 1, case let .slice(slice) = operations[0] {
if operations.count == 1, case .slice(let slice) = operations[0] {
let size = src.dim(0).int32
starts[0] = slice.start(size)
ends[0] = slice.end(size)
Expand Down
5 changes: 3 additions & 2 deletions Source/MLXNN/Module.swift
Original file line number Diff line number Diff line change
Expand Up @@ -1582,8 +1582,9 @@ extension UpdateError: LocalizedError {
"Unable to collect modules from container: \(path.joined(separator: ".")) in \(modules.joined(separator: "."))"
case .mismatchedContainers(let base, let key):
return "Mismatched containers: \(base) \(key)"
case let .mismatchedSize(
path, modules, expectedShape: expectedShape, actualShape: actualShape):
case .mismatchedSize(
let
path, let modules, let expectedShape, let actualShape):
return
"Mismatched parameter \(path.joined(separator: ".")) in \(modules.joined(separator: ".")) shape. Actual \(actualShape), expected \(expectedShape)"
case .keyNotFound(let path, let modules):
Expand Down
8 changes: 5 additions & 3 deletions Tests/MLXTests/ModuleTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -555,7 +555,7 @@ class ModuleTests: XCTestCase {
verify: .all)
) { error in
guard let error = error as? UpdateError,
case let .keyNotFound(path, modules) = error
case .keyNotFound(let path, let modules) = error
else {
XCTFail("Expected to fail with UpdateError.keyNotFound, but got: \(error)")
return
Expand Down Expand Up @@ -586,8 +586,10 @@ class ModuleTests: XCTestCase {
verify: .all)
) { error in
guard let error = error as? UpdateError,
case let .mismatchedSize(
path, modules, expectedShape: expectedShape, actualShape: actualShape) =
case .mismatchedSize(
let
path, let modules, expectedShape: let expectedShape,
actualShape: let actualShape) =
error
else {
XCTFail("Expected to fail with UpdateError.mismatchedSize, but got: \(error)")
Expand Down