Skip to content

Commit c502787

Browse files
committedJul 15, 2024
update
0 parents  commit c502787

38 files changed

+5370
-0
lines changed
 

‎README.md

+99
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
# E5-V: Universal Embeddings with Multimodal Large Language Models
2+
3+
## Example
4+
``` python
5+
import torch
6+
import torch.nn.functional as F
7+
import requests
8+
from PIL import Image
9+
from transformers import AutoTokenizer
10+
from transformers import LlavaNextProcessor, LlavaNextForConditionalGeneration
11+
12+
llama3_template = '<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n \n'
13+
14+
processor = LlavaNextProcessor.from_pretrained('royokong/e5-v')
15+
model = LlavaNextForConditionalGeneration.from_pretrained('royokong/e5-v', torch_dtype=torch.float16).cuda()
16+
17+
img_prompt = llama3_template.format('<image>\nSummary above image in one word: ')
18+
text_prompt = llama3_template.format('<sent>\nSummary above sentence in one word: ')
19+
20+
urls = ['https://upload.wikimedia.org/wikipedia/commons/thumb/4/47/American_Eskimo_Dog.jpg/360px-American_Eskimo_Dog.jpg',
21+
'https://upload.wikimedia.org/wikipedia/commons/thumb/b/b6/Felis_catus-cat_on_snow.jpg/179px-Felis_catus-cat_on_snow.jpg']
22+
images = [Image.open(requests.get(url, stream=True).raw) for url in urls]
23+
24+
texts = ['A dog sitting in the grass.',
25+
'A cat standing in the snow.']
26+
27+
text_inputs = processor([text_prompt.replace('<sent>', text) for text in texts], return_tensors="pt", padding=True).to('cuda')
28+
img_inputs = processor([img_prompt]*len(images), images, return_tensors="pt", padding=True).to('cuda')
29+
30+
with torch.no_grad():
31+
text_embs = model(**text_inputs, output_hidden_states=True, return_dict=True).hidden_states[-1][:, -1, :]
32+
img_embs = model(**img_inputs, output_hidden_states=True, return_dict=True).hidden_states[-1][:, -1, :]
33+
34+
text_embs = F.normalize(text_embs, dim=-1)
35+
img_embs = F.normalize(img_embs, dim=-1)
36+
37+
print(text_embs @ img_embs.t())
38+
```
39+
40+
41+
## Evaulate
42+
To evaluate the original results in the paper, please run following
43+
```sh
44+
# eval on coco, flickr30k, fashioniq and cirr
45+
accelerate launch --num_machines=1 --num_processes 8 --machine_rank 0 retrieval.py --use_e5v
46+
47+
# eval on i2i-coco, i2i-flickr30k
48+
accelerate launch --num_machines=1 --num_processes 8 --machine_rank 0 retrieval.py --use_e5v --ocr_replace_text
49+
50+
# eval on sts tasks
51+
cd SentEval/data/downstream/
52+
bash download_dataset.sh
53+
cd -
54+
accelerate launch --num_machines=1 --num_processes 8 --machine_rank 0 eval_sts.py --model_name_or_path royokong/e5-v
55+
```
56+
57+
## Training
58+
1. Install Dependencies
59+
60+
``` sh
61+
pip install -r requirements.txt
62+
```
63+
64+
2. Download Data
65+
66+
``` sh
67+
cd ./data
68+
bash download_nli.sh
69+
cd -
70+
```
71+
72+
3. Transfer llava-llama-3-8b model to huggingface format
73+
74+
``` sh
75+
mkdir -p models
76+
cd models
77+
for i in 1 2 3 4; do
78+
wget https://huggingface.co/lmms-lab/llama3-llava-next-8b/resolve/main/model-0000$i-of-00004.safetensors
79+
done
80+
cd -
81+
python load_llama3_hf.py
82+
rm models/*.safetensors
83+
```
84+
85+
4. Train
86+
``` sh
87+
bash run.sh
88+
```
89+
90+
5. Test
91+
Use `--lora_path` flag to test the results.
92+
``` sh
93+
accelerate launch --num_machines=1 --num_processes 8 --machine_rank 0 retrieval.py \
94+
--llava_llama3 --lora_path e5v-8b --batch_size 1
95+
```
96+
97+
98+
## Acknowledgement
99+
Our Code is based on SimCSE and alpaca-lora

‎SentEval/LICENSE

+30
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
BSD License
2+
3+
For SentEval software
4+
5+
Copyright (c) 2017-present, Facebook, Inc. All rights reserved.
6+
7+
Redistribution and use in source and binary forms, with or without modification,
8+
are permitted provided that the following conditions are met:
9+
10+
* Redistributions of source code must retain the above copyright notice, this
11+
list of conditions and the following disclaimer.
12+
13+
* Redistributions in binary form must reproduce the above copyright notice,
14+
this list of conditions and the following disclaimer in the documentation
15+
and/or other materials provided with the distribution.
16+
17+
* Neither the name Facebook nor the names of its contributors may be used to
18+
endorse or promote products derived from this software without specific
19+
prior written permission.
20+
21+
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
22+
ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
23+
WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
24+
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR
25+
ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
26+
(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
27+
LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
28+
ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
29+
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
30+
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

‎SentEval/README.md

+249
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,249 @@
1+
Our modification to SentEval:
2+
3+
1. Add the `all` setting to all STS tasks.
4+
2. Change STS-B and SICK-R to not use an additional regressor.
5+
6+
# SentEval: evaluation toolkit for sentence embeddings
7+
8+
SentEval is a library for evaluating the quality of sentence embeddings. We assess their generalization power by using them as features on a broad and diverse set of "transfer" tasks. **SentEval currently includes 17 downstream tasks**. We also include a suite of **10 probing tasks** which evaluate what linguistic properties are encoded in sentence embeddings. Our goal is to ease the study and the development of general-purpose fixed-size sentence representations.
9+
10+
11+
**(04/22) SentEval new tasks: Added probing tasks for evaluating what linguistic properties are encoded in sentence embeddings**
12+
13+
**(10/04) SentEval example scripts for three sentence encoders: [SkipThought-LN](https://github.com/ryankiros/layer-norm#skip-thoughts)/[GenSen](https://github.com/Maluuba/gensen)/[Google-USE](https://tfhub.dev/google/universal-sentence-encoder/1)**
14+
15+
## Dependencies
16+
17+
This code is written in python. The dependencies are:
18+
19+
* Python 2/3 with [NumPy](http://www.numpy.org/)/[SciPy](http://www.scipy.org/)
20+
* [Pytorch](http://pytorch.org/)>=0.4
21+
* [scikit-learn](http://scikit-learn.org/stable/index.html)>=0.18.0
22+
23+
## Transfer tasks
24+
25+
### Downstream tasks
26+
SentEval allows you to evaluate your sentence embeddings as features for the following *downstream* tasks:
27+
28+
| Task | Type | #train | #test | needs_train | set_classifier |
29+
|---------- |------------------------------ |-----------:|----------:|:-----------:|:----------:|
30+
| [MR](https://nlp.stanford.edu/~sidaw/home/projects:nbsvm) | movie review | 11k | 11k | 1 | 1 |
31+
| [CR](https://nlp.stanford.edu/~sidaw/home/projects:nbsvm) | product review | 4k | 4k | 1 | 1 |
32+
| [SUBJ](https://nlp.stanford.edu/~sidaw/home/projects:nbsvm) | subjectivity status | 10k | 10k | 1 | 1 |
33+
| [MPQA](https://nlp.stanford.edu/~sidaw/home/projects:nbsvm) | opinion-polarity | 11k | 11k | 1 | 1 |
34+
| [SST](https://nlp.stanford.edu/sentiment/index.html) | binary sentiment analysis | 67k | 1.8k | 1 | 1 |
35+
| **[SST](https://nlp.stanford.edu/sentiment/index.html)** | **fine-grained sentiment analysis** | 8.5k | 2.2k | 1 | 1 |
36+
| [TREC](http://cogcomp.cs.illinois.edu/Data/QA/QC/) | question-type classification | 6k | 0.5k | 1 | 1 |
37+
| [SICK-E](http://clic.cimec.unitn.it/composes/sick.html) | natural language inference | 4.5k | 4.9k | 1 | 1 |
38+
| [SNLI](https://nlp.stanford.edu/projects/snli/) | natural language inference | 550k | 9.8k | 1 | 1 |
39+
| [MRPC](https://aclweb.org/aclwiki/Paraphrase_Identification_(State_of_the_art)) | paraphrase detection | 4.1k | 1.7k | 1 | 1 |
40+
| [STS 2012](https://www.cs.york.ac.uk/semeval-2012/task6/) | semantic textual similarity | N/A | 3.1k | 0 | 0 |
41+
| [STS 2013](http://ixa2.si.ehu.es/sts/) | semantic textual similarity | N/A | 1.5k | 0 | 0 |
42+
| [STS 2014](http://alt.qcri.org/semeval2014/task10/) | semantic textual similarity | N/A | 3.7k | 0 | 0 |
43+
| [STS 2015](http://alt.qcri.org/semeval2015/task2/) | semantic textual similarity | N/A | 8.5k | 0 | 0 |
44+
| [STS 2016](http://alt.qcri.org/semeval2016/task1/) | semantic textual similarity | N/A | 9.2k | 0 | 0 |
45+
| [STS B](http://ixa2.si.ehu.es/stswiki/index.php/STSbenchmark#Results) | semantic textual similarity | 5.7k | 1.4k | 1 | 0 |
46+
| [SICK-R](http://clic.cimec.unitn.it/composes/sick.html) | semantic textual similarity | 4.5k | 4.9k | 1 | 0 |
47+
| [COCO](http://mscoco.org/) | image-caption retrieval | 567k | 5*1k | 1 | 0 |
48+
49+
where **needs_train** means a model with parameters is learned on top of the sentence embeddings, and **set_classifier** means you can define the parameters of the classifier in the case of a classification task (see below).
50+
51+
Note: COCO comes with ResNet-101 2048d image embeddings. [More details on the tasks.](https://arxiv.org/pdf/1705.02364.pdf)
52+
53+
### Probing tasks
54+
SentEval also includes a series of [*probing* tasks](https://github.com/facebookresearch/SentEval/tree/master/data/probing) to evaluate what linguistic properties are encoded in your sentence embeddings:
55+
56+
| Task | Type | #train | #test | needs_train | set_classifier |
57+
|---------- |------------------------------ |-----------:|----------:|:-----------:|:----------:|
58+
| [SentLen](https://github.com/facebookresearch/SentEval/tree/master/data/probing) | Length prediction | 100k | 10k | 1 | 1 |
59+
| [WC](https://github.com/facebookresearch/SentEval/tree/master/data/probing) | Word Content analysis | 100k | 10k | 1 | 1 |
60+
| [TreeDepth](https://github.com/facebookresearch/SentEval/tree/master/data/probing) | Tree depth prediction | 100k | 10k | 1 | 1 |
61+
| [TopConst](https://github.com/facebookresearch/SentEval/tree/master/data/probing) | Top Constituents prediction | 100k | 10k | 1 | 1 |
62+
| [BShift](https://github.com/facebookresearch/SentEval/tree/master/data/probing) | Word order analysis | 100k | 10k | 1 | 1 |
63+
| [Tense](https://github.com/facebookresearch/SentEval/tree/master/data/probing) | Verb tense prediction | 100k | 10k | 1 | 1 |
64+
| [SubjNum](https://github.com/facebookresearch/SentEval/tree/master/data/probing) | Subject number prediction | 100k | 10k | 1 | 1 |
65+
| [ObjNum](https://github.com/facebookresearch/SentEval/tree/master/data/probing) | Object number prediction | 100k | 10k | 1 | 1 |
66+
| [SOMO](https://github.com/facebookresearch/SentEval/tree/master/data/probing) | Semantic odd man out | 100k | 10k | 1 | 1 |
67+
| [CoordInv](https://github.com/facebookresearch/SentEval/tree/master/data/probing) | Coordination Inversion | 100k | 10k | 1 | 1 |
68+
69+
## Download datasets
70+
To get all the transfer tasks datasets, run (in data/downstream/):
71+
```bash
72+
./get_transfer_data.bash
73+
```
74+
This will automatically download and preprocess the downstream datasets, and store them in data/downstream (warning: for MacOS users, you may have to use p7zip instead of unzip). The probing tasks are already in data/probing by default.
75+
76+
## How to use SentEval: examples
77+
78+
### examples/bow.py
79+
80+
In examples/bow.py, we evaluate the quality of the average of word embeddings.
81+
82+
To download state-of-the-art fastText embeddings:
83+
84+
```bash
85+
curl -Lo glove.840B.300d.zip http://nlp.stanford.edu/data/glove.840B.300d.zip
86+
curl -Lo crawl-300d-2M.vec.zip https://dl.fbaipublicfiles.com/fasttext/vectors-english/crawl-300d-2M.vec.zip
87+
```
88+
89+
To reproduce the results for bag-of-vectors, run (in examples/):
90+
```bash
91+
python bow.py
92+
```
93+
94+
As required by SentEval, this script implements two functions: **prepare** (optional) and **batcher** (required) that turn text sentences into sentence embeddings. Then SentEval takes care of the evaluation on the transfer tasks using the embeddings as features.
95+
96+
### examples/infersent.py
97+
98+
To get the **[InferSent](https://www.github.com/facebookresearch/InferSent)** model and reproduce our results, download our best models and run infersent.py (in examples/):
99+
```bash
100+
curl -Lo examples/infersent1.pkl https://dl.fbaipublicfiles.com/senteval/infersent/infersent1.pkl
101+
curl -Lo examples/infersent2.pkl https://dl.fbaipublicfiles.com/senteval/infersent/infersent2.pkl
102+
```
103+
104+
### examples/skipthought.py - examples/gensen.py - examples/googleuse.py
105+
106+
We also provide example scripts for three other encoders:
107+
108+
* [SkipThought with Layer-Normalization](https://github.com/ryankiros/layer-norm#skip-thoughts) in Theano
109+
* [GenSen encoder](https://github.com/Maluuba/gensen) in Pytorch
110+
* [Google encoder](https://tfhub.dev/google/universal-sentence-encoder/1) in TensorFlow
111+
112+
Note that for SkipThought and GenSen, following the steps of the associated githubs is necessary.
113+
The Google encoder script should work as-is.
114+
115+
## How to use SentEval
116+
117+
To evaluate your sentence embeddings, SentEval requires that you implement two functions:
118+
119+
1. **prepare** (sees the whole dataset of each task and can thus construct the word vocabulary, the dictionary of word vectors etc)
120+
2. **batcher** (transforms a batch of text sentences into sentence embeddings)
121+
122+
123+
### 1.) prepare(params, samples) (optional)
124+
125+
*batcher* only sees one batch at a time while the *samples* argument of *prepare* contains all the sentences of a task.
126+
127+
```
128+
prepare(params, samples)
129+
```
130+
* *params*: senteval parameters.
131+
* *samples*: list of all sentences from the tranfer task.
132+
* *output*: No output. Arguments stored in "params" can further be used by *batcher*.
133+
134+
*Example*: in bow.py, prepare is is used to build the vocabulary of words and construct the "params.word_vect* dictionary of word vectors.
135+
136+
137+
### 2.) batcher(params, batch)
138+
```
139+
batcher(params, batch)
140+
```
141+
* *params*: senteval parameters.
142+
* *batch*: numpy array of text sentences (of size params.batch_size)
143+
* *output*: numpy array of sentence embeddings (of size params.batch_size)
144+
145+
*Example*: in bow.py, batcher is used to compute the mean of the word vectors for each sentence in the batch using params.word_vec. Use your own encoder in that function to encode sentences.
146+
147+
### 3.) evaluation on transfer tasks
148+
149+
After having implemented the batch and prepare function for your own sentence encoder,
150+
151+
1) to perform the actual evaluation, first import senteval and set its parameters:
152+
```python
153+
import senteval
154+
params = {'task_path': PATH_TO_DATA, 'usepytorch': True, 'kfold': 10}
155+
```
156+
157+
2) (optional) set the parameters of the classifier (when applicable):
158+
```python
159+
params['classifier'] = {'nhid': 0, 'optim': 'adam', 'batch_size': 64,
160+
'tenacity': 5, 'epoch_size': 4}
161+
```
162+
You can choose **nhid=0** (Logistic Regression) or **nhid>0** (MLP) and define the parameters for training.
163+
164+
3) Create an instance of the class SE:
165+
```python
166+
se = senteval.engine.SE(params, batcher, prepare)
167+
```
168+
169+
4) define the set of transfer tasks and run the evaluation:
170+
```python
171+
transfer_tasks = ['MR', 'SICKEntailment', 'STS14', 'STSBenchmark']
172+
results = se.eval(transfer_tasks)
173+
```
174+
The current list of available tasks is:
175+
```python
176+
['CR', 'MR', 'MPQA', 'SUBJ', 'SST2', 'SST5', 'TREC', 'MRPC', 'SNLI',
177+
'SICKEntailment', 'SICKRelatedness', 'STSBenchmark', 'ImageCaptionRetrieval',
178+
'STS12', 'STS13', 'STS14', 'STS15', 'STS16',
179+
'Length', 'WordContent', 'Depth', 'TopConstituents','BigramShift', 'Tense',
180+
'SubjNumber', 'ObjNumber', 'OddManOut', 'CoordinationInversion']
181+
```
182+
183+
## SentEval parameters
184+
Global parameters of SentEval:
185+
```bash
186+
# senteval parameters
187+
task_path # path to SentEval datasets (required)
188+
seed # seed
189+
usepytorch # use cuda-pytorch (else scikit-learn) where possible
190+
kfold # k-fold validation for MR/CR/SUB/MPQA.
191+
```
192+
193+
Parameters of the classifier:
194+
```bash
195+
nhid: # number of hidden units (0: Logistic Regression, >0: MLP); Default nonlinearity: Tanh
196+
optim: # optimizer ("sgd,lr=0.1", "adam", "rmsprop" ..)
197+
tenacity: # how many times dev acc does not increase before training stops
198+
epoch_size: # each epoch corresponds to epoch_size pass on the train set
199+
max_epoch: # max number of epoches
200+
dropout: # dropout for MLP
201+
```
202+
203+
Note that to get a proxy of the results while **dramatically reducing computation time**,
204+
we suggest the **prototyping config**:
205+
```python
206+
params = {'task_path': PATH_TO_DATA, 'usepytorch': True, 'kfold': 5}
207+
params['classifier'] = {'nhid': 0, 'optim': 'rmsprop', 'batch_size': 128,
208+
'tenacity': 3, 'epoch_size': 2}
209+
```
210+
which will results in a 5 times speedup for classification tasks.
211+
212+
To produce results that are **comparable to the literature**, use the **default config**:
213+
```python
214+
params = {'task_path': PATH_TO_DATA, 'usepytorch': True, 'kfold': 10}
215+
params['classifier'] = {'nhid': 0, 'optim': 'adam', 'batch_size': 64,
216+
'tenacity': 5, 'epoch_size': 4}
217+
```
218+
which takes longer but will produce better and comparable results.
219+
220+
For probing tasks, we used an MLP with a Sigmoid nonlinearity and and tuned the nhid (in [50, 100, 200]) and dropout (in [0.0, 0.1, 0.2]) on the dev set.
221+
222+
## References
223+
224+
Please considering citing [[1]](https://arxiv.org/abs/1803.05449) if using this code for evaluating sentence embedding methods.
225+
226+
### SentEval: An Evaluation Toolkit for Universal Sentence Representations
227+
228+
[1] A. Conneau, D. Kiela, [*SentEval: An Evaluation Toolkit for Universal Sentence Representations*](https://arxiv.org/abs/1803.05449)
229+
230+
```
231+
@article{conneau2018senteval,
232+
title={SentEval: An Evaluation Toolkit for Universal Sentence Representations},
233+
author={Conneau, Alexis and Kiela, Douwe},
234+
journal={arXiv preprint arXiv:1803.05449},
235+
year={2018}
236+
}
237+
```
238+
239+
Contact: [aconneau@fb.com](mailto:aconneau@fb.com), [dkiela@fb.com](mailto:dkiela@fb.com)
240+
241+
### Related work
242+
* [J. R Kiros, Y. Zhu, R. Salakhutdinov, R. S. Zemel, A. Torralba, R. Urtasun, S. Fidler - SkipThought Vectors, NIPS 2015](https://arxiv.org/abs/1506.06726)
243+
* [S. Arora, Y. Liang, T. Ma - A Simple but Tough-to-Beat Baseline for Sentence Embeddings, ICLR 2017](https://openreview.net/pdf?id=SyK00v5xx)
244+
* [Y. Adi, E. Kermany, Y. Belinkov, O. Lavi, Y. Goldberg - Fine-grained analysis of sentence embeddings using auxiliary prediction tasks, ICLR 2017](https://arxiv.org/abs/1608.04207)
245+
* [A. Conneau, D. Kiela, L. Barrault, H. Schwenk, A. Bordes - Supervised Learning of Universal Sentence Representations from Natural Language Inference Data, EMNLP 2017](https://arxiv.org/abs/1705.02364)
246+
* [S. Subramanian, A. Trischler, Y. Bengio, C. J Pal - Learning General Purpose Distributed Sentence Representations via Large Scale Multi-task Learning, ICLR 2018](https://arxiv.org/abs/1804.00079)
247+
* [A. Nie, E. D. Bennett, N. D. Goodman - DisSent: Sentence Representation Learning from Explicit Discourse Relations, 2018](https://arxiv.org/abs/1710.04334)
248+
* [D. Cer, Y. Yang, S. Kong, N. Hua, N. Limtiaco, R. St. John, N. Constant, M. Guajardo-Cespedes, S. Yuan, C. Tar, Y. Sung, B. Strope, R. Kurzweil - Universal Sentence Encoder, 2018](https://arxiv.org/abs/1803.11175)
249+
* [A. Conneau, G. Kruszewski, G. Lample, L. Barrault, M. Baroni - What you can cram into a single vector: Probing sentence embeddings for linguistic properties, ACL 2018](https://arxiv.org/abs/1805.01070)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
wget --no-check-certificate https://huggingface.co/datasets/princeton-nlp/datasets-for-simcse/resolve/main/senteval.tar
2+
tar xvf senteval.tar

‎SentEval/examples/bow.py

+113
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
# Copyright (c) 2017-present, Facebook, Inc.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
#
7+
8+
from __future__ import absolute_import, division, unicode_literals
9+
10+
import sys
11+
import io
12+
import numpy as np
13+
import logging
14+
15+
16+
# Set PATHs
17+
PATH_TO_SENTEVAL = '../'
18+
PATH_TO_DATA = '../data'
19+
# PATH_TO_VEC = 'glove/glove.840B.300d.txt'
20+
PATH_TO_VEC = 'fasttext/crawl-300d-2M.vec'
21+
22+
# import SentEval
23+
sys.path.insert(0, PATH_TO_SENTEVAL)
24+
import senteval
25+
26+
27+
# Create dictionary
28+
def create_dictionary(sentences, threshold=0):
29+
words = {}
30+
for s in sentences:
31+
for word in s:
32+
words[word] = words.get(word, 0) + 1
33+
34+
if threshold > 0:
35+
newwords = {}
36+
for word in words:
37+
if words[word] >= threshold:
38+
newwords[word] = words[word]
39+
words = newwords
40+
words['<s>'] = 1e9 + 4
41+
words['</s>'] = 1e9 + 3
42+
words['<p>'] = 1e9 + 2
43+
44+
sorted_words = sorted(words.items(), key=lambda x: -x[1]) # inverse sort
45+
id2word = []
46+
word2id = {}
47+
for i, (w, _) in enumerate(sorted_words):
48+
id2word.append(w)
49+
word2id[w] = i
50+
51+
return id2word, word2id
52+
53+
# Get word vectors from vocabulary (glove, word2vec, fasttext ..)
54+
def get_wordvec(path_to_vec, word2id):
55+
word_vec = {}
56+
57+
with io.open(path_to_vec, 'r', encoding='utf-8') as f:
58+
# if word2vec or fasttext file : skip first line "next(f)"
59+
for line in f:
60+
word, vec = line.split(' ', 1)
61+
if word in word2id:
62+
word_vec[word] = np.fromstring(vec, sep=' ')
63+
64+
logging.info('Found {0} words with word vectors, out of \
65+
{1} words'.format(len(word_vec), len(word2id)))
66+
return word_vec
67+
68+
69+
# SentEval prepare and batcher
70+
def prepare(params, samples):
71+
_, params.word2id = create_dictionary(samples)
72+
params.word_vec = get_wordvec(PATH_TO_VEC, params.word2id)
73+
params.wvec_dim = 300
74+
return
75+
76+
def batcher(params, batch):
77+
batch = [sent if sent != [] else ['.'] for sent in batch]
78+
embeddings = []
79+
80+
for sent in batch:
81+
sentvec = []
82+
for word in sent:
83+
if word in params.word_vec:
84+
sentvec.append(params.word_vec[word])
85+
if not sentvec:
86+
vec = np.zeros(params.wvec_dim)
87+
sentvec.append(vec)
88+
sentvec = np.mean(sentvec, 0)
89+
embeddings.append(sentvec)
90+
91+
embeddings = np.vstack(embeddings)
92+
return embeddings
93+
94+
95+
# Set params for SentEval
96+
params_senteval = {'task_path': PATH_TO_DATA, 'usepytorch': True, 'kfold': 5}
97+
params_senteval['classifier'] = {'nhid': 0, 'optim': 'rmsprop', 'batch_size': 128,
98+
'tenacity': 3, 'epoch_size': 2}
99+
100+
# Set up logger
101+
logging.basicConfig(format='%(asctime)s : %(message)s', level=logging.DEBUG)
102+
103+
if __name__ == "__main__":
104+
se = senteval.engine.SE(params_senteval, batcher, prepare)
105+
#transfer_tasks = ['STS12', 'STS13', 'STS14', 'STS15', 'STS16',
106+
#'MR', 'CR', 'MPQA', 'SUBJ', 'SST2', 'SST5', 'TREC', 'MRPC',
107+
#'SICKEntailment', 'SICKRelatedness', 'STSBenchmark',
108+
#'Length', 'WordContent', 'Depth', 'TopConstituents',
109+
#'BigramShift', 'Tense', 'SubjNumber', 'ObjNumber',
110+
#'OddManOut', 'CoordinationInversion']
111+
transfer_tasks = ['STSBenchmark']
112+
results = se.eval(transfer_tasks)
113+
print(results)

‎SentEval/examples/bow_word_piece.py

+158
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,158 @@
1+
# Copyright (c) 2017-present, Facebook, Inc.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
#
7+
8+
from __future__ import absolute_import, division, unicode_literals
9+
10+
import sys
11+
import io
12+
import numpy as np
13+
import logging
14+
15+
from transformers import BertTokenizer
16+
17+
# Set PATHs
18+
PATH_TO_SENTEVAL = '../'
19+
PATH_TO_DATA = '../data'
20+
# PATH_TO_VEC = 'glove/glove.840B.300d.txt'
21+
PATH_TO_VEC = 'fasttext/crawl-300d-2M.vec'
22+
23+
# import SentEval
24+
sys.path.insert(0, PATH_TO_SENTEVAL)
25+
import senteval
26+
27+
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
28+
a_remove_set = {".", "a", "the", "in", ",", "is", "to", "of", "and", "'", "on", "man", "-", "s", "with", "for", "\"", "at", "##s", "woman", "are", "it", "two", "that", "you", "dog", "said", "playing", "i", "an", "as", "was", "from", ":", "by", "white"}
29+
remove_set = {'?', '*', '#', '´', '’', '=', '…', '|', '~', '/', '‚', '¿', '–', '»', '-', '€', '‘', '"', '(', '•', '`', '$', ':', '[', '”', '%', '£', '<', '[UNK]', ';', '“', '@', '_', '{', '^', ',', '.', '!', '™', '&', ']', '>', '\\', "'", ')', '+', '—'}
30+
31+
# Create dictionary
32+
def create_dictionary(sentences, threshold=0):
33+
words = {}
34+
for s in sentences:
35+
for word in s:
36+
words[word] = words.get(word, 0) + 1
37+
#for word in tokenizer.convert_ids_to_tokens(tokenizer.encode(' '.join(s), add_special_tokens=False)):
38+
#if '##' in word and word in remove_set: continue
39+
#words[word] = words.get(word, 0) + 1
40+
41+
if threshold > 0:
42+
newwords = {}
43+
for word in words:
44+
if words[word] >= threshold:
45+
newwords[word] = words[word]
46+
words = newwords
47+
words['<s>'] = 1e9 + 4
48+
words['</s>'] = 1e9 + 3
49+
words['<p>'] = 1e9 + 2
50+
51+
sorted_words = sorted(words.items(), key=lambda x: -x[1]) # inverse sort
52+
id2word = []
53+
word2id = {}
54+
for i, (w, _) in enumerate(sorted_words):
55+
id2word.append(w)
56+
word2id[w] = i
57+
58+
return id2word, word2id
59+
60+
# Get word vectors from vocabulary (glove, word2vec, fasttext ..)
61+
def get_wordvec(path_to_vec, word2id):
62+
word_vec = {}
63+
64+
with io.open(path_to_vec, 'r', encoding='utf-8') as f:
65+
# if word2vec or fasttext file : skip first line "next(f)"
66+
for line in f:
67+
word, vec = line.split(' ', 1)
68+
if word in word2id:
69+
word_vec[word] = np.fromstring(vec, sep=' ')
70+
71+
logging.info('Found {0} words with word vectors, out of \
72+
{1} words'.format(len(word_vec), len(word2id)))
73+
return word_vec
74+
75+
def get_bert_wordvec(path_to_vec, word2id):
76+
word_vec = {}
77+
from transformers import BertModel
78+
bert = BertModel.from_pretrained('bert-base-uncased')
79+
vocab = tokenizer.get_vocab()
80+
bert_word_vec = bert.embeddings.word_embeddings.weight.detach().numpy()
81+
82+
for word in word2id:
83+
if word in ['<s>', '</s>', '<p>']:
84+
word_vec[word] = np.zeros(768)
85+
else:
86+
word_vec[word] = bert_word_vec[vocab[word]]
87+
88+
logging.info('Found {0} words with word vectors, out of \
89+
{1} words'.format(len(word_vec), len(word2id)))
90+
return word_vec
91+
92+
# SentEval prepare and batcher
93+
def prepare(params, samples):
94+
_, params.word2id = create_dictionary(samples)
95+
params.word_vec = get_wordvec(PATH_TO_VEC, params.word2id)
96+
params.wvec_dim = 300
97+
#params.word_vec = get_bert_wordvec(PATH_TO_VEC, params.word2id)
98+
#params.wvec_dim = 768
99+
return
100+
101+
def batcher(params, batch):
102+
batch = [sent if sent != [] else ['.'] for sent in batch]
103+
embeddings = []
104+
105+
for sent in batch:
106+
sentvec = []
107+
# for word in tokenizer.convert_ids_to_tokens(tokenizer.encode(' '.join(sent), add_special_tokens=False)):
108+
for word in sent:
109+
if word in params.word_vec:# and word not in a_remove_set and word not in remove_set:
110+
sentvec.append(params.word_vec[word])
111+
if not sentvec:
112+
vec = np.zeros(params.wvec_dim)
113+
sentvec.append(vec)
114+
sentvec = np.mean(sentvec, 0)
115+
embeddings.append(sentvec)
116+
117+
embeddings = np.vstack(embeddings)
118+
return embeddings
119+
120+
121+
# Set params for SentEval
122+
params_senteval = {'task_path': PATH_TO_DATA, 'usepytorch': True, 'kfold': 5}
123+
params_senteval['classifier'] = {'nhid': 0, 'optim': 'rmsprop', 'batch_size': 128,
124+
'tenacity': 3, 'epoch_size': 2}
125+
126+
# Set up logger
127+
logging.basicConfig(format='%(asctime)s : %(message)s', level=logging.DEBUG)
128+
129+
if __name__ == "__main__":
130+
se = senteval.engine.SE(params_senteval, batcher, prepare)
131+
#transfer_tasks = ['STS12', 'STS13', 'STS14', 'STS15', 'STS16',
132+
#'MR', 'CR', 'MPQA', 'SUBJ', 'SST2', 'SST5', 'TREC', 'MRPC',
133+
#'SICKEntailment', 'SICKRelatedness', 'STSBenchmark',
134+
#'Length', 'WordContent', 'Depth', 'TopConstituents',
135+
#'BigramShift', 'Tense', 'SubjNumber', 'ObjNumber',
136+
#'OddManOut', 'CoordinationInversion']
137+
transfer_tasks = ['STS12', 'STS13', 'STS14', 'STS15', 'STS16', 'STSBenchmark', 'SICKRelatedness']
138+
results = se.eval(transfer_tasks)
139+
print(results)
140+
task_names = []
141+
scores = []
142+
for task in ['STS12', 'STS13', 'STS14', 'STS15', 'STS16', 'STSBenchmark', 'SICKRelatedness']:
143+
task_names.append(task)
144+
if task in results:
145+
if task in ['STS12', 'STS13', 'STS14', 'STS15', 'STS16']:
146+
scores.append("%.2f" % (results[task]['all']['spearman']['all'] * 100))
147+
else:
148+
scores.append("%.2f" % (results[task]['test']['spearman'].correlation * 100))
149+
else:
150+
scores.append("0.00")
151+
task_names.append("Avg.")
152+
scores.append("%.2f" % (sum([float(score) for score in scores]) / len(scores)))
153+
154+
from prettytable import PrettyTable
155+
tb = PrettyTable()
156+
tb.field_names = task_names
157+
tb.add_row(scores)
158+
print(tb)

‎SentEval/examples/gensen.py

+74
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
# Copyright (c) 2017-present, Facebook, Inc.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
#
7+
8+
"""
9+
Clone GenSen repo here: https://github.com/Maluuba/gensen.git
10+
And follow instructions for loading the model used in batcher
11+
"""
12+
13+
from __future__ import absolute_import, division, unicode_literals
14+
15+
import sys
16+
import logging
17+
# import GenSen package
18+
from gensen import GenSen, GenSenSingle
19+
20+
# Set PATHs
21+
PATH_TO_SENTEVAL = '../'
22+
PATH_TO_DATA = '../data'
23+
24+
# import SentEval
25+
sys.path.insert(0, PATH_TO_SENTEVAL)
26+
import senteval
27+
28+
# SentEval prepare and batcher
29+
def prepare(params, samples):
30+
return
31+
32+
def batcher(params, batch):
33+
batch = [' '.join(sent) if sent != [] else '.' for sent in batch]
34+
_, reps_h_t = gensen.get_representation(
35+
sentences, pool='last', return_numpy=True, tokenize=True
36+
)
37+
embeddings = reps_h_t
38+
return embeddings
39+
40+
# Load GenSen model
41+
gensen_1 = GenSenSingle(
42+
model_folder='../data/models',
43+
filename_prefix='nli_large_bothskip',
44+
pretrained_emb='../data/embedding/glove.840B.300d.h5'
45+
)
46+
gensen_2 = GenSenSingle(
47+
model_folder='../data/models',
48+
filename_prefix='nli_large_bothskip_parse',
49+
pretrained_emb='../data/embedding/glove.840B.300d.h5'
50+
)
51+
gensen_encoder = GenSen(gensen_1, gensen_2)
52+
reps_h, reps_h_t = gensen.get_representation(
53+
sentences, pool='last', return_numpy=True, tokenize=True
54+
)
55+
56+
# Set params for SentEval
57+
params_senteval = {'task_path': PATH_TO_DATA, 'usepytorch': True, 'kfold': 5}
58+
params_senteval['classifier'] = {'nhid': 0, 'optim': 'rmsprop', 'batch_size': 128,
59+
'tenacity': 3, 'epoch_size': 2}
60+
params_senteval['gensen'] = gensen_encoder
61+
62+
# Set up logger
63+
logging.basicConfig(format='%(asctime)s : %(message)s', level=logging.DEBUG)
64+
65+
if __name__ == "__main__":
66+
se = senteval.engine.SE(params_senteval, batcher, prepare)
67+
transfer_tasks = ['STS12', 'STS13', 'STS14', 'STS15', 'STS16',
68+
'MR', 'CR', 'MPQA', 'SUBJ', 'SST2', 'SST5', 'TREC', 'MRPC',
69+
'SICKEntailment', 'SICKRelatedness', 'STSBenchmark',
70+
'Length', 'WordContent', 'Depth', 'TopConstituents',
71+
'BigramShift', 'Tense', 'SubjNumber', 'ObjNumber',
72+
'OddManOut', 'CoordinationInversion']
73+
results = se.eval(transfer_tasks)
74+
print(results)

‎SentEval/examples/googleuse.py

+67
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
# Copyright (c) 2017-present, Facebook, Inc.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
#
7+
8+
from __future__ import absolute_import, division
9+
10+
import os
11+
import sys
12+
import logging
13+
import tensorflow as tf
14+
import tensorflow_hub as hub
15+
tf.logging.set_verbosity(0)
16+
17+
# Set PATHs
18+
PATH_TO_SENTEVAL = '../'
19+
PATH_TO_DATA = '../data'
20+
21+
# import SentEval
22+
sys.path.insert(0, PATH_TO_SENTEVAL)
23+
import senteval
24+
25+
# tensorflow session
26+
session = tf.Session()
27+
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
28+
29+
# SentEval prepare and batcher
30+
def prepare(params, samples):
31+
return
32+
33+
def batcher(params, batch):
34+
batch = [' '.join(sent) if sent != [] else '.' for sent in batch]
35+
embeddings = params['google_use'](batch)
36+
return embeddings
37+
38+
def make_embed_fn(module):
39+
with tf.Graph().as_default():
40+
sentences = tf.placeholder(tf.string)
41+
embed = hub.Module(module)
42+
embeddings = embed(sentences)
43+
session = tf.train.MonitoredSession()
44+
return lambda x: session.run(embeddings, {sentences: x})
45+
46+
# Start TF session and load Google Universal Sentence Encoder
47+
encoder = make_embed_fn("https://tfhub.dev/google/universal-sentence-encoder-large/2")
48+
49+
# Set params for SentEval
50+
params_senteval = {'task_path': PATH_TO_DATA, 'usepytorch': True, 'kfold': 5}
51+
params_senteval['classifier'] = {'nhid': 0, 'optim': 'rmsprop', 'batch_size': 128,
52+
'tenacity': 3, 'epoch_size': 2}
53+
params_senteval['google_use'] = encoder
54+
55+
# Set up logger
56+
logging.basicConfig(format='%(asctime)s : %(message)s', level=logging.DEBUG)
57+
58+
if __name__ == "__main__":
59+
se = senteval.engine.SE(params_senteval, batcher, prepare)
60+
transfer_tasks = ['STS12', 'STS13', 'STS14', 'STS15', 'STS16',
61+
'MR', 'CR', 'MPQA', 'SUBJ', 'SST2', 'SST5', 'TREC', 'MRPC',
62+
'SICKEntailment', 'SICKRelatedness', 'STSBenchmark',
63+
'Length', 'WordContent', 'Depth', 'TopConstituents',
64+
'BigramShift', 'Tense', 'SubjNumber', 'ObjNumber',
65+
'OddManOut', 'CoordinationInversion']
66+
results = se.eval(transfer_tasks)
67+
print(results)

‎SentEval/examples/infersent.py

+76
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
# Copyright (c) 2017-present, Facebook, Inc.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
#
7+
8+
"""
9+
InferSent models. See https://github.com/facebookresearch/InferSent.
10+
"""
11+
12+
from __future__ import absolute_import, division, unicode_literals
13+
14+
import sys
15+
import os
16+
import torch
17+
import logging
18+
19+
# get models.py from InferSent repo
20+
from models import InferSent
21+
22+
# Set PATHs
23+
PATH_SENTEVAL = '../'
24+
PATH_TO_DATA = '../data'
25+
PATH_TO_W2V = 'PATH/TO/glove.840B.300d.txt' # or crawl-300d-2M.vec for V2
26+
MODEL_PATH = 'infersent1.pkl'
27+
V = 1 # version of InferSent
28+
29+
assert os.path.isfile(MODEL_PATH) and os.path.isfile(PATH_TO_W2V), \
30+
'Set MODEL and GloVe PATHs'
31+
32+
# import senteval
33+
sys.path.insert(0, PATH_SENTEVAL)
34+
import senteval
35+
36+
37+
def prepare(params, samples):
38+
params.infersent.build_vocab([' '.join(s) for s in samples], tokenize=False)
39+
40+
41+
def batcher(params, batch):
42+
sentences = [' '.join(s) for s in batch]
43+
embeddings = params.infersent.encode(sentences, bsize=params.batch_size, tokenize=False)
44+
return embeddings
45+
46+
47+
"""
48+
Evaluation of trained model on Transfer Tasks (SentEval)
49+
"""
50+
51+
# define senteval params
52+
params_senteval = {'task_path': PATH_TO_DATA, 'usepytorch': True, 'kfold': 5}
53+
params_senteval['classifier'] = {'nhid': 0, 'optim': 'rmsprop', 'batch_size': 128,
54+
'tenacity': 3, 'epoch_size': 2}
55+
# Set up logger
56+
logging.basicConfig(format='%(asctime)s : %(message)s', level=logging.DEBUG)
57+
58+
if __name__ == "__main__":
59+
# Load InferSent model
60+
params_model = {'bsize': 64, 'word_emb_dim': 300, 'enc_lstm_dim': 2048,
61+
'pool_type': 'max', 'dpout_model': 0.0, 'version': V}
62+
model = InferSent(params_model)
63+
model.load_state_dict(torch.load(MODEL_PATH))
64+
model.set_w2v_path(PATH_TO_W2V)
65+
66+
params_senteval['infersent'] = model.cuda()
67+
68+
se = senteval.engine.SE(params_senteval, batcher, prepare)
69+
transfer_tasks = ['STS12', 'STS13', 'STS14', 'STS15', 'STS16',
70+
'MR', 'CR', 'MPQA', 'SUBJ', 'SST2', 'SST5', 'TREC', 'MRPC',
71+
'SICKEntailment', 'SICKRelatedness', 'STSBenchmark',
72+
'Length', 'WordContent', 'Depth', 'TopConstituents',
73+
'BigramShift', 'Tense', 'SubjNumber', 'ObjNumber',
74+
'OddManOut', 'CoordinationInversion']
75+
results = se.eval(transfer_tasks)
76+
print(results)

‎SentEval/examples/models.py

+265
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,265 @@
1+
# Copyright (c) 2017-present, Facebook, Inc.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
#
7+
8+
"""
9+
This file contains the definition of encoders used in https://arxiv.org/pdf/1705.02364.pdf
10+
"""
11+
12+
import numpy as np
13+
import time
14+
15+
import torch
16+
import torch.nn as nn
17+
18+
19+
class InferSent(nn.Module):
20+
21+
def __init__(self, config):
22+
super(InferSent, self).__init__()
23+
self.bsize = config['bsize']
24+
self.word_emb_dim = config['word_emb_dim']
25+
self.enc_lstm_dim = config['enc_lstm_dim']
26+
self.pool_type = config['pool_type']
27+
self.dpout_model = config['dpout_model']
28+
self.version = 1 if 'version' not in config else config['version']
29+
30+
self.enc_lstm = nn.LSTM(self.word_emb_dim, self.enc_lstm_dim, 1,
31+
bidirectional=True, dropout=self.dpout_model)
32+
33+
assert self.version in [1, 2]
34+
if self.version == 1:
35+
self.bos = '<s>'
36+
self.eos = '</s>'
37+
self.max_pad = True
38+
self.moses_tok = False
39+
elif self.version == 2:
40+
self.bos = '<p>'
41+
self.eos = '</p>'
42+
self.max_pad = False
43+
self.moses_tok = True
44+
45+
def is_cuda(self):
46+
# either all weights are on cpu or they are on gpu
47+
return self.enc_lstm.bias_hh_l0.data.is_cuda
48+
49+
def forward(self, sent_tuple):
50+
# sent_len: [max_len, ..., min_len] (bsize)
51+
# sent: (seqlen x bsize x worddim)
52+
sent, sent_len = sent_tuple
53+
54+
# Sort by length (keep idx)
55+
sent_len_sorted, idx_sort = np.sort(sent_len)[::-1], np.argsort(-sent_len)
56+
sent_len_sorted = sent_len_sorted.copy()
57+
idx_unsort = np.argsort(idx_sort)
58+
59+
idx_sort = torch.from_numpy(idx_sort).cuda() if self.is_cuda() \
60+
else torch.from_numpy(idx_sort)
61+
sent = sent.index_select(1, idx_sort)
62+
63+
# Handling padding in Recurrent Networks
64+
sent_packed = nn.utils.rnn.pack_padded_sequence(sent, sent_len_sorted)
65+
sent_output = self.enc_lstm(sent_packed)[0] # seqlen x batch x 2*nhid
66+
sent_output = nn.utils.rnn.pad_packed_sequence(sent_output)[0]
67+
68+
# Un-sort by length
69+
idx_unsort = torch.from_numpy(idx_unsort).cuda() if self.is_cuda() \
70+
else torch.from_numpy(idx_unsort)
71+
sent_output = sent_output.index_select(1, idx_unsort)
72+
73+
# Pooling
74+
if self.pool_type == "mean":
75+
sent_len = torch.FloatTensor(sent_len.copy()).unsqueeze(1).cuda()
76+
emb = torch.sum(sent_output, 0).squeeze(0)
77+
emb = emb / sent_len.expand_as(emb)
78+
elif self.pool_type == "max":
79+
if not self.max_pad:
80+
sent_output[sent_output == 0] = -1e9
81+
emb = torch.max(sent_output, 0)[0]
82+
if emb.ndimension() == 3:
83+
emb = emb.squeeze(0)
84+
assert emb.ndimension() == 2
85+
86+
return emb
87+
88+
def set_w2v_path(self, w2v_path):
89+
self.w2v_path = w2v_path
90+
91+
def get_word_dict(self, sentences, tokenize=True):
92+
# create vocab of words
93+
word_dict = {}
94+
sentences = [s.split() if not tokenize else self.tokenize(s) for s in sentences]
95+
for sent in sentences:
96+
for word in sent:
97+
if word not in word_dict:
98+
word_dict[word] = ''
99+
word_dict[self.bos] = ''
100+
word_dict[self.eos] = ''
101+
return word_dict
102+
103+
def get_w2v(self, word_dict):
104+
assert hasattr(self, 'w2v_path'), 'w2v path not set'
105+
# create word_vec with w2v vectors
106+
word_vec = {}
107+
with open(self.w2v_path, encoding='utf-8') as f:
108+
for line in f:
109+
word, vec = line.split(' ', 1)
110+
if word in word_dict:
111+
word_vec[word] = np.fromstring(vec, sep=' ')
112+
print('Found %s(/%s) words with w2v vectors' % (len(word_vec), len(word_dict)))
113+
return word_vec
114+
115+
def get_w2v_k(self, K):
116+
assert hasattr(self, 'w2v_path'), 'w2v path not set'
117+
# create word_vec with k first w2v vectors
118+
k = 0
119+
word_vec = {}
120+
with open(self.w2v_path, encoding='utf-8') as f:
121+
for line in f:
122+
word, vec = line.split(' ', 1)
123+
if k <= K:
124+
word_vec[word] = np.fromstring(vec, sep=' ')
125+
k += 1
126+
if k > K:
127+
if word in [self.bos, self.eos]:
128+
word_vec[word] = np.fromstring(vec, sep=' ')
129+
130+
if k > K and all([w in word_vec for w in [self.bos, self.eos]]):
131+
break
132+
return word_vec
133+
134+
def build_vocab(self, sentences, tokenize=True):
135+
assert hasattr(self, 'w2v_path'), 'w2v path not set'
136+
word_dict = self.get_word_dict(sentences, tokenize)
137+
self.word_vec = self.get_w2v(word_dict)
138+
print('Vocab size : %s' % (len(self.word_vec)))
139+
140+
# build w2v vocab with k most frequent words
141+
def build_vocab_k_words(self, K):
142+
assert hasattr(self, 'w2v_path'), 'w2v path not set'
143+
self.word_vec = self.get_w2v_k(K)
144+
print('Vocab size : %s' % (K))
145+
146+
def update_vocab(self, sentences, tokenize=True):
147+
assert hasattr(self, 'w2v_path'), 'warning : w2v path not set'
148+
assert hasattr(self, 'word_vec'), 'build_vocab before updating it'
149+
word_dict = self.get_word_dict(sentences, tokenize)
150+
151+
# keep only new words
152+
for word in self.word_vec:
153+
if word in word_dict:
154+
del word_dict[word]
155+
156+
# udpate vocabulary
157+
if word_dict:
158+
new_word_vec = self.get_w2v(word_dict)
159+
self.word_vec.update(new_word_vec)
160+
else:
161+
new_word_vec = []
162+
print('New vocab size : %s (added %s words)'% (len(self.word_vec), len(new_word_vec)))
163+
164+
def get_batch(self, batch):
165+
# sent in batch in decreasing order of lengths
166+
# batch: (bsize, max_len, word_dim)
167+
embed = np.zeros((len(batch[0]), len(batch), self.word_emb_dim))
168+
169+
for i in range(len(batch)):
170+
for j in range(len(batch[i])):
171+
embed[j, i, :] = self.word_vec[batch[i][j]]
172+
173+
return torch.FloatTensor(embed)
174+
175+
def tokenize(self, s):
176+
from nltk.tokenize import word_tokenize
177+
if self.moses_tok:
178+
s = ' '.join(word_tokenize(s))
179+
s = s.replace(" n't ", "n 't ") # HACK to get ~MOSES tokenization
180+
return s.split()
181+
else:
182+
return word_tokenize(s)
183+
184+
def prepare_samples(self, sentences, bsize, tokenize, verbose):
185+
sentences = [[self.bos] + s.split() + [self.eos] if not tokenize else
186+
[self.bos] + self.tokenize(s) + [self.eos] for s in sentences]
187+
n_w = np.sum([len(x) for x in sentences])
188+
189+
# filters words without w2v vectors
190+
for i in range(len(sentences)):
191+
s_f = [word for word in sentences[i] if word in self.word_vec]
192+
if not s_f:
193+
import warnings
194+
warnings.warn('No words in "%s" (idx=%s) have w2v vectors. \
195+
Replacing by "</s>"..' % (sentences[i], i))
196+
s_f = [self.eos]
197+
sentences[i] = s_f
198+
199+
lengths = np.array([len(s) for s in sentences])
200+
n_wk = np.sum(lengths)
201+
if verbose:
202+
print('Nb words kept : %s/%s (%.1f%s)' % (
203+
n_wk, n_w, 100.0 * n_wk / n_w, '%'))
204+
205+
# sort by decreasing length
206+
lengths, idx_sort = np.sort(lengths)[::-1], np.argsort(-lengths)
207+
sentences = np.array(sentences)[idx_sort]
208+
209+
return sentences, lengths, idx_sort
210+
211+
def encode(self, sentences, bsize=64, tokenize=True, verbose=False):
212+
tic = time.time()
213+
sentences, lengths, idx_sort = self.prepare_samples(
214+
sentences, bsize, tokenize, verbose)
215+
216+
embeddings = []
217+
for stidx in range(0, len(sentences), bsize):
218+
batch = self.get_batch(sentences[stidx:stidx + bsize])
219+
if self.is_cuda():
220+
batch = batch.cuda()
221+
with torch.no_grad():
222+
batch = self.forward((batch, lengths[stidx:stidx + bsize])).data.cpu().numpy()
223+
embeddings.append(batch)
224+
embeddings = np.vstack(embeddings)
225+
226+
# unsort
227+
idx_unsort = np.argsort(idx_sort)
228+
embeddings = embeddings[idx_unsort]
229+
230+
if verbose:
231+
print('Speed : %.1f sentences/s (%s mode, bsize=%s)' % (
232+
len(embeddings)/(time.time()-tic),
233+
'gpu' if self.is_cuda() else 'cpu', bsize))
234+
return embeddings
235+
236+
def visualize(self, sent, tokenize=True):
237+
238+
sent = sent.split() if not tokenize else self.tokenize(sent)
239+
sent = [[self.bos] + [word for word in sent if word in self.word_vec] + [self.eos]]
240+
241+
if ' '.join(sent[0]) == '%s %s' % (self.bos, self.eos):
242+
import warnings
243+
warnings.warn('No words in "%s" have w2v vectors. Replacing \
244+
by "%s %s"..' % (sent, self.bos, self.eos))
245+
batch = self.get_batch(sent)
246+
247+
if self.is_cuda():
248+
batch = batch.cuda()
249+
output = self.enc_lstm(batch)[0]
250+
output, idxs = torch.max(output, 0)
251+
# output, idxs = output.squeeze(), idxs.squeeze()
252+
idxs = idxs.data.cpu().numpy()
253+
argmaxs = [np.sum((idxs == k)) for k in range(len(sent[0]))]
254+
255+
# visualize model
256+
import matplotlib.pyplot as plt
257+
x = range(len(sent[0]))
258+
y = [100.0 * n / np.sum(argmaxs) for n in argmaxs]
259+
plt.xticks(x, sent[0], rotation=45)
260+
plt.bar(x, y)
261+
plt.ylabel('%')
262+
plt.title('Visualisation of words importance')
263+
plt.show()
264+
265+
return output, idxs

‎SentEval/examples/skipthought.py

+61
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
# Copyright (c) 2017-present, Facebook, Inc.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
#
7+
8+
from __future__ import absolute_import, division, unicode_literals
9+
10+
"""
11+
Example of file for SkipThought in SentEval
12+
"""
13+
import logging
14+
import sys
15+
sys.setdefaultencoding('utf8')
16+
17+
18+
# Set PATHs
19+
PATH_TO_SENTEVAL = '../'
20+
PATH_TO_DATA = '../data/senteval_data/'
21+
PATH_TO_SKIPTHOUGHT = ''
22+
23+
assert PATH_TO_SKIPTHOUGHT != '', 'Download skipthought and set correct PATH'
24+
25+
# import skipthought and Senteval
26+
sys.path.insert(0, PATH_TO_SKIPTHOUGHT)
27+
import skipthoughts
28+
sys.path.insert(0, PATH_TO_SENTEVAL)
29+
import senteval
30+
31+
32+
def prepare(params, samples):
33+
return
34+
35+
def batcher(params, batch):
36+
batch = [str(' '.join(sent), errors="ignore") if sent != [] else '.' for sent in batch]
37+
embeddings = skipthoughts.encode(params['encoder'], batch,
38+
verbose=False, use_eos=True)
39+
return embeddings
40+
41+
42+
# Set params for SentEval
43+
params_senteval = {'task_path': PATH_TO_DATA, 'usepytorch': True, 'kfold': 10, 'batch_size': 512}
44+
params_senteval['classifier'] = {'nhid': 0, 'optim': 'adam', 'batch_size': 64,
45+
'tenacity': 5, 'epoch_size': 4}
46+
# Set up logger
47+
logging.basicConfig(format='%(asctime)s : %(message)s', level=logging.DEBUG)
48+
49+
if __name__ == "__main__":
50+
# Load SkipThought model
51+
params_senteval['encoder'] = skipthoughts.load_model()
52+
53+
se = senteval.engine.SE(params_senteval, batcher, prepare)
54+
transfer_tasks = ['STS12', 'STS13', 'STS14', 'STS15', 'STS16',
55+
'MR', 'CR', 'MPQA', 'SUBJ', 'SST2', 'SST5', 'TREC', 'MRPC',
56+
'SICKEntailment', 'SICKRelatedness', 'STSBenchmark',
57+
'Length', 'WordContent', 'Depth', 'TopConstituents',
58+
'BigramShift', 'Tense', 'SubjNumber', 'ObjNumber',
59+
'OddManOut', 'CoordinationInversion']
60+
results = se.eval(transfer_tasks)
61+
print(results)

‎SentEval/senteval/__init__.py

+10
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
# Copyright (c) 2017-present, Facebook, Inc.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
#
7+
8+
from __future__ import absolute_import
9+
10+
from senteval.engine import SE

‎SentEval/senteval/binary.py

+92
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
# Copyright (c) 2017-present, Facebook, Inc.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
#
7+
8+
'''
9+
Binary classifier and corresponding datasets : MR, CR, SUBJ, MPQA
10+
'''
11+
from __future__ import absolute_import, division, unicode_literals
12+
13+
import io
14+
import os
15+
import numpy as np
16+
import logging
17+
18+
from senteval.tools.validation import InnerKFoldClassifier
19+
20+
21+
class BinaryClassifierEval(object):
22+
def __init__(self, pos, neg, seed=1111):
23+
self.seed = seed
24+
self.samples, self.labels = pos + neg, [1] * len(pos) + [0] * len(neg)
25+
self.n_samples = len(self.samples)
26+
27+
def do_prepare(self, params, prepare):
28+
# prepare is given the whole text
29+
return prepare(params, self.samples)
30+
# prepare puts everything it outputs in "params" : params.word2id etc
31+
# Those output will be further used by "batcher".
32+
33+
def loadFile(self, fpath):
34+
with io.open(fpath, 'r', encoding='latin-1') as f:
35+
return [line.split() for line in f.read().splitlines()]
36+
37+
def run(self, params, batcher):
38+
enc_input = []
39+
# Sort to reduce padding
40+
sorted_corpus = sorted(zip(self.samples, self.labels),
41+
key=lambda z: (len(z[0]), z[1]))
42+
sorted_samples = [x for (x, y) in sorted_corpus]
43+
sorted_labels = [y for (x, y) in sorted_corpus]
44+
logging.info('Generating sentence embeddings')
45+
for ii in range(0, self.n_samples, params.batch_size):
46+
batch = sorted_samples[ii:ii + params.batch_size]
47+
embeddings = batcher(params, batch)
48+
enc_input.append(embeddings)
49+
enc_input = np.vstack(enc_input)
50+
logging.info('Generated sentence embeddings')
51+
52+
config = {'nclasses': 2, 'seed': self.seed,
53+
'usepytorch': params.usepytorch,
54+
'classifier': params.classifier,
55+
'nhid': params.nhid, 'kfold': params.kfold}
56+
clf = InnerKFoldClassifier(enc_input, np.array(sorted_labels), config)
57+
devacc, testacc = clf.run()
58+
logging.debug('Dev acc : {0} Test acc : {1}\n'.format(devacc, testacc))
59+
return {'devacc': devacc, 'acc': testacc, 'ndev': self.n_samples,
60+
'ntest': self.n_samples}
61+
62+
63+
class CREval(BinaryClassifierEval):
64+
def __init__(self, task_path, seed=1111):
65+
logging.debug('***** Transfer task : CR *****\n\n')
66+
pos = self.loadFile(os.path.join(task_path, 'custrev.pos'))
67+
neg = self.loadFile(os.path.join(task_path, 'custrev.neg'))
68+
super(self.__class__, self).__init__(pos, neg, seed)
69+
70+
71+
class MREval(BinaryClassifierEval):
72+
def __init__(self, task_path, seed=1111):
73+
logging.debug('***** Transfer task : MR *****\n\n')
74+
pos = self.loadFile(os.path.join(task_path, 'rt-polarity.pos'))
75+
neg = self.loadFile(os.path.join(task_path, 'rt-polarity.neg'))
76+
super(self.__class__, self).__init__(pos, neg, seed)
77+
78+
79+
class SUBJEval(BinaryClassifierEval):
80+
def __init__(self, task_path, seed=1111):
81+
logging.debug('***** Transfer task : SUBJ *****\n\n')
82+
obj = self.loadFile(os.path.join(task_path, 'subj.objective'))
83+
subj = self.loadFile(os.path.join(task_path, 'subj.subjective'))
84+
super(self.__class__, self).__init__(obj, subj, seed)
85+
86+
87+
class MPQAEval(BinaryClassifierEval):
88+
def __init__(self, task_path, seed=1111):
89+
logging.debug('***** Transfer task : MPQA *****\n\n')
90+
pos = self.loadFile(os.path.join(task_path, 'mpqa.pos'))
91+
neg = self.loadFile(os.path.join(task_path, 'mpqa.neg'))
92+
super(self.__class__, self).__init__(pos, neg, seed)

‎SentEval/senteval/engine.py

+131
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,131 @@
1+
# Copyright (c) 2017-present, Facebook, Inc.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
#
7+
8+
'''
9+
10+
Generic sentence evaluation scripts wrapper
11+
12+
'''
13+
from __future__ import absolute_import, division, unicode_literals
14+
15+
from senteval import utils
16+
from senteval.binary import CREval, MREval, MPQAEval, SUBJEval
17+
from senteval.snli import SNLIEval
18+
from senteval.trec import TRECEval
19+
from senteval.sick import SICKEntailmentEval, SICKEval
20+
from senteval.mrpc import MRPCEval
21+
from senteval.sts import STS12Eval, STS13Eval, STS14Eval, STS15Eval, STS16Eval, STSBenchmarkEval, SICKRelatednessEval, STSBenchmarkFinetune, STSBenchmarkEvalDev
22+
from senteval.sst import SSTEval
23+
from senteval.rank import ImageCaptionRetrievalEval
24+
from senteval.probing import *
25+
26+
class SE(object):
27+
def __init__(self, params, batcher, prepare=None):
28+
# parameters
29+
params = utils.dotdict(params)
30+
params.usepytorch = True if 'usepytorch' not in params else params.usepytorch
31+
params.seed = 1111 if 'seed' not in params else params.seed
32+
33+
params.batch_size = 128 if 'batch_size' not in params else params.batch_size
34+
params.nhid = 0 if 'nhid' not in params else params.nhid
35+
params.kfold = 5 if 'kfold' not in params else params.kfold
36+
37+
if 'classifier' not in params or not params['classifier']:
38+
params.classifier = {'nhid': 0}
39+
40+
assert 'nhid' in params.classifier, 'Set number of hidden units in classifier config!!'
41+
42+
self.params = params
43+
44+
# batcher and prepare
45+
self.batcher = batcher
46+
self.prepare = prepare if prepare else lambda x, y: None
47+
48+
self.list_tasks = ['CR', 'MR', 'MPQA', 'SUBJ', 'SST2', 'SST5', 'TREC', 'MRPC',
49+
'SICKRelatedness', 'SICKEntailment', 'STSBenchmark',
50+
'SNLI', 'ImageCaptionRetrieval', 'STS12', 'STS13',
51+
'STS14', 'STS15', 'STS16',
52+
'Length', 'WordContent', 'Depth', 'TopConstituents',
53+
'BigramShift', 'Tense', 'SubjNumber', 'ObjNumber',
54+
'OddManOut', 'CoordinationInversion', 'SICKRelatedness-finetune', 'STSBenchmark-finetune', 'STSBenchmark-fix', 'STSBenchmark-dev']
55+
56+
def eval(self, name):
57+
# evaluate on evaluation [name], either takes string or list of strings
58+
if (isinstance(name, list)):
59+
self.results = {x: self.eval(x) for x in name}
60+
return self.results
61+
62+
tpath = self.params.task_path
63+
assert name in self.list_tasks, str(name) + ' not in ' + str(self.list_tasks)
64+
65+
# Original SentEval tasks
66+
if name == 'CR':
67+
self.evaluation = CREval(tpath + '/downstream/CR', seed=self.params.seed)
68+
elif name == 'MR':
69+
self.evaluation = MREval(tpath + '/downstream/MR', seed=self.params.seed)
70+
elif name == 'MPQA':
71+
self.evaluation = MPQAEval(tpath + '/downstream/MPQA', seed=self.params.seed)
72+
elif name == 'SUBJ':
73+
self.evaluation = SUBJEval(tpath + '/downstream/SUBJ', seed=self.params.seed)
74+
elif name == 'SST2':
75+
self.evaluation = SSTEval(tpath + '/downstream/SST/binary', nclasses=2, seed=self.params.seed)
76+
elif name == 'SST5':
77+
self.evaluation = SSTEval(tpath + '/downstream/SST/fine', nclasses=5, seed=self.params.seed)
78+
elif name == 'TREC':
79+
self.evaluation = TRECEval(tpath + '/downstream/TREC', seed=self.params.seed)
80+
elif name == 'MRPC':
81+
self.evaluation = MRPCEval(tpath + '/downstream/MRPC', seed=self.params.seed)
82+
elif name == 'SICKRelatedness':
83+
self.evaluation = SICKRelatednessEval(tpath + '/downstream/SICK', seed=self.params.seed)
84+
elif name == 'STSBenchmark':
85+
self.evaluation = STSBenchmarkEval(tpath + '/downstream/STS/STSBenchmark', seed=self.params.seed)
86+
elif name == 'STSBenchmark-dev':
87+
self.evaluation = STSBenchmarkEvalDev(tpath + '/downstream/STS/STSBenchmark', seed=self.params.seed)
88+
elif name == 'STSBenchmark-fix':
89+
self.evaluation = STSBenchmarkEval(tpath + '/downstream/STS/STSBenchmark-fix', seed=self.params.seed)
90+
elif name == 'STSBenchmark-finetune':
91+
self.evaluation = STSBenchmarkFinetune(tpath + '/downstream/STS/STSBenchmark', seed=self.params.seed)
92+
elif name == 'SICKRelatedness-finetune':
93+
self.evaluation = SICKEval(tpath + '/downstream/SICK', seed=self.params.seed)
94+
elif name == 'SICKEntailment':
95+
self.evaluation = SICKEntailmentEval(tpath + '/downstream/SICK', seed=self.params.seed)
96+
elif name == 'SNLI':
97+
self.evaluation = SNLIEval(tpath + '/downstream/SNLI', seed=self.params.seed)
98+
elif name in ['STS12', 'STS13', 'STS14', 'STS15', 'STS16']:
99+
fpath = name + '-en-test'
100+
self.evaluation = eval(name + 'Eval')(tpath + '/downstream/STS/' + fpath, seed=self.params.seed)
101+
elif name == 'ImageCaptionRetrieval':
102+
self.evaluation = ImageCaptionRetrievalEval(tpath + '/downstream/COCO', seed=self.params.seed)
103+
104+
# Probing Tasks
105+
elif name == 'Length':
106+
self.evaluation = LengthEval(tpath + '/probing', seed=self.params.seed)
107+
elif name == 'WordContent':
108+
self.evaluation = WordContentEval(tpath + '/probing', seed=self.params.seed)
109+
elif name == 'Depth':
110+
self.evaluation = DepthEval(tpath + '/probing', seed=self.params.seed)
111+
elif name == 'TopConstituents':
112+
self.evaluation = TopConstituentsEval(tpath + '/probing', seed=self.params.seed)
113+
elif name == 'BigramShift':
114+
self.evaluation = BigramShiftEval(tpath + '/probing', seed=self.params.seed)
115+
elif name == 'Tense':
116+
self.evaluation = TenseEval(tpath + '/probing', seed=self.params.seed)
117+
elif name == 'SubjNumber':
118+
self.evaluation = SubjNumberEval(tpath + '/probing', seed=self.params.seed)
119+
elif name == 'ObjNumber':
120+
self.evaluation = ObjNumberEval(tpath + '/probing', seed=self.params.seed)
121+
elif name == 'OddManOut':
122+
self.evaluation = OddManOutEval(tpath + '/probing', seed=self.params.seed)
123+
elif name == 'CoordinationInversion':
124+
self.evaluation = CoordinationInversionEval(tpath + '/probing', seed=self.params.seed)
125+
126+
self.params.current_task = name
127+
self.evaluation.do_prepare(self.params, self.prepare)
128+
129+
self.results = self.evaluation.run(self.params, self.batcher)
130+
131+
return self.results
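For reference, a minimal sketch of driving this wrapper end to end, assuming the package is importable as `senteval` with this class exposed as `senteval.engine.SE` (as in upstream SentEval); the toy hash-based batcher and the `PATH_TO_DATA` placeholder are illustrative only:

``` python
import numpy as np
import senteval  # assumes upstream-style packaging exposing senteval.engine.SE

def batcher(params, batch):
    # batch is a list of tokenized sentences; return one embedding row per sentence
    embeddings = np.zeros((len(batch), 8), dtype=np.float32)
    for i, sent in enumerate(batch):
        for w in sent:
            embeddings[i, hash(w) % 8] += 1.0  # toy bag-of-hashed-words embedding
    return embeddings

params = {'task_path': 'PATH_TO_DATA', 'usepytorch': True, 'kfold': 5}
se = senteval.engine.SE(params, batcher)      # prepare defaults to a no-op
results = se.eval(['STS12', 'STSBenchmark'])  # a single name or a list of names
```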

‎SentEval/senteval/mrpc.py

+104
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
# Copyright (c) 2017-present, Facebook, Inc.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
#
7+
8+
'''
9+
MRPC : Microsoft Research Paraphrase (detection) Corpus
10+
'''
11+
from __future__ import absolute_import, division, unicode_literals
12+
13+
import os
14+
import logging
15+
import numpy as np
16+
import io
17+
18+
from senteval.tools.validation import KFoldClassifier
19+
20+
from sklearn.metrics import f1_score
21+
22+
23+
class MRPCEval(object):
24+
def __init__(self, task_path, seed=1111):
25+
logging.info('***** Transfer task : MRPC *****\n\n')
26+
self.seed = seed
27+
train = self.loadFile(os.path.join(task_path,
28+
'msr_paraphrase_train.txt'))
29+
test = self.loadFile(os.path.join(task_path,
30+
'msr_paraphrase_test.txt'))
31+
self.mrpc_data = {'train': train, 'test': test}
32+
33+
def do_prepare(self, params, prepare):
34+
# TODO : Should we separate samples in "train, test"?
35+
samples = self.mrpc_data['train']['X_A'] + \
36+
self.mrpc_data['train']['X_B'] + \
37+
self.mrpc_data['test']['X_A'] + self.mrpc_data['test']['X_B']
38+
return prepare(params, samples)
39+
40+
def loadFile(self, fpath):
41+
mrpc_data = {'X_A': [], 'X_B': [], 'y': []}
42+
with io.open(fpath, 'r', encoding='utf-8') as f:
43+
for line in f:
44+
text = line.strip().split('\t')
45+
mrpc_data['X_A'].append(text[3].split())
46+
mrpc_data['X_B'].append(text[4].split())
47+
mrpc_data['y'].append(text[0])
48+
49+
mrpc_data['X_A'] = mrpc_data['X_A'][1:]
50+
mrpc_data['X_B'] = mrpc_data['X_B'][1:]
51+
mrpc_data['y'] = [int(s) for s in mrpc_data['y'][1:]]
52+
return mrpc_data
53+
54+
def run(self, params, batcher):
55+
mrpc_embed = {'train': {}, 'test': {}}
56+
57+
for key in self.mrpc_data:
58+
logging.info('Computing embedding for {0}'.format(key))
59+
# Sort to reduce padding
60+
text_data = {}
61+
sorted_corpus = sorted(zip(self.mrpc_data[key]['X_A'],
62+
self.mrpc_data[key]['X_B'],
63+
self.mrpc_data[key]['y']),
64+
key=lambda z: (len(z[0]), len(z[1]), z[2]))
65+
66+
text_data['A'] = [x for (x, y, z) in sorted_corpus]
67+
text_data['B'] = [y for (x, y, z) in sorted_corpus]
68+
text_data['y'] = [z for (x, y, z) in sorted_corpus]
69+
70+
for txt_type in ['A', 'B']:
71+
mrpc_embed[key][txt_type] = []
72+
for ii in range(0, len(text_data['y']), params.batch_size):
73+
batch = text_data[txt_type][ii:ii + params.batch_size]
74+
embeddings = batcher(params, batch)
75+
mrpc_embed[key][txt_type].append(embeddings)
76+
mrpc_embed[key][txt_type] = np.vstack(mrpc_embed[key][txt_type])
77+
mrpc_embed[key]['y'] = np.array(text_data['y'])
78+
logging.info('Computed {0} embeddings'.format(key))
79+
80+
# Train
81+
trainA = mrpc_embed['train']['A']
82+
trainB = mrpc_embed['train']['B']
83+
trainF = np.c_[np.abs(trainA - trainB), trainA * trainB]
84+
trainY = mrpc_embed['train']['y']
85+
86+
# Test
87+
testA = mrpc_embed['test']['A']
88+
testB = mrpc_embed['test']['B']
89+
testF = np.c_[np.abs(testA - testB), testA * testB]
90+
testY = mrpc_embed['test']['y']
91+
92+
config = {'nclasses': 2, 'seed': self.seed,
93+
'usepytorch': params.usepytorch,
94+
'classifier': params.classifier,
95+
'nhid': params.nhid, 'kfold': params.kfold}
96+
clf = KFoldClassifier(train={'X': trainF, 'y': trainY},
97+
test={'X': testF, 'y': testY}, config=config)
98+
99+
devacc, testacc, yhat = clf.run()
100+
testf1 = round(100*f1_score(testY, yhat), 2)
101+
logging.debug('Dev acc : {0} Test acc {1}; Test F1 {2} for MRPC.\n'
102+
.format(devacc, testacc, testf1))
103+
return {'devacc': devacc, 'acc': testacc, 'f1': testf1,
104+
'ndev': len(trainA), 'ntest': len(testA)}
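MRPC above (and the SICK tasks below) feed the classifier the concatenation of the element-wise absolute difference and product of the two sentence embeddings; a toy illustration of that `np.c_` construction:

``` python
import numpy as np

# toy sentence embeddings for two sentence pairs (rows) in 3 dimensions
A = np.array([[1.0, 0.0, 2.0],
              [0.5, 1.0, 0.0]])
B = np.array([[1.0, 1.0, 0.0],
              [0.5, 0.0, 1.0]])

# same construction as trainF/testF above: [|A - B|, A * B], shape (2, 6)
features = np.c_[np.abs(A - B), A * B]
print(features)
# [[0.   1.   2.   1.   0.   0.  ]
#  [0.   1.   1.   0.25 0.   0.  ]]
```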

‎SentEval/senteval/probing.py

+171
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,171 @@
1+
# Copyright (c) 2017-present, Facebook, Inc.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
#
7+
8+
'''
9+
probing tasks
10+
'''
11+
12+
from __future__ import absolute_import, division, unicode_literals
13+
14+
import os
15+
import io
16+
import copy
17+
import logging
18+
import numpy as np
19+
20+
from senteval.tools.validation import SplitClassifier
21+
22+
23+
class PROBINGEval(object):
24+
def __init__(self, task, task_path, seed=1111):
25+
self.seed = seed
26+
self.task = task
27+
logging.debug('***** (Probing) Transfer task : %s classification *****', self.task.upper())
28+
self.task_data = {'train': {'X': [], 'y': []},
29+
'dev': {'X': [], 'y': []},
30+
'test': {'X': [], 'y': []}}
31+
self.loadFile(task_path)
32+
logging.info('Loaded %s train - %s dev - %s test for %s' %
33+
(len(self.task_data['train']['y']), len(self.task_data['dev']['y']),
34+
len(self.task_data['test']['y']), self.task))
35+
36+
def do_prepare(self, params, prepare):
37+
samples = self.task_data['train']['X'] + self.task_data['dev']['X'] + \
38+
self.task_data['test']['X']
39+
return prepare(params, samples)
40+
41+
def loadFile(self, fpath):
42+
self.tok2split = {'tr': 'train', 'va': 'dev', 'te': 'test'}
43+
with io.open(fpath, 'r', encoding='utf-8') as f:
44+
for line in f:
45+
line = line.rstrip().split('\t')
46+
self.task_data[self.tok2split[line[0]]]['X'].append(line[-1].split())
47+
self.task_data[self.tok2split[line[0]]]['y'].append(line[1])
48+
49+
labels = sorted(np.unique(self.task_data['train']['y']))
50+
self.tok2label = dict(zip(labels, range(len(labels))))
51+
self.nclasses = len(self.tok2label)
52+
53+
for split in self.task_data:
54+
for i, y in enumerate(self.task_data[split]['y']):
55+
self.task_data[split]['y'][i] = self.tok2label[y]
56+
57+
def run(self, params, batcher):
58+
task_embed = {'train': {}, 'dev': {}, 'test': {}}
59+
bsize = params.batch_size
60+
logging.info('Computing embeddings for train/dev/test')
61+
for key in self.task_data:
62+
# Sort to reduce padding
63+
sorted_data = sorted(zip(self.task_data[key]['X'],
64+
self.task_data[key]['y']),
65+
key=lambda z: (len(z[0]), z[1]))
66+
self.task_data[key]['X'], self.task_data[key]['y'] = map(list, zip(*sorted_data))
67+
68+
task_embed[key]['X'] = []
69+
for ii in range(0, len(self.task_data[key]['y']), bsize):
70+
batch = self.task_data[key]['X'][ii:ii + bsize]
71+
embeddings = batcher(params, batch)
72+
task_embed[key]['X'].append(embeddings)
73+
task_embed[key]['X'] = np.vstack(task_embed[key]['X'])
74+
task_embed[key]['y'] = np.array(self.task_data[key]['y'])
75+
logging.info('Computed embeddings')
76+
77+
config_classifier = {'nclasses': self.nclasses, 'seed': self.seed,
78+
'usepytorch': params.usepytorch,
79+
'classifier': params.classifier}
80+
81+
if self.task == "WordContent" and params.classifier['nhid'] > 0:
82+
config_classifier = copy.deepcopy(config_classifier)
83+
config_classifier['classifier']['nhid'] = 0
84+
print(params.classifier['nhid'])
85+
86+
clf = SplitClassifier(X={'train': task_embed['train']['X'],
87+
'valid': task_embed['dev']['X'],
88+
'test': task_embed['test']['X']},
89+
y={'train': task_embed['train']['y'],
90+
'valid': task_embed['dev']['y'],
91+
'test': task_embed['test']['y']},
92+
config=config_classifier)
93+
94+
devacc, testacc = clf.run()
95+
logging.debug('\nDev acc : %.1f Test acc : %.1f for %s classification\n' % (devacc, testacc, self.task.upper()))
96+
97+
return {'devacc': devacc, 'acc': testacc,
98+
'ndev': len(task_embed['dev']['X']),
99+
'ntest': len(task_embed['test']['X'])}
100+
101+
"""
102+
Surface Information
103+
"""
104+
class LengthEval(PROBINGEval):
105+
def __init__(self, task_path, seed=1111):
106+
task_path = os.path.join(task_path, 'sentence_length.txt')
107+
# labels: bins
108+
PROBINGEval.__init__(self, 'Length', task_path, seed)
109+
110+
class WordContentEval(PROBINGEval):
111+
def __init__(self, task_path, seed=1111):
112+
task_path = os.path.join(task_path, 'word_content.txt')
113+
# labels: 200 target words
114+
PROBINGEval.__init__(self, 'WordContent', task_path, seed)
115+
116+
"""
117+
Latent Structural Information
118+
"""
119+
class DepthEval(PROBINGEval):
120+
def __init__(self, task_path, seed=1111):
121+
task_path = os.path.join(task_path, 'tree_depth.txt')
122+
# labels: bins
123+
PROBINGEval.__init__(self, 'Depth', task_path, seed)
124+
125+
class TopConstituentsEval(PROBINGEval):
126+
def __init__(self, task_path, seed=1111):
127+
task_path = os.path.join(task_path, 'top_constituents.txt')
128+
# labels: 'PP_NP_VP_.' .. (20 classes)
129+
PROBINGEval.__init__(self, 'TopConstituents', task_path, seed)
130+
131+
class BigramShiftEval(PROBINGEval):
132+
def __init__(self, task_path, seed=1111):
133+
task_path = os.path.join(task_path, 'bigram_shift.txt')
134+
# labels: 0 or 1
135+
PROBINGEval.__init__(self, 'BigramShift', task_path, seed)
136+
137+
# TODO: Voice?
138+
139+
"""
140+
Latent Semantic Information
141+
"""
142+
143+
class TenseEval(PROBINGEval):
144+
def __init__(self, task_path, seed=1111):
145+
task_path = os.path.join(task_path, 'past_present.txt')
146+
# labels: 'PRES', 'PAST'
147+
PROBINGEval.__init__(self, 'Tense', task_path, seed)
148+
149+
class SubjNumberEval(PROBINGEval):
150+
def __init__(self, task_path, seed=1111):
151+
task_path = os.path.join(task_path, 'subj_number.txt')
152+
# labels: 'NN', 'NNS'
153+
PROBINGEval.__init__(self, 'SubjNumber', task_path, seed)
154+
155+
class ObjNumberEval(PROBINGEval):
156+
def __init__(self, task_path, seed=1111):
157+
task_path = os.path.join(task_path, 'obj_number.txt')
158+
# labels: 'NN', 'NNS'
159+
PROBINGEval.__init__(self, 'ObjNumber', task_path, seed)
160+
161+
class OddManOutEval(PROBINGEval):
162+
def __init__(self, task_path, seed=1111):
163+
task_path = os.path.join(task_path, 'odd_man_out.txt')
164+
# labels: 'O', 'C'
165+
PROBINGEval.__init__(self, 'OddManOut', task_path, seed)
166+
167+
class CoordinationInversionEval(PROBINGEval):
168+
def __init__(self, task_path, seed=1111):
169+
task_path = os.path.join(task_path, 'coordination_inversion.txt')
170+
# labels: 'O', 'I'
171+
PROBINGEval.__init__(self, 'CoordinationInversion', task_path, seed)
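Every probing task above is a thin `PROBINGEval` subclass that only points at a different data file. Following the `TODO: Voice?` comment, a hypothetical voice-probing task (the file name and labels are assumptions, not files shipped with SentEval) would be registered the same way:

``` python
import os

class VoiceEval(PROBINGEval):
    # hypothetical probing task; 'voice.txt' is an assumed file name
    def __init__(self, task_path, seed=1111):
        task_path = os.path.join(task_path, 'voice.txt')
        # labels: e.g. 'ACTIVE', 'PASSIVE'
        PROBINGEval.__init__(self, 'Voice', task_path, seed)
```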

‎SentEval/senteval/rank.py

+108
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
# Copyright (c) 2017-present, Facebook, Inc.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
#
7+
8+
'''
9+
Image-Caption Retrieval with COCO dataset
10+
'''
11+
from __future__ import absolute_import, division, unicode_literals
12+
13+
import os
14+
import sys
15+
import logging
16+
import numpy as np
17+
18+
try:
19+
import cPickle as pickle
20+
except ImportError:
21+
import pickle
22+
23+
from senteval.tools.ranking import ImageSentenceRankingPytorch
24+
25+
26+
class ImageCaptionRetrievalEval(object):
27+
def __init__(self, task_path, seed=1111):
28+
logging.debug('***** Transfer task: Image Caption Retrieval *****\n\n')
29+
30+
# Get captions and image features
31+
self.seed = seed
32+
train, dev, test = self.loadFile(task_path)
33+
self.coco_data = {'train': train, 'dev': dev, 'test': test}
34+
35+
def do_prepare(self, params, prepare):
36+
samples = self.coco_data['train']['sent'] + \
37+
self.coco_data['dev']['sent'] + \
38+
self.coco_data['test']['sent']
39+
prepare(params, samples)
40+
41+
def loadFile(self, fpath):
42+
coco = {}
43+
44+
for split in ['train', 'valid', 'test']:
45+
list_sent = []
46+
list_img_feat = []
47+
if sys.version_info < (3, 0):
48+
with open(os.path.join(fpath, split + '.pkl')) as f:
49+
cocodata = pickle.load(f)
50+
else:
51+
with open(os.path.join(fpath, split + '.pkl'), 'rb') as f:
52+
cocodata = pickle.load(f, encoding='latin1')
53+
54+
for imgkey in range(len(cocodata['features'])):
55+
assert len(cocodata['image_to_caption_ids'][imgkey]) >= 5, \
56+
cocodata['image_to_caption_ids'][imgkey]
57+
for captkey in cocodata['image_to_caption_ids'][imgkey][0:5]:
58+
sent = cocodata['captions'][captkey]['cleaned_caption']
59+
sent += ' .' # add punctuation to end of sentence in COCO
60+
list_sent.append(sent.encode('utf-8').split())
61+
list_img_feat.append(cocodata['features'][imgkey])
62+
assert len(list_sent) == len(list_img_feat) and \
63+
len(list_sent) % 5 == 0
64+
list_img_feat = np.array(list_img_feat).astype('float32')
65+
coco[split] = {'sent': list_sent, 'imgfeat': list_img_feat}
66+
return coco['train'], coco['valid'], coco['test']
67+
68+
def run(self, params, batcher):
69+
coco_embed = {'train': {'sentfeat': [], 'imgfeat': []},
70+
'dev': {'sentfeat': [], 'imgfeat': []},
71+
'test': {'sentfeat': [], 'imgfeat': []}}
72+
73+
for key in self.coco_data:
74+
logging.info('Computing embedding for {0}'.format(key))
75+
# Sort to reduce padding
76+
self.coco_data[key]['sent'] = np.array(self.coco_data[key]['sent'])
77+
self.coco_data[key]['sent'], idx_sort = np.sort(self.coco_data[key]['sent']), np.argsort(self.coco_data[key]['sent'])
78+
idx_unsort = np.argsort(idx_sort)
79+
80+
coco_embed[key]['X'] = []
81+
nsent = len(self.coco_data[key]['sent'])
82+
for ii in range(0, nsent, params.batch_size):
83+
batch = self.coco_data[key]['sent'][ii:ii + params.batch_size]
84+
embeddings = batcher(params, batch)
85+
coco_embed[key]['sentfeat'].append(embeddings)
86+
coco_embed[key]['sentfeat'] = np.vstack(coco_embed[key]['sentfeat'])[idx_unsort]
87+
coco_embed[key]['imgfeat'] = np.array(self.coco_data[key]['imgfeat'])
88+
logging.info('Computed {0} embeddings'.format(key))
89+
90+
config = {'seed': self.seed, 'projdim': 1000, 'margin': 0.2}
91+
clf = ImageSentenceRankingPytorch(train=coco_embed['train'],
92+
valid=coco_embed['dev'],
93+
test=coco_embed['test'],
94+
config=config)
95+
96+
bestdevscore, r1_i2t, r5_i2t, r10_i2t, medr_i2t, \
97+
r1_t2i, r5_t2i, r10_t2i, medr_t2i = clf.run()
98+
99+
logging.debug("\nTest scores | Image to text: \
100+
{0}, {1}, {2}, {3}".format(r1_i2t, r5_i2t, r10_i2t, medr_i2t))
101+
logging.debug("Test scores | Text to image: \
102+
{0}, {1}, {2}, {3}\n".format(r1_t2i, r5_t2i, r10_t2i, medr_t2i))
103+
104+
return {'devacc': bestdevscore,
105+
'acc': [(r1_i2t, r5_i2t, r10_i2t, medr_i2t),
106+
(r1_t2i, r5_t2i, r10_t2i, medr_t2i)],
107+
'ndev': len(coco_embed['dev']['sentfeat']),
108+
'ntest': len(coco_embed['test']['sentfeat'])}
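The ranking tool returns recall@{1,5,10} and the median rank in both directions. As a rough sketch of what those numbers mean, independent of the `ImageSentenceRankingPytorch` internals (which are not shown here), they can be computed from a query-by-candidate similarity matrix like this:

``` python
import numpy as np

def retrieval_metrics(sim, gold):
    # sim: (n_queries, n_candidates) similarity matrix
    # gold[i]: index of the correct candidate for query i
    ranks = []
    for i, row in enumerate(sim):
        order = np.argsort(-row)  # candidates sorted by decreasing similarity
        ranks.append(int(np.where(order == gold[i])[0][0]) + 1)
    ranks = np.array(ranks)
    r_at = lambda k: 100.0 * np.mean(ranks <= k)
    return r_at(1), r_at(5), r_at(10), float(np.median(ranks))

sim = np.random.RandomState(0).rand(20, 20)
print(retrieval_metrics(sim, gold=np.arange(20)))  # (R@1, R@5, R@10, median rank)
```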

‎SentEval/senteval/sick.py

+216
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,216 @@
1+
# Copyright (c) 2017-present, Facebook, Inc.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
#
7+
8+
'''
9+
SICK Relatedness and Entailment
10+
'''
11+
from __future__ import absolute_import, division, unicode_literals
12+
13+
import os
14+
import io
15+
import logging
16+
import numpy as np
17+
18+
from sklearn.metrics import mean_squared_error
19+
from scipy.stats import pearsonr, spearmanr
20+
21+
from senteval.tools.relatedness import RelatednessPytorch
22+
from senteval.tools.validation import SplitClassifier
23+
24+
class SICKEval(object):
25+
def __init__(self, task_path, seed=1111):
26+
logging.debug('***** Transfer task : SICK-Relatedness*****\n\n')
27+
self.seed = seed
28+
train = self.loadFile(os.path.join(task_path, 'SICK_train.txt'))
29+
dev = self.loadFile(os.path.join(task_path, 'SICK_trial.txt'))
30+
test = self.loadFile(os.path.join(task_path, 'SICK_test_annotated.txt'))
31+
self.sick_data = {'train': train, 'dev': dev, 'test': test}
32+
33+
def do_prepare(self, params, prepare):
34+
samples = self.sick_data['train']['X_A'] + \
35+
self.sick_data['train']['X_B'] + \
36+
self.sick_data['dev']['X_A'] + \
37+
self.sick_data['dev']['X_B'] + \
38+
self.sick_data['test']['X_A'] + self.sick_data['test']['X_B']
39+
return prepare(params, samples)
40+
41+
def loadFile(self, fpath):
42+
skipFirstLine = True
43+
sick_data = {'X_A': [], 'X_B': [], 'y': []}
44+
with io.open(fpath, 'r', encoding='utf-8') as f:
45+
for line in f:
46+
if skipFirstLine:
47+
skipFirstLine = False
48+
else:
49+
text = line.strip().split('\t')
50+
sick_data['X_A'].append(text[1].split())
51+
sick_data['X_B'].append(text[2].split())
52+
sick_data['y'].append(text[3])
53+
54+
sick_data['y'] = [float(s) for s in sick_data['y']]
55+
return sick_data
56+
57+
def run(self, params, batcher):
58+
sick_embed = {'train': {}, 'dev': {}, 'test': {}}
59+
bsize = params.batch_size
60+
61+
for key in self.sick_data:
62+
logging.info('Computing embedding for {0}'.format(key))
63+
# Sort to reduce padding
64+
sorted_corpus = sorted(zip(self.sick_data[key]['X_A'],
65+
self.sick_data[key]['X_B'],
66+
self.sick_data[key]['y']),
67+
key=lambda z: (len(z[0]), len(z[1]), z[2]))
68+
69+
self.sick_data[key]['X_A'] = [x for (x, y, z) in sorted_corpus]
70+
self.sick_data[key]['X_B'] = [y for (x, y, z) in sorted_corpus]
71+
self.sick_data[key]['y'] = [z for (x, y, z) in sorted_corpus]
72+
73+
for txt_type in ['X_A', 'X_B']:
74+
sick_embed[key][txt_type] = []
75+
for ii in range(0, len(self.sick_data[key]['y']), bsize):
76+
batch = self.sick_data[key][txt_type][ii:ii + bsize]
77+
embeddings = batcher(params, batch)
78+
sick_embed[key][txt_type].append(embeddings)
79+
sick_embed[key][txt_type] = np.vstack(sick_embed[key][txt_type])
80+
sick_embed[key]['y'] = np.array(self.sick_data[key]['y'])
81+
logging.info('Computed {0} embeddings'.format(key))
82+
83+
# Train
84+
trainA = sick_embed['train']['X_A']
85+
trainB = sick_embed['train']['X_B']
86+
trainF = np.c_[np.abs(trainA - trainB), trainA * trainB]
87+
trainY = self.encode_labels(self.sick_data['train']['y'])
88+
89+
# Dev
90+
devA = sick_embed['dev']['X_A']
91+
devB = sick_embed['dev']['X_B']
92+
devF = np.c_[np.abs(devA - devB), devA * devB]
93+
devY = self.encode_labels(self.sick_data['dev']['y'])
94+
95+
# Test
96+
testA = sick_embed['test']['X_A']
97+
testB = sick_embed['test']['X_B']
98+
testF = np.c_[np.abs(testA - testB), testA * testB]
99+
testY = self.encode_labels(self.sick_data['test']['y'])
100+
101+
config = {'seed': self.seed, 'nclasses': 5}
102+
clf = RelatednessPytorch(train={'X': trainF, 'y': trainY},
103+
valid={'X': devF, 'y': devY},
104+
test={'X': testF, 'y': testY},
105+
devscores=self.sick_data['dev']['y'],
106+
config=config)
107+
108+
devspr, yhat = clf.run()
109+
110+
pr = pearsonr(yhat, self.sick_data['test']['y'])[0]
111+
sr = spearmanr(yhat, self.sick_data['test']['y'])[0]
112+
pr = 0 if pr != pr else pr
113+
sr = 0 if sr != sr else sr
114+
se = mean_squared_error(yhat, self.sick_data['test']['y'])
115+
logging.debug('Dev : Spearman {0}'.format(devspr))
116+
logging.debug('Test : Pearson {0} Spearman {1} MSE {2} \
117+
for SICK Relatedness\n'.format(pr, sr, se))
118+
119+
return {'devspearman': devspr, 'pearson': pr, 'spearman': sr, 'mse': se,
120+
'yhat': yhat, 'ndev': len(devA), 'ntest': len(testA)}
121+
122+
def encode_labels(self, labels, nclass=5):
123+
"""
124+
Label encoding from Tree LSTM paper (Tai, Socher, Manning)
125+
"""
126+
Y = np.zeros((len(labels), nclass)).astype('float32')
127+
for j, y in enumerate(labels):
128+
for i in range(nclass):
129+
if i+1 == np.floor(y) + 1:
130+
Y[j, i] = y - np.floor(y)
131+
if i+1 == np.floor(y):
132+
Y[j, i] = np.floor(y) - y + 1
133+
return Y
134+
135+
136+
class SICKEntailmentEval(SICKEval):
137+
def __init__(self, task_path, seed=1111):
138+
logging.debug('***** Transfer task : SICK-Entailment*****\n\n')
139+
self.seed = seed
140+
train = self.loadFile(os.path.join(task_path, 'SICK_train.txt'))
141+
dev = self.loadFile(os.path.join(task_path, 'SICK_trial.txt'))
142+
test = self.loadFile(os.path.join(task_path, 'SICK_test_annotated.txt'))
143+
self.sick_data = {'train': train, 'dev': dev, 'test': test}
144+
145+
def loadFile(self, fpath):
146+
label2id = {'CONTRADICTION': 0, 'NEUTRAL': 1, 'ENTAILMENT': 2}
147+
skipFirstLine = True
148+
sick_data = {'X_A': [], 'X_B': [], 'y': []}
149+
with io.open(fpath, 'r', encoding='utf-8') as f:
150+
for line in f:
151+
if skipFirstLine:
152+
skipFirstLine = False
153+
else:
154+
text = line.strip().split('\t')
155+
sick_data['X_A'].append(text[1].split())
156+
sick_data['X_B'].append(text[2].split())
157+
sick_data['y'].append(text[4])
158+
sick_data['y'] = [label2id[s] for s in sick_data['y']]
159+
return sick_data
160+
161+
def run(self, params, batcher):
162+
sick_embed = {'train': {}, 'dev': {}, 'test': {}}
163+
bsize = params.batch_size
164+
165+
for key in self.sick_data:
166+
logging.info('Computing embedding for {0}'.format(key))
167+
# Sort to reduce padding
168+
sorted_corpus = sorted(zip(self.sick_data[key]['X_A'],
169+
self.sick_data[key]['X_B'],
170+
self.sick_data[key]['y']),
171+
key=lambda z: (len(z[0]), len(z[1]), z[2]))
172+
173+
self.sick_data[key]['X_A'] = [x for (x, y, z) in sorted_corpus]
174+
self.sick_data[key]['X_B'] = [y for (x, y, z) in sorted_corpus]
175+
self.sick_data[key]['y'] = [z for (x, y, z) in sorted_corpus]
176+
177+
for txt_type in ['X_A', 'X_B']:
178+
sick_embed[key][txt_type] = []
179+
for ii in range(0, len(self.sick_data[key]['y']), bsize):
180+
batch = self.sick_data[key][txt_type][ii:ii + bsize]
181+
embeddings = batcher(params, batch)
182+
sick_embed[key][txt_type].append(embeddings)
183+
sick_embed[key][txt_type] = np.vstack(sick_embed[key][txt_type])
184+
logging.info('Computed {0} embeddings'.format(key))
185+
186+
# Train
187+
trainA = sick_embed['train']['X_A']
188+
trainB = sick_embed['train']['X_B']
189+
trainF = np.c_[np.abs(trainA - trainB), trainA * trainB]
190+
trainY = np.array(self.sick_data['train']['y'])
191+
192+
# Dev
193+
devA = sick_embed['dev']['X_A']
194+
devB = sick_embed['dev']['X_B']
195+
devF = np.c_[np.abs(devA - devB), devA * devB]
196+
devY = np.array(self.sick_data['dev']['y'])
197+
198+
# Test
199+
testA = sick_embed['test']['X_A']
200+
testB = sick_embed['test']['X_B']
201+
testF = np.c_[np.abs(testA - testB), testA * testB]
202+
testY = np.array(self.sick_data['test']['y'])
203+
204+
config = {'nclasses': 3, 'seed': self.seed,
205+
'usepytorch': params.usepytorch,
206+
'classifier': params.classifier,
207+
'nhid': params.nhid}
208+
clf = SplitClassifier(X={'train': trainF, 'valid': devF, 'test': testF},
209+
y={'train': trainY, 'valid': devY, 'test': testY},
210+
config=config)
211+
212+
devacc, testacc = clf.run()
213+
logging.debug('\nDev acc : {0} Test acc : {1} for \
214+
SICK entailment\n'.format(devacc, testacc))
215+
return {'devacc': devacc, 'acc': testacc,
216+
'ndev': len(devA), 'ntest': len(testA)}
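`encode_labels` implements the soft target encoding from the Tree-LSTM paper: a gold relatedness score y in [1, 5] is spread over the two neighbouring integer classes so that the expected class recovers y. A worked example for y = 3.7:

``` python
import numpy as np

# reproduce encode_labels for a single gold score y = 3.7
y, nclass = 3.7, 5
Y = np.zeros(nclass, dtype=np.float32)
Y[int(np.floor(y))] = y - np.floor(y)          # class 4 (index 3) gets 0.7
Y[int(np.floor(y)) - 1] = np.floor(y) - y + 1  # class 3 (index 2) gets 0.3
print(Y)                           # [0.  0.  0.3 0.7 0. ]
print(np.dot(np.arange(1, 6), Y))  # 3.7 -- the expectation recovers the score
```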

‎SentEval/senteval/snli.py

+113
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
# Copyright (c) 2017-present, Facebook, Inc.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
#
7+
8+
'''
9+
SNLI - Entailment
10+
'''
11+
from __future__ import absolute_import, division, unicode_literals
12+
13+
import codecs
14+
import os
15+
import io
16+
import copy
17+
import logging
18+
import numpy as np
19+
20+
from senteval.tools.validation import SplitClassifier
21+
22+
23+
class SNLIEval(object):
24+
def __init__(self, taskpath, seed=1111):
25+
logging.debug('***** Transfer task : SNLI Entailment*****\n\n')
26+
self.seed = seed
27+
train1 = self.loadFile(os.path.join(taskpath, 's1.train'))
28+
train2 = self.loadFile(os.path.join(taskpath, 's2.train'))
29+
30+
trainlabels = io.open(os.path.join(taskpath, 'labels.train'),
31+
encoding='utf-8').read().splitlines()
32+
33+
valid1 = self.loadFile(os.path.join(taskpath, 's1.dev'))
34+
valid2 = self.loadFile(os.path.join(taskpath, 's2.dev'))
35+
validlabels = io.open(os.path.join(taskpath, 'labels.dev'),
36+
encoding='utf-8').read().splitlines()
37+
38+
test1 = self.loadFile(os.path.join(taskpath, 's1.test'))
39+
test2 = self.loadFile(os.path.join(taskpath, 's2.test'))
40+
testlabels = io.open(os.path.join(taskpath, 'labels.test'),
41+
encoding='utf-8').read().splitlines()
42+
43+
# sort data (by s2 first) to reduce padding
44+
sorted_train = sorted(zip(train2, train1, trainlabels),
45+
key=lambda z: (len(z[0]), len(z[1]), z[2]))
46+
train2, train1, trainlabels = map(list, zip(*sorted_train))
47+
48+
sorted_valid = sorted(zip(valid2, valid1, validlabels),
49+
key=lambda z: (len(z[0]), len(z[1]), z[2]))
50+
valid2, valid1, validlabels = map(list, zip(*sorted_valid))
51+
52+
sorted_test = sorted(zip(test2, test1, testlabels),
53+
key=lambda z: (len(z[0]), len(z[1]), z[2]))
54+
test2, test1, testlabels = map(list, zip(*sorted_test))
55+
56+
self.samples = train1 + train2 + valid1 + valid2 + test1 + test2
57+
self.data = {'train': (train1, train2, trainlabels),
58+
'valid': (valid1, valid2, validlabels),
59+
'test': (test1, test2, testlabels)
60+
}
61+
62+
def do_prepare(self, params, prepare):
63+
return prepare(params, self.samples)
64+
65+
def loadFile(self, fpath):
66+
with codecs.open(fpath, 'rb', 'latin-1') as f:
67+
return [line.split() for line in
68+
f.read().splitlines()]
69+
70+
def run(self, params, batcher):
71+
self.X, self.y = {}, {}
72+
dico_label = {'entailment': 0, 'neutral': 1, 'contradiction': 2}
73+
for key in self.data:
74+
if key not in self.X:
75+
self.X[key] = []
76+
if key not in self.y:
77+
self.y[key] = []
78+
79+
input1, input2, mylabels = self.data[key]
80+
enc_input = []
81+
n_labels = len(mylabels)
82+
for ii in range(0, n_labels, params.batch_size):
83+
batch1 = input1[ii:ii + params.batch_size]
84+
batch2 = input2[ii:ii + params.batch_size]
85+
86+
if len(batch1) == len(batch2) and len(batch1) > 0:
87+
enc1 = batcher(params, batch1)
88+
enc2 = batcher(params, batch2)
89+
enc_input.append(np.hstack((enc1, enc2, enc1 * enc2,
90+
np.abs(enc1 - enc2))))
91+
if (ii*params.batch_size) % (20000*params.batch_size) == 0:
92+
logging.info("PROGRESS (encoding): %.2f%%" %
93+
(100 * ii / n_labels))
94+
self.X[key] = np.vstack(enc_input)
95+
self.y[key] = [dico_label[y] for y in mylabels]
96+
97+
config = {'nclasses': 3, 'seed': self.seed,
98+
'usepytorch': params.usepytorch,
99+
'cudaEfficient': True,
100+
'nhid': params.nhid, 'noreg': True}
101+
102+
config_classifier = copy.deepcopy(params.classifier)
103+
config_classifier['max_epoch'] = 15
104+
config_classifier['epoch_size'] = 1
105+
config['classifier'] = config_classifier
106+
107+
clf = SplitClassifier(self.X, self.y, config)
108+
devacc, testacc = clf.run()
109+
logging.debug('Dev acc : {0} Test acc : {1} for SNLI\n'
110+
.format(devacc, testacc))
111+
return {'devacc': devacc, 'acc': testacc,
112+
'ndev': len(self.data['valid'][0]),
113+
'ntest': len(self.data['test'][0])}

‎SentEval/senteval/sst.py

+96
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
# Copyright (c) 2017-present, Facebook, Inc.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
#
7+
8+
'''
9+
SST - binary classification
10+
'''
11+
12+
from __future__ import absolute_import, division, unicode_literals
13+
14+
import os
15+
import io
16+
import logging
17+
import numpy as np
18+
19+
from senteval.tools.validation import SplitClassifier
20+
21+
22+
class SSTEval(object):
23+
def __init__(self, task_path, nclasses=2, seed=1111):
24+
self.seed = seed
25+
26+
# binary or fine-grained
27+
assert nclasses in [2, 5]
28+
self.nclasses = nclasses
29+
self.task_name = 'Binary' if self.nclasses == 2 else 'Fine-Grained'
30+
logging.debug('***** Transfer task : SST %s classification *****\n\n', self.task_name)
31+
32+
train = self.loadFile(os.path.join(task_path, 'sentiment-train'))
33+
dev = self.loadFile(os.path.join(task_path, 'sentiment-dev'))
34+
test = self.loadFile(os.path.join(task_path, 'sentiment-test'))
35+
self.sst_data = {'train': train, 'dev': dev, 'test': test}
36+
37+
def do_prepare(self, params, prepare):
38+
samples = self.sst_data['train']['X'] + self.sst_data['dev']['X'] + \
39+
self.sst_data['test']['X']
40+
return prepare(params, samples)
41+
42+
def loadFile(self, fpath):
43+
sst_data = {'X': [], 'y': []}
44+
with io.open(fpath, 'r', encoding='utf-8') as f:
45+
for line in f:
46+
if self.nclasses == 2:
47+
sample = line.strip().split('\t')
48+
sst_data['y'].append(int(sample[1]))
49+
sst_data['X'].append(sample[0].split())
50+
elif self.nclasses == 5:
51+
sample = line.strip().split(' ', 1)
52+
sst_data['y'].append(int(sample[0]))
53+
sst_data['X'].append(sample[1].split())
54+
assert max(sst_data['y']) == self.nclasses - 1
55+
return sst_data
56+
57+
def run(self, params, batcher):
58+
sst_embed = {'train': {}, 'dev': {}, 'test': {}}
59+
bsize = params.batch_size
60+
61+
for key in self.sst_data:
62+
logging.info('Computing embedding for {0}'.format(key))
63+
# Sort to reduce padding
64+
sorted_data = sorted(zip(self.sst_data[key]['X'],
65+
self.sst_data[key]['y']),
66+
key=lambda z: (len(z[0]), z[1]))
67+
self.sst_data[key]['X'], self.sst_data[key]['y'] = map(list, zip(*sorted_data))
68+
69+
sst_embed[key]['X'] = []
70+
for ii in range(0, len(self.sst_data[key]['y']), bsize):
71+
batch = self.sst_data[key]['X'][ii:ii + bsize]
72+
embeddings = batcher(params, batch)
73+
sst_embed[key]['X'].append(embeddings)
74+
sst_embed[key]['X'] = np.vstack(sst_embed[key]['X'])
75+
sst_embed[key]['y'] = np.array(self.sst_data[key]['y'])
76+
logging.info('Computed {0} embeddings'.format(key))
77+
78+
config_classifier = {'nclasses': self.nclasses, 'seed': self.seed,
79+
'usepytorch': params.usepytorch,
80+
'classifier': params.classifier}
81+
82+
clf = SplitClassifier(X={'train': sst_embed['train']['X'],
83+
'valid': sst_embed['dev']['X'],
84+
'test': sst_embed['test']['X']},
85+
y={'train': sst_embed['train']['y'],
86+
'valid': sst_embed['dev']['y'],
87+
'test': sst_embed['test']['y']},
88+
config=config_classifier)
89+
90+
devacc, testacc = clf.run()
91+
logging.debug('\nDev acc : {0} Test acc : {1} for \
92+
SST {2} classification\n'.format(devacc, testacc, self.task_name))
93+
94+
return {'devacc': devacc, 'acc': testacc,
95+
'ndev': len(sst_embed['dev']['X']),
96+
'ntest': len(sst_embed['test']['X'])}

‎SentEval/senteval/sts.py

+265
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,265 @@
1+
# Copyright (c) 2017-present, Facebook, Inc.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
#
7+
8+
'''
9+
STS-{2012,2013,2014,2015,2016} (unsupervised) and
10+
STS-benchmark (supervised) tasks
11+
'''
12+
13+
from __future__ import absolute_import, division, unicode_literals
14+
15+
import os
16+
import io
17+
import numpy as np
18+
import logging
19+
20+
from scipy.stats import spearmanr, pearsonr
21+
22+
from senteval.utils import cosine
23+
from senteval.sick import SICKEval
24+
25+
26+
class STSEval(object):
27+
def loadFile(self, fpath):
28+
self.data = {}
29+
self.samples = []
30+
31+
for dataset in self.datasets:
32+
sent1, sent2 = zip(*[l.split("\t") for l in
33+
io.open(fpath + '/STS.input.%s.txt' % dataset,
34+
encoding='utf8').read().splitlines()])
35+
raw_scores = np.array([x for x in
36+
io.open(fpath + '/STS.gs.%s.txt' % dataset,
37+
encoding='utf8')
38+
.read().splitlines()])
39+
not_empty_idx = raw_scores != ''
40+
41+
gs_scores = [float(x) for x in raw_scores[not_empty_idx]]
42+
sent1 = np.array([s.split() for s in sent1], dtype=object)[not_empty_idx]
43+
sent2 = np.array([s.split() for s in sent2], dtype=object)[not_empty_idx]
44+
45+
# sort data by length to minimize padding in batcher
46+
sorted_data = sorted(zip(sent1, sent2, gs_scores),
47+
key=lambda z: (len(z[0]), len(z[1]), z[2]))
48+
sent1, sent2, gs_scores = map(list, zip(*sorted_data))
49+
50+
self.data[dataset] = (sent1, sent2, gs_scores)
51+
self.samples += sent1 + sent2
52+
53+
def do_prepare(self, params, prepare):
54+
if 'similarity' in params:
55+
self.similarity = params.similarity
56+
else: # Default similarity is cosine
57+
self.similarity = lambda s1, s2: np.nan_to_num(cosine(np.nan_to_num(s1), np.nan_to_num(s2)))
58+
return prepare(params, self.samples)
59+
60+
def run(self, params, batcher):
61+
results = {}
62+
all_sys_scores = []
63+
all_gs_scores = []
64+
for dataset in self.datasets:
65+
sys_scores = []
66+
input1, input2, gs_scores = self.data[dataset]
67+
for ii in range(0, len(gs_scores), params.batch_size):
68+
batch1 = input1[ii:ii + params.batch_size]
69+
batch2 = input2[ii:ii + params.batch_size]
70+
71+
# we assume get_batch already throws out the faulty ones
72+
if len(batch1) == len(batch2) and len(batch1) > 0:
73+
enc1 = batcher(params, batch1)
74+
enc2 = batcher(params, batch2)
75+
76+
for kk in range(enc2.shape[0]):
77+
sys_score = self.similarity(enc1[kk], enc2[kk])
78+
sys_scores.append(sys_score)
79+
all_sys_scores.extend(sys_scores)
80+
all_gs_scores.extend(gs_scores)
81+
results[dataset] = {'pearson': pearsonr(sys_scores, gs_scores),
82+
'spearman': spearmanr(sys_scores, gs_scores),
83+
'nsamples': len(sys_scores)}
84+
logging.debug('%s : pearson = %.4f, spearman = %.4f' %
85+
(dataset, results[dataset]['pearson'][0],
86+
results[dataset]['spearman'][0]))
87+
88+
weights = [results[dset]['nsamples'] for dset in results.keys()]
89+
list_prs = np.array([results[dset]['pearson'][0] for
90+
dset in results.keys()])
91+
list_spr = np.array([results[dset]['spearman'][0] for
92+
dset in results.keys()])
93+
94+
avg_pearson = np.average(list_prs)
95+
avg_spearman = np.average(list_spr)
96+
wavg_pearson = np.average(list_prs, weights=weights)
97+
wavg_spearman = np.average(list_spr, weights=weights)
98+
all_pearson = pearsonr(all_sys_scores, all_gs_scores)
99+
all_spearman = spearmanr(all_sys_scores, all_gs_scores)
100+
results['all'] = {'pearson': {'all': all_pearson[0],
101+
'mean': avg_pearson,
102+
'wmean': wavg_pearson},
103+
'spearman': {'all': all_spearman[0],
104+
'mean': avg_spearman,
105+
'wmean': wavg_spearman}}
106+
logging.debug('ALL : Pearson = %.4f, \
107+
Spearman = %.4f' % (all_pearson[0], all_spearman[0]))
108+
logging.debug('ALL (weighted average) : Pearson = %.4f, \
109+
Spearman = %.4f' % (wavg_pearson, wavg_spearman))
110+
logging.debug('ALL (average) : Pearson = %.4f, \
111+
Spearman = %.4f\n' % (avg_pearson, avg_spearman))
112+
113+
return results
114+
115+
116+
class STS12Eval(STSEval):
117+
def __init__(self, taskpath, seed=1111):
118+
logging.debug('***** Transfer task : STS12 *****\n\n')
119+
self.seed = seed
120+
self.datasets = ['MSRpar', 'MSRvid', 'SMTeuroparl',
121+
'surprise.OnWN', 'surprise.SMTnews']
122+
self.loadFile(taskpath)
123+
124+
125+
class STS13Eval(STSEval):
126+
# STS13 here does not contain the "SMT" subtask due to license issues
127+
def __init__(self, taskpath, seed=1111):
128+
logging.debug('***** Transfer task : STS13 (-SMT) *****\n\n')
129+
self.seed = seed
130+
self.datasets = ['FNWN', 'headlines', 'OnWN']
131+
self.loadFile(taskpath)
132+
133+
134+
class STS14Eval(STSEval):
135+
def __init__(self, taskpath, seed=1111):
136+
logging.debug('***** Transfer task : STS14 *****\n\n')
137+
self.seed = seed
138+
self.datasets = ['deft-forum', 'deft-news', 'headlines',
139+
'images', 'OnWN', 'tweet-news']
140+
self.loadFile(taskpath)
141+
142+
143+
class STS15Eval(STSEval):
144+
def __init__(self, taskpath, seed=1111):
145+
logging.debug('***** Transfer task : STS15 *****\n\n')
146+
self.seed = seed
147+
self.datasets = ['answers-forums', 'answers-students',
148+
'belief', 'headlines', 'images']
149+
self.loadFile(taskpath)
150+
151+
152+
class STS16Eval(STSEval):
153+
def __init__(self, taskpath, seed=1111):
154+
logging.debug('***** Transfer task : STS16 *****\n\n')
155+
self.seed = seed
156+
self.datasets = ['answer-answer', 'headlines', 'plagiarism',
157+
'postediting', 'question-question']
158+
self.loadFile(taskpath)
159+
160+
161+
class STSBenchmarkEval(STSEval):
162+
def __init__(self, task_path, seed=1111):
163+
logging.debug('\n\n***** Transfer task : STSBenchmark*****\n\n')
164+
self.seed = seed
165+
self.samples = []
166+
#train = self.loadFile(os.path.join(task_path, 'sts-train.csv'))
167+
#dev = self.loadFile(os.path.join(task_path, 'sts-dev.csv'))
168+
#test = self.loadFile(os.path.join(task_path, 'sts-test.csv'))
169+
#self.datasets = ['train', 'dev', 'test']
170+
#self.data = {'train': train, 'dev': dev, 'test': test}
171+
test = self.loadFile(os.path.join(task_path, 'sts-test.csv'))
172+
self.datasets = ['test']
173+
self.data = {'test': test}
174+
175+
def loadFile(self, fpath):
176+
sick_data = {'X_A': [], 'X_B': [], 'y': []}
177+
with io.open(fpath, 'r', encoding='utf-8') as f:
178+
for line in f:
179+
text = line.strip().split('\t')
180+
sick_data['X_A'].append(text[5].split())
181+
sick_data['X_B'].append(text[6].split())
182+
sick_data['y'].append(text[4])
183+
184+
sick_data['y'] = [float(s) for s in sick_data['y']]
185+
self.samples += sick_data['X_A'] + sick_data["X_B"]
186+
return (sick_data['X_A'], sick_data["X_B"], sick_data['y'])
187+
188+
class STSBenchmarkEvalDev(STSEval):
189+
def __init__(self, task_path, seed=1111):
190+
logging.debug('\n\n***** Transfer task : STSBenchmark*****\n\n')
191+
self.seed = seed
192+
self.samples = []
193+
#train = self.loadFile(os.path.join(task_path, 'sts-train.csv'))
194+
#dev = self.loadFile(os.path.join(task_path, 'sts-dev.csv'))
195+
#test = self.loadFile(os.path.join(task_path, 'sts-test.csv'))
196+
#self.datasets = ['train', 'dev', 'test']
197+
#self.data = {'train': train, 'dev': dev, 'test': test}
198+
dev = self.loadFile(os.path.join(task_path, 'sts-dev.csv'))
199+
self.datasets = ['dev']
200+
self.data = {'dev': dev}
201+
202+
def loadFile(self, fpath):
203+
sick_data = {'X_A': [], 'X_B': [], 'y': []}
204+
with io.open(fpath, 'r', encoding='utf-8') as f:
205+
for line in f:
206+
text = line.strip().split('\t')
207+
sick_data['X_A'].append(text[5].split())
208+
sick_data['X_B'].append(text[6].split())
209+
sick_data['y'].append(text[4])
210+
211+
sick_data['y'] = [float(s) for s in sick_data['y']]
212+
self.samples += sick_data['X_A'] + sick_data["X_B"]
213+
return (sick_data['X_A'], sick_data["X_B"], sick_data['y'])
214+
215+
class STSBenchmarkFinetune(SICKEval):
216+
def __init__(self, task_path, seed=1111):
217+
logging.debug('\n\n***** Transfer task : STSBenchmark*****\n\n')
218+
self.seed = seed
219+
train = self.loadFile(os.path.join(task_path, 'sts-train.csv'))
220+
dev = self.loadFile(os.path.join(task_path, 'sts-dev.csv'))
221+
test = self.loadFile(os.path.join(task_path, 'sts-test.csv'))
222+
self.sick_data = {'train': train, 'dev': dev, 'test': test}
223+
224+
def loadFile(self, fpath):
225+
sick_data = {'X_A': [], 'X_B': [], 'y': []}
226+
with io.open(fpath, 'r', encoding='utf-8') as f:
227+
for line in f:
228+
text = line.strip().split('\t')
229+
sick_data['X_A'].append(text[5].split())
230+
sick_data['X_B'].append(text[6].split())
231+
sick_data['y'].append(text[4])
232+
233+
sick_data['y'] = [float(s) for s in sick_data['y']]
234+
return sick_data
235+
236+
class SICKRelatednessEval(STSEval):
237+
def __init__(self, task_path, seed=1111):
238+
logging.debug('\n\n***** Transfer task : SICKRelatedness*****\n\n')
239+
self.seed = seed
240+
self.samples = []
241+
#train = self.loadFile(os.path.join(task_path, 'SICK_train.txt'))
242+
#dev = self.loadFile(os.path.join(task_path, 'SICK_trial.txt'))
243+
#test = self.loadFile(os.path.join(task_path, 'SICK_test_annotated.txt'))
244+
#self.datasets = ['train', 'dev', 'test']
245+
#self.data = {'train': train, 'dev': dev, 'test': test}
246+
test = self.loadFile(os.path.join(task_path, 'SICK_test_annotated.txt'))
247+
self.datasets = ['test']
248+
self.data = {'test': test}
249+
250+
def loadFile(self, fpath):
251+
skipFirstLine = True
252+
sick_data = {'X_A': [], 'X_B': [], 'y': []}
253+
with io.open(fpath, 'r', encoding='utf-8') as f:
254+
for line in f:
255+
if skipFirstLine:
256+
skipFirstLine = False
257+
else:
258+
text = line.strip().split('\t')
259+
sick_data['X_A'].append(text[1].split())
260+
sick_data['X_B'].append(text[2].split())
261+
sick_data['y'].append(text[3])
262+
263+
sick_data['y'] = [float(s) for s in sick_data['y']]
264+
self.samples += sick_data['X_A'] + sick_data["X_B"]
265+
return (sick_data['X_A'], sick_data["X_B"], sick_data['y'])
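Each STS dataset is scored by comparing the cosine similarity of every embedding pair against the gold scores with Pearson and Spearman correlations. A self-contained sketch with toy embeddings and gold scores (the `cosine` helper here mirrors the default similarity that `do_prepare` falls back to):

``` python
import numpy as np
from scipy.stats import pearsonr, spearmanr

def cosine(u, v):
    # plain cosine similarity, the default when params.similarity is not set
    return np.dot(u, v) / (np.linalg.norm(u) * np.linalg.norm(v))

# toy embeddings for three sentence pairs and their gold scores in [0, 5]
enc1 = np.array([[1.0, 0.0], [0.8, 0.2], [0.0, 1.0]])
enc2 = np.array([[1.0, 0.1], [0.2, 0.8], [1.0, 0.0]])
gs_scores = [5.0, 2.5, 0.0]

sys_scores = [cosine(a, b) for a, b in zip(enc1, enc2)]
print(pearsonr(sys_scores, gs_scores)[0], spearmanr(sys_scores, gs_scores)[0])
```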

‎SentEval/senteval/tools/__init__.py

Whitespace-only changes.

‎SentEval/senteval/tools/classifier.py

+202
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,202 @@
1+
# Copyright (c) 2017-present, Facebook, Inc.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
#
7+
8+
"""
9+
Pytorch Classifier class in the style of scikit-learn
10+
Classifiers include Logistic Regression and MLP
11+
"""
12+
13+
from __future__ import absolute_import, division, unicode_literals
14+
15+
import numpy as np
16+
import copy
17+
from senteval import utils
18+
19+
import torch
20+
from torch import nn
21+
import torch.nn.functional as F
22+
23+
24+
class PyTorchClassifier(object):
25+
def __init__(self, inputdim, nclasses, l2reg=0., batch_size=64, seed=1111,
26+
cudaEfficient=False):
27+
# fix seed
28+
np.random.seed(seed)
29+
torch.manual_seed(seed)
30+
torch.cuda.manual_seed(seed)
31+
32+
self.inputdim = inputdim
33+
self.nclasses = nclasses
34+
self.l2reg = l2reg
35+
self.batch_size = batch_size
36+
self.cudaEfficient = cudaEfficient
37+
38+
def prepare_split(self, X, y, validation_data=None, validation_split=None):
39+
# Preparing validation data
40+
assert validation_split or validation_data
41+
if validation_data is not None:
42+
trainX, trainy = X, y
43+
devX, devy = validation_data
44+
else:
45+
permutation = np.random.permutation(len(X))
46+
trainidx = permutation[int(validation_split * len(X)):]
47+
devidx = permutation[0:int(validation_split * len(X))]
48+
trainX, trainy = X[trainidx], y[trainidx]
49+
devX, devy = X[devidx], y[devidx]
50+
51+
device = torch.device('cpu') if self.cudaEfficient else torch.device('cuda')
52+
53+
trainX = torch.from_numpy(trainX).to(device, dtype=torch.float32)
54+
trainy = torch.from_numpy(trainy).to(device, dtype=torch.int64)
55+
devX = torch.from_numpy(devX).to(device, dtype=torch.float32)
56+
devy = torch.from_numpy(devy).to(device, dtype=torch.int64)
57+
58+
return trainX, trainy, devX, devy
59+
60+
def fit(self, X, y, validation_data=None, validation_split=None,
61+
early_stop=True):
62+
self.nepoch = 0
63+
bestaccuracy = -1
64+
stop_train = False
65+
early_stop_count = 0
66+
67+
# Preparing validation data
68+
trainX, trainy, devX, devy = self.prepare_split(X, y, validation_data,
69+
validation_split)
70+
71+
# Training
72+
while not stop_train and self.nepoch <= self.max_epoch:
73+
self.trainepoch(trainX, trainy, epoch_size=self.epoch_size)
74+
accuracy = self.score(devX, devy)
75+
if accuracy > bestaccuracy:
76+
bestaccuracy = accuracy
77+
bestmodel = copy.deepcopy(self.model)
78+
elif early_stop:
79+
if early_stop_count >= self.tenacity:
80+
stop_train = True
81+
early_stop_count += 1
82+
self.model = bestmodel
83+
return bestaccuracy
84+
85+
def trainepoch(self, X, y, epoch_size=1):
86+
self.model.train()
87+
for _ in range(self.nepoch, self.nepoch + epoch_size):
88+
permutation = np.random.permutation(len(X))
89+
all_costs = []
90+
for i in range(0, len(X), self.batch_size):
91+
# forward
92+
idx = torch.from_numpy(permutation[i:i + self.batch_size]).long().to(X.device)
93+
94+
Xbatch = X[idx]
95+
ybatch = y[idx]
96+
97+
if self.cudaEfficient:
98+
Xbatch = Xbatch.cuda()
99+
ybatch = ybatch.cuda()
100+
output = self.model(Xbatch)
101+
# loss
102+
loss = self.loss_fn(output, ybatch)
103+
all_costs.append(loss.data.item())
104+
# backward
105+
self.optimizer.zero_grad()
106+
loss.backward()
107+
# Update parameters
108+
self.optimizer.step()
109+
self.nepoch += epoch_size
110+
111+
def score(self, devX, devy):
112+
self.model.eval()
113+
correct = 0
114+
if not isinstance(devX, torch.cuda.FloatTensor) or self.cudaEfficient:
115+
devX = torch.FloatTensor(devX).cuda()
116+
devy = torch.LongTensor(devy).cuda()
117+
with torch.no_grad():
118+
for i in range(0, len(devX), self.batch_size):
119+
Xbatch = devX[i:i + self.batch_size]
120+
ybatch = devy[i:i + self.batch_size]
121+
if self.cudaEfficient:
122+
Xbatch = Xbatch.cuda()
123+
ybatch = ybatch.cuda()
124+
output = self.model(Xbatch)
125+
pred = output.data.max(1)[1]
126+
correct += pred.long().eq(ybatch.data.long()).sum().item()
127+
accuracy = 1.0 * correct / len(devX)
128+
return accuracy
129+
130+
def predict(self, devX):
131+
self.model.eval()
132+
if not isinstance(devX, torch.cuda.FloatTensor):
133+
devX = torch.FloatTensor(devX).cuda()
134+
yhat = np.array([])
135+
with torch.no_grad():
136+
for i in range(0, len(devX), self.batch_size):
137+
Xbatch = devX[i:i + self.batch_size]
138+
output = self.model(Xbatch)
139+
yhat = np.append(yhat,
140+
output.data.max(1)[1].cpu().numpy())
141+
yhat = np.vstack(yhat)
142+
return yhat
143+
144+
def predict_proba(self, devX):
145+
self.model.eval()
146+
probas = []
147+
with torch.no_grad():
148+
for i in range(0, len(devX), self.batch_size):
149+
Xbatch = devX[i:i + self.batch_size]
150+
vals = F.softmax(self.model(Xbatch), dim=1).data.cpu().numpy()
151+
if len(probas) == 0:
152+
probas = vals
153+
else:
154+
probas = np.concatenate((probas, vals), axis=0)
155+
return probas
156+
157+
158+
"""
159+
MLP with Pytorch (nhid=0 --> Logistic Regression)
160+
"""
161+
162+
class MLP(PyTorchClassifier):
163+
def __init__(self, params, inputdim, nclasses, l2reg=0., batch_size=64,
164+
seed=1111, cudaEfficient=False):
165+
super(self.__class__, self).__init__(inputdim, nclasses, l2reg,
166+
batch_size, seed, cudaEfficient)
167+
"""
168+
PARAMETERS:
169+
-nhid: number of hidden units (0: Logistic Regression)
170+
-optim: optimizer ("sgd,lr=0.1", "adam", "rmsprop" ..)
171+
-tenacity: how many times dev acc does not increase before stopping
172+
-epoch_size: each epoch corresponds to epoch_size pass on the train set
173+
-max_epoch: max number of epochs
174+
-dropout: dropout for MLP
175+
"""
176+
177+
self.nhid = 0 if "nhid" not in params else params["nhid"]
178+
self.optim = "adam" if "optim" not in params else params["optim"]
179+
self.tenacity = 5 if "tenacity" not in params else params["tenacity"]
180+
self.epoch_size = 4 if "epoch_size" not in params else params["epoch_size"]
181+
self.max_epoch = 200 if "max_epoch" not in params else params["max_epoch"]
182+
self.dropout = 0. if "dropout" not in params else params["dropout"]
183+
self.batch_size = 64 if "batch_size" not in params else params["batch_size"]
184+
185+
if params["nhid"] == 0:
186+
self.model = nn.Sequential(
187+
nn.Linear(self.inputdim, self.nclasses),
188+
).cuda()
189+
else:
190+
self.model = nn.Sequential(
191+
nn.Linear(self.inputdim, params["nhid"]),
192+
nn.Dropout(p=self.dropout),
193+
nn.Sigmoid(),
194+
nn.Linear(params["nhid"], self.nclasses),
195+
).cuda()
196+
197+
self.loss_fn = nn.CrossEntropyLoss().cuda()
198+
self.loss_fn.size_average = False
199+
200+
optim_fn, optim_params = utils.get_optimizer(self.optim)
201+
self.optimizer = optim_fn(self.model.parameters(), **optim_params)
202+
self.optimizer.param_groups[0]['weight_decay'] = self.l2reg
