diff --git a/hatt_archive.ipynb b/hatt_archive.ipynb new file mode 100644 index 0000000..a11e8ac --- /dev/null +++ b/hatt_archive.ipynb @@ -0,0 +1,1170 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "deletable": true, + "editable": true + }, + "source": [ + "The code in this notebook is based on [Richard Liao's implementation of hierarchical attention networks](https://github.com/richliao/textClassifier/blob/master/textClassifierHATT.py) and a related [Google group discussion](https://groups.google.com/forum/#!topic/keras-users/IWK9opMFavQ). The notebook also includes code from [Keras documentation](https://keras.io/) and [blog](https://blog.keras.io/using-pre-trained-word-embeddings-in-a-keras-model.html) as well as this [word2vec tutorial](http://adventuresinmachinelearning.com/gensim-word2vec-tutorial/)." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "deletable": true, + "editable": true + }, + "source": [ + "To enable Theano to run on a single GPU: \n", + "\n", + "* check the following dependencies: \n", + "\n", + " `conda install pygpu`\n", + " \n", + "\n", + "* Replace $HOME/.theanorc with this:\n", + "```\n", + "[global]\n", + "floatX = float32\n", + "device = gpu0\n", + "[lib]\n", + "gpuarray.preallocate=1\n", + "```" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "collapsed": false, + "deletable": true, + "editable": true, + "scrolled": false + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "WARNING (theano.sandbox.cuda): The cuda backend is deprecated and will be removed in the next release (v0.10). Please switch to the gpuarray backend. You can get more information about how to switch at this URL:\n", + " https://github.com/Theano/Theano/wiki/Converting-to-the-new-gpu-back-end%28gpuarray%29\n", + "\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "gpu0\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Using gpu device 0: Tesla K80 (CNMeM is disabled, cuDNN 5110)\n" + ] + } + ], + "source": [ + "import os \n", + "os.environ['THEANO_FLAGS'] = 'floatX=float32,device=gpu0'\n", + "os.environ['PATH'] = os.environ['PATH'] + ':/usr/local/cuda-8.0/bin'\n", + "import theano\n", + "print(theano.config.device) " + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": { + "collapsed": false, + "deletable": true, + "editable": true + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Using Theano backend.\n" + ] + } + ], + "source": [ + "import numpy as np\n", + "import pandas as pd\n", + "from collections import defaultdict\n", + "import os \n", + "os.environ['KERAS_BACKEND'] = 'theano'\n", + "import subprocess\n", + "import time\n", + "\n", + "from keras.preprocessing.text import Tokenizer, text_to_word_sequence\n", + "from keras.preprocessing.sequence import pad_sequences\n", + "from keras.utils.np_utils import to_categorical\n", + "from keras.optimizers import SGD\n", + "\n", + "from keras.layers import Embedding\n", + "from keras.layers import Dense, Input, Flatten\n", + "from keras.layers import Conv1D, MaxPooling1D, Embedding, Merge, Dropout, LSTM, GRU, Bidirectional, TimeDistributed\n", + "from keras.models import Model, load_model\n", + "\n", + "from keras import backend as K\n", + "from keras.engine.topology import Layer, InputSpec\n", + "from keras import initializers, regularizers, optimizers\n", + "from keras.callbacks import History, CSVLogger" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "deletable": true, + "editable": true + }, + "source": [ + "Download the Amazon reviews data for food from the Internet archive \n", + "[J. McAuley and J. Leskovec. Hidden factors and hidden topics: understanding rating dimensions with review text. RecSys, 2013]" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": { + "collapsed": false, + "deletable": true, + "editable": true + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "--2017-09-22 09:08:24-- https://archive.org/download/amazon-reviews-1995-2013/Gourmet_Foods.txt.gz\n", + "Resolving archive.org (archive.org)... 207.241.224.2\n", + "Connecting to archive.org (archive.org)|207.241.224.2|:443... connected.\n", + "HTTP request sent, awaiting response... 302 Moved Temporarily\n", + "Location: https://ia601306.us.archive.org/24/items/amazon-reviews-1995-2013/Gourmet_Foods.txt.gz [following]\n", + "--2017-09-22 09:08:24-- https://ia601306.us.archive.org/24/items/amazon-reviews-1995-2013/Gourmet_Foods.txt.gz\n", + "Resolving ia601306.us.archive.org (ia601306.us.archive.org)... 207.241.227.176\n", + "Connecting to ia601306.us.archive.org (ia601306.us.archive.org)|207.241.227.176|:443... connected.\n", + "HTTP request sent, awaiting response... 200 OK\n", + "Length: 31388180 (30M) [application/octet-stream]\n", + "Saving to: ‘Gourmet_Foods.txt.gz’\n", + "\n", + "Gourmet_Foods.txt.g 100%[===================>] 29.93M 1.26MB/s in 13s \n", + "\n", + "2017-09-22 09:08:38 (2.26 MB/s) - ‘Gourmet_Foods.txt.gz’ saved [31388180/31388180]\n", + "\n" + ] + } + ], + "source": [ + "!wget \"https://archive.org/download/amazon-reviews-1995-2013/Gourmet_Foods.txt.gz\"" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": { + "collapsed": false, + "deletable": true, + "editable": true + }, + "outputs": [], + "source": [ + "!gunzip -f Gourmet_Foods.txt.gz" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": { + "collapsed": true, + "deletable": true, + "editable": true + }, + "outputs": [], + "source": [ + "with open(\"Gourmet_Foods.txt\", \"r\") as fp:\n", + " lst = fp.readlines()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "deletable": true, + "editable": true + }, + "source": [ + "Extract scores and review texts from file " + ] + }, + { + "cell_type": "code", + "execution_count": 193, + "metadata": { + "collapsed": true, + "deletable": true, + "editable": true + }, + "outputs": [], + "source": [ + "text_lst = lst[9:len(lst):11]\n", + "score_lst = lst[6:len(lst):11]\n", + "score_lst2 = [sc[14:17] for sc in score_lst]\n", + "text_lst2 = [txt[13:] for txt in text_lst]" + ] + }, + { + "cell_type": "code", + "execution_count": 218, + "metadata": { + "collapsed": true, + "deletable": true, + "editable": true + }, + "outputs": [], + "source": [ + "all_data = pd.DataFrame(data={'text': text_lst2, 'rating': score_lst2})\n", + "all_data.loc[:, 'rating'] = all_data['rating'].astype(float)\n", + "all_data.loc[:, 'rating'] = all_data['rating'].astype(int)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "deletable": true, + "editable": true + }, + "source": [ + "Remove medium ratings and convert to binary classification (high vs. low rating). " + ] + }, + { + "cell_type": "code", + "execution_count": 220, + "metadata": { + "collapsed": false, + "deletable": true, + "editable": true + }, + "outputs": [], + "source": [ + "all_data = all_data[all_data['rating'].isin([1, 5])]" + ] + }, + { + "cell_type": "code", + "execution_count": 221, + "metadata": { + "collapsed": true, + "deletable": true, + "editable": true + }, + "outputs": [], + "source": [ + "new_data = all_data.replace({'rating': {1: '0', 5: '1'}})\n", + "new_data.loc[:, 'rating'] = new_data['rating'].astype(int)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "deletable": true, + "editable": true + }, + "source": [ + "Extract a balanced subsample and split into training and test sets." + ] + }, + { + "cell_type": "code", + "execution_count": 224, + "metadata": { + "collapsed": true, + "deletable": true, + "editable": true + }, + "outputs": [], + "source": [ + "sample_data = pd.concat([new_data[new_data.rating == 0].sample(10000), new_data[new_data.rating == 1].sample(10000)])\n", + "shuffled = sample_data.iloc[np.random.permutation(20000), :]\n", + "train_data = shuffled.iloc[:10000, :]\n", + "test_data = shuffled.iloc[10000:, :]" + ] + }, + { + "cell_type": "code", + "execution_count": 225, + "metadata": { + "collapsed": false, + "deletable": true, + "editable": true + }, + "outputs": [ + { + "data": { + "text/plain": [ + "0 5019\n", + "1 4981\n", + "Name: rating, dtype: int64" + ] + }, + "execution_count": 225, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "train_data.rating.value_counts()" + ] + }, + { + "cell_type": "code", + "execution_count": 226, + "metadata": { + "collapsed": false, + "deletable": true, + "editable": true + }, + "outputs": [ + { + "data": { + "text/plain": [ + "1 5019\n", + "0 4981\n", + "Name: rating, dtype: int64" + ] + }, + "execution_count": 226, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "test_data.rating.value_counts()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "deletable": true, + "editable": true + }, + "source": [ + "Set the dimensions of the input and the embedding. Because of the hierarchical nature of the network, the input has to be a 3-dimensional tensor of fixed size (sample_size x n_sentences x n_words). \n", + "\n", + "MAX_SENT_LEN : the number of words in each sentence. \n", + "\n", + "MAX_SENTS : the number of sentences in each document.\n", + "\n", + "Longer documents and sentences will be truncated, shorter ones will be padded with zeros. These numbers should not be much larger than the average sentence and document lengths in the data. \n", + "\n", + "MAX_NB_WORDS : the size of the word encoding (number of most frequent words to keep in the vocabulary)\n", + "\n", + "EMBEDDING_DIM : the dimensionality of the word embedding" + ] + }, + { + "cell_type": "code", + "execution_count": 334, + "metadata": { + "collapsed": true, + "deletable": true, + "editable": true + }, + "outputs": [], + "source": [ + "MAX_SENT_LENGTH = 50\n", + "MAX_SENTS = 15\n", + "MAX_NB_WORDS = 6000\n", + "EMBEDDING_DIM = 100" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "deletable": true, + "editable": true + }, + "source": [ + "Fit a Keras tokenizer to the most frequent words using the entire training data set as the corpus.\n", + "Create the training data in the 3d format required. " + ] + }, + { + "cell_type": "code", + "execution_count": 335, + "metadata": { + "collapsed": false, + "deletable": true, + "editable": true + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[nltk_data] Downloading package punkt to /home/anargyri/nltk_data...\n", + "[nltk_data] Package punkt is already up-to-date!\n" + ] + } + ], + "source": [ + "import nltk \n", + "\n", + "nltk.download('punkt')\n", + "\n", + "reviews = []\n", + "labels = []\n", + "texts = []\n", + "\n", + "for idx in range(train_data.shape[0]):\n", + " text = train_data['text'].iloc[idx]\n", + " texts.append(text)\n", + " sentences = nltk.tokenize.sent_tokenize(text)\n", + " reviews.append(sentences)\n", + " labels.append(train_data['rating'].iloc[idx])" + ] + }, + { + "cell_type": "code", + "execution_count": 336, + "metadata": { + "collapsed": true, + "deletable": true, + "editable": true + }, + "outputs": [], + "source": [ + "tokenizer = Tokenizer(num_words=MAX_NB_WORDS)\n", + "tokenizer.fit_on_texts(texts)" + ] + }, + { + "cell_type": "code", + "execution_count": 337, + "metadata": { + "collapsed": false, + "deletable": true, + "editable": true + }, + "outputs": [], + "source": [ + "data = np.zeros((len(texts), MAX_SENTS, MAX_SENT_LENGTH), dtype='int32')\n", + "doc_lst = []\n", + "\n", + "# keep the MAX_NB_WORDS most frequent words and replace the rest with 'UNK'\n", + "# truncate to the first MAX_SENTS sentences per doc and MAX_SENT_LENGTH words per sentence\n", + "\n", + "for i, sentences in enumerate(reviews):\n", + " for j, sent in enumerate(sentences):\n", + " if j < MAX_SENTS:\n", + " wordTokens = text_to_word_sequence(sent)\n", + " k = 0\n", + " words_in_sent = []\n", + " for _, word in enumerate(wordTokens):\n", + " if k < MAX_SENT_LENGTH: \n", + " if (word in tokenizer.word_index) and (tokenizer.word_index[word] < MAX_NB_WORDS):\n", + " data[i, j, k] = tokenizer.word_index[word]\n", + " words_in_sent.append(word)\n", + " else:\n", + " data[i, j, k] = MAX_NB_WORDS\n", + " words_in_sent.append('UNK')\n", + " k = k + 1\n", + " doc_lst.append(words_in_sent)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "deletable": true, + "editable": true + }, + "source": [ + "Convert the ratings to one-hot categorical labels." + ] + }, + { + "cell_type": "code", + "execution_count": 338, + "metadata": { + "collapsed": false, + "deletable": true, + "editable": true + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Total 22846 unique tokens.\n", + "Shape of data tensor: (10000, 15, 50)\n", + "Shape of label tensor: (10000, 2)\n" + ] + } + ], + "source": [ + "word_index = tokenizer.word_index\n", + "print('Total %s unique tokens.' % len(word_index))\n", + "\n", + "y_train = to_categorical(np.asarray(labels))\n", + "x_train = data\n", + "\n", + "print('Shape of data tensor:', x_train.shape)\n", + "print('Shape of label tensor:', y_train.shape)" + ] + }, + { + "cell_type": "code", + "execution_count": 339, + "metadata": { + "collapsed": false, + "deletable": true, + "editable": true + }, + "outputs": [], + "source": [ + "n_classes = y_train.shape[1]" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "deletable": true, + "editable": true + }, + "source": [ + "Train word2vec on the training documents in order to initialize the word embedding. Ignore rare words (min_count=3). Use skip-gram as the training algorithm (sg=1)." + ] + }, + { + "cell_type": "code", + "execution_count": 340, + "metadata": { + "collapsed": false, + "deletable": true, + "editable": true, + "scrolled": true + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2017-09-22 18:38:39,323 : INFO : collecting all words and their counts\n", + "2017-09-22 18:38:39,324 : INFO : PROGRESS: at sentence #0, processed 0 words, keeping 0 word types\n", + "2017-09-22 18:38:39,355 : INFO : PROGRESS: at sentence #10000, processed 146812 words, keeping 5370 word types\n", + "2017-09-22 18:38:39,388 : INFO : PROGRESS: at sentence #20000, processed 294232 words, keeping 5871 word types\n", + "2017-09-22 18:38:39,423 : INFO : PROGRESS: at sentence #30000, processed 441636 words, keeping 5976 word types\n", + "2017-09-22 18:38:39,454 : INFO : PROGRESS: at sentence #40000, processed 591218 words, keeping 5991 word types\n", + "2017-09-22 18:38:39,475 : INFO : collected 5997 word types from a corpus of 694642 raw words and 46940 sentences\n", + "2017-09-22 18:38:39,476 : INFO : Loading a fresh vocabulary\n", + "2017-09-22 18:38:39,490 : INFO : min_count=3 retains 5984 unique words (99% of original 5997, drops 13)\n", + "2017-09-22 18:38:39,491 : INFO : min_count=3 leaves 694622 word corpus (99% of original 694642, drops 20)\n", + "2017-09-22 18:38:39,506 : INFO : deleting the raw counts dictionary of 5997 items\n", + "2017-09-22 18:38:39,507 : INFO : sample=0.001 downsamples 54 most-common words\n", + "2017-09-22 18:38:39,507 : INFO : downsampling leaves estimated 498881 word corpus (71.8% of prior 694622)\n", + "2017-09-22 18:38:39,508 : INFO : estimated required memory for 5984 words and 100 dimensions: 7779200 bytes\n", + "2017-09-22 18:38:39,522 : INFO : resetting layer weights\n", + "2017-09-22 18:38:39,613 : INFO : training model with 24 workers on 5984 vocabulary and 100 features, using sg=1 hs=0 sample=0.001 negative=5 window=5\n", + "2017-09-22 18:38:40,633 : INFO : PROGRESS: at 44.61% examples, 1106428 words/s, in_qsize 46, out_qsize 1\n", + "2017-09-22 18:38:41,568 : INFO : worker thread finished; awaiting finish of 23 more threads\n", + "2017-09-22 18:38:41,575 : INFO : worker thread finished; awaiting finish of 22 more threads\n", + "2017-09-22 18:38:41,578 : INFO : worker thread finished; awaiting finish of 21 more threads\n", + "2017-09-22 18:38:41,581 : INFO : worker thread finished; awaiting finish of 20 more threads\n", + "2017-09-22 18:38:41,582 : INFO : worker thread finished; awaiting finish of 19 more threads\n", + "2017-09-22 18:38:41,583 : INFO : worker thread finished; awaiting finish of 18 more threads\n", + "2017-09-22 18:38:41,584 : INFO : worker thread finished; awaiting finish of 17 more threads\n", + "2017-09-22 18:38:41,589 : INFO : worker thread finished; awaiting finish of 16 more threads\n", + "2017-09-22 18:38:41,590 : INFO : worker thread finished; awaiting finish of 15 more threads\n", + "2017-09-22 18:38:41,591 : INFO : worker thread finished; awaiting finish of 14 more threads\n", + "2017-09-22 18:38:41,596 : INFO : worker thread finished; awaiting finish of 13 more threads\n", + "2017-09-22 18:38:41,602 : INFO : worker thread finished; awaiting finish of 12 more threads\n", + "2017-09-22 18:38:41,610 : INFO : worker thread finished; awaiting finish of 11 more threads\n", + "2017-09-22 18:38:41,615 : INFO : worker thread finished; awaiting finish of 10 more threads\n", + "2017-09-22 18:38:41,622 : INFO : worker thread finished; awaiting finish of 9 more threads\n", + "2017-09-22 18:38:41,635 : INFO : PROGRESS: at 97.71% examples, 1213962 words/s, in_qsize 8, out_qsize 1\n", + "2017-09-22 18:38:41,636 : INFO : worker thread finished; awaiting finish of 8 more threads\n", + "2017-09-22 18:38:41,638 : INFO : worker thread finished; awaiting finish of 7 more threads\n", + "2017-09-22 18:38:41,639 : INFO : worker thread finished; awaiting finish of 6 more threads\n", + "2017-09-22 18:38:41,640 : INFO : worker thread finished; awaiting finish of 5 more threads\n", + "2017-09-22 18:38:41,641 : INFO : worker thread finished; awaiting finish of 4 more threads\n", + "2017-09-22 18:38:41,643 : INFO : worker thread finished; awaiting finish of 3 more threads\n", + "2017-09-22 18:38:41,645 : INFO : worker thread finished; awaiting finish of 2 more threads\n", + "2017-09-22 18:38:41,646 : INFO : worker thread finished; awaiting finish of 1 more threads\n", + "2017-09-22 18:38:41,649 : INFO : worker thread finished; awaiting finish of 0 more threads\n", + "2017-09-22 18:38:41,650 : INFO : training on 3473210 raw words (2494104 effective words) took 2.0s, 1233235 effective words/s\n" + ] + } + ], + "source": [ + "# train word2vec on the sentences to initialize the word embedding \n", + "import gensim, logging\n", + "\n", + "logging.basicConfig(format='%(asctime)s : %(levelname)s : %(message)s', level=logging.INFO)\n", + "# use skip-gram\n", + "word2vec_model = gensim.models.Word2Vec(doc_lst, min_count=3, size=EMBEDDING_DIM, sg=1, workers=os.cpu_count())" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "deletable": true, + "editable": true + }, + "source": [ + "Create the initial embedding matrix from the output of word2vec." + ] + }, + { + "cell_type": "code", + "execution_count": 341, + "metadata": { + "collapsed": false, + "deletable": true, + "editable": true + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Total 5984 word vectors.\n" + ] + } + ], + "source": [ + "embeddings_index = {}\n", + "\n", + "for word in word2vec_model.wv.vocab:\n", + " coefs = np.asarray(word2vec_model.wv[word], dtype='float32')\n", + " embeddings_index[word] = coefs\n", + "\n", + "print('Total %s word vectors.' % len(embeddings_index))" + ] + }, + { + "cell_type": "code", + "execution_count": 342, + "metadata": { + "collapsed": false, + "deletable": true, + "editable": true + }, + "outputs": [], + "source": [ + "# Initial embedding\n", + "embedding_matrix = np.zeros((MAX_NB_WORDS + 1, EMBEDDING_DIM))\n", + "\n", + "for word, i in word_index.items():\n", + " embedding_vector = embeddings_index.get(word)\n", + " if embedding_vector is not None and i < MAX_NB_WORDS:\n", + " embedding_matrix[i] = embedding_vector\n", + " elif i == MAX_NB_WORDS:\n", + " # index MAX_NB_WORDS in data corresponds to 'UNK'\n", + " embedding_matrix[i] = embeddings_index['UNK']" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "collapsed": true, + "deletable": true, + "editable": true + }, + "source": [ + "Define the network.\n", + "The mask_zero option determines whether masking is performed, i.e. whether the layers ignore the padded zeros in shorter documents." + ] + }, + { + "cell_type": "code", + "execution_count": 352, + "metadata": { + "collapsed": false, + "deletable": true, + "editable": true, + "scrolled": true + }, + "outputs": [], + "source": [ + "# building Hierachical Attention network\n", + "\n", + "REG_PARAM = 1e-10\n", + "l2_reg = regularizers.l2(REG_PARAM)\n", + "\n", + "embedding_layer = Embedding(MAX_NB_WORDS + 1,\n", + " EMBEDDING_DIM,\n", + " input_length=MAX_SENT_LENGTH,\n", + " trainable=True,\n", + " mask_zero=True,\n", + " embeddings_regularizer=l2_reg,\n", + " weights=[embedding_matrix])" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "deletable": true, + "editable": true + }, + "source": [ + "Define a custom layer implementing the attention mechanism." + ] + }, + { + "cell_type": "code", + "execution_count": 353, + "metadata": { + "collapsed": false, + "deletable": true, + "editable": true + }, + "outputs": [], + "source": [ + "CONTEXT_DIM = 100\n", + "\n", + "class AttLayer(Layer):\n", + " def __init__(self, regularizer=None, **kwargs):\n", + " self.regularizer = regularizer\n", + " self.supports_masking = True\n", + " super(AttLayer, self).__init__(**kwargs)\n", + "\n", + " def build(self, input_shape):\n", + " assert len(input_shape) == 3 \n", + " self.W = self.add_weight(name='W', shape=(input_shape[-1], CONTEXT_DIM), initializer='normal', trainable=True, \n", + " regularizer=self.regularizer)\n", + " self.b = self.add_weight(name='b', shape=(CONTEXT_DIM,), initializer='normal', trainable=True, \n", + " regularizer=self.regularizer)\n", + " self.u = self.add_weight(name='u', shape=(CONTEXT_DIM,), initializer='normal', trainable=True, \n", + " regularizer=self.regularizer) \n", + " super(AttLayer, self).build(input_shape) # be sure you call this somewhere!\n", + "\n", + " def call(self, x, mask=None):\n", + " eij = K.dot(K.tanh(K.dot(x, self.W) + self.b), self.u)\n", + " ai = K.exp(eij)\n", + " alphas = ai / K.sum(ai, axis=1).dimshuffle(0, 'x')\n", + " if mask is not None:\n", + " # use only the inputs specified by the mask\n", + " alphas *= mask\n", + " weighted_input = x * alphas.dimshuffle(0, 1, 'x')\n", + " return weighted_input.sum(axis=1)\n", + "\n", + " def compute_output_shape(self, input_shape):\n", + " return (input_shape[0], input_shape[-1])\n", + " \n", + " def get_config(self):\n", + " config = {}\n", + " base_config = super(AttLayer, self).get_config()\n", + " return dict(list(base_config.items()) + list(config.items()))\n", + "\n", + " def compute_mask(self, inputs, mask):\n", + " return None" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "deletable": true, + "editable": true + }, + "source": [ + "GRU_UNITS is the dimensionality of each GRU output (the number of GRU units). GRU_IMPL = 2 selects a matricized RNN implementation which is more appropriate for training on a GPU. \n", + "\n", + "There are two levels of models in the definition. The sentence model `sentEncoder` is shared across all sentences in the input document. " + ] + }, + { + "cell_type": "code", + "execution_count": 354, + "metadata": { + "collapsed": false, + "deletable": true, + "editable": true + }, + "outputs": [], + "source": [ + "GPU_IMPL = 2 \n", + "GRU_UNITS = 100 \n", + "\n", + "sentence_input = Input(shape=(MAX_SENT_LENGTH,), dtype='int32')\n", + "embedded_sequences = embedding_layer(sentence_input)\n", + "l_lstm = Bidirectional(GRU(GRU_UNITS, return_sequences=True, kernel_regularizer=l2_reg, \n", + " implementation=GPU_IMPL))(embedded_sequences)\n", + "l_att = AttLayer(regularizer=l2_reg)(l_lstm) \n", + "sentEncoder = Model(sentence_input, l_att)\n", + "\n", + "review_input = Input(shape=(MAX_SENTS, MAX_SENT_LENGTH), dtype='int32')\n", + "review_encoder = TimeDistributed(sentEncoder)(review_input)\n", + "l_lstm_sent = Bidirectional(GRU(GRU_UNITS, return_sequences=True, kernel_regularizer=l2_reg, \n", + " implementation=GPU_IMPL))(review_encoder)\n", + "l_att_sent = AttLayer(regularizer=l2_reg)(l_lstm_sent) \n", + "preds = Dense(n_classes, activation='softmax', kernel_regularizer=l2_reg)(l_att_sent)\n", + "model = Model(review_input, preds)" + ] + }, + { + "cell_type": "code", + "execution_count": 355, + "metadata": { + "collapsed": false, + "deletable": true, + "editable": true + }, + "outputs": [], + "source": [ + "model.compile(loss='categorical_crossentropy',\n", + " optimizer=optimizers.SGD(lr=0.8, nesterov=True),\n", + " metrics=['acc'])" + ] + }, + { + "cell_type": "code", + "execution_count": 356, + "metadata": { + "collapsed": false, + "deletable": true, + "editable": true, + "scrolled": true + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "_________________________________________________________________\n", + "Layer (type) Output Shape Param # \n", + "=================================================================\n", + "input_32 (InputLayer) (None, 15, 50) 0 \n", + "_________________________________________________________________\n", + "time_distributed_16 (TimeDis (None, 15, 200) 740900 \n", + "_________________________________________________________________\n", + "bidirectional_32 (Bidirectio (None, 15, 200) 180600 \n", + "_________________________________________________________________\n", + "att_layer_32 (AttLayer) (None, 200) 20200 \n", + "_________________________________________________________________\n", + "dense_16 (Dense) (None, 2) 402 \n", + "=================================================================\n", + "Total params: 942,102\n", + "Trainable params: 942,102\n", + "Non-trainable params: 0\n", + "_________________________________________________________________\n" + ] + } + ], + "source": [ + "model.summary()" + ] + }, + { + "cell_type": "code", + "execution_count": 357, + "metadata": { + "collapsed": true, + "deletable": true, + "editable": true + }, + "outputs": [], + "source": [ + "fname = 'han_food2'\n", + "history = History()\n", + "csv_logger = CSVLogger('./{0}_{1}.log'.format(fname, REG_PARAM), separator=',', append=True)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "deletable": true, + "editable": true + }, + "source": [ + "Order training data by the number of sentences in document (as suggested in the [Yang et al.] paper)." + ] + }, + { + "cell_type": "code", + "execution_count": 358, + "metadata": { + "collapsed": false, + "deletable": true, + "editable": true + }, + "outputs": [], + "source": [ + "doc_lengths = [len(r) for r in reviews]\n", + "ind = np.argsort(doc_lengths)" + ] + }, + { + "cell_type": "code", + "execution_count": 359, + "metadata": { + "collapsed": true, + "deletable": true, + "editable": true + }, + "outputs": [], + "source": [ + "BATCH_SIZE = 50\n", + "NUM_EPOCHS = 20" + ] + }, + { + "cell_type": "code", + "execution_count": 360, + "metadata": { + "collapsed": false, + "deletable": true, + "editable": true + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 1/20\n", + "58s - loss: 0.6869 - acc: 0.5516\n", + "Epoch 2/20\n", + "57s - loss: 0.6629 - acc: 0.6292\n", + "Epoch 3/20\n", + "57s - loss: 0.4835 - acc: 0.7710\n", + "Epoch 4/20\n", + "57s - loss: 0.3469 - acc: 0.8512\n", + "Epoch 5/20\n", + "57s - loss: 0.2877 - acc: 0.8860\n", + "Epoch 6/20\n", + "57s - loss: 0.2561 - acc: 0.8969\n", + "Epoch 7/20\n", + "57s - loss: 0.2222 - acc: 0.9151\n", + "Epoch 8/20\n", + "57s - loss: 0.1908 - acc: 0.9288\n", + "Epoch 9/20\n", + "57s - loss: 0.1659 - acc: 0.9393\n", + "Epoch 10/20\n", + "57s - loss: 0.1517 - acc: 0.9473\n", + "Epoch 11/20\n", + "57s - loss: 0.1335 - acc: 0.9523\n", + "Epoch 12/20\n", + "57s - loss: 0.1370 - acc: 0.9499\n", + "Epoch 13/20\n", + "57s - loss: 0.1163 - acc: 0.9603\n", + "Epoch 14/20\n", + "57s - loss: 0.1138 - acc: 0.9615\n", + "Epoch 15/20\n", + "57s - loss: 0.0921 - acc: 0.9711\n", + "Epoch 16/20\n", + "57s - loss: 0.0962 - acc: 0.9688\n", + "Epoch 17/20\n", + "57s - loss: 0.0745 - acc: 0.9759\n", + "Epoch 18/20\n", + "57s - loss: 0.0713 - acc: 0.9779\n", + "Epoch 19/20\n", + "57s - loss: 0.0570 - acc: 0.9822\n", + "Epoch 20/20\n", + "57s - loss: 0.0491 - acc: 0.9861\n" + ] + } + ], + "source": [ + "t1 = time.time()\n", + "\n", + "model.fit(x_train[ind,:,:], y_train[ind,:], epochs=NUM_EPOCHS, batch_size=BATCH_SIZE, shuffle=False, \n", + " callbacks=[history, csv_logger], verbose=2)\n", + "\n", + "t2 = time.time()" + ] + }, + { + "cell_type": "code", + "execution_count": 361, + "metadata": { + "collapsed": false, + "deletable": true, + "editable": true + }, + "outputs": [], + "source": [ + "test_reviews = []\n", + "test_labels = []\n", + "test_texts = []\n", + "\n", + "for idx in range(test_data.shape[0]):\n", + " text = test_data['text'].iloc[idx]\n", + " test_texts.append(text)\n", + " sentences = nltk.tokenize.sent_tokenize(text)\n", + " test_reviews.append(sentences)\n", + " test_labels.append(test_data['rating'].iloc[idx])" + ] + }, + { + "cell_type": "code", + "execution_count": 362, + "metadata": { + "collapsed": true, + "deletable": true, + "editable": true + }, + "outputs": [], + "source": [ + "data2 = np.zeros((len(test_texts), MAX_SENTS, MAX_SENT_LENGTH), dtype='int32')\n", + "\n", + "for i, sentences in enumerate(test_reviews):\n", + " for j, sent in enumerate(sentences):\n", + " if j < MAX_SENTS:\n", + " wordTokens = text_to_word_sequence(sent)\n", + " k = 0\n", + " words_in_sent = []\n", + " for _, word in enumerate(wordTokens):\n", + " if k < MAX_SENT_LENGTH: \n", + " if (word in tokenizer.word_index) and (tokenizer.word_index[word] < MAX_NB_WORDS):\n", + " data2[i, j, k] = tokenizer.word_index[word]\n", + " words_in_sent.append(word)\n", + " else:\n", + " data2[i, j, k] = MAX_NB_WORDS\n", + " words_in_sent.append('UNK')\n", + " k = k + 1" + ] + }, + { + "cell_type": "code", + "execution_count": 363, + "metadata": { + "collapsed": false, + "deletable": true, + "editable": true + }, + "outputs": [], + "source": [ + "y_test = to_categorical(np.asarray(test_labels))\n", + "x_test = data2" + ] + }, + { + "cell_type": "code", + "execution_count": 364, + "metadata": { + "collapsed": false, + "deletable": true, + "editable": true + }, + "outputs": [], + "source": [ + "from sklearn.metrics import accuracy_score, roc_auc_score, roc_curve" + ] + }, + { + "cell_type": "code", + "execution_count": 365, + "metadata": { + "collapsed": false, + "deletable": true, + "editable": true + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Accuracy = 0.8887 \t AUC = 0.9590876492256548\n" + ] + } + ], + "source": [ + "preds = model.predict(x_test)\n", + "print(\"Accuracy = {0} \\t AUC = {1}\".format(accuracy_score(test_labels, preds.argmax(axis=1)),\n", + " roc_auc_score(test_labels, preds[:, 1])))" + ] + }, + { + "cell_type": "code", + "execution_count": 366, + "metadata": { + "collapsed": false, + "deletable": true, + "editable": true + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Populating the interactive namespace from numpy and matplotlib\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "WARNING: pylab import has clobbered these variables: ['text']\n", + "`%matplotlib` prevents importing * from pylab and numpy\n" + ] + } + ], + "source": [ + "%pylab inline" + ] + }, + { + "cell_type": "code", + "execution_count": 367, + "metadata": { + "collapsed": false, + "deletable": true, + "editable": true + }, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 367, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAiAAAAF5CAYAAACm4JG+AAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAAPYQAAD2EBqD+naQAAIABJREFUeJzt3XuUXFWZ9/Hvk4QkBJJACCSgaEAhBBmRNCiIoA5CZBTf\nYcRLuIjoQiO49I0jynhjxIXoiGZQyYBXwoxGUWdeL4MEAW/cJQGBecOdEEEJ4ZIAuXYn+/1jV790\n2nTS3ek6uy7fz1q1dtWuc6qeOqu769f77HNOpJSQJEmq0rDSBUiSpPZjAJEkSZUzgEiSpMoZQCRJ\nUuUMIJIkqXIGEEmSVDkDiCRJqpwBRJIkVc4AIkmSKmcAkSRJlWuIABIRR0TEzyLi0YjYGBFv6cc6\nr4uIhRGxNiLujYhTq6hVkiRtu4YIIMAOwO3AGcBWL04TEVOAXwDXAAcCFwLfioij61eiJEkaKtFo\nF6OLiI3A36eUfraFZb4IHJtSenmPvvnA+JTS31VQpiRJ2gaNMgIyUIcCV/fqWwAcVqAWSZI0QM0a\nQCYDy3r1LQPGRcSoAvVIkqQBGFG6gKpExC7ADGAJsLZsNZIkNZXRwBRgQUrpyaF4wWYNII8Bk3r1\nTQKeSSmt62OdGcD36lqVJEmt7STg+0PxQs0aQG4Eju3Vd0ytvy9LAP7jP/6DadOm1aks9TZ79mzm\nzJlTuoyWt2wZ3HMPbNgA3/nObN71rjls3AiPPgqjRsEjj8Djj8Pw4bBx46a39evhD394/rUiYCjn\npo8cCcOGPX/r6oK1a+GFL8yPIza9DavtGO5uV6yAffaBESPguefgRS96frme6w8blp/fe+/8mfbY\nY9P3HTYMOjtht93yNtlhBxgzJm+T7baDHXcc/Gf057x6bvNqLV68mJNPPhlq36VDoSECSETsALwU\niFrX3hFxIPBUSulPEXE+sEdKqftcHxcDZ9aOhvkOcBRwArClI2DWAkybNo3p06fX42NoM8aPH+/2\n7qc1a2D16vwluXYtPPwwrFoFS5bkL8yuLnjgAdh+e7jrrvz41lvhz3/u/UrjOfvsTbf52LGwbh1M\nngwve1n+Mh4+/Pl20iQYPx5e/er8uPcXfM82JXjpSzft63mLgJ13hgkT8hd8O/DnvHpu82KGbApD\nQwQQ4GDg1+RzgCTgy7X+ecB7yJNO9+xeOKW0JCLeBMwBPgQ8Arw3pdT7yBipMp2d8Mc/wpNPwtNP\n59GH9etzWFi+PI8E/M//5P+8N27MAeK++3L/+vVbf/1hw/J6AFOm5LByxBE5qLz3vXD00TBuHLzr\nXfCDH+QgMWIEjB5d148tSYPSEAEkpfRbtnBETkrptM30/Q7oqGddUldX3nVx663w2GM5KDzwQP5S\nX7cO7rwz/7d/4415N0dPI0bk57q68kjAuHHwylfmUY19983PDx+eRz722Se3e++dA8mIEXlXwa67\n5l0D48c/v0tia7bbLi8vSY2sIQKIVLX163NguO02uOaaPJpw6615vsGoUXmkYcmSv15v1Ki8C2L9\nepg6Nc+5WL4cDjwQjjoKjjwSDj8cdtoph4f+hgZJajcGENXVzJkzK3/P9evz7o/rroOf/hTuvz+P\nKHR15V0gEXl3SU877ZRHISZMyGFi3Li8zrBhedLjfvvBIYfkvkZXYpu3O7d59dzmza/hTsVeLxEx\nHVi4cOFCJy41oc5OuPrqPGIxenR+fNtt8MQTORQsXpxDw113/fW6u+0G06bB3/xN3j2xfj284AV5\nHsVLXgLTpzdHsJCkUhYtWkRHRwdAR0pp0VC8pn921VBWrYK774bvfheWLs1Hg/zqV5su032URVdX\nvt/RAQccALvsAjNm5FGLV7wiH+Z5yCH5qBFJUmMxgKiIFSvyaMVvf5sncD79dJ6TsXTp88vssQcc\neigcd1x+/NGPwqteledhSJKamwFEdffoo7BoEfzmNzBvXj5MtbdXvhLe+tY832L33eH1r8+7TiRJ\nrckAoiF35515fsZ558G99/71869/PRxzTJ57sf/+eT5GxF8vJ0lqXQYQbZNnnoEf/Qhuvz3P3bju\nujxvo9uMGbDXXvDOd8JBB+WJopIkGUA0IPfdB9//fp4k2tWVd690O+igfCjrvvvCF7+YT/u9ww7l\napUkNS4DiLbquefg1FPhqqvy/W7Tp8NrXgOzZuVriIwcWa5GSVJzMYCoT11d8NWvwj/+4/N9b34z\nfPrT+fBW521IkgbLAKK/8q1vwTnnbHqV1Usugfe9r1xNkqTWYgDRJq69Fk4/Pd//4Afz0SpvfrOj\nHZKkoWUAaXNXXglnn/38dVIgn5Z86dJ8Pg5JkurBANKGVq+GuXPh3/4NHnww9+28M7z97fDSl+Yz\nju68c9kaJUmtzQDSRlauhM9+FubMeb7vwgvhtNNg7NhydUmS2o8BpMV1dcEvfgHHH/983y67wNe+\nBieckK8OK0lS1QwgLWrDhnwxt8cff77viCPgk5/ME0udVCpJKmlY6QI0tFavhp/8JO9S6Q4fc+fm\n/t/9Lp8a3fAhSSrNEZAW8uSTeRLpihX58ac/DeeeW7YmSZI2xwDSIpYuhRe/ON/fZ598cbgxY8rW\nJElSX9wF0+RSgje96fnwccEFcO+9hg9JUmMzgDSpu++G978fhg2DK66Aww+HRx7Z9LotkiQ1KnfB\nNJmNG+G443LoAJg6Nc/1OOmksnVJkjQQBpAm8oUvwD/9U76/225w/fV50qkkSc3GANIEUsoTSx94\nID++4AJ3tUiSmpsBpAnMmJHDx8EHw29/6wRTSVLzM4A0sJtvhssug1/9CiZMgFtu8SRikqTWYABp\nUBdfDB/4QL5/2GHwmc8YPiRJrcMA0oBWrnw+fDzzjFeqlSS1Hs8D0mDOOw922infv/BCw4ckqTU5\nAtJAzj8fPvWpfP8nP4F/+Iey9UiSVC8GkAaxbBl84hP5vrtdJEmtzl0wDeBrX4PJk/P9yy83fEiS\nWp8jIIXNmwcf+lC+f8UVcOyxZeuRJKkKBpCCbr4Z3v3ufP+pp2DnnYuWI0lSZdwFU8iTT8Khh+b7\nv/mN4UOS1F4MIAUsXQoTJ+b7v/wlvPa1ZeuRJKlqBpACZs3K7cUXwxvfWLYWSZJKMIBUbNasPOox\nbhy8//2lq5EkqQwDSIWefBIuuSTfv/vusrVIklSSAaRCc+fm9pZbYPfdy9YiSVJJBpAKfe5zuT3k\nkLJ1SJJUmgGkIj/6EXR2wkc+UroSSZLKM4BUYMMGOOusfP+CC8rWIklSIzCAVOAzn4GHH86nWY8o\nXY0kSeUZQOrsmWfg85/P93/2s7K1SJLUKAwgddTZCePH5/v/+Z8wwivvSJIEGEDqZu1aeMEL8v3j\nj883SZKUGUDq5KyzYPlyOPts+MlPSlcjSVJjcadAnXz967k9//yydUiS1IgaZgQkIs6MiIciYk1E\n3BQRWzxdV0ScFBG3R8SqiPhzRHw7IiZUVe+WLFmS21NOKVqGJEkNqyECSES8A/gycA5wEPBHYEFE\nTOxj+cOBecA3gf2BE4BXAt+opOCt+OhHc/uJT5StQ5KkRtUQAQSYDVySUrospXQ3MAtYDbynj+UP\nBR5KKV2UUno4pXQDcAk5hBT1q1/lOR/HHw/77Ve6GkmSGlPxABIR2wEdwDXdfSmlBFwNHNbHajcC\ne0bEsbXXmAS8Dfjv+la7dR/7WG6/972ydUiS1MiKBxBgIjAcWNarfxkweXMr1EY8TgZ+GBHrgb8A\nTwMfrGOd/bJ8ORxwAGy/felKJElqXI0QQAYsIvYHLgT+GZgOzAD2Iu+GKebKK+HRR+Hkk0tWIUlS\n42uEw3CfADYAk3r1TwIe62Ods4HrU0pfqT2+KyLOAH4fEZ9MKfUeTfn/Zs+ezfju05PWzJw5k5kz\nZw6q+J4uvji3733vNr+UJElFzJ8/n/nz52/St3LlyiF/n+IBJKXUGRELgaOAnwFERNQef7WP1cYA\n63v1bQQSsMXLvc2ZM4fp06dvU82bc8UV8NOf5kNvJ2722B1Jkhrf5v4pX7RoER0dHUP6Po2yC+Yr\nwOkR8a6I2A+4mBwyLgWIiPMjYl6P5X8OvDUiZkXEXrXDci8Ebk4p9TVqUlfvfGduP/WpEu8uSVJz\nKT4CApBSurx2zo9zybtebgdmpJSW1xaZDOzZY/l5EbEjcCZwAbCCfBTN2ZUWXvPLX8Kzz8KHPwz7\n7luiAkmSmktDBBCAlNJcYG4fz522mb6LgIvqXVd/dJ92/YILytYhSVKzaJRdME3r6afz/I9jjoER\nDRPnJElqbAaQbdQ96nHWWWXrkCSpmRhAttEPfgCTJsFRR5WuRJKk5mEA2QbLl8ODD8IJJ0Bs8eBf\nSZLUkwFkG9x2W25POaVsHZIkNRsDyDb49a9zO3Vq2TokSWo2BpBBSgm+8AXYc0/YaafS1UiS1FwM\nIIP0vvfldvjwsnVIktSMDCCDcNtt8K1v5fv331+2FkmSmpEBZBBmz87tnXc6AiJJ0mAYQAbhrrvg\nVa+CAw4oXYkkSc3JADJA69bBk0/C0UeXrkSSpOZlABmgBQty29FRtg5JkpqZAWSALrwwt699bdk6\nJElqZgaQAbr2WnjZy2DnnUtXIklS8zKADMCf/pTbU08tW4ckSc3OADIAX/96bt/2trJ1SJLU7Awg\nA9DZCWPHwpQppSuRJKm5GUAG4NZbYZ99SlchSVLzM4D0U0rw+9/DqFGlK5EkqfkZQPrppz/N7Qkn\nlK1DkqRWYADpp3/919yeeWbZOiRJagUGkH664w742791F4wkSUPBANIPjz0GTz8NRx1VuhJJklqD\nAaQf7rort697XdEyJElqGQaQfrjsstxOnVq2DkmSWoUBpB+uvBKOPhp22aV0JZIktQYDyFZ0dcHy\n5fCSl5SuRJKk1mEA2YqVK3N74IFl65AkqZUYQLbi4Ydzu9deZeuQJKmVGEC24s47c2sAkSRp6BhA\ntuKb38ytV8CVJGnoGEC2oKsLrr8eXv5yGDmydDWSJLUOA8gWXHJJbr3+iyRJQ8sAsgXPPpvb008v\nW4ckSa3GALIFy5bB5MkQUboSSZJaiwFkC665Boa5hSRJGnJ+vfbhqafyRehOOaV0JZIktR4DSB9W\nrICU4A1vKF2JJEmtxwDSh6eeyu3225etQ5KkVmQA6cNtt+V2993L1iFJUisygPTh7rtz++IXl61D\nkqRWZADpwy23wC67wPDhpSuRJKn1GEA2Y+VKuO46OOKI0pVIktSaDCCbcd11uT3jjLJ1SJLUqgwg\nm/HQQ7l99avL1iFJUqsygGzGVVfldsyYsnVIktSqDCCbsWED7Luv14CRJKleDCC9pARXXAGHHFK6\nEkmSWpcBpJfFi3Pr4beSJNWPAaSX887btJUkSUOvYQJIRJwZEQ9FxJqIuCkitrgTJCJGRsR5EbEk\nItZGxIMR8e5traP7GjAvfOG2vpIkSerLiNIFAETEO4AvA+8DbgFmAwsiYt+U0hN9rPYjYFfgNOAB\nYHeGIFCtXQsnnritryJJkrakIQIIOXBcklK6DCAiZgFvAt4D/EvvhSPijcARwN4ppRW17qVDUcj6\n9TBy5FC8kiRJ6kvxXTARsR3QAVzT3ZdSSsDVwGF9rHYccCvw8Yh4JCLuiYgvRcTobaklJbjhBhhW\nfKtIktTaGmEEZCIwHFjWq38ZMLWPdfYmj4CsBf6+9hr/BkwA3jvYQu66K7cvetFgX0GSJPVHIwSQ\nwRgGbAROTCk9BxARHwF+FBFnpJTW9bXi7NmzGT9+/CZ9M2fOZObMmfzud/nxrFn1KluSpMY2f/58\n5s+fv0nfypUrh/x9GiGAPAFsACb16p8EPNbHOn8BHu0OHzWLgQBeSJ6Uullz5sxh+vTpm3/Rv+R2\n1137UbUkSS2o+5/ynhYtWkRHR8eQvk/x2Q4ppU5gIXBUd19ERO3xDX2sdj2wR0T0vFrLVPKoyCOD\nreXBB+Hgg50DIklSvTXKV+1XgNMj4l0RsR9wMTAGuBQgIs6PiHk9lv8+8CTw3YiYFhFHko+W+faW\ndr9szapVMHnyYNeWJEn91RABJKV0OfBR4FzgNuDlwIyU0vLaIpOBPXssvwo4GtgJ+APw78BPgQ9v\nSx0LFngFXEmSqtAIc0AASCnNBeb28dxpm+m7F5gxVO+/fDmsW+cRMJIkVaEhRkAawY035vZtbytb\nhyRJ7cAAUrN2bW6nTStbhyRJ7cAAUtN9DpDR23QuVUmS1B8GkJqFC3M7omFmxUiS1LoMIDWrVsGp\np0JE6UokSWp9BhBgwwa480548YtLVyJJUnswgACLF+d2333L1iFJUrswgAB33JHbQw4pW4ckSe3C\nAAKsX5/bKVOKliFJUtswgAA//3luR44sW4ckSe3CAAI89JAnIJMkqUoGkJpddildgSRJ7cMAAnR2\nwkEHla5CkqT2YQAhT0J1/ockSdUZsgASEf8QEXcM1etVac0aA4gkSVUaUACJiPdHxI8j4vsR8apa\n399GxG3AvwPX16PIenr2WfjTn2DXXUtXIklS++h3AImIs4GvAS8G3gJcGxGfAL4H/BB4YUrpA3Wp\nso5WrsytR8FIklSdgVz79TTg9JTSvIg4Avgt8GrgpSmlVXWprgKPPprb7bcvW4ckSe1kILtgXgRc\nC5BS+j3QCZzTzOEDYMWK3O62W9k6JElqJwMJIKOAtT0erweeGtpyqre29okmTixbhyRJ7WQgu2AA\nPhcRq2v3RwKfioiVPRdIKX1kSCqryHPP5Xb06LJ1SJLUTgYSQH4HTO3x+AZg717LpG2uqGJLl8JO\nO8GOO5auRJKk9tHvAJJSel0d6yjm4Yfz7peI0pVIktQ+BrQLJiLGAoeSd7/cklJaXpeqKnTNNbD/\n/qWrkCSpvQzkPCCvAO4BFgA/B+6PiBn1KqwKXV1w//0wZUrpSiRJai8DOQrmi8CD5HN/dADXAF+v\nR1FVeeyx3B55ZNk6JElqNwPZBdMBHJNSWgQQEe8BnoqIcSmlZ+pSXZ0tW5bbnXYqW4ckSe1mICMg\nE4BHuh+klFYAq4BdhrqoqnR15XbSpLJ1SJLUbgZ6HpD9I2Jyj8cBTKtNTgUgpdQ0V8Tt7MztiIFu\nBUmStE0G+tV7DTl09PQL8vk/otYOH4K6KtE9ArLddmXrkCSp3QwkgOxVtyoKcQREkqQyBvLVeypw\nQUpp9VaXbBLdV8IdO3bLy0mSpKE1kEmo5wAtdcLye++FCRPyTZIkVWcgAaTlTlb++OOwd++r2UiS\npLobSACBJrzY3JbceKMTUCVJKmGg0y/vjYgthpCUUtPs0NhxR09CJklSCQMNIOcAK+tRSAmrV8Or\nXlW6CkmS2s9AA8gPUkqP16WSAh55BHbbrXQVkiS1n4HMAWmp+R8Azz0HEyeWrkKSpPbTtkfBPP10\nPhPqyJGlK5Ekqf30exdMSmmgR8w0tHvvze2LXlS2DkmS2lFLhYqB+K//yq2TUCVJql7bBpARIyDC\n07BLklRC2waQzk7PgipJUiltG0C6urwKriRJpbRtAOns9DTskiSV0rYB5MYbYfjw0lVIktSe2jaA\nrFgBe+1VugpJktpT2waQ556D6dNLVyFJUntq2wCyahWMHl26CkmS2lNbBpC1a+HZZ2HDhtKVSJLU\nnhomgETEmRHxUESsiYibIuKQfq53eER0RsSi/r7XypW59TwgkiSV0RABJCLeAXwZOAc4CPgjsCAi\ntnit2ogYD8wDrh7I+61Zk9sJEwZRrCRJ2mYNEUCA2cAlKaXLUkp3A7OA1cB7trLexcD3gJsG8mYP\nPZTbXXYZcJ2SJGkIFA8gEbEd0AFc092XUkrkUY3DtrDeacBewGcH+p431eLKPvsMdE1JkjQUGuFk\n5BOB4cCyXv3LgKmbWyEi9gE+D7wmpbQxIgb0hk88AaNGwY47DqJaSZK0zYqPgAxURAwj73Y5J6X0\nQHf3QF7jwQc9BFeSpJIaYQTkCWADMKlX/yTgsc0sPxY4GHhFRFxU6xsGRESsB45JKf2mrzebPXs2\nf/jDeMaMgbe8JffNnDmTmTNnbtOHkCSpFcyfP5/58+dv0rey+/DRIRR5ukVZEXETcHNK6cO1xwEs\nBb6aUvpSr2UDmNbrJc4EXg+8FViSUlqzmfeYDixcuHAhp5wynWnT4Mc/rsOHkSSpxSxatIiOjg6A\njpRSv097sSWNMAIC8BXg0ohYCNxCPipmDHApQEScD+yRUjq1NkH1//ZcOSIeB9amlBb35806Oz0H\niCRJJTVEAEkpXV4758e55F0vtwMzUkrLa4tMBvYcqve7//48CVWSJJXREAEEIKU0F5jbx3OnbWXd\nz9LPw3EffxxSghEN88klSWo/TXcUzLa6777cHn982TokSWpnbRdA1q/P7e67l61DkqR21nYBpLMz\nt84BkSSpnLYLIPffn1sDiCRJ5bRdAFm3LrcGEEmSymm7ALJ0Key8c+kqJElqb20XQLbfHqZu9hJ3\nkiSpKm0XQFKCMWNKVyFJUntrywASA7p2riRJGmptF0DAACJJUmltF0AcAZEkqTwDiCRJqpwBRJIk\nVc4AIkmSKtd2AQQMIJIkldZ2AcQREEmSyjOASJKkyhlAJElS5QwgkiSpcm0XQMAAIklSaW0XQBwB\nkSSpPAOIJEmqnAFEkiRVzgAiSZIq13YBBAwgkiSV1nYBxBEQSZLKM4BIkqTKGUAkSVLl2i6A3HGH\nAUSSpNLaLoB0dcGwtvvUkiQ1lrb7Ku7qgte9rnQVkiS1t7YLIAA77li6AkmS2ltbBpDJk0tXIElS\ne2vLADJiROkKJElqbwYQSZJUOQOIJEmqXFsGkOHDS1cgSVJ7a8sAMm5c6QokSWpvbRlAJkwoXYEk\nSe2tLQOIc0AkSSrLACJJkirXlgFku+1KVyBJUntrywDiCIgkSWW1XQAZN86r4UqSVFrbfRWPHl26\nAkmS1HYBxJOQSZJUngFEkiRVzgAiSZIq13YBxAmokiSV13Zfx46ASJJUXtsFEK8DI0lSeQ0TQCLi\nzIh4KCLWRMRNEXHIFpY9PiKuiojHI2JlRNwQEcf05322337oapYkSYPTEAEkIt4BfBk4BzgI+COw\nICIm9rHKkcBVwLHAdODXwM8j4sCtvZenYZckqbyGCCDAbOCSlNJlKaW7gVnAauA9m1s4pTQ7pXRB\nSmlhSumBlNIngfuA47b2RgYQSZLKKx5AImI7oAO4prsvpZSAq4HD+vkaAYwFntrasmPHDq5OSZI0\ndIoHEGAiMBxY1qt/GTC5n69xFrADcPnWFnQSqiRJ5TX9dWEj4kTg08BbUkpPbG15rwUjSVJ5jRBA\nngA2AJN69U8CHtvSihHxTuAbwAkppV/3582uvHI2DzwwfpO+mTNnMnPmzH4XLElSq5o/fz7z58/f\npG/lypVD/j6Rp1uUFRE3ATenlD5cexzAUuCrKaUv9bHOTOBbwDtSSr/ox3tMBxZ+4QsL+fjHpw9d\n8ZIktbhFixbR0dEB0JFSWjQUr9kIIyAAXwEujYiFwC3ko2LGAJcCRMT5wB4ppVNrj0+sPfch4A8R\n0T16sial9MyW3sgzoUqSVF5DBJCU0uW1c36cS971cjswI6W0vLbIZGDPHqucTp64elHt1m0efRy6\n281rwUiSVF5DBBCAlNJcYG4fz53W6/HrB/s+BhBJkspru69jA4gkSeW13dexAUSSpPLa7uvYSaiS\nJJXXdgHEERBJkspru6/jyf09ubskSaqbtgsgu+1WugJJktR2AWTUqNIVSJKktgsgkiSpPAOIJEmq\nnAFEkiRVzgAiSZIqZwCRJEmVM4BIkqTKGUAkSVLlDCCSJKlyBhBJklQ5A4gkSaqcAUSSJFXOACJJ\nkipnAJEkSZUzgEiSpMoZQCRJUuUMIJIkqXIGEEmSVDkDiCRJqpwBRJIkVc4AIkmSKmcAkSRJlTOA\nSJKkyhlAJElS5QwgkiSpcgYQSZJUOQOIJEmqnAFEkiRVzgAiSZIqZwCRJEmVM4BIkqTKGUAkSVLl\nDCCSJKlyBhBJklQ5A4gkSaqcAUSSJFXOACJJkipnAJEkSZUzgEiSpMoZQCRJUuUMIJIkqXIGEEmS\nVDkDiCRJqpwBRJIkVc4AIkmSKtcwASQizoyIhyJiTUTcFBGHbGX510XEwohYGxH3RsSpVdWq/ps/\nf37pEtqO27x6bvPquc2bX0MEkIh4B/Bl4BzgIOCPwIKImNjH8lOAXwDXAAcCFwLfioijq6hX/ecf\nieq5zavnNq+e27z5NUQAAWYDl6SULksp3Q3MAlYD7+lj+Q8AD6aUPpZSuieldBHw49rrSJKkBlc8\ngETEdkAHeTQDgJRSAq4GDutjtUNrz/e0YAvLS5KkBlI8gAATgeHAsl79y4DJfawzuY/lx0XEqKEt\nT5IkDbURpQuo0GiAxYsXl66jraxcuZJFixaVLqOtuM2r5zavntu8Wj2+O0cP1Ws2QgB5AtgATOrV\nPwl4rI91Hutj+WdSSuv6WGcKwMknnzy4KjVoHR0dpUtoO27z6rnNq+c2L2IKcMNQvFDxAJJS6oyI\nhcBRwM8AIiJqj7/ax2o3Asf26jum1t+XBcBJwBJg7TaULElSuxlNDh8LhuoFI8/3LCsi3g5cSj76\n5Rby0SwnAPullJZHxPnAHimlU2vLTwHuBOYC3yGHlX8F/i6l1HtyqiRJajDFR0AAUkqX1875cS55\nV8rtwIyU0vLaIpOBPXssvyQi3gTMAT4EPAK81/AhSVJzaIgREEmS1F4a4TBcSZLUZgwgkiSpci0T\nQLyYXfUGss0j4viIuCoiHo+IlRFxQ0QcU2W9rWCgP+c91js8IjojwhMnDNAg/raMjIjzImJJ7e/L\ngxHx7orKbQmD2OYnRcTtEbEqIv4cEd+OiAlV1dvsIuKIiPhZRDwaERsj4i39WGebv0NbIoB4Mbvq\nDXSbA0cCV5EPn54O/Br4eUQcWEG5LWEQ27x7vfHAPP768gXaikFu8x8BrwdOA/YFZgL31LnUljGI\nv+eHk3++vwnsTz6C8pXANyopuDXsQD744wxgqxNDh+w7NKXU9DfgJuDCHo+DfGTMx/pY/ovAHb36\n5gNXlP7vDYcBAAAFCklEQVQszXIb6Dbv4zXuAj5V+rM0y22w27z2s/1Z8h/0RaU/RzPdBvG35Y3A\nU8BOpWtv1tsgtvk/Avf16vsgsLT0Z2nGG7AReMtWlhmS79CmHwHxYnbVG+Q27/0aAYwl/7HWVgx2\nm0fEacBe5ACiARjkNj8OuBX4eEQ8EhH3RMSXImLITl/dyga5zW8E9oyIY2uvMQl4G/Df9a22rQ3J\nd2jTBxC8mF0Jg9nmvZ1FHva7fAjramUD3uYRsQ/weeCklNLG+pbXkgbzc743cATwMuDvgQ+Tdwlc\nVKcaW82At3lK6QbgZOCHEbEe+AvwNHkURPUxJN+hrRBA1GQi4kTg08DbUkpPlK6nFUXEMOB7wDkp\npQe6uwuW1C6GkYewT0wp3ZpSuhL4CHCq/9zUR0TsT56D8M/k+WUzyKN+lxQsS/3QEGdC3UZVXcxO\nzxvMNgcgIt5Jnhx2Qkrp1/UpryUNdJuPBQ4GXhER3f99DyPv/VoPHJNS+k2dam0Vg/k5/wvwaErp\nuR59i8nh74XAA5tdS90Gs83PBq5PKX2l9viuiDgD+H1EfDKl1Ps/dW27IfkObfoRkJRSJ9B9MTtg\nk4vZ9XXFvht7Ll+ztYvZqWaQ25yImAl8G3hn7T9D9dMgtvkzwAHAK8iz1A8ELgburt2/uc4lN71B\n/pxfD+wREWN69E0lj4o8UqdSW8Ygt/kYoKtX30by0RyO+tXH0HyHlp5xO0Szdt8OrAbeBexHHnp7\nEti19vz5wLwey08BniXP5J1KPvRoPfCG0p+lWW6D2OYn1rbxLHJS7r6NK/1ZmuU20G2+mfU9CqbO\n25w8r+lh4IfANPLh5/cAF5f+LM1yG8Q2PxVYV/vbshdwOPmipjeU/izNcqv93B5I/odlI/C/a4/3\n7GObD8l3aPEPPoQb8AxgCbCGnMIO7vHcd4Frey1/JDlprwHuA04p/Rma7TaQbU4+78eGzdy+U/pz\nNNNtoD/nvdY1gFSwzcnn/lgAPFcLI/8CjCr9OZrpNohtfib5CunPkUea5gG7l/4czXIDXlsLHpv9\n+1yv71AvRidJkirX9HNAJElS8zGASJKkyhlAJElS5QwgkiSpcgYQSZJUOQOIJEmqnAFEkiRVzgAi\nSZIqZwCRJEmVM4BIqpuI+G5EbIyIDbW2+/7eEXFpj8frIuK+iPh0RAyrrfvaXus+HhH/HREHlP5c\nkradAURSvf0SmNzjtjv5Oh+px3MvBb5Evl7NR3usm8jXVplMvtrmKOAXETGiotol1YkBRFK9rUsp\nLU8pPd7jtrHXc39KKX0DuBr4X73W7173dmAOsCf5KqmSmpgBRFIjWQuM7NUXABExHjip1re+yqIk\nDT2HMSXV23ER8WyPx1eklN7Re6GIeAMwA7iwZzfwp4gIYIda3/9JKd1bt2olVcIAIqnergVmURvJ\nAFb1eK47nGxXe/57wGd7PJ+A1wBrgEOBTwAfqHfBkurPACKp3lallB7q47nucNIJ/LnH3JCelqSU\nngHui4hJwOXAa+tTqqSqOAdEUkmrUkoPpZQe6SN89HYRcEBE9J6oKqnJGEAkNbLo+SCltAb4JnBu\nmXIkDRUDiKRGljbT93Vgv4g4oepiJA2dSGlzv9+SJEn14wiIJEmqnAFEkiRVzgAiSZIqZwCRJEmV\nM4BIkqTKGUAkSVLlDCCSJKlyBhBJklQ5A4gkSaqcAUSSJFXOACJJkipnAJEkSZX7f45Ym+BGspdl\nAAAAAElFTkSuQmCC\n", + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "fpr, tpr, _ = roc_curve(test_labels, preds[:, 1])\n", + "plot(fpr, tpr)\n", + "xlabel('FPR')\n", + "ylabel('TPR')" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.5.2" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +}