From 835fa6ba70c62963e70b56b9d41fa7facfdb7854 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Fri, 15 Sep 2017 19:22:44 +0000 Subject: [PATCH] LSTM on books data --- hatt_train.ipynb | 9 +- lstm_word2vec_small.ipynb | 331 ++++++++++++++++++++++++++++++-------- 2 files changed, 267 insertions(+), 73 deletions(-) diff --git a/hatt_train.ipynb b/hatt_train.ipynb index 0232aa0..70c9e3e 100644 --- a/hatt_train.ipynb +++ b/hatt_train.ipynb @@ -2,7 +2,10 @@ "cells": [ { "cell_type": "markdown", - "metadata": {}, + "metadata": { + "deletable": true, + "editable": true + }, "source": [ "The code in this notebook is based on [Richard Liao's implementation of hierarchical attention networks](https://github.com/richliao/textClassifier/blob/master/textClassifierHATT.py) and a related [Google group discussion](https://groups.google.com/forum/#!topic/keras-users/IWK9opMFavQ). The notebook also includes code from [Keras documentation](https://keras.io/) and [blog](https://blog.keras.io/using-pre-trained-word-embeddings-in-a-keras-model.html) as well as this [word2vec tutorial](http://adventuresinmachinelearning.com/gensim-word2vec-tutorial/)." ] @@ -171,11 +174,11 @@ "texts = []\n", "\n", "for idx in range(train_data.shape[0]):\n", - " text = train_data['text'][idx]\n", + " text = train_data['text'].iloc[idx]\n", " texts.append(text)\n", " sentences = nltk.tokenize.sent_tokenize(text)\n", " reviews.append(sentences)\n", - " labels.append(train_data['rating'][idx])" + " labels.append(train_data['rating'].iloc[idx])" ] }, { diff --git a/lstm_word2vec_small.ipynb b/lstm_word2vec_small.ipynb index 4cc87a7..bdbec87 100644 --- a/lstm_word2vec_small.ipynb +++ b/lstm_word2vec_small.ipynb @@ -12,7 +12,7 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 2, "metadata": { "collapsed": false, "deletable": true, @@ -57,7 +57,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 4, "metadata": { "collapsed": false, "deletable": true, @@ -65,6 +65,7 @@ }, "outputs": [], "source": [ + "\"\"\"\n", "from azureml import Workspace\n", "ws = Workspace(\n", " workspace_id='817780d9ee0d4a878e25f8c9deb3b866',\n", @@ -73,7 +74,31 @@ ")\n", "ds = ws.datasets['Book Reviews from Amazon']\n", "all_data = ds.to_dataframe()\n", - "all_data.rename(columns={0: 'rating', 1: 'text'}, inplace=True)" + "all_data.rename(columns={0: 'rating', 1: 'text'}, inplace=True)\n", + "all_data.loc[:, 'rating'] = all_data['rating'] - 1 # reindex ratings to start from 0\n", + "\"\"\"" + ] + }, + { + "cell_type": "code", + "execution_count": 35, + "metadata": { + "collapsed": true, + "deletable": true, + "editable": true + }, + "outputs": [], + "source": [ + "from azureml import Workspace\n", + "ws = Workspace(\n", + " workspace_id='817780d9ee0d4a878e25f8c9deb3b866',\n", + " authorization_token='6df8a52943bd49eba6e57446bc73f5fc',\n", + " endpoint='https://studioapi.azureml.net'\n", + ")\n", + "ds = ws.datasets['dfe_happysad_utf.csv']\n", + "all_data = ds.to_dataframe()\n", + "all_data.rename(columns={'features': 'text', 'label': 'rating'}, inplace=True)\n", + "all_data.replace({'rating': {'sadness': 0, 'happiness': 1}}, inplace=True)" ] }, { @@ -88,7 +113,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 39, "metadata": { "collapsed": false, "deletable": true, @@ -123,7 +148,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 42, "metadata": { "collapsed": true, "deletable": true, @@ -138,7 +163,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 43, "metadata": { "collapsed": true, "deletable": true, @@ -162,7 +187,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 44, "metadata": { "collapsed": false, "deletable": true, @@ -191,7 +216,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 49, "metadata": { "collapsed": false, "deletable": true, @@ -200,13 +225,12 @@ "outputs": [], "source": [ "labels = to_categorical(np.asarray(train_data[LABEL_COL]))\n", - "labels = labels[:,1:] # rating 0 does not exist, so remove \n", "labels = labels.astype('float32')" ] }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 50, "metadata": { "collapsed": false, "deletable": true, @@ -224,12 +248,12 @@ "editable": true }, "source": [ - "Train word2vec on all the documents in order to initialize the word embedding. Ignore rare words (min_count=6). Use skip-gram as the training algorithm (sg=1)." + "Train word2vec on the training documents in order to initialize the word embedding. Ignore rare words (min_count=6). Use skip-gram as the training algorithm (sg=1)." ] }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 53, "metadata": { "collapsed": false, "deletable": true, @@ -261,7 +285,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 54, "metadata": { "collapsed": false, "deletable": true, @@ -273,53 +297,45 @@ "name": "stderr", "output_type": "stream", "text": [ - "2017-09-14 16:47:12,691 : INFO : collecting all words and their counts\n", - "2017-09-14 16:47:12,692 : INFO : PROGRESS: at sentence #0, processed 0 words, keeping 0 word types\n", - "2017-09-14 16:47:12,740 : INFO : PROGRESS: at sentence #10000, processed 179872 words, keeping 17751 word types\n", - "2017-09-14 16:47:12,788 : INFO : PROGRESS: at sentence #20000, processed 362578 words, keeping 26298 word types\n", - "2017-09-14 16:47:12,834 : INFO : PROGRESS: at sentence #30000, processed 544135 words, keeping 32158 word types\n", - "2017-09-14 16:47:12,882 : INFO : PROGRESS: at sentence #40000, processed 727386 words, keeping 37250 word types\n", - "2017-09-14 16:47:12,931 : INFO : PROGRESS: at sentence #50000, processed 911012 words, keeping 41503 word types\n", - "2017-09-14 16:47:12,982 : INFO : PROGRESS: at sentence #60000, processed 1093093 words, keeping 45119 word types\n", - "2017-09-14 16:47:13,022 : INFO : collected 47857 word types from a corpus of 1245083 raw words and 68283 sentences\n", - "2017-09-14 16:47:13,022 : INFO : Loading a fresh vocabulary\n", - "2017-09-14 16:47:13,072 : INFO : min_count=6 retains 12670 unique words (26% of original 47857, drops 35187)\n", - "2017-09-14 16:47:13,073 : INFO : min_count=6 leaves 1180885 word corpus (94% of original 1245083, drops 64198)\n", - "2017-09-14 16:47:13,106 : INFO : deleting the raw counts dictionary of 47857 items\n", - "2017-09-14 16:47:13,109 : INFO : sample=0.001 downsamples 44 most-common words\n", - "2017-09-14 16:47:13,109 : INFO : downsampling leaves estimated 886078 word corpus (75.0% of prior 1180885)\n", - "2017-09-14 16:47:13,110 : INFO : estimated required memory for 12670 words and 200 dimensions: 26607000 bytes\n", - "2017-09-14 16:47:13,144 : INFO : resetting layer weights\n", - "2017-09-14 16:47:13,341 : INFO : training model with 24 workers on 12670 vocabulary and 200 features, using sg=1 hs=0 sample=0.001 negative=5 window=5\n", - "2017-09-14 16:47:14,357 : INFO : PROGRESS: at 21.53% examples, 950506 words/s, in_qsize 47, out_qsize 0\n", - "2017-09-14 16:47:15,365 : INFO : PROGRESS: at 47.88% examples, 1054121 words/s, in_qsize 47, out_qsize 0\n", - "2017-09-14 16:47:16,367 : INFO : PROGRESS: at 72.71% examples, 1069305 words/s, in_qsize 48, out_qsize 0\n", - "2017-09-14 16:47:17,329 : INFO : worker thread finished; awaiting finish of 23 more threads\n", - "2017-09-14 16:47:17,335 : INFO : worker thread finished; awaiting finish of 22 more threads\n", - "2017-09-14 16:47:17,335 : INFO : worker thread finished; awaiting finish of 21 more threads\n", - "2017-09-14 16:47:17,336 : INFO : worker thread finished; awaiting finish of 20 more threads\n", - "2017-09-14 16:47:17,347 : INFO : worker thread finished; awaiting finish of 19 more threads\n", - "2017-09-14 16:47:17,353 : INFO : worker thread finished; awaiting finish of 18 more threads\n", - "2017-09-14 16:47:17,355 : INFO : worker thread finished; awaiting finish of 17 more threads\n", - "2017-09-14 16:47:17,371 : INFO : PROGRESS: at 97.55% examples, 1076295 words/s, in_qsize 16, out_qsize 1\n", - "2017-09-14 16:47:17,372 : INFO : worker thread finished; awaiting finish of 16 more threads\n", - "2017-09-14 16:47:17,374 : INFO : worker thread finished; awaiting finish of 15 more threads\n", - "2017-09-14 16:47:17,378 : INFO : worker thread finished; awaiting finish of 14 more threads\n", - "2017-09-14 16:47:17,383 : INFO : worker thread finished; awaiting finish of 13 more threads\n", - "2017-09-14 16:47:17,391 : INFO : worker thread finished; awaiting finish of 12 more threads\n", - "2017-09-14 16:47:17,398 : INFO : worker thread finished; awaiting finish of 11 more threads\n", - "2017-09-14 16:47:17,403 : INFO : worker thread finished; awaiting finish of 10 more threads\n", - "2017-09-14 16:47:17,408 : INFO : worker thread finished; awaiting finish of 9 more threads\n", - "2017-09-14 16:47:17,411 : INFO : worker thread finished; awaiting finish of 8 more threads\n", - "2017-09-14 16:47:17,417 : INFO : worker thread finished; awaiting finish of 7 more threads\n", - "2017-09-14 16:47:17,426 : INFO : worker thread finished; awaiting finish of 6 more threads\n", - "2017-09-14 16:47:17,427 : INFO : worker thread finished; awaiting finish of 5 more threads\n", - "2017-09-14 16:47:17,428 : INFO : worker thread finished; awaiting finish of 4 more threads\n", - "2017-09-14 16:47:17,429 : INFO : worker thread finished; awaiting finish of 3 more threads\n", - "2017-09-14 16:47:17,434 : INFO : worker thread finished; awaiting finish of 2 more threads\n", - "2017-09-14 16:47:17,440 : INFO : worker thread finished; awaiting finish of 1 more threads\n", - "2017-09-14 16:47:17,445 : INFO : worker thread finished; awaiting finish of 0 more threads\n", - "2017-09-14 16:47:17,446 : INFO : training on 6225415 raw words (4431895 effective words) took 4.1s, 1083352 effective words/s\n" + "2017-09-15 11:21:16,427 : INFO : collecting all words and their counts\n", + "2017-09-15 11:21:16,428 : INFO : PROGRESS: at sentence #0, processed 0 words, keeping 0 word types\n", + "2017-09-15 11:21:16,451 : INFO : PROGRESS: at sentence #10000, processed 75804 words, keeping 13063 word types\n", + "2017-09-15 11:21:16,459 : INFO : collected 15977 word types from a corpus of 100883 raw words and 13257 sentences\n", + "2017-09-15 11:21:16,460 : INFO : Loading a fresh vocabulary\n", + "2017-09-15 11:21:16,470 : INFO : min_count=6 retains 1613 unique words (10% of original 15977, drops 14364)\n", + "2017-09-15 11:21:16,471 : INFO : min_count=6 leaves 80706 word corpus (79% of original 100883, drops 20177)\n", + "2017-09-15 11:21:16,476 : INFO : deleting the raw counts dictionary of 15977 items\n", + "2017-09-15 11:21:16,478 : INFO : sample=0.001 downsamples 65 most-common words\n", + "2017-09-15 11:21:16,479 : INFO : downsampling leaves estimated 59121 word corpus (73.3% of prior 80706)\n", + "2017-09-15 11:21:16,479 : INFO : estimated required memory for 1613 words and 200 dimensions: 3387300 bytes\n", + "2017-09-15 11:21:16,484 : INFO : resetting layer weights\n", + "2017-09-15 11:21:16,520 : INFO : training model with 24 workers on 1613 vocabulary and 200 features, using sg=1 hs=0 sample=0.001 negative=5 window=5\n", + "2017-09-15 11:21:16,831 : INFO : worker thread finished; awaiting finish of 23 more threads\n", + "2017-09-15 11:21:16,833 : INFO : worker thread finished; awaiting finish of 22 more threads\n", + "2017-09-15 11:21:16,844 : INFO : worker thread finished; awaiting finish of 21 more threads\n", + "2017-09-15 11:21:16,846 : INFO : worker thread finished; awaiting finish of 20 more threads\n", + "2017-09-15 11:21:16,848 : INFO : worker thread finished; awaiting finish of 19 more threads\n", + "2017-09-15 11:21:16,854 : INFO : worker thread finished; awaiting finish of 18 more threads\n", + "2017-09-15 11:21:16,858 : INFO : worker thread finished; awaiting finish of 17 more threads\n", + "2017-09-15 11:21:16,861 : INFO : worker thread finished; awaiting finish of 16 more threads\n", + "2017-09-15 11:21:16,865 : INFO : worker thread finished; awaiting finish of 15 more threads\n", + "2017-09-15 11:21:16,880 : INFO : worker thread finished; awaiting finish of 14 more threads\n", + "2017-09-15 11:21:16,882 : INFO : worker thread finished; awaiting finish of 13 more threads\n", + "2017-09-15 11:21:16,889 : INFO : worker thread finished; awaiting finish of 12 more threads\n", + "2017-09-15 11:21:16,891 : INFO : worker thread finished; awaiting finish of 11 more threads\n", + "2017-09-15 11:21:16,895 : INFO : worker thread finished; awaiting finish of 10 more threads\n", + "2017-09-15 11:21:16,897 : INFO : worker thread finished; awaiting finish of 9 more threads\n", + "2017-09-15 11:21:16,898 : INFO : worker thread finished; awaiting finish of 8 more threads\n", + "2017-09-15 11:21:16,904 : INFO : worker thread finished; awaiting finish of 7 more threads\n", + "2017-09-15 11:21:16,907 : INFO : worker thread finished; awaiting finish of 6 more threads\n", + "2017-09-15 11:21:16,908 : INFO : worker thread finished; awaiting finish of 5 more threads\n", + "2017-09-15 11:21:16,909 : INFO : worker thread finished; awaiting finish of 4 more threads\n", + "2017-09-15 11:21:16,910 : INFO : worker thread finished; awaiting finish of 3 more threads\n", + "2017-09-15 11:21:16,914 : INFO : worker thread finished; awaiting finish of 2 more threads\n", + "2017-09-15 11:21:16,920 : INFO : worker thread finished; awaiting finish of 1 more threads\n", + "2017-09-15 11:21:16,921 : INFO : worker thread finished; awaiting finish of 0 more threads\n", + "2017-09-15 11:21:16,921 : INFO : training on 504415 raw words (295266 effective words) took 0.4s, 758339 effective words/s\n", + "2017-09-15 11:21:16,922 : WARNING : under 10 jobs per worker: consider setting a smaller `batch_words' for smoother alpha decay\n" ] } ], @@ -343,7 +359,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 55, "metadata": { "collapsed": false, "deletable": true, @@ -354,7 +370,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "Total 12670 word vectors.\n" + "Total 1613 word vectors.\n" ] } ], @@ -389,7 +405,7 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 56, "metadata": { "collapsed": true, "deletable": true, @@ -405,7 +421,7 @@ }, { "cell_type": "code", - "execution_count": 21, + "execution_count": 57, "metadata": { "collapsed": false, "deletable": true, @@ -413,7 +429,7 @@ }, "outputs": [], "source": [ - "def lstm_create_train(reg_param):\n", + "def lstm_create_train(reg_param, ref_str):\n", " l2_reg = regularizers.l2(reg_param)\n", "\n", " # model init\n", @@ -438,7 +454,7 @@ " metrics=['acc'])\n", "\n", " history = History()\n", - " csv_logger = CSVLogger('./lstm_model_wvec_{}_log'.format(reg_param),\n", + " csv_logger = CSVLogger('./lstm_model_wvec_{0}_{1}.log'.format(reg_param, ref_str),\n", " separator=',',\n", " append=True)\n", "\n", @@ -455,10 +471,10 @@ " print(\"\\n\")\n", " \n", " # save model\n", - " model.save('./lstm_wvec_{}_model.h5'.format(reg_param))\n", - " np.savetxt('./lstm_wvec_{}_time.txt'.format(reg_param), \n", + " model.save('./lstm_wvec_{0}_{1}_model.h5'.format(reg_param, ref_str))\n", + " np.savetxt('./lstm_wvec_{0}_{1}_time.txt'.format(reg_param, ref_str), \n", " [reg_param, (t2-t1) / 3600])\n", - " with open('./lstm_wvec_{}_history.txt'.format(reg_param), \"w\") as res_file:\n", + " with open('./lstm_wvec_{0}_{1}_history.txt'.format(reg_param, ref_str), \"w\") as res_file:\n", " res_file.write(str(history.history))" ] }, @@ -537,7 +553,7 @@ } ], "source": [ - "preds = model.predict_classes(test_seq) + 1 # add 1 since we removed the 0 class" + "preds = model.predict_classes(test_seq) " ] }, { @@ -565,6 +581,181 @@ "\n", "accuracy_score(test_data[LABEL_COL], preds)" ] + }, + { + "cell_type": "code", + "execution_count": 58, + "metadata": { + "collapsed": false, + "deletable": true, + "editable": true + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Training model with regularization parameter = 1e-10\n", + "Epoch 1/10\n", + "33s - loss: 0.6922 - acc: 0.5447\n", + "Epoch 2/10\n", + "32s - loss: 0.6916 - acc: 0.5557\n", + "Epoch 3/10\n", + "32s - loss: 0.6910 - acc: 0.5652\n", + "Epoch 4/10\n", + "32s - loss: 0.6905 - acc: 0.5893\n", + "Epoch 5/10\n", + "32s - loss: 0.6899 - acc: 0.5828\n", + "Epoch 6/10\n", + "32s - loss: 0.6893 - acc: 0.5827\n", + "Epoch 7/10\n", + "32s - loss: 0.6887 - acc: 0.5943\n", + "Epoch 8/10\n", + "32s - loss: 0.6882 - acc: 0.5901\n", + "Epoch 9/10\n", + "32s - loss: 0.6876 - acc: 0.5933\n", + "Epoch 10/10\n", + "32s - loss: 0.6870 - acc: 0.5905\n", + "\n", + "\n", + "Training model with regularization parameter = 1e-07\n", + "Epoch 1/10\n", + "32s - loss: 0.6944 - acc: 0.4816\n", + "Epoch 2/10\n", + "32s - loss: 0.6936 - acc: 0.4899\n", + "Epoch 3/10\n", + "32s - loss: 0.6929 - acc: 0.5145\n", + "Epoch 4/10\n", + "32s - loss: 0.6924 - acc: 0.5299\n", + "Epoch 5/10\n", + "32s - loss: 0.6917 - acc: 0.5463\n", + "Epoch 6/10\n", + "32s - loss: 0.6911 - acc: 0.5524\n", + "Epoch 7/10\n", + "32s - loss: 0.6905 - acc: 0.5656\n", + "Epoch 8/10\n", + "32s - loss: 0.6899 - acc: 0.5616\n", + "Epoch 9/10\n", + "32s - loss: 0.6893 - acc: 0.5647\n", + "Epoch 10/10\n", + "32s - loss: 0.6888 - acc: 0.5716\n", + "\n", + "\n", + "Training model with regularization parameter = 0.0001\n", + "Epoch 1/10\n", + "33s - loss: 0.9469 - acc: 0.4776\n", + "Epoch 2/10\n", + "32s - loss: 0.9460 - acc: 0.4563\n", + "Epoch 3/10\n", + "32s - loss: 0.9453 - acc: 0.4888\n", + "Epoch 4/10\n", + "32s - loss: 0.9447 - acc: 0.5072\n", + "Epoch 5/10\n", + "32s - loss: 0.9441 - acc: 0.5355\n", + "Epoch 6/10\n", + "32s - loss: 0.9435 - acc: 0.5539\n", + "Epoch 7/10\n", + "32s - loss: 0.9429 - acc: 0.5832\n", + "Epoch 8/10\n", + "32s - loss: 0.9423 - acc: 0.5837\n", + "Epoch 9/10\n", + "32s - loss: 0.9418 - acc: 0.5871\n", + "Epoch 10/10\n", + "32s - loss: 0.9412 - acc: 0.5881\n", + "\n", + "\n", + "Training model with regularization parameter = 0.1\n", + "Epoch 1/10\n", + "33s - loss: 218.8508 - acc: 0.4709\n", + "Epoch 2/10\n", + "32s - loss: 162.2584 - acc: 0.4561\n", + "Epoch 3/10\n", + "32s - loss: 120.3472 - acc: 0.4695\n", + "Epoch 4/10\n", + "32s - loss: 89.3082 - acc: 0.4863\n", + "Epoch 5/10\n", + "32s - loss: 66.3210 - acc: 0.4891\n", + "Epoch 6/10\n", + "32s - loss: 49.2968 - acc: 0.5004\n", + "Epoch 7/10\n", + "32s - loss: 36.6888 - acc: 0.5081\n", + "Epoch 8/10\n", + "32s - loss: 27.3514 - acc: 0.5031\n", + "Epoch 9/10\n", + "32s - loss: 20.4361 - acc: 0.5056\n", + "Epoch 10/10\n", + "32s - loss: 15.3147 - acc: 0.5019\n", + "\n", + "\n", + "Training model with regularization parameter = 100.0\n", + "Epoch 1/10\n", + "32s - loss: 252060.6371 - acc: 0.4981\n", + "Epoch 2/10\n", + "32s - loss: 252059.9525 - acc: 0.5048\n", + "Epoch 3/10\n", + "32s - loss: 252059.1579 - acc: 0.5069\n", + "Epoch 4/10\n", + "32s - loss: 252058.5729 - acc: 0.5048\n", + "Epoch 5/10\n", + "32s - loss: 252057.9460 - acc: 0.5047\n", + "Epoch 6/10\n", + "32s - loss: 252057.1785 - acc: 0.5205\n", + "Epoch 7/10\n", + "32s - loss: 252056.3502 - acc: 0.5075\n", + "Epoch 8/10\n", + "32s - loss: 252055.6600 - acc: 0.5020\n", + "Epoch 9/10\n", + "32s - loss: 252054.9252 - acc: 0.5036\n", + "Epoch 10/10\n", + "32s - loss: 252054.3204 - acc: 0.5081\n", + "\n", + "\n" + ] + } + ], + "source": [ + "for rp in [1e-10, 1e-7, 1e-4, 1e-1, 1e2]:\n", + " lstm_create_train(rp, 'tweets')" + ] + }, + { + "cell_type": "code", + "execution_count": 63, + "metadata": { + "collapsed": false, + "deletable": true, + "editable": true + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "(1e-10, 0.59364081062194274)\n", + "(1e-07, 0.57092941998602376)\n", + "(0.0001, 0.57477288609364086)\n", + "(0.1, 0.50454227812718377)\n", + "(100.0, 0.56533892382948991)\n" + ] + } + ], + "source": [ + "from sklearn.metrics import accuracy_score\n", + "\n", + "for rp in [1e-10, 1e-7, 1e-4, 1e-1, 1e2]:\n", + " model = load_model('./lstm_wvec_{0}_{1}_model.h5'.format(rp, 'tweets'))\n", + " preds = model.predict_classes(test_seq, verbose=0)\n", + " print((rp, accuracy_score(test_data[LABEL_COL], preds)))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [] } ], "metadata": {