Skip to content
This repository has been archived by the owner on Jul 1, 2023. It is now read-only.

Accurate Tensor.device for TFEager backends #1077

Closed
wants to merge 12 commits into from
36 changes: 35 additions & 1 deletion Sources/TensorFlow/Core/Tensor.swift
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
// limitations under the License.

import CTensorFlow
import Foundation
import _Differentiation

infix operator .==: ComparisonPrecedence
Expand Down Expand Up @@ -768,7 +769,40 @@ extension Tensor: Differentiable & EuclideanDifferentiable where Scalar: TensorF
case .XLA:
return xlaTensor.device
case .TF_EAGER:
return Device.defaultTFEager
var kind: Device.Kind = .CPU
var ordinal = 0
let status = _ExecutionContext.global.status

// Find out what the underlying libraries think the default is.
if let cString = TFE_TensorHandleDeviceName(handle._cTensorHandle, status) {
checkOk(status)
let tfDeviceName = String(cString: cString)

// Parse type and ordinal from a string with the expected syntax:
// /job:localhost/replica:0/task:0/device:CPU:0
let pattern = ".+device:(.+):(\\d+)$"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe break this String -> Device out as a separate function?
Also, I'm concerned that the string parsing will be expensive. This function is called a lot. (whenever there is a scalar constant or anything like that). I think it would be best to see if you can add your own TFE_TensorHandleDevice_Type and TFE_TensorHandleDevice_Id. Some benchmarking results might work instead.

let regex = try! NSRegularExpression(pattern: pattern)
let nsrange = NSRange(tfDeviceName.startIndex..., in: tfDeviceName)
if let match = regex.firstMatch(in: tfDeviceName, range: nsrange) {
if let kindRange = Range(match.range(at: 1), in: tfDeviceName) {
switch String(tfDeviceName[kindRange]).uppercased() {
case "CPU":
kind = .CPU
case "GPU":
kind = .GPU
case "TPU":
kind = .TPU
default:
kind = .CPU
}
}
if let ordinalRange = Range(match.range(at: 2), in: tfDeviceName) {
ordinal = Int(tfDeviceName[ordinalRange]) ?? 0
}
}
}

return Device(kind: kind, ordinal: ordinal, backend: .TF_EAGER)
}
}
}
Expand Down