diff --git a/Sources/TensorFlow/Core/Tensor.swift b/Sources/TensorFlow/Core/Tensor.swift index 2c45ca7b5..5bb8ad71e 100644 --- a/Sources/TensorFlow/Core/Tensor.swift +++ b/Sources/TensorFlow/Core/Tensor.swift @@ -13,6 +13,7 @@ // limitations under the License. import CTensorFlow +import Foundation import _Differentiation infix operator .==: ComparisonPrecedence @@ -768,7 +769,40 @@ extension Tensor: Differentiable & EuclideanDifferentiable where Scalar: TensorF case .XLA: return xlaTensor.device case .TF_EAGER: - return Device.defaultTFEager + var kind: Device.Kind = .CPU + var ordinal = 0 + let status = _ExecutionContext.global.status + + // Find out what the underlying libraries think the default is. + if let cString = TFE_TensorHandleDeviceName(handle._cTensorHandle, status) { + checkOk(status) + let tfDeviceName = String(cString: cString) + + // Parse type and ordinal from a string with the expected syntax: + // /job:localhost/replica:0/task:0/device:CPU:0 + let pattern = ".+device:(.+):(\\d+)$" + let regex = try! NSRegularExpression(pattern: pattern) + let nsrange = NSRange(tfDeviceName.startIndex..., in: tfDeviceName) + if let match = regex.firstMatch(in: tfDeviceName, range: nsrange) { + if let kindRange = Range(match.range(at: 1), in: tfDeviceName) { + switch String(tfDeviceName[kindRange]).uppercased() { + case "CPU": + kind = .CPU + case "GPU": + kind = .GPU + case "TPU": + kind = .TPU + default: + kind = .CPU + } + } + if let ordinalRange = Range(match.range(at: 2), in: tfDeviceName) { + ordinal = Int(tfDeviceName[ordinalRange]) ?? 0 + } + } + } + + return Device(kind: kind, ordinal: ordinal, backend: .TF_EAGER) } } }