Skip to content

Commit f2bdc4d

Browse files
committed
Merge pull request #264 from CPSSD/n/visualise-model-error#187
Add model error calculation
2 parents 700f4de + d21ccae commit f2bdc4d

File tree

5 files changed

+282
-1
lines changed

5 files changed

+282
-1
lines changed

Diff for: README.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -141,4 +141,4 @@ This contains a worker to refresh the machine learning model of a particular use
141141

142142
#### `./visualisations`
143143

144-
This folder stores the programs that are intended to parse the data in the database and provide meaningfull statistics on things such as what are the most subscribed feeds and what topics a particular user is most interested in.
144+
This contains the tools used to visualise the data in the databases in a more friendly and useful way: they parse the data and provide meaningful statistics, such as which feeds are the most subscribed to and which topics a particular user is most interested in.

Diff for: script/test/visualisations

+3
Original file line numberDiff line numberDiff line change
@@ -6,5 +6,8 @@ cd visualisations
66
cd all_topics
77
python testing.py
88
cd ..
9+
cd model_error
10+
python testing.py
11+
cd ..
912

1013
cd ..

Diff for: visualisations/model_error/README.md

+26
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
Model Error Calculator
2+
======================
3+
4+
This module gets the error of a given user's prediction model.
5+
6+
It uses the '2-fold cross-validation' method, as described [here](https://en.wikipedia.org/wiki/Cross-validation_%28statistics%29#2-fold_cross-validation)
7+
8+
Dependencies
9+
------------
10+
11+
- Python 2.7
12+
- scikit-learn
13+
- gearman
14+
15+
How to Test
16+
-----------
17+
18+
Move to this directory, and run `python testing.py`.
19+
20+
Usage
21+
-----
22+
23+
Use `source /home/python/bin/activate` to set the Python environment.
24+
25+
Then, to get the error of a user's model, run `python model_error.py <username>`.
26+

Diff for: visualisations/model_error/model_error.py

+228
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,228 @@
1+
import gearman
2+
import bson
3+
import sys
4+
import pickle
5+
import os
6+
from random import shuffle
7+
from sklearn import linear_model
8+
9+
# get kw_score module, so it doesn't need to do a gearman request for each call
10+
sys.path.append(os.path.join(os.path.dirname(__file__), '..', '..', 'aggregator'))
11+
import kw_score
12+
13+
# Module-level gearman client; db_get submits its 'db-get' jobs through it.
gearman_client = gearman.GearmanClient(['localhost:4730'])
14+
15+
def get_username_from_input(argv):
    '''
    Build the username from a script's argument list.

    The first entry (the script name) is dropped and the remaining
    entries are joined with single spaces, so a username containing
    spaces can be passed unquoted on the command line.
    '''
    name_parts = argv[1:]
    return ' '.join(name_parts)
21+
22+
def get_input_data(all_data):
    '''
    Return the input half of every (input, output) pair in all_data.

    all_data is any iterable of two-element sequences; the result is a
    list holding element 0 of each pair, in the same order.
    '''
    # Iterate the pairs directly rather than indexing by position.
    return [point[0] for point in all_data]
24+
25+
def get_output_data(all_data):
    '''
    Return the output half of every (input, output) pair in all_data.

    all_data is any iterable of two-element sequences; the result is a
    list holding element 1 of each pair, in the same order.
    '''
    # Iterate the pairs directly rather than indexing by position.
    return [point[1] for point in all_data]
27+
28+
def db_get(collection, query, projection):
    '''
    Run a 'db-get' gearman job against the 'feedlark' database.

    The request (collection, query and projection) is BSON-encoded and
    submitted through the module-level gearman client. If the SECRETKEY
    environment variable is set, its value is sent along as 'key'.
    Returns the BSON-decoded result of the job.
    '''
    request = {
        'database': 'feedlark',
        'collection': collection,
        'query': query,
        'projection': projection
    }

    # Only attach the key when one has been configured in the environment.
    secret = os.getenv('SECRETKEY')
    if secret is not None:
        request['key'] = secret

    encoded_request = str(bson.BSON.encode(request))
    job = gearman_client.submit_job('db-get', encoded_request)
    return bson.BSON.decode(bson.BSON(job.result))
42+
43+
def has_enough_classes(training):
    '''
    Tell whether the training pairs contain at least two distinct classes.

    Each element of training is an (input, output) pair; only the output
    (element 1) is inspected.
    '''
    distinct_outputs = {pair[1] for pair in training}
    return len(distinct_outputs) >= 2
48+
49+
def get_model_score(training, validation):
    '''
    Fit a fresh classifier on training and score it on validation.

    Both arguments are lists of (input, output) pairs. A new
    SGDClassifier (log loss, 5 iterations) is fitted on the training
    pairs, and the value of its `score` method on the validation pairs
    is returned.
    '''
    classifier = linear_model.SGDClassifier(loss='log', n_iter=5)
    train_inputs = get_input_data(training)
    train_outputs = get_output_data(training)
    classifier.fit(train_inputs, train_outputs)

    validation_inputs = get_input_data(validation)
    validation_outputs = get_output_data(validation)
    return classifier.score(validation_inputs, validation_outputs)
54+
55+
def main():
56+
gearman_client = gearman.GearmanClient(['localhost:4730'])
57+
if len(sys.argv) < 2:
58+
print('Please specify a user to get the error of. See README.md')
59+
return
60+
username = get_username_from_input(sys.argv)
61+
62+
print 'Getting model error of {}'.format(username)
63+
print 'Loading user\'s votes from database'
64+
65+
# get the user's votes on articles
66+
db_result = db_get('vote', {
67+
'username': username
68+
},{
69+
'article_url': 1,
70+
'feed_url': 1,
71+
'positive_opinion': 1,
72+
'vote_datetime': 1
73+
})
74+
if db_result['status'] != 'ok':
75+
print 'Error'
76+
print 'Could not get user data from vote collection'
77+
print db_result['description']
78+
return
79+
articles = db_result['docs']
80+
81+
print len(articles), 'article opinions found in vote db for given user'
82+
83+
# map each article url to 1 or -1, if the user liked or disliked it
84+
article_opinions = {}
85+
vote_datetimes = {}
86+
for article in articles:
87+
# make sure all the required fields are there
88+
req_fields = ['article_url', 'positive_opinion', 'vote_datetime']
89+
if not all([s in article for s in req_fields]):
90+
print 'Error'
91+
print 'Vote is missing some fields: {}'.format(article)
92+
continue
93+
url = article['article_url']
94+
# set the classes for the votes to 1 for positive and -1 for negative
95+
vote = 1 if article['positive_opinion'] else -1
96+
article_opinions[url] = vote
97+
vote_datetimes[url] = article['vote_datetime']
98+
99+
# split the articles into the feeds they belong to, to minimise db lookups
100+
# the dict maps feed urls to a list of article urls fromt that feed
101+
feeds = {}
102+
for article in articles:
103+
if article['feed_url'] in feeds:
104+
feeds[article['feed_url']].append(article['article_url'])
105+
else:
106+
feeds[article['feed_url']] = [article['article_url']]
107+
108+
# get a set of the unique article urls
109+
article_url_set = set(article_opinions.keys())
110+
print len(article_url_set), 'unique articles in set'
111+
112+
if len(article_url_set) < 0:
113+
print 'Error'
114+
print 'Not enough articles in data set'
115+
return
116+
117+
# get the words the user is interested in
118+
db_result = db_get('user', {
119+
'username': username
120+
}, {
121+
'words': 1
122+
})
123+
if db_result['status'] != 'ok':
124+
print 'Error'
125+
print 'Could not load data from user collection'
126+
print db_result['description']
127+
return
128+
if len(db_result['docs']) < 1:
129+
print 'Error'
130+
print 'No such user in user collection'
131+
return
132+
user_data = db_result['docs'][0]
133+
user_words = user_data['words']
134+
135+
# it is required to have at least classes, so create two
136+
# inputs with extreme values to train the model
137+
data_x = [[10.0, 1], [0.0, 10000000]]
138+
data_y = [1, -1]
139+
140+
# get the data from the db for each feed a user voted on an article in
141+
for feed in feeds:
142+
db_result = db_get('feed', {
143+
'url': feed
144+
}, {
145+
'items': 1
146+
})
147+
if db_result['status'] != 'ok':
148+
print 'Error'
149+
print 'Could not get data from feed collection'
150+
print db_result['description']
151+
return
152+
if 'docs' not in db_result or len(db_result['docs']) < 1:
153+
print 'Error'
154+
print 'No feed returned for url', feed
155+
return
156+
157+
items = db_result['docs'][0]['items']
158+
# check the items in that feed for one the user voted on
159+
for item in items:
160+
if item['link'] not in article_url_set:
161+
continue
162+
print 'adding', item['link']
163+
if 'topics' not in item:
164+
print 'Error'
165+
print 'No topics for given item, skipping'
166+
continue
167+
words = item['topics']
168+
topic_crossover = kw_score.score(words, user_words)
169+
if 'pub_date' not in item:
170+
print 'Error'
171+
print 'No pub_date for given item, skipping'
172+
continue
173+
time_diff = vote_datetimes[item['link']] - item['pub_date']
174+
x = [topic_crossover, time_diff.total_seconds()]
175+
y = article_opinions[item['link']]
176+
data_x.append(x)
177+
data_y.append(y)
178+
179+
print 'Articles from feed', feed, 'added to data'
180+
181+
182+
print data_x
183+
print data_y
184+
185+
if len(data_x) < 3:
186+
print 'Error'
187+
print 'Not enough data points'
188+
return
189+
190+
data_points = [(data_x[i], data_y[i]) for i in xrange(len(data_x))]
191+
n = 0
192+
score = 0
193+
194+
# start the 2-fold cross-validation, doing up to 10 folds of the data
195+
repetitions = min(len(data_points), 10)
196+
for k in xrange(repetitions):
197+
print 'Iteration {} out of {} ({}% finished)'.format(k, len(data_points), 100*(float(k)/repetitions))
198+
shuffle(data_points)
199+
training = data_points[:len(data_points)/2]
200+
validation = data_points[len(data_points)/2:]
201+
if has_enough_classes(training):
202+
curr_score = get_model_score(training, validation)
203+
print '- Score 1 this fold: {}'.format(curr_score)
204+
score += curr_score
205+
n += 1
206+
else:
207+
print '- Not enough training classes, skipping'
208+
continue
209+
210+
#swap the training and validation data
211+
training, validation = validation, training
212+
if has_enough_classes(training):
213+
curr_score = get_model_score(training, validation)
214+
print '- Score 2 this fold: {}'.format(curr_score)
215+
score += curr_score
216+
n += 1
217+
else:
218+
print '- Not enough training classes, skipping'
219+
continue
220+
if n == 0:
221+
print 'Error'
222+
print 'Not enough valid data points'
223+
return
224+
print 'Score: {:.6f}, based on {} divisions of the data.'.format(float(score)/n, n)
225+
return
226+
227+
if __name__ == '__main__':
228+
main()

Diff for: visualisations/model_error/testing.py

+24
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
import gearman
2+
import bson
3+
import unittest
4+
import model_error as me
5+
6+
class TestDbTools(unittest.TestCase):
    '''Unit tests for the helper functions of the model_error module.'''

    def test_get_username(self):
        # The first argv entry is dropped; the rest is joined with spaces.
        one_word = me.get_username_from_input(['', 'jeremy'])
        self.assertEqual(one_word, 'jeremy')
        two_words = me.get_username_from_input(['', 'jeremy', 'corbyn'])
        self.assertEqual(two_words, 'jeremy corbyn')

    def test_has_enough_classes(self):
        # Three points that all share the same class are not enough...
        points = [[0, 1] for _ in range(3)]
        self.assertFalse(me.has_enough_classes(points))
        # ...but adding a single point of another class is.
        points.append([0, -1])
        self.assertTrue(me.has_enough_classes(points))

    def test_get_model_score(self):
        # Two clearly separated training points; validate on one of them.
        positive_point = [[1, -1], 1]
        negative_point = [[-1, 1], -1]
        training = [positive_point, negative_point]
        validation = [[[1, -1], 1]]

        model_score = me.get_model_score(training, validation)
        self.assertGreater(model_score, 0.5)
22+
23+
if __name__ == '__main__':
24+
unittest.main()

0 commit comments

Comments
 (0)