diff --git a/README.md b/README.md index 7ae40e9..ff85b08 100644 --- a/README.md +++ b/README.md @@ -37,14 +37,16 @@ Then run: ```bash # Install the Python dependencies. -# It may take a while to run because it's downloading some files. You can instead run `pip install -v -e .` to see more details. -pip install -e . +pip install git+https://github.com/Maluuba/nlg-eval.git@master # If using macOS High Sierra or higher, run this before run setup, to allow multithreading # export OBJC_DISABLE_INITIALIZE_FORK_SAFETY=YES -# Download required data files. -nlg-eval --setup +# Download required data (e.g. models, embeddings) and external code files. +# If you don't like the default path (~/.cache/nlgeval), specify a path where the files are downloaded. +# The data path is stored in ~/.config/nlgeval/rc.json and can be overwritten by +# setting the NLGEVAL_DATA environment variable. +nlg-eval --setup [data path] ``` ## Usage ## @@ -158,6 +160,8 @@ where to find its models and data. E.g. NLGEVAL_DATA=~/workspace/nlg-eval/nlgeval/data +This variable overrides the value provided during setup (stored in `~/.config/nlgeval/rc.json`) + ## Microsoft Open Source Code of Conduct ## This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). diff --git a/bin/nlg-eval b/bin/nlg-eval index 5577392..0e608ef 100755 --- a/bin/nlg-eval +++ b/bin/nlg-eval @@ -3,6 +3,7 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT license. See LICENSE.md file in the project root for full license information. +import json import logging import os import stat @@ -11,8 +12,12 @@ import time from zipfile import ZipFile import click +from xdg import XDG_CONFIG_HOME, XDG_CACHE_HOME import nlgeval +import nlgeval.utils + +CODE_PATH = nlgeval.__path__[0] def _download_file(d): @@ -57,28 +62,41 @@ def _download_file(d): raise -def setup(ctx, param, value): - if not value: - return +@click.command() +@click.argument("data_path", required=False) +def setup(data_path): + """ + Download required code and data files for nlg-eval. + + If the data_path argument is provided, install to the given location. + Otherwise, your cache directory is used (usually ~/.cache/nlgeval). + """ + from nltk.downloader import download + download('punkt') + from multiprocessing import Pool - data_path = os.getenv('NLGEVAL_DATA', 'nlgeval/data') - path = 'nlgeval/word2vec/glove2word2vec.py' + if data_path is None: + data_path = os.getenv('NLGEVAL_DATA', os.path.join(XDG_CACHE_HOME, 'nlgeval')) + click.secho("Installing to {}".format(data_path), fg='red') + click.secho("In case of incomplete downloads, delete the directory and run `nlg-eval --setup {}' again.".format(data_path), + fg='red') + + path = os.path.join(CODE_PATH, 'word2vec/glove2word2vec.py') if os.path.exists(path): - os.remove('nlgeval/word2vec/glove2word2vec.py') + os.remove(path) downloads = [] if sys.version_info[0] == 2: downloads.append(dict( url='https://raw.githubusercontent.com/manasRK/glove-gensim/42ce46f00e83d3afa028fb6bf17ed3c90ca65fcc/glove2word2vec.py', - target_dir='nlgeval/word2vec' + target_dir=os.path.join(CODE_PATH, 'word2vec') )) else: - # Change URL once https://github.com/robmsmt/glove-gensim/pull/1 has been merged. downloads.append(dict( - url='https://raw.githubusercontent.com/juharris/glove-gensim/4c2224bccd61627b76c50a5e1d6afd1c82699d22/glove2word2vec.py', - target_dir='nlgeval/word2vec' + url='https://raw.githubusercontent.com/robmsmt/glove-gensim/4c2224bccd61627b76c50a5e1d6afd1c82699d22/glove2word2vec.py', + target_dir=os.path.join(CODE_PATH, 'word2vec') )) setup_glove = not os.path.exists(os.path.join(data_path, 'glove.6B.300d.model.bin')) @@ -121,7 +139,7 @@ def setup(ctx, param, value): # multi-bleu.perl downloads.append(dict( url='https://raw.githubusercontent.com/moses-smt/mosesdecoder/b199e654df2a26ea58f234cbb642e89d9c1f269d/scripts/generic/multi-bleu.perl', - target_dir='nlgeval/multibleu' + target_dir=os.path.join(CODE_PATH, 'multibleu') )) for target_dir in {d['target_dir'] for d in downloads}: @@ -138,7 +156,7 @@ def setup(ctx, param, value): from nlgeval.word2vec.generate_w2v_files import generate with ZipFile(os.path.join(data_path, 'glove.6B.zip')) as z: z.extract('glove.6B.300d.txt', data_path) - generate() + generate(data_path) for p in [ os.path.join(data_path, 'glove.6B.zip'), os.path.join(data_path, 'glove.6B.300d.txt'), @@ -147,23 +165,52 @@ def setup(ctx, param, value): if os.path.exists(p): os.remove(p) - path = 'nlgeval/multibleu/multi-bleu.perl' + path = os.path.join(CODE_PATH, 'multibleu/multi-bleu.perl') stats = os.stat(path) os.chmod(path, stats.st_mode | stat.S_IEXEC) - ctx.exit() + cfg_path = os.path.join(XDG_CONFIG_HOME, "nlgeval") + if not os.path.exists(cfg_path): + os.makedirs(cfg_path) + rc = dict() + try: + with open(os.path.join(cfg_path, "rc.json"), 'rt') as f: + rc = json.load(f) + except: + print("WARNING: could not read rc.json in %s, overwriting" % cfg_path) + rc['data_path'] = data_path + with open(os.path.join(cfg_path, "rc.json"), 'wt') as f: + f.write(json.dumps(rc)) @click.command() -@click.option('--setup', is_flag=True, callback=setup, expose_value=False, is_eager=True) @click.option('--references', type=click.Path(exists=True), multiple=True, required=True, help='Path of the reference file. This option can be provided multiple times for multiple reference files.') @click.option('--hypothesis', type=click.Path(exists=True), required=True, help='Path of the hypothesis file.') @click.option('--no-overlap', is_flag=True, help='Flag. If provided, word overlap based metrics will not be computed.') @click.option('--no-skipthoughts', is_flag=True, help='Flag. If provided, skip-thought cosine similarity will not be computed.') @click.option('--no-glove', is_flag=True, help='Flag. If provided, other word embedding based metrics will not be computed.') def compute_metrics(hypothesis, references, no_overlap, no_skipthoughts, no_glove): + """ + Compute nlg-eval metrics. + + The --hypothesis and at least one --references parameters are required. + + To download the data and additional code files, use `nlg-eval --setup [data path]`. + + Note that nlg-eval also features an API, which may be easier to use. + """ + try: + data_dir = nlgeval.utils.get_data_dir() + except nlgeval.utils.InvalidDataDirException: + sys.exit(1) + click.secho("Using data from {}".format(data_dir), fg='green') + click.secho("In case of broken downloads, remove the directory and run setup again.", fg='green') nlgeval.compute_metrics(hypothesis, references, no_overlap, no_skipthoughts, no_glove) if __name__ == '__main__': - compute_metrics() + if len(sys.argv) > 1 and sys.argv[1] == '--setup': + del sys.argv[0] + setup() + else: + compute_metrics() diff --git a/nlgeval/skipthoughts/skipthoughts.py b/nlgeval/skipthoughts/skipthoughts.py index 963328f..b1de54e 100644 --- a/nlgeval/skipthoughts/skipthoughts.py +++ b/nlgeval/skipthoughts/skipthoughts.py @@ -13,6 +13,7 @@ from nltk.tokenize import word_tokenize from scipy.linalg import norm from six.moves import cPickle as pkl +from nlgeval.utils import get_data_dir import logging profile = False @@ -20,8 +21,8 @@ #-----------------------------------------------------------------------------# # Specify model and table locations here #-----------------------------------------------------------------------------# -path_to_models = os.environ.get('NLGEVAL_DATA', os.path.join(os.path.dirname(__file__), '..', 'data')) -path_to_tables = os.environ.get('NLGEVAL_DATA', os.path.join(os.path.dirname(__file__), '..', 'data')) +path_to_models = get_data_dir() +path_to_tables = get_data_dir() #-----------------------------------------------------------------------------# path_to_umodel = os.path.join(path_to_models, 'uni_skip.npz') diff --git a/nlgeval/utils.py b/nlgeval/utils.py new file mode 100644 index 0000000..d9df32a --- /dev/null +++ b/nlgeval/utils.py @@ -0,0 +1,33 @@ +import click +import json +import os + +from xdg import XDG_CONFIG_HOME + + +class InvalidDataDirException(Exception): + pass + + +def get_data_dir(): + if os.environ.get('NLGEVAL_DATA'): + if not os.path.exists(os.environ.get('NLGEVAL_DATA')): + click.secho("NLGEVAL_DATA variable is set but points to non-existent path.", fg='red', err=True) + raise InvalidDataDirException() + return os.environ.get('NLGEVAL_DATA') + else: + try: + cfg_file = os.path.join(XDG_CONFIG_HOME, 'nlgeval', 'rc.json') + with open(cfg_file, 'rt') as f: + rc = json.load(f) + if not os.path.exists(rc['data_path']): + click.secho("Data path found in {} does not exist: {} " % (cfg_file, rc['data_path']), fg='red', err=True) + click.secho("Run `nlg-eval --setup DATA_DIR' to download or set $NLGEVAL_DATA to an existing location", + fg='red', err=True) + raise InvalidDataDirException() + return rc['data_path'] + except: + click.secho("Could not determine location of data.", fg='red', err=True) + click.secho("Run `nlg-eval --setup DATA_DIR' to download or set $NLGEVAL_DATA to an existing location", fg='red', + err=True) + raise InvalidDataDirException() diff --git a/nlgeval/word2vec/evaluate.py b/nlgeval/word2vec/evaluate.py index 410aceb..58822e8 100644 --- a/nlgeval/word2vec/evaluate.py +++ b/nlgeval/word2vec/evaluate.py @@ -2,6 +2,10 @@ # Licensed under the MIT license. See LICENSE.md file in the project root for full license information. import os import numpy as np + +from nlgeval.utils import get_data_dir + + try: from gensim.models import KeyedVectors except ImportError: @@ -10,7 +14,7 @@ class Embedding(object): def __init__(self): - path = os.environ.get('NLGEVAL_DATA', os.path.join(os.path.dirname(__file__), '..', 'data')) + path = get_data_dir() self.m = KeyedVectors.load(os.path.join(path, 'glove.6B.300d.model.bin'), mmap='r') try: self.unk = self.m.vectors.mean(axis=0) diff --git a/nlgeval/word2vec/generate_w2v_files.py b/nlgeval/word2vec/generate_w2v_files.py index 60d8d1e..3228267 100644 --- a/nlgeval/word2vec/generate_w2v_files.py +++ b/nlgeval/word2vec/generate_w2v_files.py @@ -18,8 +18,7 @@ def txt2bin(filename): KeyedVectors.load(filename.replace('txt', 'bin'), mmap='r') -def generate(): - path = os.path.join(os.path.dirname(__file__), "..", "data") +def generate(path): glove_vector_file = os.path.join(path, 'glove.6B.300d.txt') output_model_file = os.path.join(path, 'glove.6B.300d.model.txt') diff --git a/requirements.txt b/requirements.txt index f82f94a..fd498ed 100644 --- a/requirements.txt +++ b/requirements.txt @@ -9,3 +9,4 @@ scikit-learn>=0.17 gensim>=3 Theano>=0.8.1 tqdm>=4.24 +xdg diff --git a/requirements_py2.txt b/requirements_py2.txt index 019d4e2..2b15499 100644 --- a/requirements_py2.txt +++ b/requirements_py2.txt @@ -1,6 +1,6 @@ click>=6.3 nltk>=3.1 -numpy>=1.11.0 +numpy>=1.11.0<=1.17 psutil>=5.6.2 requests>=2.19 six>=1.11 @@ -9,3 +9,4 @@ scikit-learn<0.21 gensim<1 Theano>=0.8.1 tqdm>=4.24 +xdg==1.0.7 diff --git a/setup.py b/setup.py index 609b4eb..ac5ff59 100755 --- a/setup.py +++ b/setup.py @@ -16,24 +16,6 @@ from pip.req import parse_requirements -def _post_setup(): - from nltk.downloader import download - download('punkt') - - -# Set up post install actions as per https://stackoverflow.com/a/36902139/1226799 -class PostDevelopCommand(develop): - def run(self): - develop.run(self) - _post_setup() - - -class PostInstallCommand(install): - def run(self): - install.run(self) - _post_setup() - - if __name__ == '__main__': requirements_path = 'requirements.txt' if sys.version_info[0] < 3: @@ -42,7 +24,7 @@ def run(self): reqs = [str(ir.req) for ir in install_reqs] setup(name='nlg-eval', - version='2.1', + version='2.2', description="Wrapper for multiple NLG evaluation methods and metrics.", author='Shikhar Sharma, Hannes Schulz, Justin Harris', author_email='shikhar.sharma@microsoft.com, hannes.schulz@microsoft.com, justin.harris@microsoft.com', @@ -51,7 +33,4 @@ def run(self): include_package_data=True, scripts=['bin/nlg-eval'], install_requires=reqs, - cmdclass={ - 'develop': PostDevelopCommand, - 'install': PostInstallCommand, - }) + )