Skip to content

Commit

Permalink
Eliminate use of tf1 in layers/util.py.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 579249332
  • Loading branch information
jburnim authored and tensorflower-gardener committed Nov 3, 2023
1 parent b079f83 commit 65d9d72
Showing 1 changed file with 4 additions and 5 deletions.
9 changes: 4 additions & 5 deletions tensorflow_probability/python/layers/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
import types
# Dependency imports
import numpy as np
import tensorflow.compat.v1 as tf1
import tensorflow.compat.v2 as tf

from tensorflow_probability.python import util as tfp_util
Expand All @@ -43,8 +42,8 @@

def default_loc_scale_fn(
is_singular=False,
loc_initializer=tf1.initializers.random_normal(stddev=0.1),
untransformed_scale_initializer=tf1.initializers.random_normal(
loc_initializer=tf_keras.initializers.RandomNormal(stddev=0.1),
untransformed_scale_initializer=tf_keras.initializers.RandomNormal(
mean=-3., stddev=0.1),
loc_regularizer=None,
untransformed_scale_regularizer=None,
Expand Down Expand Up @@ -124,8 +123,8 @@ def _fn(dtype, shape, name, trainable, add_variable_fn):

def default_mean_field_normal_fn(
is_singular=False,
loc_initializer=tf1.initializers.random_normal(stddev=0.1),
untransformed_scale_initializer=tf1.initializers.random_normal(
loc_initializer=tf_keras.initializers.RandomNormal(stddev=0.1),
untransformed_scale_initializer=tf_keras.initializers.RandomNormal(
mean=-3., stddev=0.1),
loc_regularizer=None,
untransformed_scale_regularizer=None,
Expand Down

0 comments on commit 65d9d72

Please sign in to comment.