diff --git a/testing/dependency_install_lib.sh b/testing/dependency_install_lib.sh index 9db1691815..801d7a3361 100644 --- a/testing/dependency_install_lib.sh +++ b/testing/dependency_install_lib.sh @@ -69,7 +69,9 @@ install_tensorflow() { PIP_FLAGS=${2-} # NB: tf-nightly pulls in other deps, like numpy, absl, and six, transitively. TF_VERSION_STR=$(find_good_tf_nightly_version_str $TF_NIGHTLY_PACKAGE) - python -m pip install $PIP_FLAGS $TF_NIGHTLY_PACKAGE==$TF_VERSION_STR + python -m pip install $PIP_FLAGS \ + $TF_NIGHTLY_PACKAGE==$TF_VERSION_STR \ + tf-keras-nightly } install_jax() {