-
Notifications
You must be signed in to change notification settings - Fork 0
/
p29_sin_samples.py
68 lines (50 loc) · 2.31 KB
/
p29_sin_samples.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import math
import numpy as np
import tensorflow as tf
from matplotlib import pyplot as plt
# Hidden-layer width and number of training samples.
# With too few samples, no number of hidden units will give a good fit; if the
# samples are of poor quality, shrink the hidden layer so the network does not
# overfit to noise points and outliers. Sample quantity, quality and
# distribution all matter.
mid_units, num_samples = 120, 1000
def get_tensors(lr=0.001):
    """Build the TF1 graph for a one-hidden-layer, two-output regressor.

    Scalar inputs pass through a dense layer of ``mid_units`` ReLU units and
    then a bias-free dense layer with two outputs. The graph is trained with
    Adam on the mean squared error against the target placeholder.

    Args:
        lr: learning rate handed to ``tf.train.AdamOptimizer``.

    Returns:
        Tuple ``(x, z, predict, train_op)``:
        - the input placeholder, shape [None];
        - the target placeholder, shape [None, 2];
        - the prediction tensor, shape [None, 2];
        - the optimizer's minimisation op.
    """
    inputs = tf.placeholder(tf.float32, [None], 'x')        # one scalar per sample
    targets = tf.placeholder(tf.float32, [None, 2], 'z')    # expected outputs

    # Hidden layer: [None, 1] @ [1, mid_units] + bias, then ReLU.
    column = tf.expand_dims(inputs, 1)
    hidden_w = tf.get_variable('w1', [1, mid_units], tf.float32)
    hidden = tf.matmul(column, hidden_w)
    hidden_b = tf.get_variable('b1', [mid_units], tf.float32)
    hidden = tf.nn.relu(hidden + hidden_b)                  # [None, mid_units]

    # Output layer (no bias): [None, mid_units] @ [mid_units, 2].
    out_w = tf.get_variable('w2', [mid_units, 2], tf.float32)
    prediction = tf.matmul(hidden, out_w)                   # [None, 2]

    # Mean of the squared error over every sample and both outputs (a scalar).
    loss = tf.reduce_mean(tf.square(prediction - targets))
    train_op = tf.train.AdamOptimizer(lr).minimize(loss)
    return inputs, targets, prediction, train_op
def train(tensors, samples, session, epoches=3000):
    """Run ``epoches`` full-batch optimisation steps.

    Args:
        tensors: 4-tuple ``(x, z, predict, train_op)``; only the two
            placeholders and the training op are used here.
        samples: pair ``(xs, zs)`` of inputs and targets, fed whole on
            every step.
        session: the ``tf.Session`` used to run the training op.
        epoches: number of training steps to run.
    """
    inputs, targets, _, step = tensors
    batch_xs, batch_zs = samples
    feed = {inputs: batch_xs, targets: batch_zs}
    print('training is started!')
    for _ in range(epoches):
        session.run(step, feed)
    print('training is finished!!!')
def predict(tensors, xs, session):
    """Evaluate the network's prediction tensor for the inputs ``xs``.

    Args:
        tensors: 4-tuple ``(x, z, predict, train_op)``; only the input
            placeholder and the prediction tensor are used.
        xs: 1-D sequence of input values fed to the ``x`` placeholder.
        session: the ``tf.Session`` to evaluate in.

    Returns:
        Whatever ``session.run`` yields for the prediction tensor. For the
        graph from ``get_tensors`` that is an array of shape [len(xs), 2].
    """
    inputs, _, output, _ = tensors
    return session.run(output, feed_dict={inputs: xs})
def make_training_samples(n):
    """Return ``n`` evenly spaced inputs on [0, 2*pi] with sin/cos targets.

    Args:
        n: number of samples to generate.

    Returns:
        ``(xs, zs)`` where ``xs`` has shape [n] and ``zs`` has shape [n, 2].
        Column 0 of ``zs`` is ``sin(xs)`` and column 1 is ``cos(xs)``.
    """
    # linspace yields exactly n points, including the 2*pi endpoint.
    # np.arange(0, 2*pi, 2*pi/(n-1)) excludes the stop value and so gives
    # n-1 points (or n, depending on float rounding).
    xs = np.linspace(0, 2 * math.pi, n)
    # Two targets per sample, matching the [None, 2] target placeholder in
    # get_tensors. Summing sin and cos would produce a single column and fail
    # to feed.
    zs = np.stack([np.sin(xs), np.cos(xs)], axis=1)
    return xs, zs


def main():
    """Train the network on sin/cos over [0, 2*pi] and plot its predictions.

    The training targets and the predictions on 600 random test points are
    drawn on the same figure. The test points come from the wider range
    [-pi/2, 2.5*pi], so the ends of the plot show extrapolation.
    """
    # Training samples.
    train_xs, train_zs = make_training_samples(num_samples)
    plt.plot(train_xs, train_zs)

    with tf.Session() as session:
        tensors = get_tensors()  # (x, z, predict, train_op)
        session.run(tf.global_variables_initializer())
        train(tensors, (train_xs, train_zs), session)

        # Test samples, sorted so each predicted curve plots as a clean line.
        test_xs = np.sort(np.random.uniform(-math.pi / 2, 2.5 * math.pi, [600]))
        test_zs = predict(tensors, test_xs, session)  # [600, 2]
        plt.plot(test_xs, test_zs)
    plt.show()
if __name__ == '__main__':
    # Train and plot only when run as a script, not when imported.
    main()