Skip to content
This repository has been archived by the owner on Jul 1, 2023. It is now read-only.

Accurate Tensor.device for TFEager backends #1077

Closed
wants to merge 12 commits into from
49 changes: 47 additions & 2 deletions Sources/x10/swift_bindings/Device.swift
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.

import CTensorFlow
import Foundation
@_implementationOnly import x10_device_wrapper

extension DeviceType {
Expand Down Expand Up @@ -127,6 +129,8 @@ public struct Device {
#endif
}

/// The ordinal used when no more specific device ordinal is known.
static var defaultOrdinal: Int {
  return 0
}

/// The default `Device`.
public static var `default`: Device {
switch defaultBackend {
Expand All @@ -146,8 +150,49 @@ public struct Device {

/// Compiled once and reused: review feedback flagged the device-name parsing as
/// potentially expensive, and building the regular expression is the avoidable part.
///
/// Expected syntax of the names it matches:
///   /job:localhost/replica:0/task:0/device:CPU:0
private static let tfEagerDeviceNameRegex = try! NSRegularExpression(
  pattern: ".+device:(.+):(\\d+)$")

/// Extracts the device kind and ordinal from a TensorFlow device name.
///
/// - Parameter name: A device name such as
///   `/job:localhost/replica:0/task:0/device:CPU:0`.
/// - Returns: The parsed kind and ordinal, or `nil` when `name` does not have the
///   expected syntax. An unrecognized kind maps to `.CPU`; an ordinal that does not
///   fit in `Int` maps to `defaultOrdinal`.
private static func parseTFEagerDeviceName(_ name: String) -> (kind: Kind, ordinal: Int)? {
  let fullRange = NSRange(name.startIndex..., in: name)
  guard
    let match = tfEagerDeviceNameRegex.firstMatch(in: name, range: fullRange),
    let kindRange = Range(match.range(at: 1), in: name),
    let ordinalRange = Range(match.range(at: 2), in: name)
  else {
    return nil
  }

  let kind: Kind
  switch String(name[kindRange]).uppercased() {
  case "CPU":
    kind = .CPU
  case "GPU":
    kind = .GPU
  case "TPU":
    kind = .TPU
  case "REMOTE_TPU":
    kind = .REMOTE_TPU
  default:
    // Unknown device kinds are reported as CPU rather than failing.
    kind = .CPU
  }

  let ordinal = Int(name[ordinalRange]) ?? defaultOrdinal
  return (kind: kind, ordinal: ordinal)
}

/// The current TF Eager device.
///
/// The device is discovered by creating a small dummy tensor on a TF Eager device
/// and asking the underlying library which device the tensor handle lives on. If the
/// name cannot be obtained or parsed, the result is the CPU device with
/// `defaultOrdinal`.
///
/// - Note: Every access allocates a tensor and parses a device name; callers that
///   need this value repeatedly should hold on to the result.
public static var defaultTFEager: Device {
  // Create a dummy tensor on any TFEager device.
  let fallback = Device(kind: .CPU, ordinal: defaultOrdinal, backend: .TF_EAGER)
  let tensor = Tensor<Float>(zeros: [1], on: fallback)

  let status = TF_NewStatus()
  // The status object is owned by us and must be released on every exit path.
  defer { TF_DeleteStatus(status) }

  // Find out what the underlying libraries think the default is.
  // `withExtendedLifetime` keeps `tensor` (and therefore its handle) alive until the
  // device name has been copied into a Swift `String`.
  let tfDeviceName = withExtendedLifetime(tensor) { () -> String? in
    guard let cString = TFE_TensorHandleDeviceName(tensor.handle._cTensorHandle, status) else {
      return nil
    }
    checkOk(status)
    // The C string is copied, not freed: the TensorFlow eager C API documents that
    // the returned memory may alias memory owned by the runtime.
    return String(cString: cString)
  }

  guard let name = tfDeviceName, let parsed = parseTFEagerDeviceName(name) else {
    return fallback
  }
  return Device(kind: parsed.kind, ordinal: parsed.ordinal, backend: .TF_EAGER)
}

/// An array of all devices.
Expand Down