-
Notifications
You must be signed in to change notification settings - Fork 3
/
sample.py
36 lines (22 loc) · 1.1 KB
/
sample.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import numpy as np
import tensorflow as tf
def eps_greedy_sample(logits, eps):
    """Sample one category index per row, exploring uniformly with probability eps.

    Every row gets its own independent exploration draw, so within a single
    batch some rows may explore while others follow the logits.

    Note that the non-exploring branch is *not* an argmax: it draws from the
    categorical distribution defined by ``logits`` via ``tf.multinomial``.

    Args:
        logits: 2-D tensor ``[batch, num_classes]`` of unnormalized
            log-probabilities (the layout ``tf.multinomial`` expects).
        eps: exploration probability; compared elementwise against a
            uniform draw in ``[0, 1)``. Presumably a scalar float tensor --
            confirm against callers.

    Returns:
        An int32 tensor of shape ``[batch, 1]`` (for scalar ``eps``) holding
        the chosen category index for each row.
    """
    # Per-row uniform draw deciding whether that row explores.
    explore_draw = tf.random_uniform(
        [tf.shape(logits)[0], 1],
        minval=0,
        maxval=1,
        dtype=tf.float32,
    )
    # 1 where the row explores, 0 where it follows the logits.
    explore = tf.cast(explore_draw < eps, tf.int32)

    # Candidate for exploring rows: uniform over [0, num_classes).
    uniform_choice = tf.random_uniform(
        [tf.shape(logits)[0], 1],
        minval=0,
        maxval=tf.shape(logits)[1],
        dtype=tf.int32,
    )
    # Candidate for non-exploring rows: one draw from softmax(logits).
    policy_choice = tf.multinomial(logits, 1, output_dtype=tf.int32)

    # Integer arithmetic blend: exactly one of the two terms is non-zero per row.
    follow = 1 - explore
    return explore * uniform_choice + follow * policy_choice
def gaussian_sample(mean, logvar):
    """Placeholder for a Gaussian sampler; not implemented yet.

    Args:
        mean: unused for now.
        logvar: unused for now.

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError()
if __name__ == '__main__':
    # Smoke test: two rows with a single category each (log(1) == 0), so the
    # only valid index for either row is 0.
    logits = tf.log(tf.constant([[1.], [1.]]))
    eps = tf.placeholder(tf.float32, ())
    sample_op = eps_greedy_sample(logits, eps)

    with tf.Session() as sess:
        # Feed eps = 0 so the uniform draw in [0, 1) is never below it.
        print(sess.run(sample_op, feed_dict={eps: 0.}))