-
Notifications
You must be signed in to change notification settings - Fork 3
/
predict_stock_price.py
72 lines (62 loc) · 2.17 KB
/
predict_stock_price.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LinearRegression
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import pickle
import os
df = pd.read_csv('AAPL.csv')
dates = list(range(0,int(len(df))))
prices = df['Close']
#Impute missing values (NaN)
prices[np.isnan(prices)] = np.median(prices[~np.isnan(prices)])
#Plot Original Data
plt.plot(df['Close'], label='Close Price history')
plt.title('Linear Regression | Time vs. Price (Original Data)')
plt.legend()
plt.xlabel('Date Integer')
plt.show()
#Convert to numpy array and reshape them
dates = np.asanyarray(dates)
prices = np.asanyarray(prices)
dates = np.reshape(dates,(len(dates),1))
prices = np.reshape(prices, (len(prices), 1))
#Load Pickle File to get the previous saved model accuracy
try:
pickle_in = open("prediction.pickle", "rb")
reg = pickle.load(pickle_in)
xtrain, xtest, ytrain, ytest = train_test_split(dates, prices, test_size=0.2)
best = reg.score(ytrain, ytest)
except:
pass
#Get the highest accuracy model
best = 0
for _ in range(100):
xtrain, xtest, ytrain, ytest = train_test_split(dates, prices, test_size=0.2)
reg = LinearRegression().fit(xtrain, ytrain)
acc = reg.score(xtest, ytest)
if acc > best:
best = acc
#Save model to pickle format
with open('prediction.pickle','wb') as f:
pickle.dump(reg, f)
print(acc)
#Load linear regression model
pickle_in = open("prediction.pickle", "rb")
reg = pickle.load(pickle_in)
#Get the average accuracy of the model
mean = 0
for i in range(10):
#Random Split Data
msk = np.random.rand(len(df)) < 0.8
xtest = dates[~msk]
ytest = prices[~msk]
mean += reg.score(xtest,ytest)
print("Average Accuracy:", mean/10)
#Plot Predicted VS Actual Data
plt.plot(xtest, ytest, color='green',linewidth=1, label= 'Actual Price') #plotting the initial datapoints
plt.plot(xtest, reg.predict(xtest), color='blue', linewidth=3, label = 'Predicted Price') #plotting the line made by linear regression
plt.title('Linear Regression | Time vs. Price ')
plt.legend()
plt.xlabel('Date Integer')
plt.show()