Skip to content

Commit

Permalink
Merge pull request #1 from posenhuang/NPMT
Browse files Browse the repository at this point in the history
NPMT
  • Loading branch information
posenhuang committed Jan 29, 2018
2 parents 7d017f0 + a6f412b commit 244053c
Show file tree
Hide file tree
Showing 24 changed files with 3,872 additions and 207 deletions.
23 changes: 21 additions & 2 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,19 +4,30 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree. An additional grant
# of patent rights can be found in the PATENTS file in the same directory.
#
# Copyright (c) Microsoft Corporation. All rights reserved
# Licensed under the BSD License

CMAKE_MINIMUM_REQUIRED(VERSION 2.6 FATAL_ERROR)
CMAKE_POLICY(VERSION 2.6)
CMAKE_MINIMUM_REQUIRED(VERSION 2.8 FATAL_ERROR)
CMAKE_POLICY(VERSION 2.8)

# Torch and CUDA are mandatory for this build; OpenMP is optional and only
# contributes extra compiler flags when it is detected.
FIND_PACKAGE(Torch REQUIRED)
FIND_PACKAGE(OpenMP)
FIND_PACKAGE(CUDA REQUIRED)

# Append to CMAKE_CXX_FLAGS instead of overwriting it, so that flags supplied
# by the caller or the toolchain are not silently discarded.
SET(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11 -Ofast")
# Older FindOpenMP modules (this file only requires CMake 2.8) report success
# through the upper-case OPENMP_FOUND; newer ones set OpenMP_FOUND. Accept
# either spelling so the OpenMP flags are not dropped on older CMake.
IF(OpenMP_FOUND OR OPENMP_FOUND)
  SET(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} ${OpenMP_C_FLAGS}")
  SET(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${OpenMP_CXX_FLAGS}")
ENDIF()

# NVCC flags: optimized, position-independent code generated for compute
# capabilities 3.0 through 5.2 (plus compute_52 PTX).
SET(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} -O3 -shared -Xcompiler -fPIC
-gencode arch=compute_30,code=sm_30 -gencode arch=compute_35,code=sm_35
-gencode arch=compute_37,code=sm_37 -gencode arch=compute_50,code=sm_50
-gencode arch=compute_52,code=sm_52 -gencode arch=compute_52,code=compute_52")

# THC headers are looked up under the install prefix.
INCLUDE_DIRECTORIES(${CMAKE_PREFIX_PATH}/include/THC)

# C++ library
IF(APPLE)
SET(CMAKE_SHARED_LIBRARY_SUFFIX ".so")
Expand All @@ -25,6 +36,14 @@ FILE(GLOB CPPSRC fairseq/clib/*.cpp)
ADD_LIBRARY(fairseq_clib SHARED ${CPPSRC})
INSTALL(TARGETS fairseq_clib DESTINATION "${ROCKS_LIBDIR}")

# Shared library built from the C++ source c_sample_dp.cc.
ADD_LIBRARY(dp_lib SHARED fairseq/models/c_sample_dp.cc)
INSTALL(TARGETS dp_lib DESTINATION "${ROCKS_LIBDIR}")

# Shared library built from the CUDA source compute_logpy.cu; it is linked
# against the CUDA libraries and installed next to the other native libraries.
CUDA_ADD_LIBRARY(compute_logpy_lib SHARED fairseq/models/compute_logpy.cu)
TARGET_LINK_LIBRARIES(compute_logpy_lib ${CUDA_LIBRARIES})
INSTALL(TARGETS compute_logpy_lib DESTINATION "${ROCKS_LIBDIR}")

# Lua library
INSTALL(DIRECTORY "fairseq" DESTINATION "${ROCKS_LUADIR}" FILES_MATCHING PATTERN "*.lua")

Expand Down
255 changes: 109 additions & 146 deletions README.md

Large diffs are not rendered by default.

5 changes: 4 additions & 1 deletion data/prepare-iwslt14.sh
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,10 @@ for l in $src $tgt; do
perl $TOKENIZER -threads 8 -l $l > $tmp/$tok
echo ""
done
perl $CLEAN -ratio 1.5 $tmp/train.tags.$lang.tok $src $tgt $tmp/train.tags.$lang.clean 1 175

# Only use up to 50 as in https://github.com/facebookresearch/MIXER/blob/master/prepareData.sh
# and https://github.com/rizar/actor-critic-public
perl $CLEAN -ratio 1.5 $tmp/train.tags.$lang.tok $src $tgt $tmp/train.tags.$lang.clean 1 50
for l in $src $tgt; do
perl $LC < $tmp/train.tags.$lang.clean.$l > $tmp/train.tags.$lang.$l
done
Expand Down
34 changes: 34 additions & 0 deletions data/prepare-iwslt15.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
#!/bin/sh
# Copied from https://github.com/tensorflow/nmt/blob/master/nmt/scripts/download_iwslt15.sh
#
# Download small-scale IWSLT15 Vietnamese to English translation data for NMT
# model training.
#
# Usage:
#   ./prepare-iwslt15.sh path-to-output-dir
#
# If output directory is not specified, "./iwslt15" will be used as the default
# output directory.

# Stop at the first failed command so a partial download is not mistaken for
# a complete dataset.
set -e

OUT_DIR="${1:-iwslt15}"
SITE_PREFIX="https://nlp.stanford.edu/projects/nmt/data"

# Quote the directory so paths containing spaces work.
mkdir -v -p "$OUT_DIR"

# fetch FILE: download one file of the en-vi corpus into $OUT_DIR.
# -f makes curl fail on HTTP errors instead of saving the error page as data;
# -L follows redirects.
fetch() {
    curl -f -L -o "$OUT_DIR/$1" "$SITE_PREFIX/iwslt15.en-vi/$1"
}

# Download iwslt15 small dataset from the Stanford website.
echo "Download training dataset train.en and train.vi."
fetch train.en
fetch train.vi

echo "Download dev dataset tst2012.en and tst2012.vi."
fetch tst2012.en
fetch tst2012.vi

echo "Download test dataset tst2013.en and tst2013.vi."
fetch tst2013.en
fetch tst2013.vi

echo "Download vocab file vocab.en and vocab.vi."
fetch vocab.en
fetch vocab.vi

Binary file added de-en_example.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
25 changes: 25 additions & 0 deletions fairseq/models/DummyCriterion.lua
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
-- Copyright (c) Microsoft Corporation. All rights reserved.
-- Licensed under the MIT License.
--
--[[
--
-- Dummy criterion: the loss is the mean of the input; the target is ignored.
--
--]]

local DummyCriterion, parent = torch.class('nn.DummyCriterion', 'nn.Criterion')

--- Construct the criterion; only the parent constructor is run.
function DummyCriterion:__init()
    parent.__init(self)
end

--- Forward pass: the loss is the mean over all elements of `input`.
-- `target` is accepted to match the criterion interface but is never read.
-- @return the mean, which is also stored in `self.output`
function DummyCriterion:updateOutput(input, target)
    local loss = input:mean()
    self.output = loss
    return loss
end

--- Backward pass: gradient of the mean with respect to every input element.
-- A fresh tensor of the same type and size as `input` is allocated and filled
-- with 1/n, where n is the number of elements in `input`.
-- @return the gradient tensor, which is also stored in `self.gradInput`
function DummyCriterion:updateGradInput(input, target)
    local count = input:nElement()
    local grad = input.new(input:size())
    grad:fill(1.0 / count)
    self.gradInput = grad
    return grad
end
11 changes: 7 additions & 4 deletions fairseq/models/avgpool_model.lua
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@
-- the root directory of this source tree. An additional grant of patent rights
-- can be found in the PATENTS file in the same directory.
--
-- Copyright (c) Microsoft Corporation. All rights reserved.
-- Licensed under the BSD License.
--
--[[
--
-- This model closely follows the conditional setup of rnn-lib v1, with -name
Expand Down Expand Up @@ -101,7 +104,7 @@ AvgpoolModel.makeAttention = argcheck{
local encoderOutPooled, encoderOutSingle = encoderOut:split(2)

-- Projection of previous hidden state onto source word space
local prevhProj = nn.Linear(config.nhid, config.nembed)(prevh)
local prevhProj = nn.Linear(config.dec_unit_size, config.nembed)(prevh)
local decoderRep = nn.CAddTable()({prevhProj, input})

-- Compute scores (usually denoted with alpha) using a simple dot
Expand Down Expand Up @@ -148,7 +151,7 @@ AvgpoolModel.makeDecoderRNN = argcheck{
local rnn = nn.CLSTM{
attention = attnmodule,
inputsize = config.nembed,
hidsize = config.nhid,
hidsize = config.dec_unit_size,
nlayer = config.nlayer,
winitfun = function(network)
rmutils.defwinitfun(network, config.init_range)
Expand All @@ -171,8 +174,8 @@ AvgpoolModel.makeDecoderRNN = argcheck{
end

local scaleHidden = nn.Identity()
if config.nhid ~= config.nembed then
scaleHidden = nn.Linear(config.nhid, config.nembed)
if config.dec_unit_size ~= config.nembed then
scaleHidden = nn.Linear(config.dec_unit_size, config.nembed)
end

local decoderRNNOut = scaleHidden(
Expand Down
174 changes: 174 additions & 0 deletions fairseq/models/bgru_model.lua
Original file line number Diff line number Diff line change
@@ -0,0 +1,174 @@
-- Copyright (c) 2017-present, Facebook, Inc.
-- All rights reserved.
--
-- This source code is licensed under the license found in the LICENSE file in
-- the root directory of this source tree. An additional grant of patent rights
-- can be found in the PATENTS file in the same directory.
--
-- Copyright (c) Microsoft Corporation. All rights reserved.
-- Licensed under the BSD License.
--
--[[
--
-- This model uses a bi-directional GRU encoder. The direction is reversed
-- between layers and two separate columns run in parallel: one on the normal
-- input and one on the reversed input (the layout described for LSTMs in
-- http://arxiv.org/abs/1606.04199).
--
-- The attention mechanism and the decoder setup are identical to the avgpool
-- model.
--
--]]

require 'nn'
require 'rnnlib'
local usecudnn = pcall(require, 'cudnn')
local argcheck = require 'argcheck'
local mutils = require 'fairseq.models.utils'
local rutils = require 'rnnlib.mutils'

local BGRUModel = torch.class('BGRUModel', 'AvgpoolModel')

-- Builds one encoder column: a stack of `nlayers` single-layer GRUs in which
-- every layer above the first consumes the reversed output of the layer below.
-- Returns the nngraph node holding the output of the top-most layer.
BGRUModel.makeEncoderColumn = argcheck{
    {name='self', type='BGRUModel'},
    {name='config', type='table'},
    {name='inith', type='nngraph.Node'},
    {name='input', type='nngraph.Node'},
    {name='nlayers', type='number'},
    call = function(self, config, inith, input, nlayers)
        -- Settings shared by every layer of the column; only `inputsize`
        -- is changed once the bottom layer has been created.
        local layerConfig = {
            inputsize = config.nembed,
            hidsize = config.nhid,
            nlayer = 1,
            winitfun = function(network)
                rutils.defwinitfun(network, config.init_range)
            end,
            usecudnn = usecudnn,
        }

        -- Bottom layer: reads the embedded input and is annotated as
        -- 'encoderRNN' so it can be looked up by name later on.
        local bottom = nn.GRU(layerConfig)
        bottom.saveHidden = false
        -- Second element of the GRU result, then its last entry
        -- (presumably the outputs of the top-most layer -- TODO confirm
        -- against rnnlib).
        local pickLast = nn.SelectTable(-1)
        local pickSecond = nn.SelectTable(2)
        local bottomNode = bottom({inith, input}):annotate{
            name = 'encoderRNN'
        }
        local output = pickLast(pickSecond(bottomNode))

        -- All further layers consume hidden-sized inputs.
        layerConfig.inputsize = config.nhid

        for layer = 2, nlayers do
            if config.dropout_hid > 0 then
                local drop = nn.MapTable(nn.Dropout(config.dropout_hid))
                output = drop(output)
            end
            local upper = nn.GRU(layerConfig)
            upper.saveHidden = false
            local selectLast = nn.SelectTable(-1)
            local selectSecond = nn.SelectTable(2)
            -- Flip the sequence so this layer runs in the opposite
            -- direction of the one below it.
            local reversed = nn.ReverseTable()(output)
            output = selectLast(selectSecond(upper({inith, reversed})))
        end

        return output
    end
}

-- Builds the encoder graph module. The module input is the table
-- {initial hidden state, source tokens}; it has two outputs, both of which
-- are the same joined encoder representation (see the comment at the end).
-- Fields read from `config`: srcdict, dropout_src, nenclayer, nhid, nembed,
-- dropout_hid (plus whatever mutils.makeLookupTable and makeEncoderColumn
-- read themselves).
BGRUModel.makeEncoder = argcheck{
    doc=[[
This encoder runs a forward and backward GRU network and concatenates their
top-most hidden states.
]],
    {name='self', type='BGRUModel'},
    {name='config', type='table'},
    call = function(self, config)
        local sourceIn = nn.Identity()()
        local inith, tokens = sourceIn:split(2)

        -- Source embeddings, optionally followed by dropout.
        local dict = config.srcdict
        local lut = mutils.makeLookupTable(config, dict:size(),
            dict.pad_index)
        local embed
        if config.dropout_src > 0 then
            embed = nn.MapTable(nn.Sequential()
                :add(lut)
                :add(nn.Dropout(config.dropout_src)))(tokens)
        else
            embed = nn.MapTable(lut)(tokens)
        end

        -- Two columns over the same embeddings: one reads the sequence as
        -- given, the other reads it reversed.
        local col1 = self:makeEncoderColumn{
            config = config,
            inith = inith,
            input = embed,
            nlayers = config.nenclayer,
        }
        local col2 = self:makeEncoderColumn{
            config = config,
            inith = inith,
            input = nn.ReverseTable()(embed),
            nlayers = config.nenclayer,
        }

        -- Each column will switch direction between layers. Before merging,
        -- they should both run in the same direction (here: forward).
        if config.nenclayer % 2 == 0 then
            col1 = nn.ReverseTable()(col1)
        else
            col2 = nn.ReverseTable()(col2)
        end

        local prepare = nn.Sequential()
        -- Concatenate forward and backward states
        prepare:add(nn.JoinTable(2, 2))
        -- Project the concatenated states (2 * nhid) down to nembed for
        -- further processing; the projection has no bias term.
        prepare:add(nn.Linear(config.nhid * 2, config.nembed, false))
        -- Add singleton dimension for subsequent joining
        prepare:add(nn.View(-1, 1, config.nembed))

        -- Pair up the per-position states of both columns, prepare each
        -- pair, and join the results along the new singleton dimension.
        local joinedOutput = nn.JoinTable(1, 2)(
            nn.MapTable(prepare)(
                nn.ZipTable()({col1, col2})
            )
        )
        if config.dropout_hid > 0 then
            joinedOutput = nn.Dropout(config.dropout_hid)(joinedOutput)
        end

        -- avgpool_model.makeDecoder() expects two encoder outputs, one for
        -- attention score computation and the other one for applying them.
        -- We'll just use the same output for both.
        return nn.gModule({sourceIn}, {
            joinedOutput, nn.Identity()(joinedOutput)
        })
    end
}

-- Returns a function that converts a sample into the encoder input table
-- {initial hidden state, list of per-position source tensors}.
BGRUModel.prepareSource = argcheck{
    {name='self', type='BGRUModel'},
    call = function(self)
        -- Device buffers for samples: one per source position, created on
        -- first use and reused by every later call of the returned function.
        local sourceBuffers = {}

        -- NOTE: It's assumed that all encoders start from the same hidden
        -- state.
        local encoderRNN = mutils.findAnnotatedNode(
            self:network(), 'encoderRNN'
        )
        assert(encoderRNN ~= nil)

        return function(sample)
            -- Encoder input
            local source = {}
            local nsteps = sample.source:size(1)
            for step = 1, nsteps do
                local buf = sourceBuffers[step]
                if not buf then
                    -- Allocate an empty tensor of the model's type.
                    buf = torch.Tensor():type(self:type())
                    sourceBuffers[step] = buf
                end
                source[step] = mutils.sendtobuf(sample.source[step], buf)
            end

            return {encoderRNN:initializeHidden(sample.bsz), source}
        end
    end
}

return BGRUModel
Loading

0 comments on commit 244053c

Please sign in to comment.