Skip to content

Commit

Permalink
allow nlg-eval --setup to work from anywhere. Fixes #78 (#79)
Browse files Browse the repository at this point in the history
* allow nlg-eval --setup to work from anywhere. Fixes #78

nlg-eval --setup now takes a path as an argument or "default".

During setup,

- The code location is inferred from the package directory.
- The data location is stored in a config file, so that it does not have
  to be provided again.

A common function (`nlgeval.utils.get_data_dir()`) provides
(rudimentarily checked) access to the data location via environment
variable or config file (in this order). Added some colorful hints
as to where data is stored to setup/CLI.

* move nltk downloading to setup function, never got this to work
* setup: update glove2word2vec download URL after PR merge
  • Loading branch information
temporaer authored and juharris committed Jul 18, 2019
1 parent 5a89faf commit 7aa86b8
Show file tree
Hide file tree
Showing 9 changed files with 118 additions and 49 deletions.
12 changes: 8 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -37,14 +37,16 @@ Then run:

```bash
# Install the Python dependencies.
# It may take a while to run because it's downloading some files. You can instead run `pip install -v -e .` to see more details.
pip install -e .
pip install git+https://github.com/Maluuba/nlg-eval.git@master

# If using macOS High Sierra or higher, run this before run setup, to allow multithreading
# export OBJC_DISABLE_INITIALIZE_FORK_SAFETY=YES

# Download required data files.
nlg-eval --setup
# Download required data (e.g. models, embeddings) and external code files.
# If you don't like the default path (~/.cache/nlgeval), specify a path where the files are downloaded.
# The data path is stored in ~/.config/nlgeval/rc.json and can be overwritten by
# setting the NLGEVAL_DATA environment variable.
nlg-eval --setup [data path]
```

## Usage ##
Expand Down Expand Up @@ -158,6 +160,8 @@ where to find its models and data. E.g.

NLGEVAL_DATA=~/workspace/nlg-eval/nlgeval/data

This variable overrides the value provided during setup (stored in `~/.config/nlgeval/rc.json`)

## Microsoft Open Source Code of Conduct ##
This project has adopted the [Microsoft Open Source Code of
Conduct](https://opensource.microsoft.com/codeofconduct/).
Expand Down
79 changes: 63 additions & 16 deletions bin/nlg-eval
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT license. See LICENSE.md file in the project root for full license information.

import json
import logging
import os
import stat
Expand All @@ -11,8 +12,12 @@ import time
from zipfile import ZipFile

import click
from xdg import XDG_CONFIG_HOME, XDG_CACHE_HOME

import nlgeval
import nlgeval.utils

CODE_PATH = nlgeval.__path__[0]


def _download_file(d):
Expand Down Expand Up @@ -57,28 +62,41 @@ def _download_file(d):
raise


def setup(ctx, param, value):
if not value:
return
@click.command()
@click.argument("data_path", required=False)
def setup(data_path):
"""
Download required code and data files for nlg-eval.
If the data_path argument is provided, install to the given location.
Otherwise, your cache directory is used (usually ~/.cache/nlgeval).
"""
from nltk.downloader import download
download('punkt')

from multiprocessing import Pool

data_path = os.getenv('NLGEVAL_DATA', 'nlgeval/data')
path = 'nlgeval/word2vec/glove2word2vec.py'
if data_path is None:
data_path = os.getenv('NLGEVAL_DATA', os.path.join(XDG_CACHE_HOME, 'nlgeval'))
click.secho("Installing to {}".format(data_path), fg='red')
click.secho("In case of incomplete downloads, delete the directory and run `nlg-eval --setup {}' again.".format(data_path),
fg='red')

path = os.path.join(CODE_PATH, 'word2vec/glove2word2vec.py')
if os.path.exists(path):
os.remove('nlgeval/word2vec/glove2word2vec.py')
os.remove(path)

downloads = []

if sys.version_info[0] == 2:
downloads.append(dict(
url='https://raw.githubusercontent.com/manasRK/glove-gensim/42ce46f00e83d3afa028fb6bf17ed3c90ca65fcc/glove2word2vec.py',
target_dir='nlgeval/word2vec'
target_dir=os.path.join(CODE_PATH, 'word2vec')
))
else:
# Change URL once https://github.com/robmsmt/glove-gensim/pull/1 has been merged.
downloads.append(dict(
url='https://raw.githubusercontent.com/juharris/glove-gensim/4c2224bccd61627b76c50a5e1d6afd1c82699d22/glove2word2vec.py',
target_dir='nlgeval/word2vec'
url='https://raw.githubusercontent.com/robmsmt/glove-gensim/4c2224bccd61627b76c50a5e1d6afd1c82699d22/glove2word2vec.py',
target_dir=os.path.join(CODE_PATH, 'word2vec')
))

setup_glove = not os.path.exists(os.path.join(data_path, 'glove.6B.300d.model.bin'))
Expand Down Expand Up @@ -121,7 +139,7 @@ def setup(ctx, param, value):
# multi-bleu.perl
downloads.append(dict(
url='https://raw.githubusercontent.com/moses-smt/mosesdecoder/b199e654df2a26ea58f234cbb642e89d9c1f269d/scripts/generic/multi-bleu.perl',
target_dir='nlgeval/multibleu'
target_dir=os.path.join(CODE_PATH, 'multibleu')
))

for target_dir in {d['target_dir'] for d in downloads}:
Expand All @@ -138,7 +156,7 @@ def setup(ctx, param, value):
from nlgeval.word2vec.generate_w2v_files import generate
with ZipFile(os.path.join(data_path, 'glove.6B.zip')) as z:
z.extract('glove.6B.300d.txt', data_path)
generate()
generate(data_path)
for p in [
os.path.join(data_path, 'glove.6B.zip'),
os.path.join(data_path, 'glove.6B.300d.txt'),
Expand All @@ -147,23 +165,52 @@ def setup(ctx, param, value):
if os.path.exists(p):
os.remove(p)

path = 'nlgeval/multibleu/multi-bleu.perl'
path = os.path.join(CODE_PATH, 'multibleu/multi-bleu.perl')
stats = os.stat(path)
os.chmod(path, stats.st_mode | stat.S_IEXEC)

ctx.exit()
cfg_path = os.path.join(XDG_CONFIG_HOME, "nlgeval")
if not os.path.exists(cfg_path):
os.makedirs(cfg_path)
rc = dict()
try:
with open(os.path.join(cfg_path, "rc.json"), 'rt') as f:
rc = json.load(f)
except:
print("WARNING: could not read rc.json in %s, overwriting" % cfg_path)
rc['data_path'] = data_path
with open(os.path.join(cfg_path, "rc.json"), 'wt') as f:
f.write(json.dumps(rc))


@click.command()
@click.option('--setup', is_flag=True, callback=setup, expose_value=False, is_eager=True)
@click.option('--references', type=click.Path(exists=True), multiple=True, required=True, help='Path of the reference file. This option can be provided multiple times for multiple reference files.')
@click.option('--hypothesis', type=click.Path(exists=True), required=True, help='Path of the hypothesis file.')
@click.option('--no-overlap', is_flag=True, help='Flag. If provided, word overlap based metrics will not be computed.')
@click.option('--no-skipthoughts', is_flag=True, help='Flag. If provided, skip-thought cosine similarity will not be computed.')
@click.option('--no-glove', is_flag=True, help='Flag. If provided, other word embedding based metrics will not be computed.')
def compute_metrics(hypothesis, references, no_overlap, no_skipthoughts, no_glove):
"""
Compute nlg-eval metrics.
The --hypothesis and at least one --references parameters are required.
To download the data and additional code files, use `nlg-eval --setup [data path]`.
Note that nlg-eval also features an API, which may be easier to use.
"""
try:
data_dir = nlgeval.utils.get_data_dir()
except nlgeval.utils.InvalidDataDirException:
sys.exit(1)
click.secho("Using data from {}".format(data_dir), fg='green')
click.secho("In case of broken downloads, remove the directory and run setup again.", fg='green')
nlgeval.compute_metrics(hypothesis, references, no_overlap, no_skipthoughts, no_glove)


if __name__ == '__main__':
compute_metrics()
if len(sys.argv) > 1 and sys.argv[1] == '--setup':
del sys.argv[0]
setup()
else:
compute_metrics()
5 changes: 3 additions & 2 deletions nlgeval/skipthoughts/skipthoughts.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,16 @@
from nltk.tokenize import word_tokenize
from scipy.linalg import norm
from six.moves import cPickle as pkl
from nlgeval.utils import get_data_dir
import logging

profile = False

#-----------------------------------------------------------------------------#
# Specify model and table locations here
#-----------------------------------------------------------------------------#
path_to_models = os.environ.get('NLGEVAL_DATA', os.path.join(os.path.dirname(__file__), '..', 'data'))
path_to_tables = os.environ.get('NLGEVAL_DATA', os.path.join(os.path.dirname(__file__), '..', 'data'))
path_to_models = get_data_dir()
path_to_tables = get_data_dir()
#-----------------------------------------------------------------------------#

path_to_umodel = os.path.join(path_to_models, 'uni_skip.npz')
Expand Down
33 changes: 33 additions & 0 deletions nlgeval/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import click
import json
import os

from xdg import XDG_CONFIG_HOME


class InvalidDataDirException(Exception):
pass


def get_data_dir():
if os.environ.get('NLGEVAL_DATA'):
if not os.path.exists(os.environ.get('NLGEVAL_DATA')):
click.secho("NLGEVAL_DATA variable is set but points to non-existent path.", fg='red', err=True)
raise InvalidDataDirException()
return os.environ.get('NLGEVAL_DATA')
else:
try:
cfg_file = os.path.join(XDG_CONFIG_HOME, 'nlgeval', 'rc.json')
with open(cfg_file, 'rt') as f:
rc = json.load(f)
if not os.path.exists(rc['data_path']):
click.secho("Data path found in {} does not exist: {} " % (cfg_file, rc['data_path']), fg='red', err=True)
click.secho("Run `nlg-eval --setup DATA_DIR' to download or set $NLGEVAL_DATA to an existing location",
fg='red', err=True)
raise InvalidDataDirException()
return rc['data_path']
except:
click.secho("Could not determine location of data.", fg='red', err=True)
click.secho("Run `nlg-eval --setup DATA_DIR' to download or set $NLGEVAL_DATA to an existing location", fg='red',
err=True)
raise InvalidDataDirException()
6 changes: 5 additions & 1 deletion nlgeval/word2vec/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@
# Licensed under the MIT license. See LICENSE.md file in the project root for full license information.
import os
import numpy as np

from nlgeval.utils import get_data_dir


try:
from gensim.models import KeyedVectors
except ImportError:
Expand All @@ -10,7 +14,7 @@

class Embedding(object):
def __init__(self):
path = os.environ.get('NLGEVAL_DATA', os.path.join(os.path.dirname(__file__), '..', 'data'))
path = get_data_dir()
self.m = KeyedVectors.load(os.path.join(path, 'glove.6B.300d.model.bin'), mmap='r')
try:
self.unk = self.m.vectors.mean(axis=0)
Expand Down
3 changes: 1 addition & 2 deletions nlgeval/word2vec/generate_w2v_files.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,7 @@ def txt2bin(filename):
KeyedVectors.load(filename.replace('txt', 'bin'), mmap='r')


def generate():
path = os.path.join(os.path.dirname(__file__), "..", "data")
def generate(path):
glove_vector_file = os.path.join(path, 'glove.6B.300d.txt')
output_model_file = os.path.join(path, 'glove.6B.300d.model.txt')

Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,4 @@ scikit-learn>=0.17
gensim>=3
Theano>=0.8.1
tqdm>=4.24
xdg
3 changes: 2 additions & 1 deletion requirements_py2.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
click>=6.3
nltk>=3.1
numpy>=1.11.0
numpy>=1.11.0<=1.17
psutil>=5.6.2
requests>=2.19
six>=1.11
Expand All @@ -9,3 +9,4 @@ scikit-learn<0.21
gensim<1
Theano>=0.8.1
tqdm>=4.24
xdg==1.0.7
25 changes: 2 additions & 23 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,24 +16,6 @@
from pip.req import parse_requirements


def _post_setup():
from nltk.downloader import download
download('punkt')


# Set up post install actions as per https://stackoverflow.com/a/36902139/1226799
class PostDevelopCommand(develop):
def run(self):
develop.run(self)
_post_setup()


class PostInstallCommand(install):
def run(self):
install.run(self)
_post_setup()


if __name__ == '__main__':
requirements_path = 'requirements.txt'
if sys.version_info[0] < 3:
Expand All @@ -42,7 +24,7 @@ def run(self):
reqs = [str(ir.req) for ir in install_reqs]

setup(name='nlg-eval',
version='2.1',
version='2.2',
description="Wrapper for multiple NLG evaluation methods and metrics.",
author='Shikhar Sharma, Hannes Schulz, Justin Harris',
author_email='[email protected], [email protected], [email protected]',
Expand All @@ -51,7 +33,4 @@ def run(self):
include_package_data=True,
scripts=['bin/nlg-eval'],
install_requires=reqs,
cmdclass={
'develop': PostDevelopCommand,
'install': PostInstallCommand,
})
)

0 comments on commit 7aa86b8

Please sign in to comment.