Skip to content

Commit

Permalink
Add sqrt implementation that takes custom storage
Browse files Browse the repository at this point in the history
  • Loading branch information
alejandro-isaza committed May 3, 2018
1 parent 37c3760 commit bb3e9b4
Showing 1 changed file with 26 additions and 10 deletions.
36 changes: 26 additions & 10 deletions Sources/Surge/Arithmetic.swift
Original file line number Diff line number Diff line change
Expand Up @@ -188,27 +188,43 @@ public func remainder<X: UnsafeMemoryAccessible, Y: UnsafeMemoryAccessible>(_ x:
///
/// - Warning: does not support memory stride (assumes stride is 1).
public func sqrt<C: UnsafeMemoryAccessible>(_ x: C) -> [Float] where C.Element == Float {
var results = [Float](repeating: 0.0, count: numericCast(x.count))
sqrt(x, into: &results)
return results
}

/// Elemen-wise square root with custom output storage.
///
/// - Warning: does not support memory stride (assumes stride is 1).
public func sqrt<MI: UnsafeMemoryAccessible, MO: UnsafeMutableMemoryAccessible>(_ x: MI, into results: inout MO) where MI.Element == Float, MO.Element == Float {
return x.withUnsafeMemory { xm in
precondition(xm.stride == 1, "\(#function) does not support strided memory access")
var results = [Float](repeating: 0.0, count: numericCast(xm.count))
results.withUnsafeMutableBufferPointer { bufferPointer in
vvsqrtf(bufferPointer.baseAddress!, xm.pointer, [numericCast(xm.count)])
results.withUnsafeMutableMemory { rm in
precondition(xm.stride == 1 && rm.stride == 1, "sqrt doesn't support step values other than 1")
precondition(rm.count >= xm.count, "`results` doesnt have enough capacity to store the results")
vvsqrtf(rm.pointer, xm.pointer, [numericCast(xm.count)])
}
return results
}
}

/// Elemen-wise square root.
///
/// - Warning: does not support memory stride (assumes stride is 1).
public func sqrt<C: UnsafeMemoryAccessible>(_ x: C) -> [Double] where C.Element == Double {
var results = [Double](repeating: 0.0, count: numericCast(x.count))
sqrt(x, into: &results)
return results
}

/// Elemen-wise square root with custom output storage.
///
/// - Warning: does not support memory stride (assumes stride is 1).
public func sqrt<MI: UnsafeMemoryAccessible, MO: UnsafeMutableMemoryAccessible>(_ x: MI, into results: inout MO) where MI.Element == Double, MO.Element == Double {
return x.withUnsafeMemory { xm in
precondition(xm.stride == 1, "\(#function) does not support strided memory access")
var results = [Double](repeating: 0.0, count: numericCast(xm.count))
results.withUnsafeMutableBufferPointer { bufferPointer in
vvsqrt(bufferPointer.baseAddress!, xm.pointer, [numericCast(xm.count)])
results.withUnsafeMutableMemory { rm in
precondition(xm.stride == 1 && rm.stride == 1, "sqrt doesn't support step values other than 1")
precondition(rm.count >= xm.count, "`results` doesnt have enough capacity to store the results")
vvsqrt(rm.pointer, xm.pointer, [numericCast(xm.count)])
}
return results
}
}

Expand Down

0 comments on commit bb3e9b4

Please sign in to comment.