Skip to content

Commit 548cebe

Browse files
author
dhiren
committed
added more features to the models for improving accuracy.
1 parent d096134 commit 548cebe

File tree

1 file changed

+77
-9
lines changed

1 file changed

+77
-9
lines changed

fbmetrics.py

+77-9
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,10 @@
44
import matplotlib.pyplot as plt
55
from matplotlib import style
66
from sklearn.linear_model import LinearRegression,SGDRegressor,Ridge,Lasso,LogisticRegression
7-
from adspy_shared_utilities import plot_fruit_knn
7+
from adspy_shared_utilities import plot_fruit_knn,plot_feature_importances,plot_decision_tree
88
import graphviz
9+
from sklearn.tree import DecisionTreeClassifier
10+
from sklearn.metrics import r2_score
911

1012

1113

@@ -34,6 +36,24 @@ def Type(x,type_of):
3436

3537
plotdf = df.drop(df.columns[7:15],axis =1)
3638

39+
def Weekday(x):
    """Return the two-letter weekday abbreviation for a weekday number.

    The numbering starts on Sunday: 1 -> 'Su', 2 -> 'Mo', 3 -> 'Tu',
    4 -> 'We', 5 -> 'Th', 6 -> 'Fr', 7 -> 'Sa'.

    Args:
        x: Weekday number; compared with ``==`` against 1 through 7.

    Returns:
        The matching abbreviation, or None when ``x`` is not one of 1-7.
    """
    day_labels = (
        (1, 'Su'),
        (2, 'Mo'),
        (3, 'Tu'),
        (4, 'We'),
        (5, 'Th'),
        (6, 'Fr'),
        (7, 'Sa'),
    )
    # Compare with == (not a dict lookup) so any value that merely compares
    # equal to the day number matches, without requiring x to be hashable.
    for day_number, label in day_labels:
        if x == day_number:
            return label
    return None
54+
55+
# Derive a two-letter day label column from the numeric 'Post Weekday' column.
df['Weekday'] = df['Post Weekday'].apply(Weekday)
56+
3757
fig, ax = plt.subplots()
3858
paid = df[df['Paid']==1]
3959
free = df[df['Paid']==0] #separated free users
@@ -47,24 +67,66 @@ def Type(x,type_of):
4767
ax.legend(labels=['Paid','Free'])
4868
#plt.show()
4969

50-
x = df[['Page total likes','Post Month', 'Post Weekday',
51-
'Post Hour', 'Paid','Photo','Video','Link','Status','Cat_1','Cat_2']]
52-
y = df['Lifetime Engaged Users']
70+
#print(df.head())
71+
72+
# One-hot encode the weekday label, the posting hour and the posting month.
dayDf = pd.get_dummies(df['Weekday'])
# prefix='hr_' combined with pandas' default '_' separator produces column
# names of the form 'hr__<hour>', which the feature selection below uses.
hourDf = pd.get_dummies(df['Post Hour'], prefix='hr_')
monthDf = pd.get_dummies(df['Post Month'], prefix='Mo')
df = pd.concat([df, dayDf, hourDf, monthDf], axis=1)

# Labels 'hr_0' .. 'hr_17'. Nothing in the code shown here reads this list,
# and its single-underscore names differ from the 'hr__<hour>' dummy columns.
hours = ['hr_' + str(hour) for hour in range(0, 18)]

df['Video'] = pd.get_dummies(df['Type'])['Video']

# Feature matrix. 'Su' and 'Mo_3' are not selected -- presumably left out as
# baseline categories; TODO confirm.
x = df[[
    'Page total likes', 'Paid', 'Video', 'Status', 'Photo',
    'Cat_1', 'Cat_2', 'Mo', 'Tu', 'Sa', "We", 'Th', 'Fr',
    'hr__17', 'hr__1', 'hr__2', 'hr__3', 'hr__4', 'hr__5', 'hr__6', 'hr__7', 'hr__8',
    'hr__9', 'hr__10', 'hr__11', 'hr__12', 'hr__13', 'hr__14', 'hr__15', 'hr__16', 'Mo_1',
    'Mo_2', 'Mo_12', 'Mo_4', 'Mo_5', 'Mo_6', 'Mo_7', 'Mo_8', 'Mo_9', 'Mo_11', 'Mo_10',
]]
# Regression target.
y = df['like']
92+
93+
# Hold out 10% of the rows for evaluation; the fixed random_state makes the
# split reproducible between runs.
x_train, x_test, y_train, y_test = model_selection.train_test_split(
    x, y, test_size=0.1, random_state=42)

# NOTE(review): the `normalize` argument was deprecated in scikit-learn 1.0 and
# removed in 1.2 -- confirm the installed version before upgrading.
reg = LinearRegression(normalize=True)
lasso = Lasso(normalize=True)

# Fit on the training split only. Refitting `reg` on (x_test, y_test) would
# discard the training fit and make any score on x_test an in-sample score
# rather than a held-out one.
reg.fit(x_train, y_train)
lasso.fit(x_train, y_train)
#print(reg.score(x_test,y_test))
53104

54-
x_train,x_test,y_train,y_test = model_selection.train_test_split(x,y, test_size=0.2)
105+
# Predict on both splits with the fitted regression, then print the R^2 of
# the predictions for the test split.
predicted_train, predicted_test = (
    reg.predict(features) for features in (x_train, x_test)
)
test_score = r2_score(y_true=y_test, y_pred=predicted_test)
print(test_score)
109+
#x_train,x_test,y_train,y_test = model_selection.train_test_split(x,y,test_size=0.2)
110+
# clf = DecisionTreeClassifier(max_depth = 4, min_samples_leaf = 8,
111+
# random_state = 0).fit(x_train, y_train)
55112

113+
114+
115+
#clf= LinearRegression()
116+
117+
#cv_scores= model_selection.cross_val_score(clf,x,y)
56118
# scaler= preprocessing.MinMaxScaler()
57119
# scaler.fit(x_train)
58120
# x_test= scaler.transform(x_test)
59121
# x_train= scaler.transform(x_train)
60122
#Clf=Lasso(alpha=8.0,max_iter=10000)
61123
#clf= neighbors.KNeighborsRegressor(n_neighbors=5)
124+
#print(cv_scores)
62125

63-
clf=LogisticRegression()
64-
clf.fit(x_train,y_train)
126+
#clf.fit(x_train,y_train)
65127

66-
accuracy= clf.score(x_test,y_test)
67-
print(accuracy*100)
128+
#accuracy= clf.score(x_test,y_test)
129+
#print(accuracy*100)
68130

69131
# x_train.plot()
70132
# y_train.plot()
@@ -76,3 +138,9 @@ def Type(x,type_of):
76138
# sgd = sgd.fit(x,y)
77139

78140
#print(sgd.score(x_test,y_test))
141+
142+
143+
#plt.figure(figsize=(10,4))
144+
#plot_feature_importances(clf,['Page total likes','Post Month', 'Post Weekday',
145+
# 'Post Hour', 'Paid','Photo','Video','Link','Status','Cat_1','Cat_2'])
146+
#plt.show()

0 commit comments

Comments
 (0)