Skip to content

Commit

Permalink
Upgrade tensorflow to 2.x (#41)
Browse files Browse the repository at this point in the history
  • Loading branch information
xuhdev authored Aug 13, 2021
1 parent 1ddbeb3 commit 8368dd9
Show file tree
Hide file tree
Showing 5 changed files with 19 additions and 12 deletions.
2 changes: 1 addition & 1 deletion .travis.yml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
language: python
python:
- 3.6
- 3.8
services:
- docker
install:
Expand Down
6 changes: 3 additions & 3 deletions Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,14 @@
# limitations under the License.
#

FROM quay.io/codait/max-base:v1.4.0
FROM quay.io/codait/max-base:v1.5.1

COPY requirements.txt .

RUN pip install -r requirements.txt

COPY . .

EXPOSE 5000

CMD python app.py
2 changes: 1 addition & 1 deletion api/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

from core.model import ModelWrapper
from maxfw.core import MAX_API, PredictAPI
from flask_restplus import fields
from flask_restx import fields
from werkzeug.datastructures import FileStorage
from config import DEFAULT_MODEL, MODELS

Expand Down
12 changes: 10 additions & 2 deletions core/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from config import DEFAULT_MODEL_PATH, MODELS, MODEL_META_DATA as model_meta
from keras.models import load_model
import numpy as np
from sklearn.externals import joblib
import joblib

logging.basicConfig()
logger = logging.getLogger()
Expand All @@ -34,12 +34,19 @@ def load_array(input_data):
class SingleModelWrapper(object):

def __init__(self, model, path):
# The code was originally written for TF1 and doesn't work with eager mode.
tf.compat.v1.disable_eager_execution()
self.session = tf.compat.v1.Session()

self.model_name = model

# load model
model_path = '{}/{}_model'.format(path, model)
logger.info(model_path)
self.graph = tf.get_default_graph()
self.graph = tf.compat.v1.get_default_graph()
# See https://github.com/tensorflow/tensorflow/issues/28287#issuecomment-495005162
# We have to do this because we load 3 models in the process.
tf.compat.v1.keras.backend.set_session(self.session)
self.model = load_model(model_path)

# load scaler
Expand Down Expand Up @@ -96,6 +103,7 @@ def _rescale_preds(self, preds):
def predict(self, x):
reshaped_x = self._reshape_data(x)
with self.graph.as_default():
tf.compat.v1.keras.backend.set_session(self.session)
preds = self.model.predict(reshaped_x)
rescaled_preds = self._rescale_preds(preds)
return rescaled_preds
Expand Down
9 changes: 4 additions & 5 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
numpy==1.16.2
tensorflow==1.15.2
keras==2.2.4
scikit-learn==0.22.1
h5py==2.9.0
tensorflow==2.6.0
keras==2.6.0
scikit-learn==0.22.2
# numpy and h5py not specified here because tensorflow has specifies a major.minor version range for them

0 comments on commit 8368dd9

Please sign in to comment.