forked from brenns10/eecs440
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathw06.py
50 lines (38 loc) · 1.23 KB
/
w06.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
"""
Code to plot decision boundary of ANN.
Stephen Brennan, EECS 440 Written 6, 10/06/2015.
"""
import numpy as np
import matplotlib.pyplot as plt
plt.style.use('ggplot')
def sigmoid(X):
    """Compute the logistic sigmoid ``1 / (1 + exp(-X))`` element-wise.

    The value is derived from ``exp(-|X|)``, which always lies in
    ``(0, 1]``, so large-magnitude inputs saturate to 0 or 1 without
    triggering a floating-point overflow warning.

    Args:
        X: Real-valued scalar or array-like (lists are accepted).

    Returns:
        A NumPy scalar for scalar input, otherwise an array with the same
        shape as ``X``; every value lies in ``[0, 1]``.
    """
    X = np.asarray(X)
    if not np.issubdtype(X.dtype, np.floating):
        # Promote non-float input so the arithmetic below is floating-point.
        X = X.astype(float)
    decay = np.exp(-np.abs(X))
    # Both branches equal 1 / (1 + exp(-X)) algebraically; picking by sign
    # means exp() is only ever evaluated on a non-positive argument.
    result = np.where(X >= 0, 1.0 / (1.0 + decay), decay / (1.0 + decay))
    # Indexing with an empty tuple unwraps a 0-d result into a scalar and
    # leaves n-d arrays untouched.
    return result[()]
def output(X, weights, out_weights):
    """Run inputs through one sigmoid hidden layer and a sigmoid output unit.

    Computes ``sigmoid(sigmoid(weights . X^T)^T . out_weights)``. No bias
    terms are applied at either layer.

    Args:
        X: Input samples; transposed before being multiplied by ``weights``.
            The caller in this file passes an (n_points, 2) array.
        weights: Hidden-layer weight matrix (2 x 2 in this file).
        out_weights: Output-unit weight vector (length 2 in this file).

    Returns:
        The network's sigmoid output for the given samples.
    """
    hidden_pre = np.dot(weights, X.T)
    hidden_act = sigmoid(hidden_pre)
    final_pre = np.dot(hidden_act.T, out_weights)
    return sigmoid(final_pre)
def plot_decision_boundary(wmin, wmax):
    """Scatter-plot how a randomly weighted 2-2-1 network labels a grid.

    Hidden-layer weights (2 x 2) and output weights (2) are drawn with
    ``np.random.uniform(wmin, wmax, ...)``. The network is evaluated on a
    101 x 101 grid covering -5 to 5 on both axes; points whose output
    exceeds 0.5 are drawn in blue, the remaining points in red. The raw
    outputs, the class counts, and both point sets are printed to stdout.

    Args:
        wmin: Lower bound passed to ``np.random.uniform`` for the weights.
        wmax: Upper bound passed to ``np.random.uniform`` for the weights.

    Returns:
        The matplotlib axes holding the scatter plot; the owning figure is
        reachable through its ``figure`` attribute.
    """
    hidden_weights = np.random.uniform(wmin, wmax, (2, 2))
    output_weights = np.random.uniform(wmin, wmax, 2)

    # linspace pins both endpoints and the point count exactly; arange with
    # a fractional step needs a fudged stop value and accumulates rounding
    # error along the axis.
    axis = np.linspace(-5.0, 5.0, 101)
    # Flattened meshgrid keeps the first coordinate varying fastest, so each
    # row of X is one (x1, x2) grid point.
    grid_x1, grid_x2 = np.meshgrid(axis, axis)
    X = np.column_stack([grid_x1.ravel(), grid_x2.ravel()])
    y = output(X, hidden_weights, output_weights)

    _, ax = plt.subplots()
    pos_idx = y > 0.5
    pos = X[pos_idx]
    neg = X[~pos_idx]
    print(y)
    print('%r positives, %r negatives' % (len(pos), len(neg)))
    print(pos)
    print(neg)
    # Transposing yields one row per coordinate, unpacked as scatter's x, y.
    ax.scatter(*pos.T, color='b')
    ax.scatter(*neg.T, color='r')
    return ax
if __name__ == '__main__':
    import sys

    # The first command-line argument is the weight bound; the weights are
    # drawn between -bound and +bound, and the argument text also names the
    # output PDF.
    cli_args = sys.argv[1:]
    if not cli_args:
        print('Need weight bounds.')
    else:
        raw_bound = cli_args[0]
        limit = float(raw_bound)
        axes = plot_decision_boundary(-limit, limit)
        axes.figure.savefig(raw_bound + '.pdf')