Skip to content

Commit 77e6304

Browse files
authored
Upgrade to TensorFlow 2 (#40)
Changes our uses of TensorFlow APIs to be compatible with TensorFlow 2, mostly automatically through the tf_upgrade_v2 command. This change has been validated to work with all capsules in the Capsule Zoo as well as our internally developed capsules at Aotu.ai. This aims to make OpenVisionCapsules TensorFlow 2 compatible without any features disabled, but does not attempt to upgrade our use of old v1 APIs. While this upgrade requires no changes to the API of vcap or vcap-utils, it is still a breaking change for any capsule that has to use TensorFlow directly. Fixes #39
1 parent 5506c2c commit 77e6304

File tree

9 files changed

+14
-14
lines changed

9 files changed

+14
-14
lines changed

vcap/setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
"scipy==1.4.1",
2525
"scikit-learn==0.22.2",
2626
"numpy>=1.16,<2",
27-
"tensorflow-gpu==1.15.4",
27+
"tensorflow~=2.5.0",
2828
],
2929
extras_require={
3030
"tests": test_packages,

vcap/vcap/device_mapping.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ def get_all_devices() -> List[str]:
2525
#
2626
# TODO: Use tf.config.list_physical_devices in TF 2.1
2727

28-
with tf.Session():
28+
with tf.compat.v1.Session():
2929
all_devices = device_lib.list_local_devices()
3030

3131
# Get the device names and remove duplicates, just in case...

vcap_utils/vcap_utils/backends/crowd_density.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ class CrowdDensityCounter(BaseTFBackend):
1414

1515
def __init__(self, model_bytes,
1616
device: str=None,
17-
session_config: tf.ConfigProto=None):
17+
session_config: tf.compat.v1.ConfigProto=None):
1818
"""
1919
:param model_bytes: Model file data, likely a loaded *.pb file
2020
:param device: The device to run the model on

vcap_utils/vcap_utils/backends/depth.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ class DepthPredictor(BaseTFBackend):
1919

2020
def __init__(self, model_bytes,
2121
device: str=None,
22-
session_config: tf.ConfigProto=None):
22+
session_config: tf.compat.v1.ConfigProto=None):
2323
"""
2424
:param model_bytes: Model file data, likely a loaded *.pb file
2525
:param device: The device to run the model on

vcap_utils/vcap_utils/backends/load_utils.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
def parse_tf_model_bytes(model_bytes,
77
device: str = None,
8-
session_config: tf.ConfigProto = None):
8+
session_config: tf.compat.v1.ConfigProto = None):
99
"""
1010
1111
:param model_bytes: The bytes of the model to load
@@ -18,7 +18,7 @@ def parse_tf_model_bytes(model_bytes,
1818
detection_graph = tf.Graph()
1919
with detection_graph.as_default():
2020
# Load a (frozen) Tensorflow model from memory
21-
graph_def = tf.GraphDef()
21+
graph_def = tf.compat.v1.GraphDef()
2222
graph_def.ParseFromString(model_bytes)
2323

2424
with tf.device(device):
@@ -29,16 +29,16 @@ def parse_tf_model_bytes(model_bytes,
2929
name='')
3030

3131
if session_config is None:
32-
session_config = tf.ConfigProto()
32+
session_config = tf.compat.v1.ConfigProto()
3333

3434
if device is not None:
3535
# allow_soft_placement lets us remap GPU only ops to GPU, and doesn't
3636
# crash for non-gpu only ops (it will place those on CPU, instead)
3737
session_config.allow_soft_placement = True
3838

3939
# Create a session for later use
40-
persistent_sess = tf.Session(graph=detection_graph,
41-
config=session_config)
40+
persistent_sess = tf.compat.v1.Session(graph=detection_graph,
41+
config=session_config)
4242

4343
return detection_graph, persistent_sess
4444

vcap_utils/vcap_utils/backends/openface_encoder.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ class OpenFaceEncoder(BaseEncoderBackend, BaseTFBackend):
2222

2323
def __init__(self, model_bytes, model_name,
2424
device: str = None,
25-
session_config: tf.ConfigProto = None):
25+
session_config: tf.compat.v1.ConfigProto = None):
2626
"""
2727
:param model_bytes: Model file bytes, a loaded *.pb file
2828
:param model_name: The name of the model in order to load correct

vcap_utils/vcap_utils/backends/segmentation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ class Segmenter(BaseTFBackend):
1616

1717
def __init__(self, model_bytes, metadata_bytes,
1818
device: str = None,
19-
session_config: tf.ConfigProto = None):
19+
session_config: tf.compat.v1.ConfigProto = None):
2020
"""
2121
:param model_bytes: Model file data, likely a loaded *.pb file
2222
:param metadata_bytes: The dataset metadata file data, likely named

vcap_utils/vcap_utils/backends/tf_image_classification.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
class TFImageClassifier(BaseTFBackend):
1616
def __init__(self, model_bytes, metadata_bytes, model_name,
1717
device: str = None,
18-
session_config: tf.ConfigProto = None):
18+
session_config: tf.compat.v1.ConfigProto = None):
1919
"""
2020
:param model_bytes: Loaded model data, likely from a *.pb file
2121
:param metadata_bytes: Loaded dataset metadata, likely from a file
@@ -41,7 +41,7 @@ def __init__(self, model_bytes, metadata_bytes, model_name,
4141
# Create the input node to the graph, with preprocessing built-in
4242
with self.graph.as_default():
4343
# Create a new input node for images of various sizes
44-
self.input_node = tf.placeholder(
44+
self.input_node = tf.compat.v1.placeholder(
4545
dtype=tf.float32,
4646
shape=[None, self.config.img_size, self.config.img_size, 3])
4747

vcap_utils/vcap_utils/backends/tf_object_detection.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ class TFObjectDetector(BaseTFBackend):
1818
def __init__(self, model_bytes, metadata_bytes,
1919
confidence_thresh=0.05,
2020
device: str = None,
21-
session_config: tf.ConfigProto = None):
21+
session_config: tf.compat.v1.ConfigProto = None):
2222
"""
2323
:param model_bytes: Model file data, likely a loaded *.pb file
2424
:param metadata_bytes: The dataset metadata file data, likely named

0 commit comments

Comments
 (0)