diff --git a/CMakeLists.txt b/CMakeLists.txt index 925d58a..3b6093e 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -4,12 +4,16 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. An additional grant # of patent rights can be found in the PATENTS file in the same directory. +# +# Copyright (c) Microsoft Corporation. All rights reserved +# Licensed under the BSD License -CMAKE_MINIMUM_REQUIRED(VERSION 2.6 FATAL_ERROR) -CMAKE_POLICY(VERSION 2.6) +CMAKE_MINIMUM_REQUIRED(VERSION 2.8 FATAL_ERROR) +CMAKE_POLICY(VERSION 2.8) FIND_PACKAGE(Torch REQUIRED) FIND_PACKAGE(OpenMP) +FIND_PACKAGE(CUDA REQUIRED) SET(CMAKE_CXX_FLAGS "-std=c++11 -Ofast") IF(OpenMP_FOUND) @@ -17,6 +21,13 @@ IF(OpenMP_FOUND) SET(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${OpenMP_CXX_FLAGS}") ENDIF() +SET(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} -O3 -shared -Xcompiler -fPIC + -gencode arch=compute_30,code=sm_30 -gencode arch=compute_35,code=sm_35 + -gencode arch=compute_37,code=sm_37 -gencode arch=compute_50,code=sm_50 + -gencode arch=compute_52,code=sm_52 -gencode arch=compute_52,code=compute_52") + +INCLUDE_DIRECTORIES(${CMAKE_PREFIX_PATH}/include/THC) + # C++ library IF(APPLE) SET(CMAKE_SHARED_LIBRARY_SUFFIX ".so") @@ -25,6 +36,14 @@ FILE(GLOB CPPSRC fairseq/clib/*.cpp) ADD_LIBRARY(fairseq_clib SHARED ${CPPSRC}) INSTALL(TARGETS fairseq_clib DESTINATION "${ROCKS_LIBDIR}") +ADD_LIBRARY(dp_lib SHARED fairseq/models/c_sample_dp.cc) +INSTALL(TARGETS dp_lib DESTINATION "${ROCKS_LIBDIR}") + +CUDA_ADD_LIBRARY(compute_logpy_lib SHARED fairseq/models/compute_logpy.cu) +INSTALL(TARGETS compute_logpy_lib DESTINATION "${ROCKS_LIBDIR}") + +TARGET_LINK_LIBRARIES(compute_logpy_lib ${CUDA_LIBRARIES}) + # Lua library INSTALL(DIRECTORY "fairseq" DESTINATION "${ROCKS_LUADIR}" FILES_MATCHING PATTERN "*.lua") diff --git a/README.md b/README.md index d62c767..142dfcb 100644 --- a/README.md +++ b/README.md @@ -1,46 +1,60 @@ # Introduction -This is fairseq, a sequence-to-sequence learning toolkit for [Torch](http://torch.ch/) from Facebook AI Research tailored to Neural Machine Translation (NMT). -It implements the convolutional NMT models proposed in [Convolutional Sequence to Sequence Learning](https://arxiv.org/abs/1705.03122) and [A Convolutional Encoder Model for Neural Machine Translation](https://arxiv.org/abs/1611.02344) as well as a standard LSTM-based model. -It features multi-GPU training on a single machine as well as fast beam search generation on both CPU and GPU. -We provide pre-trained models for English to French, English to German and English to Romanian translation. +This is NPMT, the source codes of [Towards Nerual Phrase-based Machine Translation](https://arxiv.org/abs/1706.05565) and [Sequence Modeling via Segmentations](https://arxiv.org/abs/1702.07463) from Microsoft Research. +It is built on top of the [fairseq toolkit](https://github.com/facebookresearch/fairseq) in [Torch](http://torch.ch/). +We present the setup and Neural Machine Translation (NMT) experiments in [Towards Nerual Phrase-based Machine Translation](https://arxiv.org/abs/1706.05565). -Note, there is now a PyTorch version [fairseq-py](https://github.com/facebookresearch/fairseq-py) of this toolkit and new development efforts will focus on it. +## NPMT +Neural Phrase-based Machine Translation (NPMT) explicitly models the phrase structures in output sequences using Sleep-WAke Networks (SWAN), a recently proposed segmentation-based sequence modeling method. 
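+Concretely, SWAN generates the target sentence as a concatenation of (possibly empty) segments, one per encoder position and each at most `max_segment_len` words long; training marginalizes over all valid segmentations using the dynamic program implemented in `fairseq/models/c_sample_dp.cc`.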
+To mitigate the monotonic alignment requirement of SWAN, we introduce a new layer to perform (soft) local reordering of input sequences. +Different from existing neural machine translation (NMT) approaches, NPMT does not use attention-based decoding mechanisms. +Instead, it directly outputs phrases in a sequential order and can decode in linear time. + +Model architecture +![Example](npmt.png) + +An illustration of using NPMT in German-English translation +![Example](de-en_example.png) + + +Please refer to the [PR](https://github.com/posenhuang/NPMT/pull/1) for our implementations. Our implementation is based on the [lastest version](https://github.com/posenhuang/NPMT/commit/7d017f0a46a3cddfc420a4778d9541ba38b6a43d) of fairseq. -![Model](fairseq.gif) # Citation If you use the code in your paper, then please cite it as: ``` -@article{gehring2017convs2s, - author = {Gehring, Jonas, and Auli, Michael and Grangier, David and Yarats, Denis and Dauphin, Yann N}, - title = "{Convolutional Sequence to Sequence Learning}", - journal = {ArXiv e-prints}, - archivePrefix = "arXiv", - eprinttype = {arxiv}, - eprint = {1705.03122}, - primaryClass = "cs.CL", - keywords = {Computer Science - Computation and Language}, - year = 2017, - month = May, +@article{pshuang2018NPMT, + author = {Po{-}Sen Huang and + Chong Wang and + Sitao Huang and + Dengyong Zhou and + Li Deng}, + title = {Towards Neural Phrase-based Machine Translation}, + journal = {CoRR}, + volume = {abs/1706.05565}, + year = {2017}, + url = {http://arxiv.org/abs/1706.05565}, + archivePrefix = {arXiv}, + eprint = {1706.05565}, } ``` and ``` -@article{gehring2016convenc, - author = {Gehring, Jonas, and Auli, Michael and Grangier, David and Dauphin, Yann N}, - title = "{A Convolutional Encoder Model for Neural Machine Translation}", - journal = {ArXiv e-prints}, - archivePrefix = "arXiv", - eprinttype = {arxiv}, - eprint = {1611.02344}, - primaryClass = "cs.CL", - keywords = {Computer Science - Computation and Language}, - year = 2016, - month = Nov, +@inproceedings{wang2017SWAN, + author = {Chong Wang and + Yining Wang and + Po{-}Sen Huang and + Abdelrahman Mohamed and + Dengyong Zhou and + Li Deng}, + title = {Sequence Modeling via Segmentations}, + booktitle = {Proceedings of the 34th International Conference on Machine Learning, + {ICML} 2017, Sydney, NSW, Australia, 6-11 August 2017}, + pages = {3674--3683}, + year = {2017}, } ``` @@ -71,41 +85,6 @@ The LuaRocks installation provides a command-line tool that includes the followi # Quick Start -## Evaluating Pre-trained Models -First, download a pre-trained model along with its vocabularies: -``` -$ curl https://s3.amazonaws.com/fairseq/models/wmt14.en-fr.fconv-cuda.tar.bz2 | tar xvjf - -``` - -This will unpack vocabulary files and a serialized model for English to French translation to `wmt14.en-fr.fconv-cuda/`. - -Alternatively, use a CPU-based model: -``` -$ curl https://s3.amazonaws.com/fairseq/models/wmt14.en-fr.fconv-float.tar.bz2 | tar xvjf - -``` - -Let's use `fairseq generate-lines` to translate some text. -This model uses a [Byte Pair Encoding (BPE) vocabulary](https://arxiv.org/abs/1508.07909), so we'll have to apply the encoding to the source text. -This can be done with [apply_bpe.py](https://github.com/rsennrich/subword-nmt/blob/master/apply_bpe.py) using the `bpecodes` file in within `wmt14.en-fr.fconv-cuda/`. -`@@` is used as a continuation marker and the original text can be easily recovered with e.g. `sed s/@@ //g`. 
-Prior to BPE, input text needs to be tokenized using `tokenizer.perl` from [mosesdecoder](https://github.com/moses-smt/mosesdecoder). -Here, we use a beam size of 5: -``` -$ fairseq generate-lines -path wmt14.en-fr.fconv-cuda/model.th7 -sourcedict wmt14.en-fr.fconv-cuda/dict.en.th7 \ - -targetdict wmt14.en-fr.fconv-cuda/dict.fr.th7 -beam 5 -| [target] Dictionary: 44666 types -| [source] Dictionary: 44409 types -> Why is it rare to discover new marine mam@@ mal species ? -S Why is it rare to discover new marine mam@@ mal species ? -O Why is it rare to discover new marine mam@@ mal species ? -H -0.068684287369251 Pourquoi est-il rare de découvrir de nouvelles espèces de mammifères marins ? -A 1 1 4 4 6 6 7 11 9 9 9 12 13 -``` - -This generation script produces four types of output: a line prefixed with *S* shows the supplied source sentence after applying the vocabulary; *O* is a copy of the original source sentence; *H* is the hypothesis along with an average log-likelihood and *A* are attention maxima for each word in the hypothesis (including the end-of-sentence marker which is omitted from the text). - -Check [below](#pre-trained-models) for a full list of pre-trained models available. - ## Training a New Model ### Data Pre-processing @@ -123,118 +102,102 @@ $ fairseq preprocess -sourcelang de -targetlang en \ ``` This will write binarized data that can be used for model training to data-bin/iwslt14.tokenized.de-en. +We also provide an example of pre-processing script for the IWSLT15 English-Vietnamese corpus. +Pre-process and binarize the data as follows: +``` +$ cd data/ +$ bash prepare-iwslt15.sh +$ cd .. +$ TEXT=data/iwslt15 +$ fairseq preprocess -sourcelang en -targetlang vi \ + -trainpref $TEXT/train -validpref $TEXT/tst2012 -testpref $TEXT/tst2013 \ + -thresholdsrc 5 -thresholdtgt 5 -destdir data-bin/iwslt15.tokenized.en-vi +``` + ### Training Use `fairseq train` to train a new model. 
-Here a few example settings that work well for the IWSLT14 dataset: +Here a few example settings that work well for the IWSLT14, IWSLT15 datasets: ``` -# Standard bi-directional LSTM model -$ mkdir -p trainings/blstm -$ fairseq train -sourcelang de -targetlang en -datadir data-bin/iwslt14.tokenized.de-en \ - -model blstm -nhid 512 -dropout 0.2 -dropout_hid 0 -optim adam -lr 0.0003125 -savedir trainings/blstm - -# Fully convolutional sequence-to-sequence model -$ mkdir -p trainings/fconv -$ fairseq train -sourcelang de -targetlang en -datadir data-bin/iwslt14.tokenized.de-en \ - -model fconv -nenclayer 4 -nlayer 3 -dropout 0.2 -optim nag -lr 0.25 -clip 0.1 \ - -momentum 0.99 -timeavg -bptt 0 -savedir trainings/fconv - -# Convolutional encoder, LSTM decoder -$ mkdir -p trainings/convenc +# NPMT model (IWSLT DE-EN) +$ mkdir -p trainings/iwslt_de_en $ fairseq train -sourcelang de -targetlang en -datadir data-bin/iwslt14.tokenized.de-en \ - -model conv -nenclayer 6 -dropout 0.2 -dropout_hid 0 -savedir trainings/convenc + -sourcelang de -targetlang en -model npmt -nhid 256 -dec_unit_size 512 -dropout .5 \ + -dropout_hid 0 -npmt_dropout .5 -optim adam -lr 0.001 -batchsize 32 -log_interval 100 \ + -nlayer 2 -nenclayer 2 -kwidth 7 -max_segment_len 6 -rnn_mode GRU -group_size 500 \ + -use_resnet_enc -use_resnet_dec -log -momentum 0.99 -clip 10 -maxbatch 600 -bptt 0 \ + -maxepoch 100 -ndatathreads 4 -seed 1002 -maxsourcelen 75 -num_lower_win_layers 1 \ + -save_interval 250 -use_accel -noearlystop -validbleu -lrshrink 1.25 -minepochtoanneal 18 \ + -annealing_type slow -savedir trainings/iwslt_de_en + +# NPMT model (IWSLT EN-DE) +$ mkdir -p trainings/iwslt_en_de +$ fairseq train -sourcelang en -targetlang de -datadir data-bin/iwslt14.tokenized.en-de \ + -model npmt -nhid 256 -dec_unit_size 512 -dropout .5 -dropout_hid 0 -npmt_dropout .5 \ + -optim adam -lr 0.001 -batchsize 32 -log_interval 100 -nlayer 2 -nenclayer 2 -kwidth 7 \ + -max_segment_len 6 -rnn_mode GRU -group_size 500 -use_resnet_enc -use_resnet_dec \ + -log -momentum 0.99 -clip 10 -maxbatch 800 -bptt 0 -maxepoch 100 -ndatathreads 4 \ + -seed 1002 -maxsourcelen 75 -num_lower_win_layers 1 -save_interval 250 -use_accel \ + -noearlystop -validbleu -lrshrink 1.25 -minepochtoaneal 15 \ + -annealing_type slow -savedir trainings/iwslt_en_de + +# NPMT model (IWSLT EN-VI) +$ mkdir -p trainings/iwslt_en_vi +$ fairseq train -sourcelang en -targetlang vi -datadir data-bin/iwslt15.tokenized.en-vi \ + -model npmt -nhid 512 -dec_unit_size 512 -dropout .4 -dropout_hid 0 -npmt_dropout .4 \ + -optim adam -lr 0.001 -batchsize 48 -log_interval 100 -nlayer 3 -nenclayer 2 -kwidth 7 \ + -max_segment_len 7 -rnn_mode LSTM -group_size 800 -use_resnet_enc -use_resnet_dec -log \ + -momentum 0.99 -clip 500 -maxbatch 800 -bptt 0 -maxepoch 50 -ndatathreads 4 -seed 1002 \ + -maxsourcelen 75 -num_lower_win_layers 1 -save_interval 250 -use_accel -noearlystop \ + -validbleu -nembed 512 -lrshrink 1.25 -minepochtoanneal 8 -annealing_type slow \ + -savedir trainings/iwslt_en_vi ``` + By default, `fairseq train` will use all available GPUs on your machine. Use the [CUDA_VISIBLE_DEVICES](http://acceleware.com/blog/cudavisibledevices-masking-gpus) environment variable to select specific GPUs or `-ngpus` to change the number of GPU devices that will be used. ### Generation Once your model is trained, you can translate with it using `fairseq generate` (for binarized data) or `fairseq generate-lines` (for text). 
-Here, we'll do it for a fully convolutional model: +Here, we'll do it for a NPMT model: ``` -# Optional: optimize for generation speed -$ fairseq optimize-fconv -input_model trainings/fconv/model_best.th7 -output_model trainings/fconv/model_best_opt.th7 # Translate some text $ DATA=data-bin/iwslt14.tokenized.de-en $ fairseq generate-lines -sourcedict $DATA/dict.de.th7 -targetdict $DATA/dict.en.th7 \ - -path trainings/fconv/model_best_opt.th7 -beam 10 -nbest 2 -| [target] Dictionary: 24738 types -| [source] Dictionary: 35474 types -> eine sprache ist ausdruck des menschlichen geistes . -S eine sprache ist ausdruck des menschlichen geistes . -O eine sprache ist ausdruck des menschlichen geistes . -H -0.23804219067097 a language is expression of human mind . -A 2 2 3 4 5 6 7 8 9 -H -0.23861141502857 a language is expression of the human mind . -A 2 2 3 4 5 7 6 7 9 9 -``` + -path trainings/iwslt_de_en/model_bestbleu.th7 -beam 1 -model npmt +| [target] Dictionary: 22823 types +| [source] Dictionary: 32010 types +> danke , aber das beste kommt noch . +max decoding: | 1:184 1:15| 2:4| 3:28| 4:6 4:282| 6:16 6:201 6:311| 8:5| +avg. phrase size 1.666667 +S danke , aber das beste kommt noch . +O danke , aber das beste kommt noch . +H -0.10934638977051 thank you , but the best is still coming . +A 1 -### CPU Generation -Use `fairseq tofloat` to convert a trained model to use CPU-only operations (this has to be done on a GPU machine): ``` -# Optional: optimize for generation speed -$ fairseq optimize-fconv -input_model trainings/fconv/model_best.th7 -output_model trainings/fconv/model_best_opt.th7 +where the ``max decoding`` suggests the output segments are ``| thank you | , | but | the best | is still coming | . |``, and ``avg. phrase size`` represents the average phrase length ``10/6 = 1.666667``. -# Convert to float -$ fairseq tofloat -input_model trainings/fconv/model_best_opt.th7 \ - -output_model trainings/fconv/model_best_opt-float.th7 -# Translate some text -$ fairseq generate-lines -sourcedict $DATA/dict.de.th7 -targetdict $DATA/dict.en.th7 \ - -path trainings/fconv/model_best_opt-float.th7 -beam 10 -nbest 2 -> eine sprache ist ausdruck des menschlichen geistes . -S eine sprache ist ausdruck des menschlichen geistes . -O eine sprache ist ausdruck des menschlichen geistes . -H -0.2380430996418 a language is expression of human mind . -A 2 2 3 4 5 6 7 8 9 -H -0.23861189186573 a language is expression of the human mind . -A 2 2 3 4 5 7 6 7 9 9 +Generation with the binarized test sets can be run as follows (not in batched mode), e.g. 
for German-English: ``` -# Pre-trained Models - -We provide the following pre-trained fully convolutional sequence-to-sequence models: - -* [wmt14.en-fr.fconv-cuda.tar.bz2](https://s3.amazonaws.com/fairseq/models/wmt14.en-fr.fconv-cuda.tar.bz2): Pre-trained model for [WMT14 English-French](http://statmt.org/wmt14/translation-task.html#Download) including vocabularies -* [wmt14.en-fr.fconv-float.tar.bz2](https://s3.amazonaws.com/fairseq/models/wmt14.en-fr.fconv-float.tar.bz2): CPU version of the above -* [wmt14.en-de.fconv-cuda.tar.bz2](https://s3.amazonaws.com/fairseq/models/wmt14.en-de.fconv-cuda.tar.bz2): Pre-trained model for [WMT14 English-German](https://nlp.stanford.edu/projects/nmt) including vocabularies -* [wmt14.en-de.fconv-float.tar.bz2](https://s3.amazonaws.com/fairseq/models/wmt14.en-de.fconv-float.tar.bz2): CPU version of the above -* [wmt16.en-ro.fconv-cuda.tar.bz2](https://s3.amazonaws.com/fairseq/models/wmt16.en-ro.fconv-cuda.tar.bz2): Pre-trained model for WMT16 English-Romanian including vocabularies. - This model was trained on the [original WMT bitext](http://statmt.org/wmt16/translation-task.html#Download) as well as [back-translated data](http://data.statmt.org/rsennrich/wmt16_backtranslations/en-ro) provided by Rico Sennrich. -* [wmt16.en-ro.fconv-float.tar.bz2](https://s3.amazonaws.com/fairseq/models/wmt16.en-ro.fconv-float.tar.bz2): CPU version of the above - -In addition, we provide pre-processed and binarized test sets for the models above: - -* [wmt14.en-fr.newstest2014.tar.bz2](https://s3.amazonaws.com/fairseq/data/wmt14.en-fr.newstest2014.tar.bz2): newstest2014 test set for WMT14 English-French -* [wmt14.en-fr.ntst1213.tar.bz2](https://s3.amazonaws.com/fairseq/data/wmt14.en-fr.ntst1213.tar.bz2): newstest2012 and newstest2013 test sets for WMT14 English-French -* [wmt14.en-de.newstest2014.tar.bz2](https://s3.amazonaws.com/fairseq/data/wmt14.en-de.newstest2014.tar.bz2): newstest2014 test set for WMT14 English-German -* [wmt16.en-ro.newstest2014.tar.bz2](https://s3.amazonaws.com/fairseq/data/wmt16.en-ro.newstest2016.tar.bz2): newstest2016 test set for WMT16 English-Romanian - -Generation with the binarized test sets can be run in batch mode as follows, e.g. for English-French on a GTX-1080ti: -``` -$ curl https://s3.amazonaws.com/fairseq/data/wmt14.en-fr.newstest2014.tar.bz2 | tar xvjf - - -$ fairseq generate -sourcelang en -targetlang fr -datadir data-bin/wmt14.en-fr -dataset newstest2014 \ - -path wmt14.en-fr.fconv-cuda/model.th7 -beam 5 -batchsize 128 | tee /tmp/gen.out +$ fairseq generate -sourcelang de -targetlang en -datadir data-bin/iwslt14.tokenized.de-en \ + -path trainings/iwslt_de_en/model_bestbleu.th7 -beam 10 -lenpen 1 -dataset test -model npmt | tee /tmp/gen.out ... 
-| Translated 3003 sentences (95451 tokens) in 136.3s (700.49 tokens/s) -| Timings: setup 0.1s (0.1%), encoder 1.9s (1.4%), decoder 108.9s (79.9%), search_results 0.0s (0.0%), search_prune 12.5s (9.2%) -| BLEU4 = 43.43, 68.2/49.2/37.4/28.8 (BP=0.996, ratio=1.004, sys_len=92087, ref_len=92448) +| Translated 6750 sentences (137891 tokens) in 3013.7s (45.75 tokens/s) +| Timings: setup 10.7s (0.4%), encoder 28.2s (0.9%), decoder 2747.9s (91.2%), search_results 0.0s (0.0%), search_prune 0.0s (0.0%) +| BLEU4 = 29.92, 64.7/37.9/23.8/15.3 (BP=0.973, ratio=1.027, sys_len=127660, ref_len=131141) # Word-level BLEU scoring: $ grep ^H /tmp/gen.out | cut -f3- | sed 's/@@ //g' > /tmp/gen.out.sys $ grep ^T /tmp/gen.out | cut -f2- | sed 's/@@ //g' > /tmp/gen.out.ref $ fairseq score -sys /tmp/gen.out.sys -ref /tmp/gen.out.ref -BLEU4 = 40.55, 67.6/46.5/34.0/25.3 (BP=1.000, ratio=0.998, sys_len=81369, ref_len=81194) -``` +BLEU4 = 29.92, 64.7/37.9/23.8/15.3 (BP=0.973, ratio=1.027, sys_len=127660, ref_len=131141) -# Join the fairseq community +``` -* Facebook page: https://www.facebook.com/groups/fairseq.users -* Google group: https://groups.google.com/forum/#!forum/fairseq-users -* Contact: [jgehring@fb.com](mailto:jgehring@fb.com), [michaelauli@fb.com](mailto:michaelauli@fb.com) # License -fairseq is BSD-licensed. -The license applies to the pre-trained models as well. -We also provide an additional patent grant. +fairseq is BSD-licensed. The released codes modified the original fairseq are BSD-licensed. +The rest of the codes are MIT-licensed. diff --git a/data/prepare-iwslt14.sh b/data/prepare-iwslt14.sh index 2effe47..db28fcb 100644 --- a/data/prepare-iwslt14.sh +++ b/data/prepare-iwslt14.sh @@ -57,7 +57,10 @@ for l in $src $tgt; do perl $TOKENIZER -threads 8 -l $l > $tmp/$tok echo "" done -perl $CLEAN -ratio 1.5 $tmp/train.tags.$lang.tok $src $tgt $tmp/train.tags.$lang.clean 1 175 + +# Only use up to 50 as in https://github.com/facebookresearch/MIXER/blob/master/prepareData.sh +# and https://github.com/rizar/actor-critic-public +perl $CLEAN -ratio 1.5 $tmp/train.tags.$lang.tok $src $tgt $tmp/train.tags.$lang.clean 1 50 for l in $src $tgt; do perl $LC < $tmp/train.tags.$lang.clean.$l > $tmp/train.tags.$lang.$l done diff --git a/data/prepare-iwslt15.sh b/data/prepare-iwslt15.sh new file mode 100755 index 0000000..52d7938 --- /dev/null +++ b/data/prepare-iwslt15.sh @@ -0,0 +1,34 @@ +#!/bin/sh +# Copied from https://github.com/tensorflow/nmt/blob/master/nmt/scripts/download_iwslt15.sh +# +# Download small-scale IWSLT15 Vietnames to English translation data for NMT +# model training. +# +# Usage: +# ./download_iwslt15.sh path-to-output-dir +# +# If output directory is not specified, "./iwslt15" will be used as the default +# output directory. + +OUT_DIR="${1:-iwslt15}" +SITE_PREFIX="https://nlp.stanford.edu/projects/nmt/data" + +mkdir -v -p $OUT_DIR + +# Download iwslt15 small dataset from standford website. +echo "Download training dataset train.en and train.vi." +curl -o "$OUT_DIR/train.en" "$SITE_PREFIX/iwslt15.en-vi/train.en" +curl -o "$OUT_DIR/train.vi" "$SITE_PREFIX/iwslt15.en-vi/train.vi" + +echo "Download dev dataset tst2012.en and tst2012.vi." +curl -o "$OUT_DIR/tst2012.en" "$SITE_PREFIX/iwslt15.en-vi/tst2012.en" +curl -o "$OUT_DIR/tst2012.vi" "$SITE_PREFIX/iwslt15.en-vi/tst2012.vi" + +echo "Download test dataset tst2013.en and tst2013.vi." 
+curl -o "$OUT_DIR/tst2013.en" "$SITE_PREFIX/iwslt15.en-vi/tst2013.en" +curl -o "$OUT_DIR/tst2013.vi" "$SITE_PREFIX/iwslt15.en-vi/tst2013.vi" + +echo "Download vocab file vocab.en and vocab.vi." +curl -o "$OUT_DIR/vocab.en" "$SITE_PREFIX/iwslt15.en-vi/vocab.en" +curl -o "$OUT_DIR/vocab.vi" "$SITE_PREFIX/iwslt15.en-vi/vocab.vi" + diff --git a/de-en_example.png b/de-en_example.png new file mode 100755 index 0000000..9da9459 Binary files /dev/null and b/de-en_example.png differ diff --git a/fairseq/models/DummyCriterion.lua b/fairseq/models/DummyCriterion.lua new file mode 100755 index 0000000..df1c624 --- /dev/null +++ b/fairseq/models/DummyCriterion.lua @@ -0,0 +1,25 @@ +-- Copyright (c) Microsoft Corporation. All rights reserved. +-- Licensed under the MIT License. +-- +--[[ +-- +-- Dummy Criterion +-- +--]] + +local DummyCriterion, parent = torch.class('nn.DummyCriterion', 'nn.Criterion') + +function DummyCriterion:__init() + parent.__init(self) +end + +function DummyCriterion:updateOutput(input, target) + self.output = torch.mean(input) + return self.output +end + +function DummyCriterion:updateGradInput(input, target) + local n = input:nElement() + self.gradInput = input.new(input:size()):fill(1.0/n) + return self.gradInput +end diff --git a/fairseq/models/avgpool_model.lua b/fairseq/models/avgpool_model.lua index 7ea1b93..d815ce1 100644 --- a/fairseq/models/avgpool_model.lua +++ b/fairseq/models/avgpool_model.lua @@ -5,6 +5,9 @@ -- the root directory of this source tree. An additional grant of patent rights -- can be found in the PATENTS file in the same directory. -- +-- Copyright (c) Microsoft Corporation. All rights reserved. +-- Licensed under the BSD License. +-- --[[ -- -- This model closely follows the conditional setup of rnn-lib v1, with -name @@ -101,7 +104,7 @@ AvgpoolModel.makeAttention = argcheck{ local encoderOutPooled, encoderOutSingle = encoderOut:split(2) -- Projection of previous hidden state onto source word space - local prevhProj = nn.Linear(config.nhid, config.nembed)(prevh) + local prevhProj = nn.Linear(config.dec_unit_size, config.nembed)(prevh) local decoderRep = nn.CAddTable()({prevhProj, input}) -- Compute scores (usually denoted with alpha) using a simple dot @@ -148,7 +151,7 @@ AvgpoolModel.makeDecoderRNN = argcheck{ local rnn = nn.CLSTM{ attention = attnmodule, inputsize = config.nembed, - hidsize = config.nhid, + hidsize = config.dec_unit_size, nlayer = config.nlayer, winitfun = function(network) rmutils.defwinitfun(network, config.init_range) @@ -171,8 +174,8 @@ AvgpoolModel.makeDecoderRNN = argcheck{ end local scaleHidden = nn.Identity() - if config.nhid ~= config.nembed then - scaleHidden = nn.Linear(config.nhid, config.nembed) + if config.dec_unit_size ~= config.nembed then + scaleHidden = nn.Linear(config.dec_unit_size, config.nembed) end local decoderRNNOut = scaleHidden( diff --git a/fairseq/models/bgru_model.lua b/fairseq/models/bgru_model.lua new file mode 100755 index 0000000..47a3e78 --- /dev/null +++ b/fairseq/models/bgru_model.lua @@ -0,0 +1,174 @@ +-- Copyright (c) 2017-present, Facebook, Inc. +-- All rights reserved. +-- +-- This source code is licensed under the license found in the LICENSE file in +-- the root directory of this source tree. An additional grant of patent rights +-- can be found in the PATENTS file in the same directory. +-- +-- Copyright (c) Microsoft Corporation. All rights reserved. +-- Licensed under the BSD License. +-- +--[[ +-- +-- This model uses a bi-directional LSTM encoder. 
The direction is reversed +-- between layers and two separate columns run in parallel: one on the normal +-- input and one on the reversed input (as described in +-- http://arxiv.org/abs/1606.04199). +-- +-- The attention mechanism and the decoder setup are identical to the avgpool +-- model. +-- +--]] + +require 'nn' +require 'rnnlib' +local usecudnn = pcall(require, 'cudnn') +local argcheck = require 'argcheck' +local mutils = require 'fairseq.models.utils' +local rutils = require 'rnnlib.mutils' + +local BGRUModel = torch.class('BGRUModel', 'AvgpoolModel') + +BGRUModel.makeEncoderColumn = argcheck{ + {name='self', type='BGRUModel'}, + {name='config', type='table'}, + {name='inith', type='nngraph.Node'}, + {name='input', type='nngraph.Node'}, + {name='nlayers', type='number'}, + call = function(self, config, inith, input, nlayers) + local rnnconfig = { + inputsize = config.nembed, + hidsize = config.nhid, + nlayer = 1, + winitfun = function(network) + rutils.defwinitfun(network, config.init_range) + end, + usecudnn = usecudnn, + } + + local rnn = nn.GRU(rnnconfig) + rnn.saveHidden = false + local output = nn.SelectTable(-1)(nn.SelectTable(2)( + rnn({inith, input}):annotate{name = 'encoderRNN'} + )) + rnnconfig.inputsize = config.nhid + + for i = 2, nlayers do + if config.dropout_hid > 0 then + output = nn.MapTable(nn.Dropout(config.dropout_hid))(output) + end + local rnn = nn.GRU(rnnconfig) + rnn.saveHidden = false + output = nn.SelectTable(-1)(nn.SelectTable(2)( + rnn({ + inith, + nn.ReverseTable()(output), + }) + )) + end + return output + end +} + +BGRUModel.makeEncoder = argcheck{ + doc=[[ +This encoder runs a forward and backward LSTM network and concatenates their +top-most hidden states. +]], + {name='self', type='BGRUModel'}, + {name='config', type='table'}, + call = function(self, config) + local sourceIn = nn.Identity()() + local inith, tokens = sourceIn:split(2) + + local dict = config.srcdict + local lut = mutils.makeLookupTable(config, dict:size(), + dict.pad_index) + local embed + if config.dropout_src > 0 then + embed = nn.MapTable(nn.Sequential() + :add(lut) + :add(nn.Dropout(config.dropout_src)))(tokens) + else + embed = nn.MapTable(lut)(tokens) + end + + local col1 = self:makeEncoderColumn{ + config = config, + inith = inith, + input = embed, + nlayers = config.nenclayer, + } + local col2 = self:makeEncoderColumn{ + config = config, + inith = inith, + input = nn.ReverseTable()(embed), + nlayers = config.nenclayer, + } + + -- Each column will switch direction between layers. Before merging, + -- they should both run in the same direction (here: forward). + if config.nenclayer % 2 == 0 then + col1 = nn.ReverseTable()(col1) + else + col2 = nn.ReverseTable()(col2) + end + + local prepare = nn.Sequential() + -- Concatenate forward and backward states + prepare:add(nn.JoinTable(2, 2)) + -- Scale down to nhid for further processing + prepare:add(nn.Linear(config.nhid * 2, config.nembed, false)) + -- Add singleton dimension for subsequent joining + prepare:add(nn.View(-1, 1, config.nembed)) + + local joinedOutput = nn.JoinTable(1, 2)( + nn.MapTable(prepare)( + nn.ZipTable()({col1, col2}) + ) + ) + if config.dropout_hid > 0 then + joinedOutput = nn.Dropout(config.dropout_hid)(joinedOutput) + end + + -- avgpool_model.makeDecoder() expects two encoder outputs, one for + -- attention score computation and the other one for applying them. + -- We'll just use the same output for both. 
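+        -- The returned module thus maps {initial hidden state, source token table}
+        -- to two identical encoder outputs whose inner dimension is config.nembed
+        -- (the concatenated 2*nhid column states are projected down by the Linear above).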
+ return nn.gModule({sourceIn}, { + joinedOutput, nn.Identity()(joinedOutput) + }) + end +} + +BGRUModel.prepareSource = argcheck{ + {name='self', type='BGRUModel'}, + call = function(self) + -- Device buffers for samples + local buffers = { + source = {}, + } + + -- NOTE: It's assumed that all encoders start from the same hidden + -- state. + local encoderRNN = mutils.findAnnotatedNode( + self:network(), 'encoderRNN' + ) + assert(encoderRNN ~= nil) + + return function(sample) + -- Encoder input + local source = {} + for i = 1, sample.source:size(1) do + buffers.source[i] = buffers.source[i] + or torch.Tensor():type(self:type()) + source[i] = mutils.sendtobuf(sample.source[i], + buffers.source[i]) + end + + local initialHidden = encoderRNN:initializeHidden(sample.bsz) + return {initialHidden, source} + end + end +} + +return BGRUModel diff --git a/fairseq/models/c_sample_dp.cc b/fairseq/models/c_sample_dp.cc new file mode 100755 index 0000000..ae2ed81 --- /dev/null +++ b/fairseq/models/c_sample_dp.cc @@ -0,0 +1,220 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +// +// Forward-backward probability computation using dynamic programming. +// +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +using namespace std; + +extern "C" { + #include "lua.h" + #include "lualib.h" + #include "lauxlib.h" +}; + +#define real double + +#define IND_LOGPY(i,j,k) ((i)*(T2+1)*(max_segment_len+1)+(j)*(max_segment_len+1)+(k)) +#define IND_ALPHA(i,j) ((i)*(T2+1)+(j)) +#define IND_BETA(i,j) ((i)*(T2+1)+(j)) + +#define LOGINF 1000000 +#define EPS 1e-10 +#define MAX(x,y) ((x)>(y))?(x):(y) + +inline real log1sub(real log_a) { + return log1p(-exp(log_a)); +} + +inline real logadd(real log_a, real log_b) { + return max(log_a, log_b) + log1p(exp(-abs(log_a-log_b))); +} + +void set_all(real* x, int start, int end, real val) { + for (int i = start; i < end; ++i) + x[i] = val; +} + +void in_place_softmax(real* x, int start, int end) { + real max_x = x[start]; + for (int i = start+1; i < end; ++i) { + max_x = max(max_x, x[i]); + } + real sum_x = 0.; + for (int i = start; i < end; ++i) { + x[i] = exp(x[i] - max_x); + sum_x += x[i]; + } + for (int i = start; i < end; ++i) { + x[i] /= sum_x; + } +} + +void in_place_cumsum(real* x, int start, int end) { + real sum_x = 0.0; + for (int i = start; i < end; ++i) { + sum_x += x[i]; + x[i] = sum_x; + } +} + +extern "C" { + static int c_sample_dp(lua_State*); + static int c_reverse_log_cumsum(lua_State*); + int luaopen_libdp_lib(lua_State*); +}; + +typedef struct { + int batch_size; + int T1; + int T2; + int max_segment_len; + real* logpy; + real* alpha; + real* beta; + real* seg_weight; + real* ylength; + real* xlength; +} strct_states; + +static void subprocess_c_sample_dp(strct_states* s, int p_batch) { + //default_random_engine generator; + //uniform_real_distribution distribution(0.0,1.0); + int batch_size = s->batch_size; + int T1 = s->T1; + int T2 = s->T2; + int max_segment_len = s->max_segment_len; + real* logpy = s->logpy + p_batch * T1 * (T2+1) * (max_segment_len+1); + real* alpha = s->alpha + p_batch * (T1+1) * (T2+1); + real* beta = s->beta + p_batch * (T1+1) * (T2+1); + real* seg_weight = s->seg_weight + p_batch * T1 * (T2+1) * (max_segment_len+1); + int ylength = (int)(s->ylength[p_batch]); + int xlength = (int)(s->xlength[p_batch]); + alpha[IND_ALPHA(0, 0)] = 0.; + for (int t = 1; t <= T1; ++t) { + for (int j = 0; j <= ylength; ++j) { + int j_low = max(1, j - 
max_segment_len + 1); + for (int j_start = j_low; j_start <= j+1; ++j_start) { + real prob = alpha[IND_ALPHA(t-1, j_start-1)] + logpy[IND_LOGPY(t-1, j_start-1, j-j_start+1)]; + alpha[IND_ALPHA(t, j)] = logadd(alpha[IND_ALPHA(t, j)], prob); + } + } + } + beta[IND_BETA(xlength, ylength)] = 0.; + for (int t = xlength-1; t >= 0; --t) { + for (int j = 0; j <= ylength; ++j) { + int j_high = min(ylength, j + max_segment_len); + for (int j_end = j; j_end <= j_high; ++j_end) { + real prob = beta[IND_BETA(t+1, j_end)] + logpy[IND_LOGPY(t, j, j_end-j)]; + beta[IND_BETA(t, j)] = logadd(beta[IND_BETA(t, j)], prob); + } + } + } + for (int t = 1; t <= T1; ++t) { + int jstart_l = max(1, ylength - (T1 - t + 1) * max_segment_len +1); + int jstart_u = min(ylength+1, (t - 1) * max_segment_len + 1); + for (int j_start = jstart_l; j_start <= jstart_u; j_start++) { + int j_len = min(max_segment_len, ylength-j_start+1); + int j_end = j_start + j_len - 1; + for (int j = j_start-1; j <= j_end; ++j) { + seg_weight[IND_LOGPY(t-1, j_start-1, j-j_start+1)] + = logpy[IND_LOGPY(t-1, j_start-1,j-j_start+1)] + + alpha[IND_ALPHA(t-1, j_start-1)] + + beta[IND_BETA(t, j)]; + } + } + } +} + +static int c_sample_dp(lua_State* L) { + strct_states s; + s.batch_size = (int)(lua_tonumber(L, 1)); + s.T1 = (int)(lua_tonumber(L, 2)); + s.T2 = (int)(lua_tonumber(L, 3)); + s.max_segment_len = (int)(lua_tonumber(L, 4)); + int num_thread = (int)(lua_tonumber(L, 5)); + s.logpy = (real*)((unsigned long long)(lua_tonumber(L, 6))); + s.alpha = (real*)((unsigned long long)(lua_tonumber(L, 7))); + s.beta = (real*)((unsigned long long)(lua_tonumber(L, 8))); + s.seg_weight = (real*)((unsigned long long)(lua_tonumber(L, 9))); + s.ylength = (real*)((unsigned long long)(lua_tonumber(L, 10))); + s.xlength = (real*)((unsigned long long)(lua_tonumber(L, 11))); + + int p = 0; + std::thread* ths = new std::thread[num_thread]; + while (p < s.batch_size) { + int p_ths = 0; + for(int i = 0; i < num_thread; i++) { + ths[p_ths++] = std::thread(subprocess_c_sample_dp, &s, p++); + if (p >= s.batch_size) break; + } + for(int i = 0; i < p_ths; i++) { + ths[i].join(); + } + } + delete[] ths; + return 0; +} + +static void subprocess_c_reverse_log_cumsum(strct_states* s, int p_batch) { + int batch_size = s->batch_size; + int T1 = s->T1; + int T2 = s->T2; + int max_segment_len = s->max_segment_len; + real* seg_weight = s->seg_weight + p_batch * T1 * (T2+1) * (max_segment_len+1); + int ylength = (int)(s->ylength[p_batch]); + for (int t = 1; t <= T1; ++t) { + int jstart_l = max(1, ylength - (T1 - t + 1) * max_segment_len +1); + int jstart_u = min(ylength+1, (t - 1) * max_segment_len + 1); + for (int j_start = jstart_l; j_start <= jstart_u; j_start++) { + int j_len = min(max_segment_len, ylength-j_start+1); + int j_end = j_start + j_len - 1; + for (int j = j_end-1; j >= j_start; --j) { + seg_weight[IND_LOGPY(t-1, j_start-1, j-j_start+1)] = logadd(seg_weight[IND_LOGPY(t-1, j_start-1, j-j_start+1)], + seg_weight[IND_LOGPY(t-1, j_start-1, j-j_start+2)]); + } + } + } +} + +static int c_reverse_log_cumsum(lua_State* L) { + strct_states s; + s.batch_size = (int)(lua_tonumber(L, 1)); + s.T1 = (int)(lua_tonumber(L, 2)); + s.T2 = (int)(lua_tonumber(L, 3)); + s.max_segment_len = (int)(lua_tonumber(L, 4)); + int num_thread = (int)(lua_tonumber(L, 5)); + s.seg_weight = (real*)((unsigned long long)(lua_tonumber(L, 6))); + s.ylength = (real*)((unsigned long long)(lua_tonumber(L, 7))); + + int p = 0; + std::thread* ths = new std::thread[num_thread]; + while (p < s.batch_size) { + int 
p_ths = 0; + for(int i = 0; i < num_thread; i++) { + ths[p_ths++] = std::thread(subprocess_c_reverse_log_cumsum, &s, p++); + if (p >= s.batch_size) break; + } + for(int i = 0; i < p_ths; i++) { + ths[i].join(); + } + } + delete[] ths; + return 0; + +} + +int luaopen_libdp_lib(lua_State* L) { + lua_register(L, "c_sample_dp", c_sample_dp); + lua_register(L, "c_reverse_log_cumsum", c_reverse_log_cumsum); + return 0; +} diff --git a/fairseq/models/compute_logpy.cu b/fairseq/models/compute_logpy.cu new file mode 100755 index 0000000..9c739c1 --- /dev/null +++ b/fairseq/models/compute_logpy.cu @@ -0,0 +1,158 @@ +/* Copyright (c) Microsoft Corporation. All rights reserved. + Licensed under the MIT License. + + CUDA implementation of the compute_logpy_post + +*/ + +#include +#include + +#include "THC.h" +#include "THCTensor.h" + +extern "C" { + #include "lua.h" + #include "luaT.h" + #include "lualib.h" + #include "lauxlib.h" +}; + +extern "C" { + static int compute_logpy_prep(lua_State* L); + static int compute_logpy_post(lua_State* L); + int luaopen_libcompute_logpy_lib(lua_State* L); +}; + +const float loginf = 1000000.0; + +template +T *getStoragePtr(lua_State* L, THCT * tct) +{ + T *ptr; + if (tct->storage) { + ptr = (T*)(tct->storage->data + tct->storageOffset); + } else { + lua_pushfstring(L, "THCudaTensor cannot be an empty tensor"); + lua_error(L); + } + return ptr; +} + +int compute_logpy_prep(lua_State* L) +{ + THCudaTensor *hidden_inputs_tensor = static_cast(luaT_checkudata(L, 1, "torch.CudaTensor")); + THCudaTensor *xlength_tensor = static_cast(luaT_checkudata(L, 2, "torch.CudaTensor")); + THCudaTensor *yref_tensor = static_cast(luaT_checkudata(L, 3, "torch.CudaTensor")); + THCudaTensor *ylength_tensor = static_cast(luaT_checkudata(L, 4, "torch.CudaTensor")); + + int batch_size = (int)(lua_tonumber(L, 5)); + int T1 = (int)(lua_tonumber(L, 6)); + int T2 = (int)(lua_tonumber(L, 7)); + + THCudaTensor *concat_hts_g_tensor = static_cast(luaT_checkudata(L, 8, "torch.CudaTensor")); + THCudaTensor *concat_inputs_g_tensor = static_cast(luaT_checkudata(L, 9, "torch.CudaTensor")); + + float *hidden_inputs = getStoragePtr(L, hidden_inputs_tensor); + float *xlength = getStoragePtr(L, xlength_tensor); + float *yref = getStoragePtr(L, yref_tensor); + float *ylength = getStoragePtr(L, ylength_tensor); + float *concat_hts_g = getStoragePtr(L, concat_hts_g_tensor); + float *concat_inputs_g = getStoragePtr(L, concat_inputs_g_tensor); + + + + return 0; +} + +__global__ void compute_logpy_post_kernel( float *t_prob_all, + float *yref, + float *ylength, + float *logpy, + int *sorted_schedule, + int s, + int max_jlen, + int vocab_size, + int batch_size, + int batch_max_segment_len, + int T1, + int T2, + int si) +{ + unsigned int idx = blockIdx.x * blockDim.x + threadIdx.x; + + int local_ylength = int(ylength[threadIdx.x]); + + int t = sorted_schedule[(blockIdx.x + si - 1) * 4]; + int j_start = sorted_schedule[(blockIdx.x + si - 1) * 4 + 1]; + int j_len = sorted_schedule[(blockIdx.x + si - 1) * 4 + 2]; + + float local_t_vec = 0; + + if (j_start <= (local_ylength + 1)) + { + local_t_vec = loginf + t_prob_all[idx * (max_jlen+1) * vocab_size + vocab_size - 1]; + + atomicAdd(&(logpy[threadIdx.x * T1 * (T2+1) * (batch_max_segment_len+1) + + (t-1) * (T2+1) * (batch_max_segment_len+1) + + (j_start-1) * (batch_max_segment_len+1) ]), + local_t_vec + ); + } + + float tmp_result = 0; + for (int i = 1; (i < j_len + 1) && (i + j_start <= (local_ylength+1)); i++) + { + tmp_result += t_prob_all[idx * (max_jlen+1) * 
vocab_size + (i-1) * vocab_size + int(yref[threadIdx.x * T2 + j_start + i - 2]) - 1]; + + local_t_vec = loginf + tmp_result + + t_prob_all[idx * (max_jlen+1) * vocab_size + (i) * vocab_size + vocab_size - 1]; + + atomicAdd(&(logpy[threadIdx.x * T1 * (T2+1) * (batch_max_segment_len+1) + + (t-1) * (T2+1) * (batch_max_segment_len+1) + + (j_start-1) * (batch_max_segment_len+1) + i ]), + local_t_vec + ); + } +} + +int compute_logpy_post(lua_State* L) +{ + THCudaTensor *t_prob_all_tensor = static_cast(luaT_checkudata(L, 1, "torch.CudaTensor")); + THCudaTensor *yref_tensor = static_cast(luaT_checkudata(L, 2, "torch.CudaTensor")); + THCudaTensor *ylength_tensor = static_cast(luaT_checkudata(L, 3, "torch.CudaTensor")); + THCudaTensor *logpy_tensor = static_cast(luaT_checkudata(L, 4, "torch.CudaTensor")); + THCudaIntTensor *sorted_schedule_tensor = static_cast(luaT_checkudata(L, 5, "torch.CudaIntTensor")); + + int s = (int)(lua_tonumber(L, 6)); + int max_jlen = (int)(lua_tonumber(L, 7)); + int vocab_size = (int)(lua_tonumber(L, 8)); + int batch_size = (int)(lua_tonumber(L, 9)); + int batch_max_segment_len = (int)(lua_tonumber(L, 10)); + int T1 = (int)(lua_tonumber(L, 11)); + int T2 = (int)(lua_tonumber(L, 12)); + int si = (int)(lua_tonumber(L, 13)); + + float *t_prob_all = getStoragePtr(L, t_prob_all_tensor); + float *yref = getStoragePtr(L, yref_tensor); + float *ylength = getStoragePtr(L, ylength_tensor); + float *logpy = getStoragePtr(L, logpy_tensor); + int *sorted_schedule = getStoragePtr(L, sorted_schedule_tensor); + + dim3 blockDim(batch_size); + dim3 gridDim(s); + compute_logpy_post_kernel<<>>(t_prob_all, yref, ylength, logpy, sorted_schedule, + s, max_jlen, vocab_size, batch_size, batch_max_segment_len, + T1, T2, si); + + cudaDeviceSynchronize(); + + return 0; +} + +int luaopen_libcompute_logpy_lib(lua_State* L) { + lua_register(L, "compute_logpy_prep", compute_logpy_prep); + lua_register(L, "compute_logpy_post", compute_logpy_post); + return 0; +} + diff --git a/fairseq/models/init.lua b/fairseq/models/init.lua index 62e099d..526327c 100644 --- a/fairseq/models/init.lua +++ b/fairseq/models/init.lua @@ -5,6 +5,9 @@ -- the root directory of this source tree. An additional grant of patent rights -- can be found in the PATENTS file in the same directory. -- +-- Copyright (c) Microsoft Corporation. All rights reserved. +-- Licensed under the BSD License. +-- --[[ -- -- init files for the models. @@ -14,7 +17,14 @@ require 'fairseq.models.model' require 'fairseq.models.avgpool_model' require 'fairseq.models.blstm_model' +require 'fairseq.models.bgru_model' require 'fairseq.models.fconv_model' require 'fairseq.models.selection_blstm_model' require 'fairseq.models.conv_model' require 'fairseq.models.ensemble_model' +require 'fairseq.models.npmt_model' +require 'fairseq.models.mRNN' +require 'fairseq.models.npmt' +require 'fairseq.models.npmt_utils' +require 'fairseq.models.DummyCriterion' +require 'fairseq.models.window_attn' \ No newline at end of file diff --git a/fairseq/models/mRNN.lua b/fairseq/models/mRNN.lua new file mode 100755 index 0000000..665a346 --- /dev/null +++ b/fairseq/models/mRNN.lua @@ -0,0 +1,254 @@ +-- Copyright (c) Microsoft Corporation. All rights reserved. +-- Licensed under the MIT License. 
+-- +--[[ +-- +-- Input: a tensor of input: eq_length * batch_size * input_dim or a table of h0 and input +-- where h0: num_layer * batch_size * hidden_dim; the first dimension optional if num_layer == 1 +-- Output: a tensor of eq_length * batch_size * output_dim +-- +--]] +-- + +require("nn") +require("cudnn") +require("cutorch") + +local mRNN, parent = torch.class('nn.mRNN', 'nn.Container') + +function mRNN:__init(input_dim, hidden_dim, bd, mode, has_otherOutput, use_resnet, use_skip_mode, dropout, add_output) + parent.__init(self) + local batchFirst = batchFirst or true + assert(batchFirst == true) + + if use_resnet then + print("resnet is used in current layer") + end + + self.bd = bd or false + self.mode = mode or "CUDNN_GRU" + self.has_otherOutput = has_otherOutput or false + self.dropout = dropout or 0 + self.add_output = add_output or false + self.use_resnet = use_resnet or false + + local rnn_hidden_dim = hidden_dim + if self.bd then + if self.add_output then + print("bd rnn output is added") + else + rnn_hidden_dim = rnn_hidden_dim / 2 + end + end + + self.rnn_hidden_dim = rnn_hidden_dim + local rnn = cudnn.RNN(input_dim, rnn_hidden_dim, 1, batchFirst) + if use_skip_mode then + assert(not self.bd and input_dim == hidden_dim) + rnn.inputMode = 'CUDNN_SKIP_INPUT' + end + self.rnn = rnn + + rnn.mode = self.mode + if self.bd then + rnn.numDirections = 2 + rnn.bidirectional = 'CUDNN_BIDIRECTIONAL' + end + rnn:reset() + self:add(rnn) + if use_resnet and input_dim ~= hidden_dim then + self.input_proj = nn.Bottle(nn.Linear(input_dim, hidden_dim, false)) + self:add(self.input_proj) + end +end + +function mRNN:setStates(h0, c0) + self.rnn.hiddenInput = h0:clone() + if c0 then + self.rnn.cellInput = c0:clone() + end +end + +function mRNN:getNextMemInput() + if self.rnn.cellOutput then + return self.rnn.cellOutput:clone() + else + return self.rnn.hiddenOutput:clone() + end +end + +function mRNN:updateOutput(input) + self.recompute_backward = true + local c0, h0, x + if torch.type(input) == "table" then + assert(not self.bd) + if #input == 2 then + h0, x = unpack(input) + if (h0:dim() == 2) then + h0 = h0:view(1, h0:size(1), h0:size(2)) + end + if self.mode == "CUDNN_LSTM" then + self.rnn.cellInput = h0 + else + self.rnn.hiddenInput = h0 + end + elseif #input == 3 then + c0, h0, x = unpack(input) + if (h0:dim() == 2) then + h0 = h0:view(1, h0:size(1), h0:size(2)) + end + if (c0:dim() == 2) then + c0 = c0:view(1, c0:size(1), c0:size(2)) + end + self.rnn.hiddenInput = h0 + self.rnn.cellInput = c0 + end + else + x = input + end + local rnn_output = self.rnn:updateOutput(x) + if self.bd and self.add_output then + rnn_output = torch.add(rnn_output[{{}, {}, {1,self.rnn_hidden_dim}}], + rnn_output[{{}, {}, {self.rnn_hidden_dim+1,2*self.rnn_hidden_dim}}]) + end + local output = rnn_output + if self.use_resnet then + if self.input_proj then + output = torch.add(output, self.input_proj:updateOutput(x)) + else + output = torch.add(output, x) + end + end + + if self.has_otherOutput then + local otherOutput + if self.mode == "CUDNN_LSTM" then + otherOutput = self.rnn.cellOutput:clone() + else + otherOutput = self.rnn.hiddenOutput:clone() + end + assert(otherOutput:dim() == 3) + if self.bd then + if self.add_output then + otherOutput = torch.add(otherOutput[1], otherOutput[2]) + else + otherOutput = torch.cat(otherOutput[1], otherOutput[2], 2) + end + else + assert(otherOutput:size(1) == 1) + otherOutput = otherOutput[1] + end + assert(otherOutput:dim() == 2) + self.output = {output, otherOutput} + else + 
self.output = output + end + return self.output +end + +function mRNN:backward(input, gradOutput, scale) + scale = scale or 1 + self.recompute_backward = false + if self.has_otherOutput then + local otherGradOutput + gradOutput, otherGradOutput = unpack(gradOutput) + assert(otherGradOutput:dim() == 2) + if self.bd then + local forwardGradOutput, backwardGradOutput + if self.add_output then + forwardGradOutput = otherGradOutput + backwardGradOutput = otherGradOutput + else + forwardGradOutput, backwardGradOutput = unpack(torch.chunk(otherGradOutput, 2, 2)) + end + otherGradOutput = torch.cat(forwardGradOutput:reshape(1, forwardGradOutput:size(1), forwardGradOutput:size(2)), + backwardGradOutput:reshape(1, backwardGradOutput:size(1), backwardGradOutput:size(2)), 1) + else + otherGradOutput = otherGradOutput:view(1, otherGradOutput:size(1), otherGradOutput:size(2)) + end + assert(otherGradOutput:dim() == 3) + if self.mode == "CUDNN_LSTM" then + self.rnn.gradCellOutput = otherGradOutput + else + self.rnn.gradHiddenOutput = otherGradOutput + end + end + + local h0, c0, x + if torch.type(input) == "table" then + if #input == 2 then + h0, x = unpack(input) + elseif #input == 3 then + c0, h0, x = unpack(input) + end + else + x = input + end + + local gradInput = x.new(x:size()):zero() + if self.use_resnet then + if self.input_proj then + gradInput:add(self.input_proj:backward(x, gradOutput)) + else + gradInput:add(gradOutput) + end + end + + if self.bd and self.add_output then + gradInput:add(self.rnn:backward(x, torch.cat(gradOutput, gradOutput, 3), scale)) + else + gradInput:add(self.rnn:backward(x, gradOutput, scale)) + end + + local h0_grad, c0_grad + if h0 and c0 then + h0_grad = self.rnn.gradHiddenInput + c0_grad = self.rnn.gradCellInput + self.gradInput = {c0_grad:view(h0:size()), h0_grad:view(c0:size()), gradInput} + elseif h0 then + if self.mode == "CUDNN_LSTM" then + h0_grad = self.rnn.gradCellInput + else + h0_grad = self.rnn.gradHiddenInput + end + self.gradInput = {h0_grad:view(h0:size()), gradInput} + else + self.gradInput = gradInput + end + return self.gradInput +end + +function mRNN:updateGradInput(input, gradOutput) + if self.recompute_backward then + self:backward(input, gradOutput, 1.0) + end + return self.gradInput +end + +function mRNN:accGradParameters(input, gradOutput, scale) + if self.recompute_backward then + self:backward(input, gradOutput, scale) + end +end + +function mRNN:clearState() + parent.clearState(self) + self.rnn:resetStates() +end + +function mRNN:__tostring__() + local tab = ' ' + local line = '\n' + local next = ' -> ' + local str = 'nn.Container' + str = str .. ' {' .. line .. tab .. '[input' + for i=1,#self.modules do + str = str .. next .. '(' .. i .. ')' + end + str = str .. next .. 'output]' + for i=1,#self.modules do + str = str .. line .. tab .. '(' .. i .. '): ' .. tostring(self.modules[i]):gsub(line, line .. tab) + end + str = str .. line .. '}' + return str +end diff --git a/fairseq/models/npmt.lua b/fairseq/models/npmt.lua new file mode 100755 index 0000000..860fb99 --- /dev/null +++ b/fairseq/models/npmt.lua @@ -0,0 +1,1143 @@ +-- Copyright (c) Microsoft Corporation. All rights reserved. +-- Licensed under the MIT License. +-- +--[[ +-- +-- NPMT model +-- it should actually be a criterion, but since it itself has +-- parameters, we still treat it as a module. +-- +--]] +-- +local NPMT, parent = torch.class('nn.NPMT', 'nn.Container') + +function torch.tab_to_ngram(tab) + local ngram = {} + for i = 1, #tab do + ngram[i] = " " .. 
table.concat(tab[i], " ") + end + return ngram +end + +function NPMT:__init(configs) + parent.__init(self) + self.use_cuda = configs.use_cuda or true + self.num_thread = self.num_thread or 5 + self.max_segment_len = configs.max_segment_len or 5 + self.ngpus = configs.ngpus or 1 + print("npmt is running on ", self.ngpus, " GPUs") + + self.use_cimpl = configs.use_cimpl or true + self.use_accel = configs.use_accel or false + + if self.use_cimpl then + require "libdp_lib" + end + if self.use_accel then -- TODO + print('Use CUDA accel') + require "libcompute_logpy_lib" + end + + self.vocab_size = configs.target_vocab_size or configs.dict:size() + 2 + self.embedding_size = configs.nembed or 256 + self.dec_unit_size = configs.dec_unit_size or configs.nhid or 256 + self.num_layers = configs.num_dec_layers or configs.nlayer or 1 + self.grad_check = configs.grad_check or false + self.rnn_mode = configs.npmt_rnn_mode or configs.rnn_mode or "LSTM" + self.nnlm_rnn_mode = configs.npmt_nnlm_rnn_mode or self.rnn_mode + if self.grad_check then + require('rnn') + self.start_symbol = configs.start_symbol or self.vocab_size + else + self.start_symbol = configs.start_symbol or self.vocab_size - 1 + end + + self.end_segment_symbol = configs.end_segment_symbol or self.vocab_size + self.pad_index = configs.dict.pad_index + self.use_nnlm = configs.use_nnlm or false + self.group_size = configs.group_size or 512 + + self.report_time = configs.report_time or false + self.precompute_gradInput = true + self.lm_concat = configs.lm_concat or false + self.dropout = configs.npmt_dropout or configs.dropout or 0 + self.nnlm_dropout = (configs.nnlm_dropout and configs.nnlm_dropout > 0 and configs.nnlm_dropout) or self.dropout or 0 + + if self.dropout > 0 then + print("npmt is using dropout ", self.dropout) + end + + self.seq = nn.Sequential() + if self.use_cuda then + require "cudnn" + local cudnn_mode = string.format("CUDNN_%s", self.rnn_mode) + local rnn = nn.mRNN(self.embedding_size, self.dec_unit_size, false, cudnn_mode, false, configs.use_resnet_dec) + self.seq:add(rnn) + if self.dropout > 0 then + self.seq:add(nn.Dropout(self.dropout)) + end + for i = 2, self.num_layers do + local rnn = nn.mRNN(self.dec_unit_size, self.dec_unit_size, false, cudnn_mode, false, configs.use_resnet_dec) + self.seq:add(rnn) + if self.dropout > 0 then + self.seq:add(nn.Dropout(self.dropout)) + end + end + else + local rnn_class + if self.rnn_mode == "GRU" then + rnn_class = nn.SeqGRU + else + rnn_class = nn.SeqLSTM + end + local rnn = rnn_class(self.embedding_size, self.dec_unit_size) + rnn.batchfirst = true + self.seq:add(rnn) + for i = 2, self.num_layers do + rnn = rnn_class(self.dec_unit_size, self.dec_unit_size) + rnn.batchfirst = true + self.seq:add(rnn) + end + end + + self.sub_outnet = nn.Sequential() + self.sub_outnet:add(self.seq) + self.sub_outnet:add(nn.Contiguous()) + self.sub_outnet:add(nn.View(-1, self.dec_unit_size)) + self.sub_outnet:add(nn.Linear(self.dec_unit_size, self.vocab_size)) + self.sub_outnet:add(nn.LogSoftMax()) + + self.outnet = nn.Sequential() + self.dict = nn.LookupTable(self.vocab_size, self.embedding_size) + self.outnet:add(nn.ParallelTable():add(nn.Identity()):add(self.dict)) + self.outnet:add(self.sub_outnet) + + self:add(self.outnet) + + if self.use_nnlm then -- if we opt to use an additional language model for input + self.nnlm = nn.Sequential() + if configs.npmt_separate_embeddding then + self.nnlm_dict = nn.LookupTable(self.vocab_size, self.embedding_size) + else + self.nnlm_dict = 
self.dict:clone("weight", "gradWeight") -- sharing the weights with self.dict + end + self.nnlm:add(self.nnlm_dict) + self.nnlm_rnn = nn.Sequential() + local nnlm_rnn_inst + if self.use_cuda then + local cudnn_mode = string.format("CUDNN_%s", self.nnlm_rnn_mode) + nnlm_rnn_inst = nn.mRNN(self.embedding_size, self.dec_unit_size, false, cudnn_mode, false, configs.use_resnet_dec) + self.nnlm_rnn:add(nnlm_rnn_inst) + if self.nnlm_dropout> 0 then + self.nnlm_rnn:add(nn.Dropout(self.nnlm_dropout)) + end + for i = 2, self.num_layers do + nnlm_rnn_inst = nn.mRNN(self.dec_unit_size, self.dec_unit_size, false, cudnn_mode, false, configs.use_resnet_dec) + self.nnlm_rnn:add(nnlm_rnn_inst) + if self.nnlm_dropout > 0 then + self.nnlm_rnn:add(nn.Dropout(self.nnlm_dropout)) + end + end + else + if self.rnn_mode == "GRU" then + nnlm_rnn_inst = nn.SeqGRU(self.embedding_size, self.dec_unit_size) + else + nnlm_rnn_inst = nn.SeqLSTM(self.embedding_size, self.dec_unit_size) + end + nnlm_rnn_inst.batchfirst = true + self.nnlm_rnn:add(nnlm_rnn_inst) + for i = 2, self.num_layers do + if self.rnn_mode == "GRU" then + nnlm_rnn_inst = nn.SeqGRU(self.dec_unit_size, self.dec_unit_size) + else + nnlm_rnn_inst = nn.SeqLSTM(self.dec_unit_size, self.dec_unit_size) + end + nnlm_rnn_inst.batchfirst = true + self.nnlm_rnn:add(nnlm_rnn_inst) + end + end + self.nnlm:add(self.nnlm_rnn) + self:add(self.nnlm) + if self.lm_concat then + self.lm_concat_proj = nn.Linear(self.dec_unit_size*2, self.dec_unit_size, false) + self:add(lm_concat_proj) + end + end + + self.logpy = torch.Tensor() + self.alpha = torch.Tensor() + self.beta = torch.Tensor() + self.logpy_per_data = torch.Tensor() + self.seg_weight = torch.Tensor() + self.seg_weight_cum = torch.Tensor() + + if self.use_cuda then + cudnn.convert(self.outnet, cudnn) + self:cuda() + end +end + +function NPMT:get_jstart_range(t, T1, minT2, maxT2) + return math.max(1, minT2 - (T1-t+1)* self.batch_max_segment_len + 1), math.min(maxT2+1, (t-1) * self.batch_max_segment_len + 1) +end + +function NPMT:compute_logpy(hidden_inputs, xlength, yref, ylength, batch_size, T1, T2) + self.outnet:evaluate() --- set outnet in evalate mode + self.logpy:resize(batch_size, T1, T2+1, self.batch_max_segment_len+1):fill(-torch.loginf()) +-- self.logpy_c:resize(batch_size, T1, T2+1, self.batch_max_segment_len+1):fill(-torch.loginf()) + -- for word-based, each word is padded with zero (so that we can -- + -- easily know how long each word is), then in the following code, + -- we will turn this into a proper sequence and padded with + -- end-symbol + -- + -- for letter-based, the entire sequence is padded with end-symbol + + local start_vector = yref.new(batch_size,1):fill(self.start_symbol) + if torch.type(start_vector) ~= "torch.CudaTensor" then + start_vector = start_vector:long() + end + + if self.use_nnlm then + self.nnlm_input = torch.cat(start_vector, yref) + self.nnlm_output = self.nnlm:forward(self.nnlm_input) + end + + local y_input + local minT2 = ylength:min() + + local schedule = {} + for t = 1, T1 do + local jstart_l, jstart_u = self:get_jstart_range(t, T1, minT2, T2) + for j_start = jstart_l, jstart_u do + local j_len = math.min(self.batch_max_segment_len, T2-j_start+1) + local j_end = j_start + j_len - 1 + table.insert(schedule, {t, j_start, j_len, j_end}) + end + end + if #schedule == 0 then + return nil + end + local _, schedule_order = torch.sort(torch.Tensor(schedule)[{{}, 3}]) + local sorted_schedule = {} + for si = 1, #schedule do + table.insert(sorted_schedule, 
schedule[schedule_order[si]]) + end + + self.sorted_schedule = sorted_schedule + local sorted_schedule_tensor = torch.CudaIntTensor(sorted_schedule) -- for compute_logpy_post use + +-- print(os.clock(), "forward", self.sorted_schedule[{1,1}], T1, hidden_inputs:size(2)) + + local concat_inputs = torch.Tensor() + local concat_hts = torch.Tensor() + if self.use_cuda then + concat_inputs = concat_inputs:cuda() + concat_hts = concat_hts:cuda() + end + self.group_size = math.max(self.group_size, batch_size) + + concat_inputs:resize(self.group_size, self.batch_max_segment_len + 1) + concat_hts:resize(self.group_size, self.dec_unit_size) + + local si = 1 + while si <= #sorted_schedule do + local si_next = math.min(si + math.floor(self.group_size / batch_size) - 1, #sorted_schedule) + local s = si_next - si + 1 + local max_jlen = sorted_schedule[si_next][3] + + local t_concatInputs = concat_inputs[{{1, s * batch_size}, {1, 1 + max_jlen}}] + local t_concatHts = concat_hts[{{1, s * batch_size}, {}}] + t_concatInputs:fill(self.end_segment_symbol) + t_concatHts:zero() + + for ell = si, si_next do + local t, j_start, j_len, j_end = unpack(sorted_schedule[ell]) + local low_idx, high_idx = (ell-si)*batch_size+1, (ell-si+1)*batch_size + y_input = start_vector:clone() + if j_len > 0 then + y_input = torch.cat({y_input, yref[{{}, {j_start,j_end}}]}) + end + local hidden_input = hidden_inputs[{{}, t, {}}] + if self.use_nnlm then + if self.lm_concat then + local hidden_input_concat = torch.cat(hidden_input, self.nnlm_output[{{}, j_start, {}}], 2) + hidden_input = self.lm_concat_proj:updateOutput(hidden_input_concat) + else + hidden_input = torch.add(hidden_input, self.nnlm_output[{{}, j_start, {}}]) + end + end + t_concatHts[{{low_idx, high_idx}, {}}]:copy(hidden_input) + t_concatInputs[{{low_idx, high_idx}, {1, y_input:size(2)}}]:copy(y_input) + end + + local t_prob_all = self.outnet:updateOutput({t_concatHts, t_concatInputs}):view(s*batch_size, max_jlen+1, self.vocab_size) + + if self.use_accel then + compute_logpy_post( t_prob_all, yref, ylength, self.logpy, sorted_schedule_tensor, + s, max_jlen, self.vocab_size, batch_size, self.batch_max_segment_len, T1, T2, si) + else + -- Torch version of compute_logpy_post + local t_vec = t_prob_all.new(batch_size) + local t_valid = t_prob_all.new(batch_size) + for ell = si, si_next do + local t, j_start, j_len, j_end = unpack(sorted_schedule[ell]) + local low_idx, high_idx = (ell-si)*batch_size+1, (ell-si+1)*batch_size + local t_prob = t_prob_all[{{low_idx, high_idx}, {}, {}}] + + local t_vec_whole = nil + if j_len > 0 then + t_vec_whole = t_prob[{{},{1, j_len},{}}]:gather(3, yref[{{},{j_start, j_end}}]:contiguous():view(batch_size, j_len, 1)):view(batch_size, j_len) + end + + t_valid:copy(ylength:ge(j_start-1)) -- a 0/1 vector of length batch_size + self.logpy[{{},t,j_start,1}]:add(torch.cmul(t_valid, torch.loginf() + t_prob[{{},1,self.end_segment_symbol}])) + + t_vec:zero() + for j = j_start, j_end do --- this implies j_end >= j_start (when j=j_start-1, it means an empty segment) + t_valid:copy(ylength:ge(j)) -- a 0/1 vector of length batch_size + -- Use gather to fetch the corresponding values in the yref + t_vec:add(t_vec_whole[{{}, j-j_start+1}]) + -- when j = j_start-1, this j-j_start+2 is 1, which is the first index, in WASM, + -- index 1 is for empty segment (segment length 0) while in segment.lua, index 1 is for + -- segment length 1. So they differ by shifting 1 index. 
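+                    -- Net effect: logpy[{b, t, j_start, len+1}] holds
+                    -- log p(y_{j_start}, ..., y_{j_start+len-1}, $ | h_t), i.e. the log-probability of
+                    -- emitting that length-len segment followed by the end-of-segment symbol $ at
+                    -- source step t; len = 0 (index 1, the empty segment) was handled just above.
+                    -- The CUDA path (compute_logpy_post in compute_logpy.cu) fills logpy the same way.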
+ -- If non-empty, add end_symbol + t_vec; else add end_symbol + self.logpy[{{},t,j_start,j-j_start+2}]:add( + torch.cmul(t_valid, torch.loginf() + t_vec + t_prob[{{},j-j_start+2,self.end_segment_symbol}])) + end + end + end + si = si_next + 1 + end + -- For debug use. Need to declare and use logpy_c +-- print(torch.all(torch.eq(self.logpy_c, self.logpy))) +-- io.write("finall: Press to continue...") +-- io.read() +end + +function NPMT:print_best_path(xlength_, yref_, ylength_, vocab) + assert(self.alpha:size(1) == 1) -- only work for batch size 1 + local T1 = xlength_[1] + local T2 = ylength_[1] + local yref = yref_[1] + local logpy = self.logpy[{1, {}, {}, {}}] + local alpha = logpy.new(T1+1, T2+1) + local prev = logpy.new(T1+1, T2+1):fill(-1) + alpha:fill(-torch.loginf()) + alpha[{1,1}] = 0 + for t = 1, T1 do + for j = 0, T2 do + local j_low = math.max(1, j-self.batch_max_segment_len+1) + for j_start = j_low, j+1 do + local logprob = alpha[{t, j_start}] + logpy[{t, j_start, j-j_start+2}] + if logprob > alpha[{t+1, j+1}] then + alpha[{t+1, j+1}] = logprob + prev[{t, j+1}] = j_start-1 + end + end + end + end + local j = T2 + local out_str = "|" + for t = T1, 1, -1 do + local prev_j = prev[{t, j+1}] + for k = j, prev_j+1, -1 do + out_str = vocab[yref[k]] .. out_str + end + if j > prev_j then + out_str = "|" .. out_str + end + j = prev_j + end + print("best path: ", out_str) + return out_str +end + +function NPMT:alpha_and_beta(xlength, ylength, batch_size, T1, T2) + self.alpha:resize(batch_size, T1+1, T2+1):fill(-torch.loginf()) + self.beta:resize(batch_size, T1+1, T2+1):fill(-torch.loginf()) + self.seg_weight:resizeAs(self.logpy):fill(-torch.loginf()) + + if self.use_cimpl then + self.logpy = self.logpy:double() + self.alpha = self.alpha:double() + self.beta = self.beta:double() + ylength = ylength:double() + xlength = xlength:double() + self.seg_weight = self.seg_weight:double() + + c_sample_dp( + batch_size, + T1, + T2, + self.batch_max_segment_len, + self.num_thread, + tonumber(torch.data(self.logpy, true)), + tonumber(torch.data(self.alpha, true)), + tonumber(torch.data(self.beta, true)), + tonumber(torch.data(self.seg_weight, true)), + tonumber(torch.data(ylength, true)), + tonumber(torch.data(xlength, true))) + + if (self.use_cuda) then + self.logpy = self.logpy:cuda() + self.alpha = self.alpha:cuda() + self.beta = self.beta:cuda() + self.seg_weight = self.seg_weight:cuda() + ylength = ylength:cuda() + xlength = xlength:cuda() + else + ylength = ylength:long() + xlength = xlength:long() + end + else + --- not use c implementation --- + self.alpha[{{}, 1, 1}]:zero() + for t = 1, T1 do + for j = 0, T2 do + local j_low = math.max(1, j-self.batch_max_segment_len+1) + for j_start = j_low, j+1 do + local logprob = self.alpha[{{}, t, j_start}] + self.logpy[{{}, t, j_start, j-j_start+2}] + self.alpha[{{}, t+1, j+1}] = torch.logadd(self.alpha[{{}, t+1, j+1}], logprob) + end + end + end + for i = 1, batch_size do + self.beta[{i, xlength[i]+1, ylength[i]+1}] = 0 + end + for t = T1-1, 0, -1 do + for j = 0, T2 do + for j_end = j, math.min(T2, j + self.batch_max_segment_len) do + local logprob = self.beta[{{}, t+2, j_end+1}] + self.logpy[{{}, t+1, j+1, j_end-j+1}] + self.beta[{{}, t+1, j+1}] = torch.logadd(self.beta[{{}, t+1, j+1}], logprob) + end + end + end + + local minT2 = ylength:min() + for t = 1, T1 do + local jstart_l, jstart_u = self:get_jstart_range(t, T1, minT2, T2) + for j_start = jstart_l, jstart_u do + local j_len = math.min(self.batch_max_segment_len, T2-j_start+1) + local j_end = 
j_start + j_len - 1 + for j = j_start-1, j_end do + self.seg_weight[{{}, t, j_start, j-j_start+2}] = self.logpy[{{}, t, j_start, j-j_start+2}] + + self.alpha[{{}, t, j_start}] + + self.beta[{{}, t+1, j+1}] + end + end + end + end + + self.logpy_per_data = self.beta[{{}, 1, 1}]:clone() + if self.report_time then + local logpy_per_data_alpha = self.alpha.new(batch_size) + for i = 1, batch_size do + logpy_per_data_alpha[i] = self.alpha[{i, xlength[i]+1, ylength[i]+1}] + end + print(string.format("%.25f", torch.sum(self.logpy_per_data) - torch.sum(logpy_per_data_alpha))) + print(torch.sum(self.logpy_per_data)) + end + + self.seg_weight:add(-self.logpy_per_data:view(batch_size, 1, 1, 1):repeatTensor(1, T1, T2+1, self.batch_max_segment_len+1)) + self.seg_weight_cum:resizeAs(self.seg_weight):fill(-torch.loginf()) + + if self.use_cimpl then + self.seg_weight_cum:copy(self.seg_weight) + self.seg_weight_cum = self.seg_weight_cum:double() + ylength = ylength:double() + c_reverse_log_cumsum( + batch_size, + T1, + T2, + self.batch_max_segment_len, + self.num_thread, + tonumber(torch.data(self.seg_weight_cum, true)), + tonumber(torch.data(ylength, true))) + if (self.use_cuda) then + ylength= ylength:cuda() + self.seg_weight_cum = self.seg_weight_cum:cuda() + else + ylength = ylength:long() + end + self.seg_weight:exp() -- make it actual weight + self.seg_weight_cum:exp() + else + self.seg_weight:exp() -- make it actual weight + self.seg_weight_cum:copy(self.seg_weight) + self.seg_weight_cum = self.seg_weight_cum:index(4, torch.linspace(self.batch_max_segment_len+1, 1, self.batch_max_segment_len+1):long()) + :cumsum(4) + :index(4, torch.linspace(self.batch_max_segment_len+1, 1, self.batch_max_segment_len+1):long()) + end +end + +function NPMT:compute_gradients(hidden_inputs, xlength, yref, ylength, batch_size, T1, T2) + self.outnet:training() --- set outnet in training mode + local grad_hidden_inputs = hidden_inputs.new(hidden_inputs:size()):zero() + + local start_vector = yref.new(batch_size,1):fill(self.start_symbol) + if torch.type(start_vector) ~= "torch.CudaTensor" then + start_vector = start_vector:long() + end + + local nnlm_gradOutput + if self.use_nnlm then + nnlm_gradOutput = self.nnlm_output.new(self.nnlm_output:size()):zero() + end + + local sorted_schedule = self.sorted_schedule -- copy from forward + assert(#sorted_schedule > 0) +-- print(os.clock(), "backward", self.sorted_schedule[{1,1}], T1, hidden_inputs:size(2)) + + local concat_inputs = torch.Tensor() + local concat_hts = torch.Tensor() + local gradOutput = torch.Tensor() + if (self.use_cuda) then + concat_inputs = concat_inputs:cuda() + concat_hts = concat_hts:cuda() + gradOutput = gradOutput:cuda() + end + + concat_inputs:resize(self.group_size, self.batch_max_segment_len + 1) + concat_hts:resize(self.group_size, self.dec_unit_size) + gradOutput:resize(self.group_size, self.batch_max_segment_len + 1, self.vocab_size) + + local grad_scale = -1.0 / (yref:size(1) * self.ngpus) + local y_input + local si = 1 + + local skip_sample = false + for si = 1, #sorted_schedule do + if sorted_schedule[si][1] > hidden_inputs:size(2) then + skip_sample = true + end + end + if skip_sample then + print('skip') + else + while si <= #sorted_schedule do + local si_next = math.min(si + math.floor(self.group_size / batch_size) - 1, #sorted_schedule) + local s = si_next - si + 1 + local max_jlen = sorted_schedule[si_next][3] + + local t_concatInputs = concat_inputs[{{1, s * batch_size}, {1, max_jlen + 1}}] + local t_concatHts = concat_hts[{{1, s * 
batch_size}, {}}] + local t_gradOutput = gradOutput[{{1, s * batch_size}, {1, max_jlen +1}, {}}] + t_concatInputs:fill(self.end_segment_symbol) + t_concatHts:zero() + t_gradOutput:zero() + + for ell = si, si_next do + local t, j_start, j_len, j_end = unpack(sorted_schedule[ell]) + local low_idx, high_idx = (ell-si)*batch_size+1, (ell-si+1)*batch_size + y_input = start_vector:clone() + if j_end >= j_start then + y_input = torch.cat({y_input, yref[{{}, {j_start,j_end}}]}) + end +-- if t > hidden_inputs:size(2) then +-- print("xlength", xlength) +-- print("ylength", ylength) +-- print("ylength", yref) +-- print("size", hidden_inputs:size()) +-- print("t", t, j_start, j_len, j_end, T1) +-- end + local hidden_input = hidden_inputs[{{}, t, {}}] + if self.use_nnlm then + if self.lm_concat then + local hidden_input_concat = torch.cat(hidden_input, self.nnlm_output[{{}, j_start, {}}], 2) + hidden_input = self.lm_concat_proj:updateOutput(hidden_input_concat) + else + hidden_input = torch.add(hidden_input, self.nnlm_output[{{}, j_start, {}}]) + end + end + t_concatHts[{{low_idx, high_idx}, {}}]:copy(hidden_input) + t_concatInputs[{{low_idx, high_idx}, {1, y_input:size(2)}}]:copy(y_input) + if j_len > 0 then + local yweight = self.seg_weight_cum[{{}, t, j_start, {2, j_len+1}}]:contiguous() + local ysnipt = yref[{{}, {j_start, j_end}}]:contiguous() + -- Use scatter to put batch of y batch to corresponding place + t_gradOutput[{{low_idx, high_idx}, {1, j_len}, {}}]:scatter(3, ysnipt:view(batch_size, j_len, 1), yweight:view(batch_size, j_len, 1)) + end + t_gradOutput[{{low_idx, high_idx}, {1,j_len+1}, self.end_segment_symbol}]:copy(self.seg_weight[{{}, t, j_start, {1,j_len+1}}]) + end + + t_gradOutput:mul(grad_scale) + local reshaped_t_gradOutput = t_gradOutput:reshape(s*batch_size*(max_jlen+1), self.vocab_size) + self.outnet:forward({t_concatHts, t_concatInputs}) + self.outnet:backward({t_concatHts, t_concatInputs}, reshaped_t_gradOutput) + reshaped_t_gradOutput:set() + + + for ell = si, si_next do + local t, j_start, j_len, j_end = unpack(sorted_schedule[ell]) + local t_valid = gradOutput.new(batch_size):zero() + t_valid:copy(xlength:ge(t)) + local low_idx, high_idx = (ell-si)*batch_size+1, (ell-si+1)*batch_size + local grad_input = torch.cmul(self.outnet.gradInput[1][{{low_idx, high_idx}, {}}], + t_valid:view(batch_size, 1):expand(batch_size, hidden_inputs:size(3))) + if self.nnlm then + if self.lm_concat then + local hidden_input_concat = torch.cat(hidden_inputs[{{}, t, {}}], self.nnlm_output[{{}, j_start, {}}], 2) + local concat_grad_input = self.lm_concat_proj:backward(hidden_input_concat, grad_input) + grad_hidden_inputs[{{},t,{}}]:add(concat_grad_input[{{}, {1, self.dec_unit_size}}]) + nnlm_gradOutput[{{}, j_start, {}}]:add(concat_grad_input[{{}, {self.dec_unit_size + 1, 2*self.dec_unit_size}}]) + else + nnlm_gradOutput[{{}, j_start, {}}]:add(grad_input) + grad_hidden_inputs[{{},t,{}}]:add(grad_input) + end + else + grad_hidden_inputs[{{},t,{}}]:add(grad_input) + end + end + si = si_next + 1 + end + end + if self.use_nnlm then + self.nnlm:backward(self.nnlm_input, nnlm_gradOutput) + end + + if (self.report_time) then + print(' compute time => ', os.clock() - t_clock) + end + + self.gradInput = {grad_hidden_inputs, + xlength.new(xlength:size()):zero(), + yref.new(yref:size()):zero(), + ylength.new(ylength:size()):zero()} +end + +function NPMT:forward_and_backward(input) + local t_clock = nil + if self.report_time then + t_clock = os.clock() + end + local hidden_inputs, xlength, yref, ylength = 
unpack(input) + -- hidden_inputs: [torch.CudaTensor of size batch_size, T1, hidden] + -- yref: [torch.CudaTensor of size batch_size, T2] + -- xlength: [torch.CudaTensor of size batch_size] + -- ylength: [torch.CudaTensor of size batch_size] + + local batch_size = hidden_inputs:size(1) + local T1, T2 = hidden_inputs:size(2), yref:size(2) + self.batch_max_segment_len = math.min(self.max_segment_len, T2) + + assert(hidden_inputs:dim() == 3) + if torch.type(yref) ~= "torch.CudaTensor" then + xlength = xlength:long() + yref = yref:long() + ylength = ylength:long() + else + xlength = xlength:cuda() + yref = yref:cuda() + ylength = ylength:cuda() + end + + --Initialization + self.output = nil + if self.precompute_gradInput then + self.gradInput = {} + end + + if (self.report_time) then + print(' Initialization time => ', os.clock() - t_clock) + end + + -- Step 1: compute log p(y_{j1:j2}|h_t) + if (self.report_time) then + t_clock = os.clock() + end + + self:compute_logpy(hidden_inputs, xlength, yref, ylength, batch_size, T1, T2) + + if (self.report_time) then + print(' Time for phase 1 ==> ', os.clock() - t_clock) + end + + -- Step 2: sum over the probs + if (self.report_time) then + t_clock = os.clock() + end + + self:alpha_and_beta(xlength, ylength, batch_size, T1, T2) + self.output = self.logpy_per_data.new(1):fill(-torch.sum(self.logpy_per_data)) + + if (self.report_time) then + print(' Time for phase 2 ==> ', os.clock() - t_clock) + end + + -- Step 3: compute gradients + if (self.report_time) then + t_clock = os.clock() + end + + if (self.precompute_gradInput) then + self:compute_gradients(hidden_inputs, xlength, yref, ylength, batch_size, T1, T2) + end + + if (self.report_time) then + print(' Time for phase 3 ==> ', os.clock() - t_clock) + end +end + +function NPMT:forward(input) + self:forward_and_backward(input) + return self.output +end + +function NPMT:backward(input, gradOutput, scale) + assert(self.precompute_gradInput) + return self.gradInput +end + +function NPMT:updateOutput(input) + return self:forward(input) +end + +function NPMT:updateGradInput(input, gradOutput) + if self.precompute_gradInput then + self:backward(input, gradOutput, 1.0) + end + return self.gradInput +end + +function NPMT:accGradParameters(input, gradOutput, scale) + if self.precompute_gradInput then + self:backward(input, gradOutput, scale) + end +end + +function NPMT:SelectRememberRNNStates(rnns, new_ht_idx) + local ht + if self.use_cuda then + if self.rnn_mode == "LSTM" then + ht = rnns[1].cellOutput:view(-1, self.dec_unit_size) + :index(1, torch.LongTensor(new_ht_idx)) + :view(1, #new_ht_idx, self.dec_unit_size) + else + ht = rnns[1].hiddenOutput:view(-1, self.dec_unit_size) + :index(1, torch.LongTensor(new_ht_idx)) + :view(1, #new_ht_idx, self.dec_unit_size) + end + for i = 1, #rnns do + rnns[i].hiddenInput = rnns[i].hiddenOutput:view(-1, self.dec_unit_size) + :index(1, torch.LongTensor(new_ht_idx)) + :view(1, #new_ht_idx, self.dec_unit_size) + if self.rnn_mode == "LSTM" then + rnns[i].cellInput = rnns[i].cellOutput:view(-1, self.dec_unit_size) + :index(1, torch.LongTensor(new_ht_idx)) + :view(1, #new_ht_idx, self.dec_unit_size) + end + end + else + --TODO fix lstm on cpu + ht = rnns[1]._output[1]:view(-1, self.dec_unit_size) + :index(1, torch.LongTensor(new_ht_idx)) + :view(#new_ht_idx, self.dec_unit_size) + for i = 1, #rnns do + rnns[i].h0 = rnns[i]._output[1]:view(-1, self.dec_unit_size) + :index(1, torch.LongTensor(new_ht_idx)) + :view(#new_ht_idx, self.dec_unit_size) + if rnns[i].cell then + rnns[i].c0 
= rnns[i].cell[1]:view(-1, self.dec_unit_size) + :index(1, torch.LongTensor(new_ht_idx)) + :view(#new_ht_idx, self.dec_unit_size) + end + end + end + return ht +end + +function NPMT:rememberRNNStates(rnns) + local ht + if self.use_cuda then + if self.rnn_mode == "LSTM" then + ht = rnns[1].cellOutput:clone() + else + ht = rnns[1].hiddenOutput:clone() + end + for i = 1, #rnns do --remember old states + rnns[i].hiddenInput = rnns[i].hiddenOutput:clone() + if self.rnn_mode == "LSTM" then + rnns[i].cellInput = rnns[i].cellOutput:clone() + end + end + else + --TODO fix lstm on cpu + ht = rnns[1]._output[1]:clone() + for i = 1, #rnns do --remember old states + rnns[i].h0 = rnns[i]._output[1]:clone() + if rnns[i].cell then + rnns[i].c0 = rnns[i].cell[1]:clone() + end + end + end + return ht +end + +function NPMT:GetNNLMRnns() + local rnns, rnn_containers + if self.use_cuda then + rnns, rnn_containers = self.nnlm:findModules("cudnn.RNN") + else + if self.rnn_mode == "LSTM" then + rnns, rnn_containers = self.nnlm:findModules("nn.SeqLSTM") + elseif self.rnn_mode == "GRU" then + rnns, rnn_containers = self.nnlm:findModules("nn.SeqGRU") + else + assert(false) + end + end + return rnns +end + +function NPMT:GetOutnetRnns() + local rnns, rnn_containers + if self.use_cuda then + rnns, rnn_containers = self.outnet:findModules("cudnn.RNN") + else + if self.rnn_mode == "LSTM" then + rnns, rnn_containers = self.outnet:findModules("nn.SeqLSTM") + elseif self.rnn_mode == "GRU" then + rnns, rnn_containers = self.outnet:findModules("nn.SeqGRU") + else + assert(false) + end + end + return rnns +end + +function NPMT:clearOutnetStates() + local rnns = self:GetOutnetRnns() + for i = 1, #rnns do + rnns[i]:resetStates() + end +end + +function NPMT:resetRNNStates() + local rnns = self:GetOutnetRnns() + for i = 1, #rnns do + rnns[i]:resetStates() + end + if self.use_nnlm then + rnns = self:GetNNLMRnns() + for i = 1, #rnns do + rnns[i]:resetStates() + end + end +end + +function NPMT:training() + self.precompute_gradInput = true + self:resetRNNStates() + parent.training(self) +end + +function NPMT:evaluate() + self.precompute_gradInput = false + self:resetRNNStates() + parent.evaluate(self) +end + +function NPMT:clearState() + parent.clearState(self) + self:resetRNNStates() + self.logpy:set() + self.alpha:set() + self.beta:set() + self.logpy_per_data:set() + self.seg_weight:set() + self.seg_weight_cum:set() + self.output = 0. 
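+   -- The :set() calls above detach the scratch tensors from their underlying
+   -- storage so the memory can be reclaimed between examples; they are simply
+   -- resized again on the next forward pass.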
+ self.gradInput = {} + self.sorted_schedule = {} + if self.use_nnlm then + if self.nnlm_output ~= nil then + self.nnlm_output:set() + end + if self.nnlm_input ~= nil then + self.nnlm_input:set() + end + end +end + +function NPMT:predict(input, xlength, test_mode) + -- TODO fixes + local batch_size, T1 = input:size(1), input:size(2) +-- assert(batch_size == 1) + local max_segment_len = self.max_segment_len + local sts_input = torch.Tensor()--:fill(self.start_symbol) + if (self.use_cuda) then + sts_input = sts_input:cuda() + end + local start_symbol = self.start_symbol + local rnns = self:GetOutnetRnns() + + local tab_output_seqs = {} + local tab_output_probs = {} + local test_segments = 0 + + for b = 1, batch_size do + local out_str = "|" + local output_symbol = nil + local output_seq = {} + local output_probs = {} + local num_segments = 0 + sts_input:resize(1, 1):fill(self.start_symbol) + + local nnlm_rnns, nnlm_output + if self.use_nnlm then + nnlm_rnns = self:GetNNLMRnns() + nnlm_output = self.nnlm:updateOutput(sts_input):view(-1, self.dec_unit_size) + self:rememberRNNStates(nnlm_rnns) + end + + for t = 1, xlength[b] do + sts_input[1][1] = start_symbol + local ht = input[{{b},t,{}}]:clone() + self:clearOutnetStates() + if self.use_nnlm then + if self.lm_concat then + ht = self.lm_concat_proj:updateOutput(torch.cat(ht, nnlm_output, 2)) + else + ht:add(nnlm_output) + end + end + local new_segment = false + for j = 1, max_segment_len do + local output_prob = self.outnet:updateOutput({ht, sts_input}):view(-1) -- 1-dimensional vector of length V + local max_prob, output_symbol = output_prob:view(-1):max(1) + output_symbol = output_symbol:squeeze() + if output_symbol == self.end_segment_symbol then + break -- finished reading this segment + else + if not new_segment then + num_segments = num_segments + 1 + new_segment = true + end + + table.insert(output_seq, output_symbol) + table.insert(output_probs, max_prob[1]) + if test_mode then +-- out_str = out_str .. " " .. output_symbol + -- input time t, output corresponding segments + out_str = out_str .. " " .. t .. ':' .. output_symbol +-- out_str = out_str .. " " .. vocab[output_symbol] + end + sts_input[1][1] = output_symbol + ht = self:rememberRNNStates(rnns) + if self.use_nnlm then + nnlm_output = self.nnlm:updateOutput(sts_input):view(-1, self.dec_unit_size) + self:rememberRNNStates(nnlm_rnns) + end + end + end + if new_segment and test_mode then + out_str = out_str .. 
'|' + end + end + + if #output_seq == 0 then + local eos_index = 3 + table.insert(output_seq, eos_index) -- eos + table.insert(output_probs, 1) + end + self:clearState() -- don't leave things for next one example + if test_mode then + print("max decoding:", out_str) + end + + table.insert(tab_output_seqs, torch.Tensor(output_seq)) + table.insert(tab_output_probs, output_probs) -- place holder, dummy + test_segments = test_segments + num_segments + end + tab_output_seqs = nn.FlattenTable():forward(tab_output_seqs) + local output_count = 0 + for i = 1, #tab_output_seqs do + output_count = output_count + tab_output_seqs[i]:nElement() + end + return tab_output_seqs, nn.FlattenTable():forward(tab_output_probs), output_count, test_segments +end + +function NPMT:beam_search(input, xlength, configs) + local configs = configs or {} + local word_weight = configs.lenpen or configs.word_weight or 0 + local beam_size = configs.beam_size or configs.beam or 20 + + local lm_weight = configs.lm_weight or 0 + local lm = configs.lm + + local batch_size, T1 = input:size(1), input:size(2) + local rnns = self:GetOutnetRnns() + + local max_segment_len = self.max_segment_len + local sts_input = torch.Tensor():cuda() + + local tab_output_seqs = {} + local tab_output_probs = {} + + for b = 1, batch_size do + sts_input:resize(1, 1):fill(self.start_symbol) + + local fin_trans = {} + local fin_probs = {} + + local nnlm_rnns, nnlm_output + if self.use_nnlm then + nnlm_output = self.nnlm:updateOutput(sts_input):view(-1, self.dec_unit_size) + end + + for t = 1, xlength[b] do + local trans_t = {} + local probs_t = {} + local fin_trans_t = {} + local fin_probs_t = {} + + local ht = input[{{b},t,{}}]:clone() + if t > 1 then + sts_input:resize(#fin_trans, 1):fill(self.start_symbol) + for i = 1, #fin_trans do + table.insert(trans_t, torch.copy_array(fin_trans[i])) + end + probs_t = torch.copy_array(fin_probs) + ht = ht:repeatTensor(#fin_trans, 1) + else + table.insert(trans_t, {}) + table.insert(probs_t, 0.) + sts_input:resize(1, 1):fill(self.start_symbol) + end + + local ngrams + if lm_weight > 0 then + ngrams = torch.tab_to_ngram(trans_t) + end + + -- nnlm 1. dropout, 2. 
nnlm with increasing/decreasing weights + if self.use_nnlm then + if self.lm_concat then + ht = self.lm_concat_proj:updateOutput(torch.cat(ht, nnlm_output, 2)) + elseif self.schedule_nnlm > 0 then + ht = ht + self.schedule_nnlm * nnlm_output + ht = self.lm_concat_proj:updateOutput(torch.cat(ht, nnlm_output, 2)) + else + ht:add(nnlm_output) + end + end + self:clearOutnetStates() + local nsamples = beam_size + for j = 1, max_segment_len + 1 do + local output_prob = self.outnet:updateOutput({ht, sts_input}) + local new_trans_t = {} + local new_probs_t = {} + local new_ht_idx = {} + if j == max_segment_len + 1 then -- we have to force to have the end_segment_symbol for each input + probs_t = torch.Tensor(probs_t):cuda() + output_prob[{{}, self.end_segment_symbol}] -- + word_weight + for k = 1, nsamples do + table.insert(fin_trans_t, trans_t[k]) + table.insert(fin_probs_t, probs_t[k]) + end + else + probs_t = torch.Tensor(probs_t):cuda():view(-1, 1):repeatTensor(1, self.vocab_size) + output_prob -- + word_weight + probs_t = probs_t:view(-1) + local _, sorted_idx = torch.sort(probs_t, true) + for k = 1, math.min(nsamples, sorted_idx:nElement()) do + local tran_id = math.floor((sorted_idx[k]-1) / self.vocab_size) + 1 + local word_id = (sorted_idx[k]-1) % self.vocab_size + 1 + if word_id == self.end_segment_symbol then + --- end symbol + table.insert(fin_trans_t, trans_t[tran_id]) + table.insert(fin_probs_t, probs_t[sorted_idx[k]]) + else + --- continue in the pool + local new_tran = torch.copy_array(trans_t[tran_id]) + table.insert(new_tran, word_id) + table.insert(new_trans_t, new_tran) + + local new_prob = probs_t[sorted_idx[k]] + word_weight + if lm_weight > 0 then + new_prob = new_prob + lm_weight * lookup_lm_prob(lm, ngrams[tran_id], tostring(word_id))[1] + end + table.insert(new_probs_t, new_prob) + table.insert(new_ht_idx, tran_id) + end + end + end + trans_t = new_trans_t + probs_t = new_probs_t + nsamples = #trans_t + if nsamples == 0 then + break + end + if lm_weight > 0 then + ngrams = torch.tab_to_ngram(trans_t) + end + sts_input:resize(#trans_t, 1):copy(torch.last_nelements(trans_t, 1, self.start_symbol)) + ht = self:SelectRememberRNNStates(rnns, new_ht_idx) + end + if t > 1 then + --- merge same sequences + local merge_fin_trans_t = {} + merge_fin_trans_t[table.concat(fin_trans_t[1], '-')] = {fin_trans_t[1], fin_probs_t[1]} + for i = 2, #fin_trans_t do + local tran_str = table.concat(fin_trans_t[i], '-') + if merge_fin_trans_t[tran_str] ~= nil then + merge_fin_trans_t[tran_str][2] = torch.logadd(merge_fin_trans_t[tran_str][2], fin_probs_t[i]) + else + merge_fin_trans_t[tran_str] = {fin_trans_t[i], fin_probs_t[i]} + end + end + fin_trans_t = {} + fin_probs_t = {} + for key, value in pairs(merge_fin_trans_t) do + table.insert(fin_trans_t, value[1]) + table.insert(fin_probs_t, value[2]) + end + end + fin_trans = fin_trans_t + fin_probs = fin_probs_t + + if self.use_nnlm then + nnlm_output:resize(#fin_trans, self.dec_unit_size):zero() + local max_len = #(fin_trans[1]) + for i = 1, #fin_trans do + max_len = math.max(max_len, #(fin_trans[i])) + end + local lm_input = nnlm_output.new(#fin_trans, max_len+1):fill(self.start_symbol) + for i = 1, #fin_trans do + if #fin_trans[i] > 0 then + lm_input[{i, {2,#(fin_trans[i])+1}}]:copy(torch.Tensor(fin_trans[i])) + end + end + local lm_output = self.nnlm:updateOutput(lm_input) + for i = 1, #fin_trans do + nnlm_output[{i, {}}]:copy(lm_output[{i, #(fin_trans[i])+1, {}}]) + end + end + end + + if configs.use_avg_prob then + for i = 1, #fin_probs do + 
fin_probs[i] = fin_probs[i] / (#fin_trans[i]) + end + end + + local _, sorted_idx = torch.sort(torch.Tensor(fin_probs), true) + + local output_seqs = {} + local output_probs = {} + + for i = 1, math.min(configs.beam_size, #fin_probs) do + if #fin_trans[sorted_idx[i]] > 0 then + table.insert(output_seqs, torch.Tensor(fin_trans[sorted_idx[i]])) + else + local eos_index = 3 + table.insert(output_seqs, torch.Tensor(1):fill(eos_index)) + end + table.insert(output_probs, fin_probs[sorted_idx[i]]) + end + self:clearState() -- don't leave things for next one example + table.insert(tab_output_seqs, output_seqs) + table.insert(tab_output_probs, output_probs) + end + + return nn.FlattenTable():forward(tab_output_seqs), nn.FlattenTable():forward(tab_output_probs) + +end diff --git a/fairseq/models/npmt_model.lua b/fairseq/models/npmt_model.lua new file mode 100755 index 0000000..5ac112b --- /dev/null +++ b/fairseq/models/npmt_model.lua @@ -0,0 +1,598 @@ +-- Copyright (c) 2017-present, Facebook, Inc. +-- All rights reserved. +-- +-- This source code is licensed under the license found in the LICENSE file in +-- the root directory of this source tree. An additional grant of patent rights +-- can be found in the PATENTS file in the same directory. +-- +-- Copyright (c) Microsoft Corporation. All rights reserved. +-- Licensed under the BSD License. +-- +--[[ +-- +-- This model closely follows the conditional setup of rnn-lib v1, with -name +-- clstm and -aux conv_attn. See the individual functions (makeEncoder, +-- makeDecoder) for detailed comments regarding the model architecture. +-- +--]] + +require 'nn' +require 'nngraph' +require 'rnnlib' +local argcheck = require 'argcheck' +local mutils = require 'fairseq.models.utils' +local rutils = require 'rnnlib.mutils' +local utils = require 'fairseq.utils' + +local cuda = utils.loadCuda() + +local NPMTModel, parent = torch.class('NPMTModel', 'Model') + +NPMTModel.make = argcheck{ + {name='self', type='NPMTModel'}, + {name='config', type='table'}, + call = function(self, config) + config.use_cuda = true + local encoder = self:makeEncoder(config) + local decoder = self:makeDecoder(config) + -- Wire up encoder and decoder + local input = nn.Identity()() + local sourceIn, xlength, targetIn, ylength = input:split(4) + -- reformat the shape + -- input to npmt is {hidden_inputs, xlength, yref, ylength} + local output = decoder({ + encoder(sourceIn):annotate{name = 'encoder'}, + xlength, + targetIn, + ylength + }):annotate{name = 'decoder'} + + return nn.gModule({input}, {output}) + end +} + +-- Use the same encoder as BLSTMModel + + +NPMTModel.makeEncoderColumn = argcheck{ + {name='self', type='NPMTModel'}, + {name='config', type='table'}, + {name='inith', type='nngraph.Node'}, + {name='input', type='nngraph.Node'}, + {name='nlayers', type='number'}, + call = function(self, config, inith, input, nlayers) + local rnnconfig = { + inputsize = config.nembed, + hidsize = config.nhid, + nlayer = 1, + winitfun = function(network) + rutils.defwinitfun(network, config.init_range) + end, + usecudnn = usecudnn, + } + + local rnn_class = nn.LSTM + if config.rnn_mode == "GRU" then + rnn_class = nn.GRU + end + + local rnn = rnn_class(rnnconfig) + rnn.saveHidden = false + local output = nn.SelectTable(-1)(nn.SelectTable(2)( + rnn({inith, input}):annotate{name = 'encoderRNN'} + )) + + if config.use_resnet_enc then + if config.nembed ~= config.nhid then + local input_proj = nn.MapTable(nn.Linear(config.nembed, config.nhid, false))(input) + output = 
nn.MapTable(nn.CAddTable())(nn.ZipTable()({input_proj, output})) + else + output = nn.MapTable(nn.CAddTable())(nn.ZipTable()({input, output})) + end + end + + rnnconfig.inputsize = config.nhid + + for i = 2, nlayers do + if config.dropout_hid > 0 then + output = nn.MapTable(nn.Dropout(config.dropout_hid))(output) + end + local rnn = rnn_class(rnnconfig) + rnn.saveHidden = false + local prev_input + if config.use_resnet_enc then + prev_input = nn.Identity()(output) + end + output = nn.SelectTable(-1)(nn.SelectTable(2)( + rnn({ + inith, + nn.ReverseTable()(output), + }) + )) + if config.use_resnet_enc then + output = nn.MapTable(nn.CAddTable())(nn.ZipTable()({prev_input, output})) + end + end + return output + end +} + +NPMTModel.makeEncoder = argcheck{ + doc=[[ +This encoder runs a forward and backward LSTM network and concatenates their +top-most hidden states. +]], + {name='self', type='NPMTModel'}, + {name='config', type='table'}, + call = function(self, config) + local sourceIn = nn.Identity()() + local inith, tokens = sourceIn:split(2) + + local dict = config.srcdict + local lut = mutils.makeLookupTable(config, dict:size(), + dict.pad_index) + local embed + if config.dropout_src > 0 then + embed = nn.MapTable(nn.Sequential() + :add(lut) + :add(nn.Dropout(config.dropout_src)))(tokens) + else + embed = nn.MapTable(lut)(tokens) + end + assert(config.num_lower_conv_layers + config.num_mid_conv_layers + config.num_high_conv_layers <= 1) + + -- Low level - Add temporal conv stride to reduced computations + if config.num_lower_conv_layers > 0 then + local conv_embed = nn.Sequential() + conv_embed:add(nn.MapTable(nn.View(-1, 1, config.nembed))) + conv_embed:add(nn.JoinTable(2)) -- Split table to tensor as it expects tensor {batch_size x T x nembed} + conv_embed:add(nn.Padding(2, 1-config.conv_kW_size))-- pad left with zeros + conv_embed:add(nn.TemporalConvolution(config.nembed, config.nembed, config.conv_kW_size, config.conv_dW_size)) + conv_embed:add(nn.ReLU()) + embed = conv_embed(embed):annotate{name = 'TemporalConv'} + end + + if config.num_lower_win_layers > 0 then + local reorder_embed = nn.Sequential() + -- Reshape as a table T elements of (batch_size x 1 x nembed) + if config.num_lower_conv_layers == 0 then + reorder_embed:add(nn.MapTable(nn.View(-1, 1, config.nembed))) + reorder_embed:add(nn.JoinTable(2)) -- Split table to tensor as it expects tensor {batch_size x T x nembed} + end + + if config.num_lower_win_layers > 0 then + local winattn_layer + if config.win_attn_type == 'ori' then + winattn_layer = nn.winAttn(config.nembed, config.kwidth, config.use_win_middle) + else + winattn_layer = nil -- Error + end + for i = 1, config.num_lower_win_layers do + reorder_embed:add(winattn_layer) + end + embed = reorder_embed(embed) + end + end + + -- Mid level - Add temporal conv stride to reduced computations + if config.num_mid_conv_layers > 0 then + local conv_embed = nn.Sequential() + if config.num_lower_win_layers == 0 and config.num_lower_conv_layers == 0 then + conv_embed:add(nn.MapTable(nn.View(-1, 1, config.nembed))) + conv_embed:add(nn.JoinTable(2)) -- Split table to tensor as it expects tensor {batch_size x T x nembed} + end + conv_embed:add(nn.Padding(2, 1-config.conv_kW_size))-- pad left with zeros + conv_embed:add(nn.TemporalConvolution(config.nembed, config.nembed, config.conv_kW_size, config.conv_dW_size)) + conv_embed:add(nn.ReLU()) + embed = conv_embed(embed):annotate{name = 'TemporalConv'} + end + if config.num_lower_conv_layers > 0 or config.num_lower_win_layers > 0 or 
config.num_mid_conv_layers > 0 then + embed = nn.SplitTable(2)(embed) + end + + local col1 = self:makeEncoderColumn{ + config = config, + inith = inith, + input = embed, + nlayers = config.nenclayer, + } + local col2 = self:makeEncoderColumn{ + config = config, + inith = inith, + input = nn.ReverseTable()(embed), + nlayers = config.nenclayer, + } + + -- Each column will switch direction between layers. Before merging, + -- they should both run in the same direction (here: forward). + if config.nenclayer % 2 == 0 then + col1 = nn.ReverseTable()(col1) + else + col2 = nn.ReverseTable()(col2) + end + + local prepare = nn.Sequential() + -- Concatenate forward and backward states + prepare:add(nn.JoinTable(2, 2)) + -- Scale down to nhid for further processing + prepare:add(nn.Linear(config.nhid * 2, config.dec_unit_size, false)) + -- Add singleton dimension for subsequent joining + prepare:add(nn.View(-1, 1, config.dec_unit_size)) + + local joinedOutput = nn.JoinTable(1, 2)( + nn.MapTable(prepare)( + nn.ZipTable()({col1, col2}) + ) + ) + if config.dropout_hid > 0 then + joinedOutput = nn.Dropout(config.dropout_hid)(joinedOutput) + end + + -- TODO add attention layer + + -- TODO add temporal conv stride to reduced computations + if config.num_high_conv_layers > 0 then + local conv_embed = nn.Sequential() + conv_embed:add(nn.Padding(2, 1-config.conv_kW_size))-- pad left with zeros + conv_embed:add(nn.TemporalConvolution(config.dec_unit_size, config.dec_unit_size, config.conv_kW_size, config.conv_dW_size)) + conv_embed:add(nn.ReLU()) + joinedOutput = conv_embed(joinedOutput):annotate{name = 'TemporalConv'} + end + + -- avgpool_model.makeDecoder() expects two encoder outputs, one for + -- attention score computation and the other one for applying them. + -- We'll just use the same output for both. + return nn.gModule({sourceIn}, {joinedOutput}) + end +} + + +NPMTModel.makeDecoder = argcheck{ + doc=[[ + Constructs a WASM. + ]], + {name='self', type='NPMTModel'}, + {name='config', type='table'}, + call = function(self, config) + -- input to npmt is {hidden_inputs, xlength, yref, ylength} + local input = nn.Identity()() + local encoderOut, xlength, targetIn, ylength = input:split(4) + local output = nn.NPMT(config)({encoderOut, xlength, targetIn, ylength}):annotate{name = 'npmt'} + return nn.gModule({input}, {output}) + end +} + + +NPMTModel.prepareSource = argcheck{ + {name='self', type='NPMTModel'}, + call = function(self) + -- Device buffers for samples + local buffers = { + source = {}, + xlength = {} + } + + -- NOTE: It's assumed that all encoders start from the same hidden + -- state. 
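+        -- The code below looks up the annotated 'encoderRNN' node (used to
+        -- build the initial hidden state for each batch) and, if a
+        -- 'TemporalConv' node exists, reads its kernel width kW and stride dW.
+        -- The stride matters because a strided convolution shortens the source
+        -- time axis, so each xlength is later remapped to
+        -- floor((len - 1) / dW) + 1; e.g. a length-7 source with dW = 2
+        -- becomes floor(6/2) + 1 = 4 encoder steps.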
+ local encoderRNN = mutils.findAnnotatedNode( + self:network(), 'encoderRNN' + ) + assert(encoderRNN ~= nil) + local conv_kW_size, conv_dW_size = 0, 0 + if mutils.findAnnotatedNode(self:network(), 'TemporalConv') then + if #mutils.findAnnotatedNode(self:network(), 'TemporalConv') > 3 then + conv_kW_size = mutils.findAnnotatedNode(self:network(), 'TemporalConv'):get(4).kW + conv_dW_size = mutils.findAnnotatedNode(self:network(), 'TemporalConv'):get(4).dW + else + conv_kW_size = mutils.findAnnotatedNode(self:network(), 'TemporalConv'):get(2).kW + conv_dW_size = mutils.findAnnotatedNode(self:network(), 'TemporalConv'):get(2).dW + end + end + + return function(sample) + -- Encoder input + local source = {} + local xlength = torch.Tensor(sample.bsz):zero() + local source_t = sample.source:t() + + local pad_index = 2 + local eos_index = 3 + local max_xlength = 0 + for i = 1, sample.bsz do + buffers.xlength[i] = buffers.xlength[i] or torch.Tensor():type(self:type()) + xlength[i] = source_t:size(2) - torch.sum(source_t[i]:eq(pad_index)) + xlength[i] = xlength[i] - torch.sum(source_t[i]:eq(eos_index)) + max_xlength = math.max(max_xlength, xlength[i]) + + source_t[{i, xlength[i]+1}] = pad_index + end + source_t = source_t[{{}, {1, max_xlength}}]:clone() + + for j = 1, source_t:size(2) do + buffers.source[j] = buffers.source[j] or torch.Tensor():type(self:type()) + source[j] = mutils.sendtobuf(source_t[{{}, j}], buffers.source[j]) + end + -- change xlength when there is a TemporalConv layer + if conv_dW_size > 0 then + for i = 1, sample.bsz do + xlength[i] = math.floor((xlength[i] - 1)/ conv_dW_size) + 1 -- Using temporal convolution + end + end + + local initialHidden = encoderRNN:initializeHidden(sample.bsz) + return {{initialHidden, source}, xlength} + end + end +} + + +NPMTModel.prepareHidden = argcheck{ + {name='self', type='NPMTModel'}, + call = function(self) + local decoderRNN = mutils.findAnnotatedNode( + self:network(), + 'decoder' + ) + assert(decoderRNN ~= nil) + + return function(sample) + -- The sample contains a _cont entry if this sample is a + -- continuation of a previous one (for truncated bptt training). In + -- that case, start from the RNN's previous hidden state. + if not sample._cont then + return decoderRNN:initializeHidden(sample.bsz) + else + return decoderRNN:getLastHidden() + end + end + end +} + +NPMTModel.prepareInput = argcheck{ + {name='self', type='NPMTModel'}, + call = function(self) + local buffers = { + input = {}, + } + + return function(sample) + -- Copy data to device buffers. Recurrent modules expect a table of + -- tensors as their input. 
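+            -- sample.input is sliced along its first dimension and each slice
+            -- is copied into a lazily created, reusable device buffer, giving
+            -- the table-of-tensors layout the recurrent modules expect.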
+ local input = {} + for i = 1, sample.input:size(1) do + buffers.input[i] = buffers.input[i] + or torch.Tensor():type(self:type()) + input[i] = mutils.sendtobuf(sample.input[i], + buffers.input[i]) + end + return input + end + end +} + +NPMTModel.prepareTarget = argcheck{ + {name='self', type='NPMTModel'}, + call = function(self) + local buffers = { + target = torch.Tensor():type(self:type()), + ylength = torch.Tensor():type(self:type()) + } + + return function(sample) + local target = mutils.sendtobuf(sample.target:t(), buffers.target) + local ylength = torch.Tensor(target:size(1)):zero() + + local pad_index = 2 + local eos_index = 3 + local max_ylength = 0 + for i = 1, target:size(1) do + ylength[i] = target:size(2) - torch.sum(target[i]:eq(pad_index)) + ylength[i] = ylength[i] - torch.sum(target[i]:eq(eos_index)) + max_ylength = math.max(ylength[i], max_ylength) + target[{i, ylength[i]+1}] = pad_index + end + target = target[{{},{1,max_ylength}}]:clone() + local ylength = mutils.sendtobuf(ylength, buffers.ylength) + + return {target, ylength} + end + end +} + +NPMTModel.prepareSample = argcheck{ + {name='self', type='NPMTModel'}, + call = function(self) + local prepareSource = self:prepareSource() + local prepareTarget = self:prepareTarget() + return function(sample) + local source = prepareSource(sample) + local target = prepareTarget(sample) + + local source, xlength = source[1], source[2] + local target, ylength = target[1], target[2] + sample.target = target + sample.input = {source, xlength, target, ylength} + end + end +} + + +NPMTModel.generate = argcheck{ + doc=[[ +Sentence generation. See search.lua for a description of search functions. +]], + {name='self', type='Model'}, + {name='config', type='table'}, + {name='sample', type='table'}, + {name='search', type='table'}, + call = function(self, config, sample, search) + local dict = config.dict + local minlen = config.minlen + local maxlen = config.maxlen + local bsz = sample.source:size(2) + local bbsz = config.beam * bsz + local callbacks = self:generationCallbacks(config, bsz) + local vocabsize = sample.targetVocab and sample.targetVocab:size(1) or dict:size() + + local timers = { + setup = torch.Timer(), + encoder = torch.Timer(), + decoder = torch.Timer(), + search_prune = torch.Timer(), + search_results = torch.Timer(), + } + + for k, v in pairs(timers) do + v:stop() + v:reset() + end + + timers.setup:resume() + local state = callbacks.setup(sample) + if cuda.cutorch then + cuda.cutorch.synchronize() + end + timers.setup:stop() + + timers.encoder:resume() + callbacks.encode(state) + timers.encoder:stop() + + timers.decoder:resume() + local results, output_count, num_segments = callbacks.decode(state) + if cuda.cutorch then + cuda.cutorch.synchronize() + end + timers.decoder:stop() + + timers.search_results:resume() +-- local results = table.pack(search.results()) + callbacks.finalize(state, sample, results) + timers.search_results:stop() + + local times = {} + for k, v in pairs(timers) do + times[k] = v:time() + end + -- hypos, scores, attns, t + local attns = {} + for i = 1, #results[2] do + attns[i] = torch.zeros(1, vocabsize) + end + table.insert(results, attns) + table.insert(results, times) + table.insert(results, output_count) + table.insert(results, num_segments) + -- TODO expect hypos, scores, attns, t + return table.unpack(results) + end +} + + +NPMTModel.generationSetup = argcheck{ + {name='self', type='NPMTModel'}, + {name='config', type='table'}, + {name='bsz', type='number'}, + call = function(self, config, 
bsz) + local beam = config.beam + local bbsz = beam * bsz + local m = self:network() + local prepareSource = self:prepareSource() + return function(sample) + m:evaluate() + local source = prepareSource(sample) + local state = { + sourceIn = source[1], + xlength = source[2], + } + return state + end + end +} + +NPMTModel.generationEncode = argcheck{ + {name='self', type='NPMTModel'}, + {name='config', type='table'}, + {name='bsz', type='number'}, + call = function(self, config, bsz) + local m = self:network() + local encoder = mutils.findAnnotatedNode(m, 'encoder') + local beam = config.beam + local bbsz = beam * bsz + + return function(state) + local encoderOut = encoder:forward(state.sourceIn) + + -- There will be 'beam' hypotheses for each sentence in the batch, + -- so duplicate the encoder output accordingly. +-- local index = torch.range(1, bsz + 1, 1 / beam) +-- index = index:narrow(1, 1, bbsz):floor():long() +-- for i = 1, encoderOut:size(1) do +-- encoderOut[i] = encoderOut[i]:index(1, index) +-- end + state.encoderOut = encoderOut + end + end +} + +NPMTModel.generationDecode = argcheck{ + {name='self', type='NPMTModel'}, + {name='config', type='table'}, + {name='bsz', type='number'}, + call = function(self, config, bsz) + local m = self:network() + + local npmt = mutils.findAnnotatedNode(m, 'npmt') + assert(npmt ~= nil) + -- TODO add more parameters for beam search + config.beam_size = config.beam + config.word_weight = config.lenpen + return function(state, targetIn) + local output_seqs, output_probs + local output_counts, num_segments = 0, 0 + if config.beam == 1 then + output_seqs, output_probs, output_counts, num_segments = npmt:predict(state.encoderOut, state.xlength, config.verbose or false) + else + output_seqs, output_probs = npmt:beam_search(state.encoderOut, state.xlength, config) + end + return {output_seqs, output_probs}, output_counts, num_segments + end + end +} + +NPMTModel.generationUpdate = argcheck{ + {name='self', type='NPMTModel'}, + {name='config', type='table'}, + {name='bsz', type='number'}, + call = function(self, config, bsz) + local bbsz = config.beam * bsz + local m = self:network() + local decoderRNN = mutils.findAnnotatedNode(m, 'decoder') + assert(decoderRNN ~= nil) + + return function(state, indexH) + local lastH = decoderRNN:getLastHidden(bbsz) + for i = 1, #state.prevhIn do + for j = 1, #state.prevhIn[i] do + local dim = lastH[i][j]:dim() - 1 + state.prevhIn[i][j]:copy(lastH[i][j]:index(dim, indexH)) + end + end + end + end +} + +function NPMTModel:float(...) + self.module:replace(function(m) + if torch.isTypeOf(m, 'nn.WrappedCudnnRnn') then + return mutils.wrappedCudnnRnnToLSTMs(m) + elseif torch.typename(m) == 'nn.SequenceTable' then + -- Use typename() to avoid matching RecurrentTables + return mutils.replaceCudnnRNNs(m) + end + return m + end) + return parent.float(self, ...) +end + +return NPMTModel diff --git a/fairseq/models/npmt_utils.lua b/fairseq/models/npmt_utils.lua new file mode 100755 index 0000000..e75ebc1 --- /dev/null +++ b/fairseq/models/npmt_utils.lua @@ -0,0 +1,767 @@ +-- Copyright (c) Microsoft Corporation. All rights reserved. +-- Licensed under the MIT License. 
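+--
+-- Numerical note: the helpers below work in log space. torch.logadd uses the
+-- identity log(exp(a) + exp(b)) = max(a, b) + log1p(exp(-|a - b|)); for
+-- instance logadd(log 0.5, log 0.25) = log 0.75 (-0.693 and -1.386 give
+-- -0.693 + log1p(0.5) = -0.288). torch.loginf() returns a large finite
+-- constant (1e6) standing in for +inf, so -torch.loginf() plays the role of
+-- log 0 throughout the segmentation code.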
+-- +--[[ +-- +-- Auxiliary functions +-- +--]] +-- + +-- log(ret) = log(a + b) +function torch.logadd(log_a, log_b) + if (type(log_a) == 'number') then + return math.max(log_a, log_b) + torch.log1p(torch.exp(-torch.abs(log_a - log_b))) + else + return torch.cmax(log_a, log_b) + torch.log1p(torch.exp(-torch.abs(log_a - log_b))) + end +end + +-- log(ret) = log(1-a) +function torch.log1sub(log_a) + return torch.log1p(-torch.exp(log_a)) +end + +-- log(ret) = log(a - b) +function torch.logsub(log_a, log_b) + assert(torch.all(torch.ge(log_a, log_b))) + return log_a + torch.log1p(-torch.exp(log_b - log_a) + 1e-15) +end + +-- log(ret) = log(+inf) +function torch.loginf() + return 1000000 +end + +function torch.copy_array(tab) + return torch.Tensor(tab):totable() +end + +function torch.last_nelements(tab, n, pad) + local n = n or 1 + local pad = pad or -1 + local out = torch.Tensor(#tab, n) + for i = 1, #tab do + for j = 1, n do + out[{i,j}] = tab[i][#(tab[i]) - n + j] or pad + end + end + return out +end + +-- generate n random rules for a vocabulary of V (excluding 1 and V) and probabilities specified in probs +-- return: a table of size n, each being a table of two 1-dimensional tensors +function gen_random_rules(V, n, probs, use_ctc, longer_input) + local lens_src + if use_ctc then + lens_src = torch.Tensor(n):fill(probs:nElement()) + assert(not longer_input, "use_ctc can not be used with longer input") + else + if longer_input then + lens_src = torch.Tensor(n):random(3, 15) + else + lens_src = torch.multinomial(probs, n, true) + end + end + local lens_trgt = torch.multinomial(probs, n, true) + if longer_input then + lens_trgt = torch.Tensor(n):random(1, 6) + end + local rules = {} + for i = 1, n do + local t_src = torch.Tensor(lens_src[i]):random(1, V) + local t_trgt = torch.Tensor(lens_trgt[i]):random(1, V) + table.insert(rules, {t_src, t_trgt}) + end + return rules +end + +-- generate n sequences each of length T, using those specified in rules +-- return three tables: input, output and markers, each has exactly n elements +function gen_random_data(rules, n, T, sort_output) + if sort_output then + print("--- the outputs are sorted ---") + end + local input = {} + local output = {} + local markers = {} + + for i = 1, n do + local output_indices = {} + local t_src = {} + local t_trgt = {} + local t_marker = {} + local count = 0 + local random_T = torch.random(1, T) + local rule_idx = {} + for k = 1, random_T do + local j = torch.random(1, #rules) + table.insert(rule_idx, j) + end + rule_idx = torch.sort(torch.Tensor(rule_idx)) + for k = 1, random_T do + local j = rule_idx[k] + table.insert(t_src, rules[j][1]) + table.insert(t_trgt, rules[j][2]) + count = count + rules[j][1]:nElement() + table.insert(t_marker, count) + table.insert(output_indices, j) + end + table.insert(input, torch.cat(t_src)) + table.insert(markers, torch.Tensor(t_marker)) + if sort_output then + local t_trgt_sorted = {} + _, indices = torch.sort(torch.Tensor(output_indices)) + for j = 1, indices:nElement() do + table.insert(t_trgt_sorted, t_trgt[indices[j]]) + end + table.insert(output, torch.cat(t_trgt_sorted)) + else + table.insert(output, torch.cat(t_trgt)) + end + end + return input, output, markers +end + +function toy_prepare_minibatch(input_table, output_table, markers_table, p, batch_size) + + local T1 = input_table[p]:nElement() + local T2 = 0 + local input = torch.Tensor(batch_size, T1):fill(2) + local markers = {} + local ylength = torch.Tensor(batch_size) + for i = 1, batch_size do + 
input[{i,{}}]:copy(input_table[p+i-1]) + table.insert(markers, markers_table[p+i-1]) + ylength[i] = output_table[p+i-1]:nElement() + if (ylength[i] > T2) then + T2 = ylength[i] + end + end + local yref = torch.Tensor(batch_size, T2):fill(2) + for i = 1, batch_size do + yref[{i,{1,ylength[i]}}]:copy(output_table[p+i-1]) + end + + return input, yref, ylength, markers +end + +function toy_evaluate_zeta_loss(logprob, markers) + + local loss = 0 + local T1 = logprob:size(2) + for i = 1, #markers do + loss = loss + torch.sum(logprob[{i,{}}]:gather(1, markers[i]:long())) + end + + return loss + +end + +function get_toy_data(V, sort_output, use_ctc, longer_input) + local n_train = 16384 + local n_test = 128 + local m = 100 + local T1 = 6 + local V = V + + -- generate data + local rules = gen_random_rules(V, m, torch.Tensor({1/3,1/3,1/3}), false, longer_input) + local input_train, output_train, markers_train = gen_random_data(rules, n_train, T1, sort_output) + local input_test, output_test, markers_test = gen_random_data(rules, n_test, T1, sort_output) + return {rules, + input_train, + output_train, + markers_train, + input_test, + output_test, + markers_test} +end + +function sort_data_by_length(input, output, max_sen_len) + max_sen_len = max_sen_len or nil + local all_lengths = {} + for i = 1, #input do + table.insert(all_lengths, input[i]:nElement()) + end + local _, sorted_idx = torch.sort(torch.Tensor(all_lengths)) + local sorted_input = {} + local sorted_output = {} + for i = 1, #input do + local j = sorted_idx[i] + if not max_sen_len then + table.insert(sorted_input, input[j]) + table.insert(sorted_output, output[j]) + elseif input[j]:nElement() <= max_sen_len then + table.insert(sorted_input, input[j]) + table.insert(sorted_output, output[j]) + end + end + return sorted_input, sorted_output +end + + +local log0 = -1000000. 
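+
+--[[
+How the toy-data helpers in this file are meant to be combined (an
+illustrative sketch only; nothing in this file runs it, and the constants
+are arbitrary):
+
+   local d = get_toy_data(25, false, false, false)
+   local input_train, output_train, markers_train = d[2], d[3], d[4]
+   -- generate_batches_by_length expects length-sorted inputs, and
+   -- toy_prepare_minibatch assumes every source in a window has the same
+   -- length, hence req_same_length = true.
+   local batches = generate_batches_by_length(input_train, 32, true)
+   local s, t = batches[1][1], batches[1][2]
+   local input, yref, ylength, markers =
+      toy_prepare_minibatch(input_train, output_train, markers_train, s, t - s + 1)
+--]]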
+function prepare_minibatch_3d(input, output, s, t, params) + local y_vocab = params.target_vocab_size + + local input_max_len = 0 + local input_feature_dim = 123 + local output_max_len = 0 + local batch_size = t - s + 1 + for i = 1, batch_size do + input_max_len = math.max(input_max_len, input[s+i-1]:nElement()/input_feature_dim) + output_max_len = math.max(output_max_len, output[s+i-1]:nElement()) + end + local batch_input = torch.Tensor(batch_size, input_max_len, input_feature_dim):fill(params.start_symbol) + local batch_output = torch.Tensor(batch_size, output_max_len):fill(y_vocab) + local xlength = torch.Tensor(batch_size):zero() + local ylength = torch.Tensor(batch_size):zero() + for i = 1, batch_size do + local input_sequence_length = input[s+i-1]:nElement() / input_feature_dim + batch_input[{i, {1, input_sequence_length}, {}}]:copy(input[s+i-1]:reshape(input_sequence_length, input_feature_dim)) + batch_output[{i, {1,output[s+i-1]:nElement()}}]:copy(output[s+i-1]) + if params.input_temporalconv_stride > 0 then + xlength[i] = math.floor((input_sequence_length - params.input_temporalconv_width)/params.input_temporalconv_stride) + 1 -- Using temporal convolution + else + xlength[i] = input_sequence_length + end + ylength[i] = output[s+i-1]:nElement() + end +-- local debugger = require('fb.debugger') +-- debugger.enter() + + if params.temporal_sampling == 'TemporalSampling' then + for i = 1, batch_size do + xlength[i] = math.floor(xlength[i] / params.temporalsampling_stride) + if xlength[i] == 0 then + xlength[i] = 1 + end + end + elseif params.temporal_sampling == 'TemporalConvolution' then + for i = 1, batch_size do + xlength[i] = math.floor((xlength[i] - params.temporalconv_width) / params.temporalconv_stride) + 1 + end + end + + return batch_input, batch_output, xlength, ylength +end + +local log0 = -1000000. 
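+
+-- prepare_minibatch below right-pads sources with params.start_symbol and
+-- targets with y_vocab, returns the true per-sequence lengths in xlength /
+-- ylength, and builds mask_input in log space (0 for real positions, log0 for
+-- padding), the usual additive log-space mask. When an end_symbol is
+-- appended, both lengths grow by one; when a strided temporal convolution is
+-- used, lengths shrink to floor((len - 1) / stride) + 1 (e.g. length 5 with
+-- stride 2 becomes 3).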
+function prepare_minibatch(input, output, s, t, x_vocab, y_vocab, end_symbol, conv_dW_size, params) + local input_max_len = 0 + local output_max_len = 0 + local batch_size = t - s + 1 + for i = 1, batch_size do + input_max_len = math.max(input_max_len, input[s+i-1]:nElement()) + output_max_len = math.max(output_max_len, output[s+i-1]:nElement()) + end + if end_symbol and end_symbol > 0 then + input_max_len = input_max_len + 1 + output_max_len = output_max_len + 1 + end + local batch_input = torch.Tensor(batch_size, input_max_len):fill(params.start_symbol) + local mask_input = torch.Tensor(batch_size, input_max_len):fill(log0) -- logspace + local batch_output = torch.Tensor(batch_size, output_max_len):fill(y_vocab) + local xlength = torch.Tensor(batch_size):zero() + local ylength = torch.Tensor(batch_size):zero() + for i = 1, batch_size do + batch_input[{i, {1,input[s+i-1]:nElement()}}]:copy(input[s+i-1]) + mask_input[{i, {1, input[s+i-1]:nElement()}}]:zero() -- logspace + batch_output[{i, {1,output[s+i-1]:nElement()}}]:copy(output[s+i-1]) + xlength[i] = input[s+i-1]:nElement() + ylength[i] = output[s+i-1]:nElement() + if end_symbol and end_symbol > 0 then + batch_input[{i, input[s+i-1]:nElement()+1}] = end_symbol + mask_input[{i, input[s+i-1]:nElement()+1}] = 0 + batch_output[{i, output[s+i-1]:nElement()+1}] = end_symbol + xlength[i] = xlength[i] + 1 + ylength[i] = ylength[i] + 1 + end + end + + if params.input_temporalconv_stride > 0 then + input_max_len = math.floor((input_max_len - 1)/params.input_temporalconv_stride) + 1 -- Using temporal convolution + mask_input:resize(batch_size, input_max_len):fill(log0) + for i = 1, batch_size do + xlength[i] = math.floor((xlength[i] - 1)/params.input_temporalconv_stride) + 1 -- Using temporal convolution + mask_input[{i, {1, xlength[i]}}] = 0 + end + end + + if conv_dW_size then + input_max_len = math.floor((input_max_len - 1) / conv_dW_size) + 1 + mask_input:resize(batch_size, input_max_len):fill(log0) + for i = 1, batch_size do + xlength[i] = math.floor((xlength[i] - 1) / conv_dW_size) + 1 + mask_input[{i, {1, xlength[i]}}] = 0 + end + end + + return batch_input, batch_output, xlength, ylength, mask_input +end + +function g_cloneManyTimes(net, T) + net:clearState() + local clones = {} + local params, gradParams = net:parameters() + local mem = torch.MemoryFile("w"):binary() + mem:writeObject(net) + for t = 1, T do + -- We need to use a new reader for each clone. + -- We don't want to use the pointers to already read objects. 
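+      -- Deserializing from the in-memory buffer gives a structurally
+      -- identical but independent clone; the :set() calls below then point
+      -- each clone's parameter and gradient tensors back at the originals'
+      -- storage, so all T clones share a single copy of the weights (the
+      -- standard Torch recipe for weight-shared clones, e.g. for BPTT).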
+ local reader = torch.MemoryFile(mem:storage(), "r"):binary() + local clone = reader:readObject() + reader:close() + local cloneParams, cloneGradParams = clone:parameters() + for i = 1, #params do + cloneParams[i]:set(params[i]) + cloneGradParams[i]:set(gradParams[i]) + end + clones[t] = clone + collectgarbage() + end + mem:close() + return clones +end + +function toy_evaluate_predict_loss_mapping(ypred, yref, use_crossentropy, mapping_path, without_mapping) + local phones_mapping = {} + if without_mapping then + for _id = 1, 29 do + phones_mapping[_id] = _id + end + else + for line in io.lines('%s/phones.60-39.mapping' % mapping_path) do + splits = line:split('\t') + phones_mapping[tonumber(splits[1]) + 1] = splits[2] + end + end + + assert(#ypred == #yref) + + local count = 0 + local loss = 0 + local WER = 0 + local total_num_words = 0 + + local sequenceError = SequenceError() + for i = 1, #ypred do + local ypred_mapping = '' + local prevalue = -1 + for j = 1, ypred[i]:nElement() do + if use_crossentropy then + _value = torch.floor((ypred[i][j] - 1) / 3) + 1 + else + _value = ypred[i][j] + end + if phones_mapping[_value] then + ypred_mapping = ypred_mapping .. ' ' .. phones_mapping[_value] + end + end + + local yref_mapping = '' + local prevalue = -1 + for j = 1, yref[i]:nElement() do + if use_crossentropy then + _value = torch.floor((yref[i][j] - 1) / 3) + 1 + else + _value = yref[i][j] + end + if phones_mapping[_value] then -- and _value ~= prevalue then + yref_mapping = yref_mapping .. ' ' .. phones_mapping[_value] + end + end + + local word_errs , num_words = sequenceError:calculateWER(yref_mapping:gsub("^%s*(.-)%s*$", "%1"), ypred_mapping:gsub("^%s*(.-)%s*$", "%1")) + WER = WER + word_errs + total_num_words = total_num_words + num_words + + -- unmerged edit distance, count + loss = loss + EditDistance(torch.totable(ypred[i]:long()), torch.totable(yref[i]:long())) + count = count + yref[i]:nElement() + end + print('WER: ', WER, ' total num words: ', total_num_words) + print(' ==> test predict WER: ', WER / total_num_words * 100) + print(' ==> test predict edit distance (normalized): ', loss / count) + return WER / total_num_words * 100 +end + + +function toy_evaluate_predict_loss(ypred, yref) + assert(#ypred == #yref) + + local count = 0 + local loss = 0 + for i = 1, #ypred do + loss = loss + EditDistance(torch.totable(ypred[i]:long()), torch.totable(yref[i]:long())) + count = count + yref[i]:nElement() + end + return loss / count +end + +function generate_batches_by_length(input_train, batch_size, req_same_length) + --- input train must have sorted + local batches = {} + if req_same_length then + local count = 0 + local cur_len = input_train[1]:nElement() + local s = 1 + local t = 1 + for i = 1, #input_train do + if count < batch_size and input_train[i]:nElement() == cur_len then + t, count = i, count + 1 + else + table.insert(batches, {s, t}) + s, t = i, i + count = 1 + cur_len = input_train[i]:nElement() + end + end + if count > 0 then + table.insert(batches, {s, t}) + end + else + local num_batches = math.ceil(#input_train / batch_size) + for i = 1, num_batches do + local s = (i-1)*batch_size + 1 + local t = math.min(s + batch_size - 1, #input_train) + table.insert(batches, {s,t}) + end + end + return batches +end + +---- machine translation specific --- + +function load_data(filename, note, type) + local note = note or filename + print(string.format("--loading %s data from %s", note, filename)) + local data = {} + for line in io.lines(filename) do + splits = line:split(" ") + 
if type == 'Double' then + table.insert(data, torch.DoubleTensor(splits)) + else + table.insert(data, torch.LongTensor(splits)) + end + end + print(string.format(" %s data size: %d", note, #data)) + return data +end + +function load_vocab(filename, note) + local note = note or filename + print(string.format("--loading %s vocab from %s", note, filename)) + local vocab = {} + for line in io.lines(filename) do + splits = line:split(" ") + if splits[1] == '' then + vocab[tonumber(splits[2])] = ' ' + else + vocab[tonumber(splits[2])] = splits[1] + end + if (not splits[2]) or splits[3] then + print("error") + end + end + -- the actual size is #vocab + 1, since we use 0 as padding + print(string.format(" %s vocab size: %d", note, #vocab)) + return vocab +end + +function tensor_to_string(x) + local phrases = {} + for i = 1, x:nElement() do + table.insert(phrases, x[i]) + end + return table.concat(phrases, ' ') +end + +function decipher_tables(x, vocab) + local decipher_tables = {} + -- sed -r 's/(@@ )|(@@ ?$)//g' + for i = 1, #x do + local decipher_str = decipher(x[i]:view(-1, 1), vocab) + local decipher_str_merge = string.gsub(string.gsub(decipher_str, "@@ ", ""), "@@ ", "") + table.insert(decipher_tables, decipher_str_merge) + end + return decipher_tables +end + + +function decipher(x, ivocab) + local phrases = {} + for i = 1, x:nElement() do + local phrase = ivocab[torch.totable(x[i])[1]] or 'UNK' + table.insert(phrases, phrase) + end + return table.concat(phrases, ' ') +end + +function eval_bleu_score(refs, outputs, work_dir, iter, unk_as_error, config_dir, unk_symbol, is_string) + assert(#outputs == #refs) + local ref_filename + if is_string then + ref_filename = string.format("%s/refs-at-iter-%d.txt", work_dir, iter) + ref_f = assert(io.open(ref_filename, "w")) + for i = 1, #refs do + ref_f:write(refs[i], "\n") + end + ref_f:close() + else -- otherwise, table of tensor + if unk_as_error then + ref_filename = string.format("%s/refs-at-iter-%d-unk-as-error.txt", work_dir, iter) + ref_f = assert(io.open(ref_filename, "w")) + for i = 1, #refs do + local new_ref = refs[i]:long() - torch.cmul(refs[i]:long(), refs[i]:eq(unk_symbol):long()) + ref_f:write(table.concat(new_ref:totable(), " "), "\n") + end + ref_f:close() + else + ref_filename = string.format("%s/refs-at-iter-%d.txt", work_dir, iter) + ref_f = assert(io.open(ref_filename, "w")) + for i = 1, #refs do + ref_f:write(table.concat(refs[i]:totable(), " "), "\n") + end + ref_f:close() + end + end + + local output_filename + if unk_as_error then + output_filename = string.format("%s/results-at-iter-%d-unk-as-error.txt", work_dir, iter) + else + output_filename = string.format("%s/results-at-iter-%d.txt", work_dir, iter) + end + local out_f = assert(io.open(output_filename, "w")) + + if is_string then + for i = 1, #outputs do + out_f:write(outputs[i], "\n") + end + else -- otherwise, table of tensor + for i = 1, #outputs do + out_f:write(table.concat(outputs[i]:totable(), " "), "\n") + end + end + out_f:close() + if config_dir == nil then + config_dir = './' + end + local cmd = string.format("perl %s/data_processing/multi-bleu.perl %s < %s", config_dir, ref_filename, output_filename) + os.execute(cmd) + local cmd = string.format("perl %s/data_processing/multi-bleu.perl %s < %s", config_dir, ref_filename, output_filename) + local handle = io.popen(cmd) + local result = handle:read("*a") + handle:close() + _, str_end = string.find(result, 'EVALERR =') + return tonumber(string.sub(result, str_end + 1)) +end + +function eval_wer_score(refs, 
outputs, work_dir, iter, without_mapping) + assert(#outputs == #refs) + local ref_filename + + ref_filename = string.format("%s/refs-at-iter-%d.txt", work_dir, iter) + ref_f = assert(io.open(ref_filename, "w")) + for i = 1, #refs do + ref_f:write(i, ' ', table.concat(refs[i]:totable(), " "), "\n") + end + ref_f:close() + -- end + + local output_filename + output_filename = string.format("%s/results-at-iter-%d.txt", work_dir, iter) + local out_f = assert(io.open(output_filename, "w")) + for i = 1, #outputs do + out_f:write(i, ' ', table.concat(outputs[i]:totable(), " "), "\n") + end + out_f:close() + if without_mapping then + local cmd = string.format("sh /home/pshuang/work/tools/kaldi/egs/timit/s5/local/score_timit_womapping.sh %s %s %s %s", output_filename, ref_filename, work_dir, iter) + os.execute(cmd) + else + local cmd = string.format("sh /home/pshuang/work/tools/kaldi/egs/timit/s5/local/score_timit.sh %s %s %s %s", output_filename, ref_filename, work_dir, iter) + os.execute(cmd) + end +end + +function yref_totable(yref, ylength) + local bsize = yref:size(1) + local yref_tab = {} + for i = 1, bsize do + table.insert(yref_tab, torch.totable(yref[{i, {1, ylength[i]}}])) + end + return yref_tab +end + +function ctc_decodeOutput(predictions) + --[[ + Turns the predictions tensor into a list of the most likely tokens + NOTE: + to compute WER we strip the begining and ending spaces + --]] + local tokens = {} + local blankToken = 0 + local preToken = blankToken + -- The prediction is a sequence of likelihood vectors + local _, maxIndices = torch.max(predictions, 2) + maxIndices = maxIndices:float():squeeze() + + for i=1, maxIndices:size(1) do + local token = maxIndices[i] - 1 -- CTC indexes start from 1, while token starts from 0 + -- add token if it's not blank, and is not the same as pre_token + if token ~= blankToken and token ~= preToken then + table.insert(tokens, token) + end + preToken = token + end + return torch.Tensor(tokens) +end + +function ctc_beam_search(predictions, configs) + local beam_size = configs.beam_size or 40 + local use_avg_prob = configs.use_avg_prob + local softmax = nn.LogSoftMax():cuda() + local preds = softmax:forward(predictions):double() + + local T, V = preds:size(1), preds:size(2) + local B = {{}} + local B_inv = nil + local Pr = {0} + local pnb_ = {-torch.loginf()} + local pb_ = {0} + for t = 1, T do + local B_new, Pr_new, pnb_new, pb_new = {}, {}, {}, {} + local _, ind = torch.sort(torch.Tensor(Pr), true) + B_inv = {} + for i = 1, math.min(beam_size, ind:nElement()) do + local j = ind[i] + B_inv[table.concat(B[j], "-")] = j + end + for i = 1, math.min(beam_size, ind:nElement()) do + local j = ind[i] + local y = B[j] + local pnb = -torch.loginf() + if #y > 0 then + pnb = pnb_[j] + preds[{t, y[#y]+1}] + local y_1_str = '' + if #y > 1 then + y_1_str = table.concat(torch.totable(torch.Tensor(y)[{{1,#y-1}}]), "-") + end + local jj = B_inv[y_1_str] + if jj ~= nil then + local y_1 = B[jj] + if y_1_str:len() > 0 and y[#y] == y_1[#y_1] then + pnb = torch.logadd(pnb, pb_[jj] + preds[{t, y[#y]+1}]) + else + pnb = torch.logadd(pnb, Pr[jj] + preds[{t, y[#y]+1}]) + end + end + end + local pb = Pr[j] + preds[{t, 1}] -- 1 is the blank symbol + table.insert(B_new, torch.copy_array(y)) + table.insert(Pr_new, torch.logadd(pnb, pb)) + table.insert(pnb_new, pnb) + table.insert(pb_new, pb) + for v = 2, V do + pb = -torch.loginf() + if #y > 0 and v-1 == y[#y] then + pnb = pb_[j] + preds[{t, v}] + else + pnb = Pr[j] + preds[{t, v}] + end + table.insert(pb_new, pb) + 
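+            -- pb/pnb hold the log-probabilities of the extended prefix ending in a blank /
+            -- non-blank symbol at time t. A prefix just extended with symbol (v-1) cannot
+            -- end in a blank yet, hence pb = log(0) above; if (v-1) repeats the previous
+            -- symbol, only paths that ended in a blank (pb_[j]) may extend it, otherwise
+            -- any path (Pr[j]) may. This is standard CTC prefix beam search bookkeeping.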
table.insert(pnb_new, pnb) + table.insert(Pr_new, torch.logadd(pnb, pb)) + local y_ = torch.copy_array(y) + table.insert(y_, v-1) + table.insert(B_new, y_) + end + end + B = B_new + Pr = Pr_new + pnb_ = pnb_new + pb_ = pb_new + end + if use_avg_prob then + for i = 1, #Pr do + Pr[i] = Pr[i] / #(B[i]) + end + end + local _, indx = torch.sort(torch.Tensor(Pr), true) + return torch.Tensor(B[indx[1]]) +end + +function xent_decodeOutput(predictions) + --[[ + Turns the predictions tensor into a list of the most likely tokens + NOTE: + to compute WER we strip the begining and ending spaces + --]] + local tokens = {} + local _, maxIndices = torch.max(predictions, 2) + maxIndices = maxIndices:float():squeeze() + for i=1, maxIndices:size(1) do + local token = maxIndices[i] + table.insert(tokens, token) + end + return torch.Tensor(tokens) +end + +function file_exists(file) + local f = io.open(file, "rb") + if f then f:close() end + return f ~= nil +end + +function line_from(file) + if not file_exists(file) then return 0 end + local f = io.open(file, "r") + io.input(f) + local data = io.read() + io.close(f) + return data +end + +function lines_from(file, type) + if not file_exists(file) then return 0 end + local data = {} + for line in io.lines(file) do + line = string.gsub(line, "\n", "") + if type == 'Int' then + table.insert(data, tonumber(line)) + elseif type == 'phrases' then + local line_split = line:split(" ") + local subtable = {} + for i = 1, #line_split do + table.insert(subtable, tonumber(line_split[i])) + end + table.insert(data, table.concat(subtable, '-')) + else + table.insert(data, line) + end + end + return data +end + + +function Set(list) + local set = {} + for _, l in ipairs(list) do set[l] = true end + return set +end + +function tensor_to_table(input) + local table_input = input:totable() + return Set(table_input) +end + +function table.slice(tbl, first, last, step) + local sliced = {} + for i = first or 1, last or #tbl, step or 1 do + sliced[#sliced+1] = tbl[i] + end + return sliced +end diff --git a/fairseq/models/utils.lua b/fairseq/models/utils.lua index 0de0d28..5be8964 100644 --- a/fairseq/models/utils.lua +++ b/fairseq/models/utils.lua @@ -5,6 +5,9 @@ -- the root directory of this source tree. An additional grant of patent rights -- can be found in the PATENTS file in the same directory. -- +-- Copyright (c) Microsoft Corporation. All rights reserved. +-- Licensed under the BSD License. +-- --[[ -- -- Shared utility functions used for model construction. @@ -112,6 +115,11 @@ function mutils.loadLegacyModel(path, typename) return model end +function mutils.loadModel(path, typename) + return torch.load(path) +end + + function mutils.sendtobuf(data, buffer) assert(data and torch.isTensor(data)) assert(buffer and torch.isTensor(buffer)) diff --git a/fairseq/models/window_attn.lua b/fairseq/models/window_attn.lua new file mode 100755 index 0000000..4347f73 --- /dev/null +++ b/fairseq/models/window_attn.lua @@ -0,0 +1,128 @@ +-- Copyright (c) Microsoft Corporation. All rights reserved. +-- Licensed under the MIT License. 
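+--
+-- For every output position t, the layer reads a window of kW input vectors,
+-- predicts kW sigmoid gates from the concatenated window, rescales each vector
+-- by its gate, and emits the tanh of their sum (see make_win_unit below).
+-- This is the (soft) local reordering layer used by NPMT.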
+-- +--[[ +-- +-- Reordering layer +-- +--]] +-- +require("nn") +require("nngraph") + +local winAttn, parent = torch.class('nn.winAttn', 'nn.Container') + +function make_win_unit(input_size, kW) + local inputs = {} + table.insert(inputs, nn.Identity()()) + local x = unpack(inputs) -- B * kW * d + local reshaped_x = nn.Reshape(input_size*kW, true)(x) -- B * (kW*d) + local weight = nn.Sigmoid()(nn.Linear(input_size*kW, kW)(reshaped_x)):annotate{name = 'winAtt_weight'} -- B * kW + weight = nn.Replicate(input_size, 3)(weight) -- B * kW * d + local output = nn.CMulTable()({x, weight}) -- B * kW * d + output = nn.Tanh()(nn.Sum(2)(output)) -- B * d + return nn.gModule(inputs, {output}) +end + +function winAttn:__init(input_size, kW, use_middle) + parent.__init(self) + self.gradInput = torch.Tensor() + self.output = torch.Tensor() + self.padded_input = torch.Tensor() + + self.input_size = input_size + self.kW = kW + if use_middle then + local width = math.floor(self.kW / 2) + self.kW = width * 2 + 1 + self.padding = nn.Sequential():add(nn.Padding(2, -width)):add(nn.Padding(2, width)) + else + self.padding = nn.Padding(2, 1 - self.kW) + end + + self:add(self.padding) + self.win_unit = make_win_unit(input_size, self.kW) + self:add(self.win_unit) + self.win_unit_clones = {} + self.max_T = 0 +end + +function winAttn:updateOutput(input) + self.recompute_backward = true + + local T = input:size(2) + if self.max_T < T then + self.win_unit:clearState() + local more_win_units = g_cloneManyTimes(self.win_unit, T - self.max_T) + for i = 1, T - self.max_T do + table.insert(self.win_unit_clones, more_win_units[i]) + end + self.max_T = T + end + for t = 1, T do + self.win_unit_clones[t]:clearState() + end + self.padded_input = self.padding:updateOutput(input) + self.output = input.new(input:size()):zero() +-- local mutils = require 'fairseq.models.utils' + for t = 1, T do + local x = self.padded_input[{{}, {t, t+self.kW-1}, {}}] + local y = self.win_unit_clones[t]:updateOutput(x) + self.output[{{}, t, {}}]:add(y) +-- local weight = mutils.findAnnotatedNode(self.win_unit_clones[t], 'winAtt_weight') +-- print('t', t, 'weight', weight.output) +-- print(self.win_unit_clones[t].forwardnodes[6].data.input[1]) + end + return self.output +end + +function winAttn:backward(input, gradOutput, scale) + local scale = scale or 1 + self.recompute_backward = false + local grad_padded_input = self.padded_input.new(self.padded_input:size()):zero() + local T = input:size(2) + for t = 1, T do + local x = self.padded_input[{{}, {t, t+self.kW-1}, {}}] + local grad_x = self.win_unit_clones[t]:backward(x, gradOutput[{{}, t, {}}]) + grad_padded_input[{{}, {t, t+self.kW-1}, {}}]:add(grad_x) + end + self.gradInput = self.padding:backward(input, grad_padded_input, scale) + return self.gradInput +end + +function winAttn:updateGradInput(input, gradOutput) + if self.recompute_backward then + self:backward(input, gradOutput, 1.0) + end + return self.gradInput +end + +function winAttn:accGradParameters(input, gradOutput, scale) + if self.recompute_backward then + self:backward(input, gradOutput, scale) + end +end + +function winAttn:training() + parent.training(self) + for t = 1, self.max_T do + self.win_unit_clones[t]:training() + end +end + +function winAttn:evaluate() + parent.evaluate(self) + for t = 1, self.max_T do + self.win_unit_clones[t]:evaluate() + end +end + +function winAttn:clearState() + parent.clearState(self) + self.output:set() + self.padded_input:set() + self.gradInput:set() + for t = 1, self.max_T do + 
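+      -- also free the intermediate buffers cached by every per-timestep clone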
self.win_unit_clones[t]:clearState() + end +end diff --git a/fairseq/torchnet/ResumableDPOptimEngine.lua b/fairseq/torchnet/ResumableDPOptimEngine.lua index 9575e18..bddf99b 100644 --- a/fairseq/torchnet/ResumableDPOptimEngine.lua +++ b/fairseq/torchnet/ResumableDPOptimEngine.lua @@ -5,6 +5,9 @@ -- the root directory of this source tree. An additional grant of patent rights -- can be found in the PATENTS file in the same directory. -- +-- Copyright (c) Microsoft Corporation. All rights reserved. +-- Licensed under the BSD License. +-- --[[ -- -- A version of OptimEngine that implements data parallelism and @@ -16,7 +19,7 @@ local tnt = require 'torchnet' local argcheck = require 'argcheck' local utils = require 'fairseq.utils' local threads = require 'threads' - +local mutils = require 'fairseq.models.utils' local cuda = utils.loadCuda() local ResumableDPOptimEngine = @@ -215,10 +218,24 @@ ResumableDPOptimEngine.test = argcheck{ _G.prepareSample(sample) _G.model:resizeCriterionWeights( _G.criterion, _G.critweights, sample) + -- HACK for not having OOM + local group_size + if torch.typename(_G.model) == 'NPMTModel' then + local npmt = mutils.findAnnotatedNode(_G.model:network(), 'npmt') + group_size = npmt.group_size + npmt.group_size = 64 + end local net = _G.model:network() local crit = _G.criterion net:forward(sample.input) crit:forward(net.output, sample.target) + if torch.typename(_G.model) == 'NPMTModel' then + local npmt = mutils.findAnnotatedNode(_G.model:network(), 'npmt') + npmt.group_size = group_size + npmt:clearState() + _G.model:network():clearState() + end + collectgarbage() collectgarbage() return crit.output end, @@ -350,7 +367,7 @@ ResumableDPOptimEngine.doTrain = argcheck{ if sample then state.ntokens = state.ntokens + sample.ntokens self.pool:addjob(shardid, - function(optconfig, sample, clipv, prevn) + function(optconfig, sample, clipv, prevn, epoch) -- Clip gradients and update parameters. -- Note: this is being done for the -- previous sample. @@ -362,7 +379,6 @@ ResumableDPOptimEngine.doTrain = argcheck{ optconfig.method(_G.feval, _G.params, optconfig, _G.optstate) end - -- Process the current sample. _G.prepareSample(sample) _G.model:resizeCriterionWeights( @@ -378,13 +394,14 @@ ResumableDPOptimEngine.doTrain = argcheck{ crit:backward(net.output, sample.target) net:backward(sample.input, crit.gradInput) collectgarbage() + collectgarbage() return crit.output end, function(loss) state.loss = state.loss + loss end, state.epoch_t > 0 and state.optconfig or nil, - sample, clipv, prevn + sample, clipv, prevn, state.epoch ) end end diff --git a/fairseq/torchnet/hooks.lua b/fairseq/torchnet/hooks.lua index 44425bc..466787f 100644 --- a/fairseq/torchnet/hooks.lua +++ b/fairseq/torchnet/hooks.lua @@ -5,6 +5,9 @@ -- the root directory of this source tree. An additional grant of patent rights -- can be found in the PATENTS file in the same directory. -- +-- Copyright (c) Microsoft Corporation. All rights reserved. +-- Licensed under the BSD License. +-- --[[ -- -- Common torchnet engine hooks. 
The general format for functions declared @@ -185,11 +188,16 @@ hooks.runGeneration = argcheck{ local scorer = clib.bleu(dict:getPadIndex(), dict:getEosIndex()) local fp = outfile and io.open(outfile, 'a') local targetBuf = torch.IntTensor() + local output_counts, num_segments = 0, 0 for samples in iterator() do -- We don't shard generation computeSampleStats({samples = samples}) local sample = samples[1] - local hypos, scores, attns, _ = generate(model, sample) + local hypos, scores, attns, _, output_count, num_segment = generate(model, sample) + if output_count and num_segment then + output_counts = output_counts + output_count + num_segments = num_segments + num_segment + end local targetTT = sample.target:t() local targetT = targetBuf:resizeAs(targetTT):copy(targetTT) local beam = #hypos / sample.bsz @@ -233,7 +241,9 @@ hooks.runGeneration = argcheck{ scorer:add(ref, hypo:int()) end end - + if num_segments > 0 then + print(string.format("avg. phrase size %f", output_counts / num_segments)) + end if fp then fp:close() end @@ -263,6 +273,7 @@ hooks.onCheckpoint = argcheck{ isAnnealing = false, prevvalloss = nil, bestvalloss = nil, + bestvalbleu = nil, } end @@ -293,8 +304,8 @@ hooks.onCheckpoint = argcheck{ stats['wordspersec'] = lossMeter.n / cptime stats['current_lr'] = state.optconfig.learningRate * config.lrscale - local loss = lossMeter:value() / math.log(2) - local ppl = math.pow(2, loss) + local loss = lossMeter:value() + local ppl = math.min(math.pow(2, loss / math.log(2)), 1000) -- Hack print(string.format( '%s | trainloss %8.2f | train ppl %8.2f', logPrefix, loss, ppl) @@ -308,8 +319,8 @@ hooks.onCheckpoint = argcheck{ for name, set in pairs(testsets) do meter:reset() runTest(state, set, meter) - local loss = meter:value() / math.log(2) - local ppl = math.pow(2, loss) + local loss = meter:value() + local ppl = math.pow(2, loss / math.log(2)) str2print = string.format('%s | %sloss %8.2f | %s ppl %8.2f', str2print, name, loss, name, ppl) stats[name .. 'ppl'] = ppl @@ -351,6 +362,7 @@ hooks.onCheckpoint = argcheck{ local valloss = stats['validloss'] + local valbleu = stats['validbleu'] -- Save model and best model if not config.nosave then @@ -374,6 +386,19 @@ hooks.onCheckpoint = argcheck{ state._onCheckpoint.bestvalloss = valloss end + if valbleu + and (not state._onCheckpoint.bestvalbleu + or valbleu > state._onCheckpoint.bestvalbleu) then + local bestmodelpath = plpath.join(config.savedir, + 'model_bestbleu.th7') + if utils.retry(3, engine.saveModel, engine, bestmodelpath) + then + print(string.format( + '%s | saved new best bleu model to %s', logPrefix, + bestmodelpath)) + end + state._onCheckpoint.bestvalbleu = valbleu + end end io.stdout:flush() diff --git a/generate-lines.lua b/generate-lines.lua index 2a84869..7e2d109 100644 --- a/generate-lines.lua +++ b/generate-lines.lua @@ -5,6 +5,9 @@ -- the root directory of this source tree. An additional grant of patent rights -- can be found in the PATENTS file in the same directory. -- +-- Copyright (c) Microsoft Corporation. All rights reserved. +-- Licensed under the BSD License. +-- --[[ -- -- Hypothesis generation script with text file input, processed line-by-line. @@ -51,6 +54,8 @@ cmd:option('-freqthreshold', -1, 'the minimum frequency for an alignment candidate in order' .. 
'to be considered (default no limit)') cmd:option('-fconvfast', false, 'make fconv model faster') +cmd:option('-lm_weight', 0.0, 'external lm weight.') +cmd:option('-lm_path', "", 'external lm path.') local config = cmd:parse(arg) @@ -62,6 +67,13 @@ print(string.format('| [target] Dictionary: %d types', config.dict:size())) config.srcdict = torch.load(config.sourcedict) print(string.format('| [source] Dictionary: %d types', config.srcdict:size())) +if config.lm_weight > 0 and config.lm_path:len() > 0 then +-- os.execute('./compile_lm.sh') + require "lua_lm" + config['lm'] = create_lm_instance(config.lm_path) +end + + if config.aligndictpath ~= '' then config.aligndict = tnt.IndexedDatasetReader{ indexfilename = config.aligndictpath .. '.idx', @@ -145,7 +157,9 @@ local dataset = tnt.DatasetIterator{ } local model -if config.model ~= '' then +if config.model == 'npmt' then + model = mutils.loadModel(config.path, config.model) +elseif config.model ~= '' then model = mutils.loadLegacyModel(config.path, config.model) else model = require( @@ -215,7 +229,15 @@ until dict:getIndex(runk) == dict:getUnkIndex() for sample in dataset() do sample.bsz = 1 - local hypos, scores, attns = model:generate(config, sample, searchf) + local hypos, scores, attns, t, num_counts, num_segments + if config.model == 'npmt' then + -- TODO fetch reordering layer weights + config.verbose = true + hypos, scores, attns, t, num_counts, num_segments = model:generate(config, sample, searchf) + print(string.format("avg. phrase size %f", num_counts / num_segments)) + else + hypos, scores, attns, t = model:generate(config, sample, searchf) + end -- Print results local sourceString = config.srcdict:getString(sample.source:t()[1]) diff --git a/generate.lua b/generate.lua index 1f125f6..b532c2b 100644 --- a/generate.lua +++ b/generate.lua @@ -5,6 +5,9 @@ -- the root directory of this source tree. An additional grant of patent rights -- can be found in the PATENTS file in the same directory. -- +-- Copyright (c) Microsoft Corporation. All rights reserved. +-- Licensed under the BSD License. +-- --[[ -- -- Batch hypothesis generation script. @@ -61,6 +64,10 @@ cmd:option('-freqthreshold', -1, 'the minimum frequency for an alignment candidate in order' .. 
'to be considered (default no limit)') cmd:option('-fconvfast', false, 'make fconv model faster') +cmd:option('-lm_weight', 0.0, 'external lm weight.') +cmd:option('-lm_path', "", 'external lm path.') +cmd:option('-verbose', false, 'True print test_mode details') + local cuda = utils.loadCuda() @@ -70,6 +77,14 @@ if cuda.cutorch then cutorch.manualSeed(config.seed) end + +if config.lm_weight > 0 and config.lm_path:len() > 0 then + -- os.execute('./compile_lm.sh') + require "lua_lm" + config['lm'] = create_lm_instance(config.lm_path) +end + + local function accTime() local total = {} return function(times) @@ -133,7 +148,9 @@ local _, test = data.loadCorpus{config = config, testsets = {config.dataset}} local dataset = test[config.dataset] local model -if config.model ~= '' then +if config.model == 'npmt' then + model = mutils.loadModel(config.path, config.model) +elseif config.model ~= '' then model = mutils.loadLegacyModel(config.path, config.model) else model = require( @@ -203,6 +220,7 @@ local addBleu = accBleu(config.beam, dict) local addTime = accTime() local timer = torch.Timer() local nsents, ntoks, nbatch = 0, 0, 0 +local total_count, total_segments, num_counts, num_segments = 0, 0, 0, 0 local state = {} for samples in dataset() do if (nbatch % nparts == partidx - 1) then @@ -210,7 +228,14 @@ for samples in dataset() do state.samples = samples computeSampleStats(state) local sample = state.samples[1] - local hypos, scores, attns, t = model:generate(config, sample, searchf) + local hypos, scores, attns, t + if config.model == 'npmt' then + hypos, scores, attns, t, num_counts, num_segments = model:generate(config, sample, searchf) + total_count = total_count + num_counts + total_segments = total_segments + num_segments + else + hypos, scores, attns, t = model:generate(config, sample, searchf) + end nsents = nsents + sample.bsz ntoks = ntoks + sample.ntokens addTime(t) @@ -227,6 +252,9 @@ for samples in dataset() do end nbatch = nbatch + 1 end +if num_segments > 0 then + print(string.format("avg. phrase size %f", total_count / total_segments)) +end -- report overall stats local elapsed = timer:time().real diff --git a/npmt.png b/npmt.png new file mode 100755 index 0000000..96e1656 Binary files /dev/null and b/npmt.png differ diff --git a/rocks/fairseq-scm-1.rockspec b/rocks/fairseq-scm-1.rockspec index 221c3d0..2981c1b 100644 --- a/rocks/fairseq-scm-1.rockspec +++ b/rocks/fairseq-scm-1.rockspec @@ -1,7 +1,7 @@ package = 'fairseq' version = 'scm-1' source = { - url = 'git://github.com/facebookresearch/fairseq', + url = 'git://github.com:posenhuang/NPMT.git', tag = 'master', } description = { diff --git a/train.lua b/train.lua index 9b412f2..cca246f 100644 --- a/train.lua +++ b/train.lua @@ -5,6 +5,9 @@ -- the root directory of this source tree. An additional grant of patent rights -- can be found in the PATENTS file in the same directory. -- +-- Copyright (c) Microsoft Corporation. All rights reserved. +-- Licensed under the BSD License. +-- --[[ -- -- Main training script. 
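The hunks below register the NPMT-specific training options alongside the existing fairseq flags. For orientation only, a hypothetical invocation might combine them roughly as follows; the data directory, hyperparameter values, and save path are illustrative placeholders, not settings taken from the paper or this repository:

```
$ fairseq train -sourcelang de -targetlang en -datadir data-bin/iwslt14.de-en \
    -model npmt -nembed 256 -nhid 256 -dec_unit_size 512 \
    -max_segment_len 6 -num_lower_win_layers 1 -rnn_mode LSTM \
    -group_size 512 -optim adam -lr 0.001 -batchsize 32 \
    -savedir trainings/npmt.de-en
```

With `-model npmt`, the later hunks switch training to `nn.DummyCriterion`, skip the beam-search helper when computing validation BLEU, and let the adam optimizer take part in learning-rate annealing.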
@@ -32,7 +35,7 @@ local cmd = torch.CmdLine() cmd:option('-sourcelang', 'de', 'source language') cmd:option('-targetlang', 'en', 'target language') cmd:option('-datadir', 'data-bin') -cmd:option('-model', 'avgpool', 'model type {avgpool|blstm|conv|fconv}') +cmd:option('-model', 'avgpool', 'model type {avgpool|blstm|bgru|conv|fconv|npmt}') cmd:option('-nembed', 256, 'dimension of embeddings and attention') cmd:option('-noutembed', 256, 'dimension of the output embeddings') cmd:option('-nhid', 256, 'number of hidden units per layer') @@ -101,6 +104,37 @@ cmd:option('-fconv_kwidths', '', cmd:option('-fconv_klmwidths', '', 'comma-separated list of kernel widths for convolutional language model') +-- Options for NPMT +cmd:option('-max_segment_len', 6, 'maximum segment length in the output') +cmd:option('-num_lower_win_layers', 0, 'reorder layer') +cmd:option('-use_win_middle', true, 'reorder layer with window centered at t') +cmd:option('-dec_unit_size', 256, 'number of hidden units per layer in decoder (uni-directional)') +cmd:option('-word_weight', 0.5, 'Use word weight.') +cmd:option('-lm_weight', 0.0, 'external lm weight.') +cmd:option('-lm_path', "", 'external lm path.') +cmd:option('-use_resnet_enc', false, 'use resnet connections in enc') +cmd:option('-use_resnet_dec', false, 'use resnet connections in dec') +cmd:option('-npmt_dropout', 0, 'npmt dropout factor') +cmd:option('-rnn_mode', "LSTM", 'or GRU') +cmd:option('-use_cuda', true, 'use cuda') +cmd:option('-beam', 10, 'beam size') +cmd:option('-group_size', 512, 'group size') +cmd:option('-use_accel', false, 'use C++/CUDA acceleration') +cmd:option('-conv_kW_size', 3, 'kernel width for temporal conv layer') +cmd:option('-conv_dW_size', 2, 'kernel stride for temporal conv layer') +cmd:option('-num_lower_conv_layers', 0, 'num lower temporal conv layers') +cmd:option('-num_mid_conv_layers', 0, 'num mid temporal conv layers') +cmd:option('-num_high_conv_layers', 0, 'num higher temporal conv layers') +cmd:option('-win_attn_type', 'ori', 'ori: original') +cmd:option('-reset_lrate', false, 'True reset learning rate after reloading') +cmd:option('-use_nnlm', false, 'True use a separated RNN') + + + +--chowang: we don't need the following anymore? 
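+-- (note: prepare_minibatch in the data utilities still reads params.start_symbol and
+-- takes an end_symbol argument, so some code paths may still expect these ids to be set.)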
+--cmd:option('-unk_symbol', 1, 'unk symbol id') +--cmd:option('-start_symbol', 2, 'start symbol id') +--cmd:option('-end_symbol', 3, 'end symbol id') local config = cmd:parse(arg) @@ -133,6 +167,7 @@ assert(config.ngpus >= 1 and config.ngpus <= cuda.cutorch.getDeviceCount()) -- Effective batchsize equals to the base batchsize * ngpus config.batchsize = config.batchsize * config.ngpus config.maxbatch = config.maxbatch * config.ngpus +print(config) ------------------------------------------------------------------- -- Load data @@ -160,7 +195,7 @@ end local train, test = data.loadCorpus{ config = config, trainsets = {'train'}, - testsets = {'valid', 'test'}, + testsets = {'valid'}, } local corpus = { train = train.train, @@ -205,7 +240,12 @@ local make_criterion_fn = function(id) local padindex = config.dict:getIndex(config.dict.pad) local critweights = torch.ones(config.dict:size()):cuda() critweights[padindex] = 0 - local criterion = nn.CrossEntropyCriterion(critweights, false):cuda() + local criterion + if config.model ~= 'npmt' then + criterion = nn.CrossEntropyCriterion(critweights, false):cuda() + else + criterion = nn.DummyCriterion(critweights, false):cuda() + end return criterion, critweights end @@ -257,6 +297,7 @@ if config.optim == 'sgd' then end elseif config.optim == 'adam' then optalgConfig.method = optim.adam + config.minlr = 1e-5 elseif config.optim == 'nag' then optalgConfig.method = require('fairseq.optim.nag') optalgConfig.momentum = config.momentum @@ -265,6 +306,13 @@ else error('wrong optimization algorithm') end +if config.model == 'npmt' then + optalgConfig.prune_schedule = config.prune_schedule + optalgConfig.prune_schedule_start_epoch = config.prune_schedule_start_epoch + optalgConfig.schedule_max_segment_len = config.schedule_max_segment_len + optalgConfig.max_segment_len = config.max_segment_len +end + local runGeneration, genconfig, gensets = nil, nil, {} if not config.nobleu or config.validbleu then genconfig = pltablex.copy(config) @@ -286,8 +334,10 @@ if config.validbleu then generate = function(model, sample) genconfig.minlen = 1 genconfig.maxlen = genconfig._maxlen - local searchf = search.greedy(model:type(), genconfig.dict, - genconfig.maxlen) + local searchf = {} + if config.model ~= 'npmt' then + searchf = search.greedy(model:type(), genconfig.dict, genconfig.maxlen) + end return model:generate(genconfig, sample, searchf) end, } @@ -314,7 +364,7 @@ end engine.hooks.onStartEpoch = hooks.shuffleData(seed) engine.hooks.onJumpToEpoch = hooks.shuffleData(seed) -local annealing = (config.optim == 'sgd' or config.optim == 'nag') +local annealing = (config.optim == 'sgd' or config.optim == 'nag' or config.optim == 'adam') local onCheckpoint = hooks.call{ function(state) state.checkpoint = state.checkpoint + 1 @@ -354,8 +404,8 @@ engine.hooks.onUpdate = hooks.call{ }, function(state) if timeMeter.n == config.log_interval then - local loss = lossMeter:value() / math.log(2) - local ppl = math.pow(2, loss) + local loss = lossMeter:value() + local ppl = math.pow(2, loss / math.log(2)) local elapsed = timeMeter.n * timeMeter:value() local statsstr = string.format( '| epoch %03d | %07d updates | words/s %7d' .. 
@@ -400,6 +450,13 @@ if plpath.isfile(lastStatePath) and not config.nosave then -- Support modifying the maxepoch setting during resume engine.hooks.onResume = function(state) state.maxepoch = config.maxepoch + state.maxbatch = config.maxbatch + state.group_size = config.group_size + state.optconfig = optalgConfig + if config.reset_lrate then + print('Reset lr to ', config.lr) + state.optconfig.learningRate = config.lr + end end engine:resume{ @@ -416,41 +473,50 @@ else end local function runFinalEval() - -- Evaluate the best network on the supplied test set - local path = plpath.join(config.savedir, 'model_best.th7') - local best_model = torch.load(path) - - genconfig.batchsize = 1 - genconfig.minlen = 1 - genconfig.maxlen = genconfig._maxlen - - for _, beam in ipairs({1, 5, 10, 20}) do - genconfig.beam = beam - if not config.notext then - genconfig.outfile = plpath.join( - config.savedir, string.format('gen-b%02d.txt', beam) - ) + -- TODO + local checkpoint_paths = {'model_best.th7', 'model_bestbleu.th7'} + for icheckpoint_path = 1, #checkpoint_paths do + print('checkpoint', checkpoint_paths[icheckpoint_path]) + local path = plpath.join(config.savedir, checkpoint_paths[icheckpoint_path]) + local best_model = torch.load(path) + + genconfig.batchsize = 1 + genconfig.minlen = 1 + genconfig.maxlen = genconfig._maxlen + + for _, beam in ipairs({1}) do + genconfig.beam = beam + if not config.notext then + genconfig.outfile = plpath.join( + config.savedir, string.format('gen-b%02d.txt', beam) + ) + end + local searchf = {} + if config.model ~= 'npmt' then + searchf = search.beam{ + ttype = best_model:type(), + dict = genconfig.dict, + srcdict = genconfig.srcdict, + beam = genconfig.beam + } + end + local _, result = hooks.runGeneration{ + model = best_model, + dict = genconfig.dict, + generate = function(model, sample) + return model:generate(genconfig, sample, searchf) + end, + outfile = genconfig.outfile, + srcdict = config.srcdict, + }(gensets.test) + print(string.format('| Test with beam=%d: %s', beam, result)) + io.stdout:flush() end - local searchf = search.beam{ - ttype = best_model:type(), - dict = genconfig.dict, - srcdict = genconfig.srcdict, - beam = genconfig.beam - } - local _, result = hooks.runGeneration{ - model = best_model, - dict = genconfig.dict, - generate = function(model, sample) - return model:generate(genconfig, sample, searchf) - end, - outfile = genconfig.outfile, - srcdict = config.srcdict, - }(gensets.test) - print(string.format('| Test with beam=%d: %s', beam, result)) - io.stdout:flush() end end + + if not config.nobleu and not config.nosave then engine:executeAll( function(id)