diff --git a/tensorflow_probability/python/layers/util.py b/tensorflow_probability/python/layers/util.py index 6f3010920b..c8b607f3c1 100644 --- a/tensorflow_probability/python/layers/util.py +++ b/tensorflow_probability/python/layers/util.py @@ -21,7 +21,6 @@ import types # Dependency imports import numpy as np -import tensorflow.compat.v1 as tf1 import tensorflow.compat.v2 as tf from tensorflow_probability.python import util as tfp_util @@ -43,8 +42,8 @@ def default_loc_scale_fn( is_singular=False, - loc_initializer=tf1.initializers.random_normal(stddev=0.1), - untransformed_scale_initializer=tf1.initializers.random_normal( + loc_initializer=tf_keras.initializers.RandomNormal(stddev=0.1), + untransformed_scale_initializer=tf_keras.initializers.RandomNormal( mean=-3., stddev=0.1), loc_regularizer=None, untransformed_scale_regularizer=None, @@ -124,8 +123,8 @@ def _fn(dtype, shape, name, trainable, add_variable_fn): def default_mean_field_normal_fn( is_singular=False, - loc_initializer=tf1.initializers.random_normal(stddev=0.1), - untransformed_scale_initializer=tf1.initializers.random_normal( + loc_initializer=tf_keras.initializers.RandomNormal(stddev=0.1), + untransformed_scale_initializer=tf_keras.initializers.RandomNormal( mean=-3., stddev=0.1), loc_regularizer=None, untransformed_scale_regularizer=None,